# -*- coding: utf-8 -*-

from qgis.PyQt import QtCore, QtGui, QtWidgets
from qgis.gui import QgsMapTool, QgsRubberBand
from qgis.core import QgsRectangle, QgsVectorLayer, QgsRasterLayer, QgsPointCloudLayer, QgsLayerTreeLayer, QgsProject, QgsCoordinateTransform, QgsPointXY, QgsWkbTypes, QgsCoordinateReferenceSystem, QgsApplication,  QgsTask

import os, json, traceback
from functools import partial

# User-facing strings of the plugin, keyed by identifier.
TEXTS = dict(pluginToolTip="Select tiles to load")

class Picker(QgsMapTool):
    """Map tool that loads tiles for a rectangle dragged on the canvas.

    Pressing the mouse starts a selection, moving it draws a dashed rubber
    band and selects the features of the target layers under the rectangle,
    and releasing it creates a ``LoadTiles`` request for the rectangle.

    ``self.layers`` / ``self.files`` act as a cache for ``getLayers()`` /
    ``getFiles()`` during one interaction; they are ``None`` outside of it.
    """

    TEXTS = {
        'pluginToolTip': "Select tiles to load",
    }

    def __init__(self, plugin):
        """Create the tool on the canvas of ``plugin.iface``.

        :param plugin: plugin object; must expose ``iface`` (and, for the
            methods below, ``tr`` and ``tileFolder``).
        """
        canvas = plugin.iface.mapCanvas()
        super().__init__(canvas)
        self.plugin = plugin
        self.iface = plugin.iface
        self.selecting = False
        self.param = Param()
        self.layers = None
        self.files = None
        # First pixel / map point of the current drag, set on mouse press.
        self.fpp = None
        self.fmp = None
        self.rubberBand = QgsRubberBand(canvas, QgsWkbTypes.GeometryType.PolygonGeometry)
        self.rubberBand.setStrokeColor(QtCore.Qt.GlobalColor.darkGray)
        self.rubberBand.setLineStyle(QtCore.Qt.PenStyle.DashLine)
        self.rubberBand.setWidth(2)
        # Live LoadTiles requests; each one registers/unregisters itself.
        self.lts = []
        self.tm = QgsApplication.taskManager()
        self.map = QgsProject.instance()
        # Finished requests are purged each time the canvas finishes rendering.
        self.iface.mapCanvas().renderComplete.connect(self.clear)
        self.active = None

    def unload(self):
        """Reset the selection rubber band and unload every live request."""
        try:
            self.rubberBand.reset(QgsWkbTypes.GeometryType.PolygonGeometry)
        except Exception:
            # Best effort: the rubber band may no longer be usable.
            pass
        # Iterate over a copy: LoadTiles.unload() removes itself from self.lts.
        for lt in list(self.lts):
            lt.unload()

    def clear(self):
        """Unload the requests that have no pending layer left."""
        for lt in list(self.lts):
            if not lt.layers:
                lt.unload()

    def getInLoad(self):
        """Return the paths still pending in all live requests."""
        pending = []
        for lt in list(self.lts):
            pending.extend(lt.layers.keys())
        return pending

    def activate(self):
        """Activate the tool and show a cross cursor."""
        super().activate()
        self.canvas().setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.CrossCursor))

    def deactivate(self):
        """Deactivate the tool, restore the arrow cursor and emit ``deactivated``."""
        super().deactivate()
        self.canvas().setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.ArrowCursor))
        # NOTE(review): the base class may already emit `deactivated`;
        # confirm whether this second emission is intended.
        self.deactivated.emit()

    def loadFromSelected(self):
        """Load the tiles referenced by the features already selected in all vector layers."""
        # Drop any cached list so that force=True is really honoured.
        self.layers = None
        try:
            self.layers = self.getLayers(force=True)
            LoadTiles(self, unselect=False)
        finally:
            self.layers = None

    def canvasPressEvent(self, e):
        """Start a rectangle selection at the pressed position."""
        # Invalidate the cache first: an interaction that ended early must
        # not leak its layers/files into this one.
        self.layers = None
        self.files = None
        self.layers = self.getLayers()
        self.files = self.getFiles()
        self.fpp = e.pixelPoint()
        self.fmp = self.getPoint(e)
        self.selecting = True
        self.selectFeatures(e)

    def canvasReleaseEvent(self, e):
        """Finish the selection and request the tiles of the rectangle.

        The request is refused with an info message when the rectangle area
        exceeds ``maxZone * 1000000`` (area expressed in canvas CRS units).
        """
        if not self.selecting:
            # Release without a matching press: there is no start point.
            return
        self.selecting = False
        self.rubberBand.reset(QgsWkbTypes.GeometryType.PolygonGeometry)
        try:
            mp = self.getPoint(e)
            authid = self._canvasAuthid()
            rect = QgsRectangle(mp[authid], self.fmp[authid])
            if rect.area() > self.param.get('maxZone')*1000000:
                msg = self.plugin.tr("Too big area")
                self.iface.messageBar().pushInfo("", msg)
                return
            LoadTiles(self, points=(self.fmp, mp))
        finally:
            # Always drop the cache, including on the early return above.
            self.layers = None
            self.files = None

    def canvasMoveEvent(self, e):
        """While dragging, redraw the rubber band and update the selection."""
        if self.selecting:
            self.showRect(self.rubberBand, self.fmp[self._canvasAuthid()], e.mapPoint())
            self.selectFeatures(e)

    def selectFeatures(self, e):
        """Select, in every target layer, the features inside the dragged rectangle."""
        mp = self.getPoint(e)
        rects = {}
        for l in self.getLayers():
            authid = l.crs().authid()
            # One rectangle per CRS, shared by all layers using that CRS.
            if authid not in rects:
                rects[authid] = QgsRectangle(mp[authid], self.fmp[authid])
            l.selectByRect(rects[authid])

    def showRect(self, rb, p1, p2):
        """Draw on rubber band ``rb`` the rectangle whose opposite corners are ``p1`` and ``p2``."""
        rb.reset(QgsWkbTypes.GeometryType.PolygonGeometry)
        rb.addPoint(QgsPointXY(p1.x(), p1.y()), False)
        rb.addPoint(QgsPointXY(p1.x(), p2.y()), False)
        rb.addPoint(QgsPointXY(p2.x(), p2.y()), False)
        # Only the last point triggers the update of the rubber band.
        rb.addPoint(QgsPointXY(p2.x(), p1.y()), True)
        rb.show()

    def getPoint(self, e):
        """Return the event position in every CRS needed by the current targets.

        :returns: dict mapping a CRS authid to the event map point expressed
            in that CRS (canvas CRS, CRS of each target layer, CRS of each
            tile description file).
        """
        pts = {}
        pt = e.mapPoint()
        crs = self.iface.mapCanvas().mapSettings().destinationCrs()
        canvasId = crs.authid()
        pts[canvasId] = pt
        for l in self.getLayers():
            if l.crs().authid() != canvasId:
                tr = QgsCoordinateTransform(crs, l.crs(), QgsProject.instance())
                pts[l.crs().authid()] = tr.transform(pt)

        for fc, d in self.getFiles().items():
            # Same lookup (with default) as LoadTiles.calculTiles, so both
            # sides agree on the key even when the file has no 'crs' entry.
            authid = TileParam.get('crs', d['data'])
            if authid != canvasId:
                tr = QgsCoordinateTransform(crs, QgsCoordinateReferenceSystem(authid), QgsProject.instance())
                pts[authid] = tr.transform(pt)

        return pts

    def getLayers(self, force=False):
        """Return the vector layers used as tile index.

        The cached list is returned when one is set. Otherwise, in 'index'
        mode (or when ``force`` is true) the list is the active layer or all
        project layers depending on the 'targetLayers' setting; ``force``
        always selects all project layers. Non-vector layers are dropped.
        """
        if self.layers is not None:
            return self.layers
        layers = []
        if self.param.get('mode') == 'index' or force:
            if self.param.get('targetLayers') == 'active':
                layers = [self.iface.activeLayer()]
            if self.param.get('targetLayers') == 'all' or force:
                layers = list(QgsProject.instance().mapLayers().values())
        return [l for l in layers if isinstance(l, QgsVectorLayer)]

    def getFiles(self):
        """Return the tile description files used in 'coords' mode.

        :returns: dict mapping each file name of the 'fromCoords' setting to
            ``{'data': <parsed JSON object>, 'list': []}``. Files that cannot
            be read or parsed, or whose content is not a JSON object, are
            skipped.
        """
        if self.files is not None:
            return self.files
        files = {}
        if self.param.get('mode') == 'coords':
            for fc in self.param.get('fromCoords'):
                try:
                    # `with` closes the handle even when parsing fails.
                    with open(os.path.join(self.plugin.tileFolder, fc)) as fh:
                        data = json.load(fh)
                except (OSError, ValueError, TypeError):
                    # Unreadable file, invalid JSON or invalid file name.
                    continue
                if not isinstance(data, dict):
                    continue
                files[fc] = {'data': data, 'list': []}
        return files

    def _canvasAuthid(self):
        """Return the authid of the canvas destination CRS."""
        return self.iface.mapCanvas().mapSettings().destinationCrs().authid()
        

class LoadTiles(QtCore.QObject):
    """One tile-loading request created by a ``Picker``.

    The constructor collects candidate tile paths (computed from the dragged
    rectangle and/or read from selected index features), then starts a
    background task that drops the paths missing on disk; the remaining
    tiles are loaded and added to the project.

    ``self.layers`` maps each pending path to ``{'crs', 'rb', 'in'}`` and is
    emptied as tiles are processed; the picker unloads the request once it
    is empty.
    """

    def __init__(self, picker, points=None, unselect=True):
        """
        :param picker: owning ``Picker``.
        :param points: optional pair of dicts (authid -> point) as returned
            by ``Picker.getPoint``, giving two opposite corners.
        :param unselect: clear the selection of the index layers once read.
        """
        super().__init__()
        self.picker = picker
        self.canvas = picker.canvas()
        self.unselect = unselect
        self.rbs = []
        self.active = None
        self.picker.lts.append(self)
        # Sources already in the project...
        self.loads = [l.source() for l in picker.map.mapLayers().values()]
        self.nopaths = []
        self.layers = {}
        # ...plus the paths other requests are still loading.
        self.loads.extend(picker.getInLoad())
        self.noReload = self.picker.param.get('noReload')
        if points:
            self.calculTiles(points[0], points[1])
        self.loadSelected()
        self.launch()

    def unload(self):
        """Unregister from the picker and remove the remaining rubber bands."""
        if self in self.picker.lts:
            self.picker.lts.remove(self)
        try:
            self.clearRubber()
        except Exception:
            # Best effort cleanup.
            pass

    def addRubber(self, p1, p2):
        """Create, show and return a dotted rubber band for rectangle ``p1``-``p2``."""
        rb = QgsRubberBand(self.canvas, QgsWkbTypes.GeometryType.PolygonGeometry)
        self.rbs.append(rb)
        rb.setStrokeColor(QtCore.Qt.GlobalColor.darkGray)
        rb.setLineStyle(QtCore.Qt.PenStyle.DotLine)
        rb.setWidth(1)
        self.picker.showRect(rb, p1, p2)
        return rb

    def delRubber(self, rb):
        """Forget and reset rubber band ``rb``; unknown or ``None`` values are ignored."""
        try:
            self.rbs.remove(rb)
            rb.reset(QgsWkbTypes.GeometryType.PolygonGeometry)
        except Exception:
            pass

    def clearRubber(self):
        """Remove every rubber band of this request."""
        for rb in list(self.rbs):
            self.delRubber(rb)

    def calculTiles(self, fmp, lmp):
        """Register the tiles of each description file that intersect the rectangle.

        :param fmp: first corner, dict authid -> point.
        :param lmp: opposite corner, dict authid -> point.
        """
        # Offsets added to coord/size to get (first index, end index) for
        # each possible reference corner of the tile name.
        dec = {
            'upper': (1,2),
            'lower': (0,1),
            'left': (0,1),
            'right': (1,2),
        }
        crs = self.canvas.mapSettings().destinationCrs()
        for fc, d in self.picker.getFiles().items():
            data = d['data']
            authid = TileParam.get('crs', data)
            crs2 = QgsCoordinateReferenceSystem(authid)
            tr = None
            if authid != crs.authid():
                # Tile CRS -> canvas CRS, used only to draw the rubber bands.
                tr = QgsCoordinateTransform(crs2, crs, QgsProject.instance())
            size = TileParam.get('size', data)
            prec = TileParam.get('prec', data)
            if size <= 0 or prec <= 0:
                # Unusable description: would divide by zero below.
                continue
            aff = int(size/prec)
            if aff < 1:
                # A zero step would make range() raise.
                continue
            corner = TileParam.get('corner', data)
            parts = TileParam.get('parts', data)
            folder = TileParam.get('folder', data)
            digits = TileParam.get('digits', data)
            xFirst = TileParam.get('order', data) == 'xy'
            dx = dec[corner[1]]
            dy = dec[corner[0]]
            xs = (fmp[authid].x(), lmp[authid].x())
            ys = (fmp[authid].y(), lmp[authid].y())
            # NOTE(review): int() truncates toward zero, so rectangles with
            # negative coordinates are not floored to the tile grid.
            minx = int(min(xs)/size+dx[0])*aff
            maxx = int(max(xs)/size+dx[1])*aff
            miny = int(min(ys)/size+dy[0])*aff
            maxy = int(max(ys)/size+dy[1])*aff
            for x in range(minx, maxx, aff):
                for y in range(miny, maxy, aff):
                    vx = self.formatCoords(x, digits)
                    vy = self.formatCoords(y, digits)
                    v = (vx, vy) if xFirst else (vy, vx)
                    file = f"{parts[0]}{v[0]}{parts[1]}{v[1]}{parts[2]}"
                    p1 = QgsPointXY((x/aff-dx[0])*size, (y/aff-dy[0])*size)
                    p2 = QgsPointXY((x/aff-dx[0]+1)*size, (y/aff-dy[0]+1)*size)
                    if tr is not None:
                        p1 = tr.transform(p1)
                        p2 = tr.transform(p2)
                    rb = self.addRubber(p1, p2)
                    self.loadPath(os.path.join(folder, file), crs2, rb)

    def formatCoords(self, v, n):
        """Return ``v`` as text, left-padded with zeros to at least ``n`` characters."""
        return str(v).rjust(n, "0")

    def loadSelected(self):
        """Register the tiles referenced by the selected features of the index layers.

        For each layer the attribute holding the tile path is guessed: every
        attribute is tried until one resolves to an existing path, and that
        attribute is tried first for the following features.
        """
        if self.picker.param.get('mode') == 'index':
            # Remembered so the active layer can be restored after loading.
            self.active = self.picker.iface.activeLayer()
        for l in self.picker.getLayers():
            atts = l.fields().names()
            att = None
            for f in l.getSelectedFeatures():
                found = None
                if att is not None:
                    found = self.loadFromAttribute(l, f, att)
                for a in atts:
                    if found:
                        break
                    found = self.loadFromAttribute(l, f, a)
                    if found:
                        att = a
            if self.unselect:
                l.removeSelection()

    def loadFromAttribute(self, l, f, a):
        """Register the tile whose path is stored in attribute ``a`` of feature ``f``.

        The value is used as is, or relative to the folder of the source of
        layer ``l``.

        :returns: ``True`` when the value resolves to an existing path,
            ``None`` otherwise.
        """
        path = str(f.attribute(a))
        if not path.strip():
            # An empty value would resolve to the layer folder itself.
            return None
        if not os.path.exists(path):
            base = os.path.dirname(l.source().split('|')[0])
            path = os.path.join(base, path)
        if not os.path.exists(path):
            return None
        bbox = f.geometry().boundingBox()
        p1 = QgsPointXY(bbox.xMinimum(), bbox.yMinimum())
        p2 = QgsPointXY(bbox.xMaximum(), bbox.yMaximum())
        rb = self.addRubber(p1, p2)
        return self.loadPath(path, l.crs(), rb)

    def loadPath(self, path, crs, rb=None):
        """Queue ``path`` for loading unless 'noReload' is set and it is already loaded.

        :returns: always ``True``.
        """
        if self.noReload and path in self.loads:
            self.delRubber(rb)
            return True
        self.layers[path] = {'crs': crs, 'rb': rb, 'in': False}
        return True

    def checkExist(self):
        """Drop the queued paths missing on disk and record their rubber bands in ``self.nopaths``."""
        self.nopaths = []
        for path, v in list(self.layers.items()):
            if not os.path.exists(path):
                del self.layers[path]
                self.nopaths.append(v['rb'])

    def launch(self):
        """Start the background existence check; ``endLaunch`` runs when it completes."""
        t = checkTask(self)
        t.taskCompleted.connect(self.endLaunch)
        # A terminated check never reaches endLaunch: release everything so
        # the request can be purged.
        t.taskTerminated.connect(self._abort)
        self.picker.tm.addTask(t)

    def endLaunch(self):
        """Load the tiles that exist, after confirmation when above 'maxTiles'."""
        for rb in self.nopaths:
            self.delRubber(rb)

        if len(self.layers) < 1:
            msg = self.picker.plugin.tr("No tile found, check your config")
            self.picker.iface.messageBar().pushInfo("", msg)
            return

        if len(self.layers) > self.picker.param.get('maxTiles'):
            qm = QtWidgets.QMessageBox
            resp = qm.question(None,'', self.picker.plugin.tr("You're about to load")+f" {len(self.layers)} "+self.picker.plugin.tr("tiles"), qm.StandardButton.Yes | qm.StandardButton.No)
            if resp != qm.StandardButton.Yes:
                self._abort()
                return

        for path, v in list(self.layers.items()):
            ext = os.path.splitext(path)[1]
            if ext.lower() in ('.las','.laz'):
                # Point clouds are created directly here, not in a task.
                layer = self.loadLayer(path, v['crs'])
                self.endLoadLayer(path, layer=layer, rb=v['rb'])
            else:
                t = loadTask(self, path, v['crs'])
                t.taskCompleted.connect(partial(self.endLoadLayer, path, task=t, rb=v['rb']))
                t.taskTerminated.connect(partial(self._dropLayer, path, v['rb']))
                self.picker.tm.addTask(t)

    def loadLayer(self, path, crs):
        """Create the layer for ``path`` and assign ``crs`` to it.

        ``.las``/``.laz`` files become point cloud layers (provider 'copc'
        when the name ends with ``.copc.las``/``.copc.laz``, 'pdal'
        otherwise); anything else becomes a 'gdal' raster layer.

        :returns: the layer, or ``None`` when its creation raised.
        """
        layer = None
        try:
            n, e = os.path.splitext(os.path.basename(path))
            if e.lower() in ('.las','.laz'):
                provider = 'pdal'
                nn, ne = os.path.splitext(n)
                if ne.lower() == '.copc':
                    n = nn
                    provider = 'copc'
                options = QgsPointCloudLayer.LayerOptions()
                options.skipCrsValidation = True
                layer = QgsPointCloudLayer(path, n, provider, options)
            else:
                options = QgsRasterLayer.LayerOptions()
                options.skipCrsValidation = True
                layer = QgsRasterLayer(path, n, 'gdal', options)
            try:
                layer.setCrs(crs)
            except Exception:
                # Keep the layer even if the CRS could not be assigned.
                pass
        except Exception:
            pass
        return layer

    def endLoadLayer(self, path, layer=None, task=None, rb=None):
        """Add a loaded layer to the project and mark ``path`` as processed.

        :param path: queued path the layer was created from.
        :param layer: the layer, when created synchronously.
        :param task: the ``loadTask`` holding the layer otherwise.
        :param rb: rubber band of the tile, removed when loading failed.
        """
        if layer is None and task is not None:
            layer = getattr(task, 'layer', None)
        if layer is None or not layer.isValid():
            # Nothing usable was produced for this path.
            self._dropLayer(path, rb)
            return

        try:
            project = self.picker.map
            if self.picker.param.get('groupActive'):
                # Add without legend entry, then place it in the group.
                project.addMapLayer(layer, False)
                group = self.makeGroup(self.picker.param.get('groupName'))
                node = group.addLayer(layer)
                node.setExpanded(False)
            else:
                project.addMapLayer(layer)

            if self.picker.param.get('mode') == 'index':
                # Put back the layer that was active when the request started.
                try:
                    self.picker.iface.setActiveLayer(self.active)
                except Exception:
                    pass
        finally:
            self.layers.pop(path, None)

    def makeGroup(self, name):
        """Return the layer tree group called ``name``, creating it at the root if needed."""
        root = self.picker.map.layerTreeRoot()
        group = root.findGroup(name)
        if not group:
            group = root.addGroup(name)
        return group

    def _dropLayer(self, path, rb=None):
        """Give up on ``path``: remove its rubber band and its queue entry."""
        self.delRubber(rb)
        self.layers.pop(path, None)

    def _abort(self):
        """Give up on the whole request: remove all rubber bands and queued paths."""
        self.clearRubber()
        self.layers = {}

class Param:
    """Plugin settings stored in ``QSettings`` under a group named after the plugin folder.

    Values read back are coerced to the type of the matching entry of
    ``DEFAULT``.
    """

    DEFAULT = {
        'targetLayers': "active",
        'noReload': True,
        'fromCoords': [],
        'group': "TilePicker",
        'mode': "index",
        'maxTiles': 20,
        'maxZone': 1000,
        'groupActive': True,
        'groupName': "TilePick",
    }

    def __init__(self):
        self.plugin = self.getPlugin()
        self.getSettings()

    def get(self, key):
        """Return the stored value of ``key`` coerced to the type of its default.

        - missing value: the default is returned;
        - bool default: ``True`` only for ``'True'``, ``'true'`` or ``1``;
        - int / float default: converted; when the conversion fails the
          default is written back to the settings and returned;
        - list default: a stored string is wrapped into a one-item list (an
          empty string gives ``[]``), so that callers can always iterate
          over items rather than over characters.
        """
        val = self.readIni(key)
        d = self.DEFAULT.get(key)
        if val is None:
            return d
        try:
            # bool is tested first because bool is a subclass of int.
            if isinstance(d, bool):
                return val in ('True', 'true', 1)
            if isinstance(d, float):
                return float(val)
            if isinstance(d, int):
                return int(val)
        except (TypeError, ValueError, OverflowError):
            # Unusable stored value: repair the settings with the default.
            self.set(key, d)
            return d
        if isinstance(d, list) and not isinstance(val, list):
            if isinstance(val, str):
                return [val] if val else []
            if isinstance(val, (tuple, set)):
                return list(val)
        return val

    def set(self, key, val):
        """Store ``val`` under ``key``."""
        self.writeIni(key, val)

    def getPlugin(self):
        """Return the name of the folder containing this file."""
        folder = os.path.dirname(__file__).replace("\\", "/")
        return folder.rsplit("/", 1)[-1]

    def getSettings(self):
        """Open the settings and enter the group named after the plugin folder."""
        self.s = QtCore.QSettings()
        self.s.beginGroup(self.plugin)

    def readIni(self, key):
        """Return the raw stored value of ``key``."""
        return self.s.value(key)

    def writeIni(self, key, value):
        """Write ``value`` under ``key``."""
        self.s.setValue(key, value)


class TileParam:
    """Typed access to the entries of a tile description dict, with defaults."""

    DEFAULT = {
        'alias': '',
        'folder': '',
        'crs': 'EPSG:2154',
        'corner': ('upper','left'),
        'size': 1000,
        'prec': 1000,
        'digits': 4,
        'order': 'xy',
        'parts': ('','',''),
    }

    @staticmethod
    def get(key, data):
        """Return ``data[key]`` coerced to the type of ``DEFAULT[key]``.

        :param key: entry name.
        :param data: dict-like tile description.
        :returns: the default when the entry is missing, ``None``, or cannot
            be converted to the int / float type of the default; bool
            defaults give ``True`` only for ``'True'``, ``'true'`` or ``1``;
            other values are returned unchanged.
        """
        val = data.get(key)
        d = TileParam.DEFAULT.get(key)
        if val is None:
            return d
        try:
            # bool is tested first because bool is a subclass of int.
            if isinstance(d, bool):
                return val in ('True', 'true', 1)
            if isinstance(d, float):
                return float(val)
            if isinstance(d, int):
                return int(val)
        except (TypeError, ValueError, OverflowError):
            return d
        return val
        
 
class checkTask(QgsTask):
    """Cancellable task that calls ``checkExist()`` on the object it is given."""

    def __init__(self, picker):
        """
        :param picker: object whose ``checkExist()`` method is run by the task.
        """
        super().__init__("Calcul des tuiles disponibles sur le périmètre", QgsTask.CanCancel)
        self.error = None
        self.picker = picker

    def run(self):
        """Run the check; a failure is stored as a traceback in ``self.error``.

        :returns: always ``True``.
        """
        try:
            self.picker.checkExist()
        except BaseException:
            self.error = traceback.format_exc()
        return True

    def finished(self, result):
        """Print the traceback recorded by ``run()``, if there is one."""
        if self.error is None:
            return
        print(self.error)

    def cancel(self):
        """Forward the cancellation to the base class."""
        super().cancel()
 
 
class loadTask(QgsTask):
    """Cancellable task that creates the layer of one tile.

    After ``run()``, ``self.layer`` holds what ``loadLayer()`` returned, or
    ``None`` when the call raised (the traceback is then in ``self.error``).
    """

    def __init__(self, picker, path, crs):
        """
        :param picker: object whose ``loadLayer(path, crs)`` method is run.
        :param path: path of the tile to load.
        :param crs: CRS passed to ``loadLayer``.
        """
        description = "Chargement de la tuile"
        super().__init__(description, QgsTask.CanCancel)
        self.picker = picker
        self.error = None
        # Defined up front so readers never hit a missing attribute.
        self.layer = None
        self.path = path
        self.crs = crs

    def run(self):
        """Create the layer; a failure is stored as a traceback in ``self.error``.

        :returns: always ``True``.
        """
        try:
            self.layer = self.picker.loadLayer(self.path, self.crs)
        except Exception:
            self.error = traceback.format_exc()
        return True

    def finished(self, result):
        """Print the traceback recorded by ``run()``, if there is one."""
        if self.error is not None:
            print(self.error)

    def cancel(self):
        """Forward the cancellation to the base class."""
        super().cancel()