# -*- coding: utf-8 -*-

from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsMapLayer,
                       QgsGeometry,
                       QgsPoint,
                       QgsFeatureSink,
                       QgsProcessingException,
                       QgsProcessingAlgorithm,
                       QgsUnitTypes,
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterField,
                       QgsProcessingParameterDistance,
                       QgsProcessingParameterFeatureSink)
from qgis import processing


class BranchConnectProcessingAlgorithm(QgsProcessingAlgorithm):
    """Rebuild a line layer into single, unbranched runs.

    The input is first exploded into single-part lines with
    ``native:multiparttosingleparts``; the parts are then chained end to end
    by ``biefsCollection`` and the resulting features are written to the
    output sink.
    """

    INPUT = 'INPUT'
    OUTPUT = 'OUTPUT'
    TOLERANCE = 'TOLERANCE'

    def tr(self, string):
        """Translate *string* in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        return BranchConnectProcessingAlgorithm()

    def name(self):
        return 'branchconnect'

    def displayName(self):
        return self.tr('Connexion des branches')

    def group(self):
        return self.tr('Scripts')

    def groupId(self):
        return 'scripts'

    def shortHelpString(self):
        # Help text (French): splits then rebuilds a multi-polyline into a
        # ramification tree made of single branches.
        return self.tr("Scinde puis reconstruit une multipolyligne pour en faire un arbre de ramification à branches uniques")

    def initAlgorithm(self, config=None):
        """Declare the input line layer, the snapping tolerance and the output sink."""
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT,
                self.tr("Multipolyligne source"),
                [QgsProcessing.TypeVectorLine]
            )
        )

        # The tolerance ends up compared with QgsPointXY.distance() between
        # vertex coordinates of the input layer (biefsCollection.matchPoint).
        distance = QgsProcessingParameterDistance(
            self.TOLERANCE,
            self.tr("Tolérance"),
            defaultValue=1
        )
        distance.setDefaultUnit(QgsUnitTypes.DistanceMeters)
        self.addParameter(distance)

        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr('Branches connectées')
            )
        )

    def getMapLayer(self, v, context):
        """Return *v* if it is already a map layer, otherwise resolve it through *context*."""
        if not isinstance(v, QgsMapLayer):
            v = context.getMapLayer(v)
        return v

    def processAlgorithm(self, parameters, context, feedback):
        """Explode the input into single parts, chain them and write the branches.

        :raises QgsProcessingException: if the input source, the output sink
            or the intermediate single-part layer cannot be obtained.
        """
        source = self.parameterAsSource(
            parameters,
            self.INPUT,
            context
        )
        if source is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT))

        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            source.fields(),
            source.wkbType(),
            source.sourceCrs()
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        # Evaluate through the parameter API: the raw value is not guaranteed
        # to be a number (it may be a string or a property).
        tolerance = self.parameterAsDouble(parameters, self.TOLERANCE, context)

        feedback.pushInfo("Creating single lines")
        # Pass the raw INPUT value to the child algorithm rather than
        # overwriting the caller's parameters dict with a resolved layer.
        alg = processing.run("native:multiparttosingleparts", {
            'INPUT': parameters[self.INPUT],
            'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
        }, context=context, feedback=feedback)

        outLayer = self.getMapLayer(alg['OUTPUT'], context)
        if outLayer is None:
            raise QgsProcessingException(self.tr("Could not load the single-part lines"))

        feedback.pushInfo("Assembling branches")
        biefs = biefsCollection(outLayer, feedback, tolerance)
        if not feedback.isCanceled():
            feedback.pushInfo("Exporting results")
            biefs.exportBranches(sink)

        return {self.OUTPUT: dest_id}


class biefUnit():
    """One single-part line feature, held as an editable list of vertices.

    ``self.list`` is loaded from ``geometry().asPolyline()`` (QgsPointXY
    vertices) and written back into the feature by ``setFeature`` whenever
    it changes.
    """

    def __init__(self, feature):
        self.feature = feature
        self.geom = feature.geometry()
        self.setList()

    def setList(self):
        """(Re)load the vertex list from the current geometry."""
        self.list = self.geom.asPolyline()

    def setFeature(self):
        """Rebuild the feature geometry as a line from ``self.list``."""
        # fromPolylineXY accepts the QgsPointXY vertices as they are, so no
        # per-point conversion (and no lazy iterator) is handed to QGIS.
        self.feature.setGeometry(QgsGeometry.fromPolylineXY(self.list))
        self.geom = self.feature.geometry()

    def first(self):
        """Return the first vertex."""
        return self.list[0]

    def last(self):
        """Return the last vertex."""
        return self.list[-1]

    def extend(self, end, list):
        """Join the vertices *list* onto this line and refresh the feature.

        :param end: True to append after the last vertex, False to prepend
            before the first one.
        :param list: vertices to join, already oriented by the caller. It is
            left unmodified. (The name shadows the builtin but is kept for
            keyword compatibility.)
        """
        points = list
        # Build a new list instead of extending in place, so neither the
        # caller's list nor a previously shared vertex list is mutated.
        self.list = [*self.list, *points] if end else [*points, *self.list]
        self.setFeature()

    def id(self):
        """Return the value of the feature's 'id' attribute."""
        return self.feature.attribute('id')

class biefsCollection():
    """Chain single-part lines end to end into branches.

    A branch starts from one line and repeatedly absorbs the only remaining
    line whose first or last vertex lies within ``tolerance`` of the
    branch's free end. When no line, or more than one line, touches that
    end, growth on that side stops.

    Each merge step scans every remaining unit, so the total cost grows
    quadratically with the number of units.
    """

    def __init__(self, layer, feedback, tolerance=0):
        self.layer = layer
        self.feedback = feedback
        self.setTolerance(tolerance)
        self.makeUnits()
        self.agglomerate()

    def setTolerance(self, dist):
        """Set the snapping distance, in the units of the vertex coordinates."""
        # float() so that a numeric string cannot reach the ``<=``
        # comparison in matchPoint().
        self.tolerance = float(dist)

    def makeUnits(self):
        """Wrap every layer feature in a biefUnit and reset the branch list."""
        self.units = [biefUnit(f) for f in self.layer.getFeatures()]
        self.total = len(self.units)
        self.branches = []

    def matchPoint(self, p1, p2):
        """Return True when *p1* and *p2* are at most ``tolerance`` apart."""
        return p1.distance(p2) <= self.tolerance

    def agglomerate(self):
        """Consume all units into branches, reporting progress to feedback."""
        while self.units:
            if self.feedback.isCanceled():
                break
            self.feedback.setProgress(int(100 * (1 - (len(self.units) / self.total))))
            self.branches.append(self.units.pop())
            # Grow the new branch from its first vertex, then from its last.
            self.connect()
            self.connect(True)

    def currentBief(self):
        """Return the branch being grown (the last one added), or None."""
        return self.branches[-1] if self.branches else None

    def connect(self, end=False):
        """Grow the current branch on one side while the continuation is unique.

        Iterative rather than recursive, so long chains of segments cannot
        exceed Python's recursion limit.

        :param end: False to grow from the first vertex, True from the last.
        """
        bief = self.currentBief()
        self.end = end  # kept so melt() called without ``side`` still works
        while not self.feedback.isCanceled():
            point = bief.last() if end else bief.first()
            matches = self._matchesAt(point)
            if len(matches) != 1:
                break
            self.melt(*matches[0], side=end)

    def _matchesAt(self, point):
        """Return ``(index, matched_on_last)`` for unit ends near *point*.

        Stops once two matches are found (a merge needs exactly one) and
        returns an empty list if the run is cancelled, so a partial scan
        never triggers a merge.
        """
        matches = []
        for i, b in enumerate(self.units):
            if self.feedback.isCanceled():
                return []
            if self.matchPoint(point, b.first()):
                matches.append((i, False))
            if self.matchPoint(point, b.last()):
                matches.append((i, True))
            if len(matches) > 1:
                break
        return matches

    def melt(self, i, end, side=None):
        """Move unit *i* from ``self.units`` into the current branch.

        :param i: index of the unit in ``self.units``.
        :param end: True if the unit matched with its last vertex, False if
            with its first.
        :param side: branch side being grown (True = last vertex); defaults
            to the side recorded by the latest connect() call.
        """
        if side is None:
            side = self.end
        points = self.units.pop(i).list.copy()
        # Matching the same kind of end on both lines (first/first or
        # last/last) means the unit runs the wrong way: flip it first.
        if side == end:
            points.reverse()
        self.currentBief().extend(side, points)

    def exportBranches(self, sink):
        """Write every branch feature to *sink*, stopping early if cancelled."""
        for b in self.branches:
            if self.feedback.isCanceled():
                break
            sink.addFeature(b.feature, QgsFeatureSink.FastInsert)



        