# -*- coding: utf-8 -*-

"""
/***************************************************************************
 MedianAxisVoronoi
                                 A QGIS plugin
 Create median axis of polygon features with the Voronoi method
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2022-01-22
        copyright            : (C) 2022 by Dynafor
        email                : gabriel.marques@toulouse-inp.fr
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata.
__author__ = "Dynafor"
__date__ = "2022-01-22"
__copyright__ = "(C) 2022 by Dynafor"

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = "$Format:%H$"

import processing
from hedge_tools import (
    resources,
)  # Only need in hedge_tools.py normaly but just to keep track of import
from hedge_tools.tools.classes import HedgeGraph as hG
from hedge_tools.tools.vector import attribute_table as at
from hedge_tools.tools.vector import geometry as g
from hedge_tools.tools.vector import utils
from qgis.core import (
    QgsFeatureRequest,
    QgsFeatureSink,
    QgsGeometry,
    QgsLineString,
    QgsPoint,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterNumber,
    QgsProcessingUtils,
    QgsVectorLayer,
)
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.PyQt.QtGui import QIcon


class TopologicalArcAlgorithm(QgsProcessingAlgorithm):
    """
    Implementation of median axis computation from Voronoi extraction.
    Pre-processing is inspired by the OpenJUMP skeletonizer algorithm.

    Parameters
    ---
    INPUT (QgisObject : QgsVectorLayer) : Polygon : Layer input from users.
        A "pid" field is created (and filled with feature ids) in this layer
        when it does not exist yet.
    MIN_WIDTH (float) : Minimal width of a polygon in the input layer
                        Value used for pre process (simplification/densification).
                        -1 lets the algorithm estimate the value for each polygon.
    DANGLE_LGTH (int) : Below this length the path is considered a dangle and deleted

    Return
    ---
    OUTPUT (QgisObject : QgsVectorLayer) : Linestring : Skeleton of the input layer
        features, carrying "fid", "eid" and "pid" fields.
    ERROR (QgisObject : QgsVectorLayer) : Linestring : Potential error in the median axis.
    Usually disconnected graph caused by a too high min width resulting in a vertex density too low.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    INPUT = "INPUT"
    MIN_WIDTH = "MIN_WIDTH"
    DANGLE_LGTH = "DANGLE_LGTH"
    OUTPUT = "OUTPUT"
    ERROR = "ERROR"

    def initAlgorithm(self, config=None):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.
        """

        # Input polygon features source.
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT, self.tr("Polygon layer"), [QgsProcessing.TypeVectorPolygon]
            )
        )

        # Densification distance; -1 means "estimate per polygon".
        self.addParameter(
            QgsProcessingParameterNumber(
                name=self.MIN_WIDTH,
                description=self.tr("Densification (meters)"),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=-1,
                optional=False,
                minValue=-1,
                maxValue=1000,
            )
        )

        # Dangles shorter than this length are pruned at junctions.
        self.addParameter(
            QgsProcessingParameterNumber(
                name=self.DANGLE_LGTH,
                description=self.tr("Dangle length (meters)"),
                type=QgsProcessingParameterNumber.Integer,
                defaultValue=30,
                optional=False,
                minValue=0,
                maxValue=1000,
            )
        )

        # We add a feature sink in which to store our processed features (this
        # usually takes the form of a newly created vector layer when the
        # algorithm is run in QGIS).
        self.addParameter(
            QgsProcessingParameterFeatureSink(self.OUTPUT, self.tr("Output layer"))
        )

        self.addParameter(
            QgsProcessingParameterFeatureSink(self.ERROR, self.tr("Error layer"))
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Build the median axis of every input polygon and write the results
        to the OUTPUT and ERROR sinks.

        Polygons whose processing raises are skipped; their pid values are
        reported in a single warning once every polygon has been visited.

        Raises
        ------
        QgsProcessingException
            If the input layer or one of the sinks cannot be obtained.
        """
        polygon_layer = self.parameterAsVectorLayer(parameters, self.INPUT, context)
        if polygon_layer is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )
        min_width = self.parameterAsDouble(parameters, self.MIN_WIDTH, context)
        dangle_lgth = self.parameterAsInt(parameters, self.DANGLE_LGTH, context)

        self._ensure_pid_field(polygon_layer, feedback)

        count = polygon_layer.featureCount()

        feedback.setProgress(1)

        features = []
        errors = []
        unprocess = []

        # Running offset so that fid/eid stay unique across all polygons.
        next_id = 0
        for current, poly in enumerate(polygon_layer.getFeatures()):
            pid = poly["pid"]
            try:
                feats, errs = self._skeletonize(poly, min_width, dangle_lgth)
                # Add fid, edge id (eid) and poly id (pid) and set their values
                feats = [
                    at.add_fields(
                        feature,
                        ["fid", "eid", "pid"],
                        [QVariant.Int, QVariant.Int, QVariant.Int],
                    )
                    for feature in feats
                ]
                feats = [
                    at.set_fields_value(
                        feature,
                        ["fid", "eid", "pid"],
                        [next_id + i, next_id + i, pid],
                    )
                    for i, feature in enumerate(feats, start=1)
                ]
            except Exception as err:
                # Best effort: one failing polygon must not abort the whole run.
                unprocess.append(pid)
                feedback.pushDebugInfo(f"pid {pid}: {err!r}")
            else:
                features += feats
                errors += errs
                next_id += len(feats)

            # Progress and cancellation are handled whether or not the
            # polygon succeeded.
            if count > 0:
                feedback.setProgress(int((current + 1) / count * 100))
            if feedback.isCanceled():
                return {}

        # Load into layers
        graph_uri = QgsProcessingUtils.generateTempFilename("graph_layer.gpkg")
        graph_layer = utils.create_layer(
            polygon_layer, geom_type="LineString", data_provider="ogr", path=graph_uri
        )
        error_uri = QgsProcessingUtils.generateTempFilename("error_layer.gpkg")
        error_layer = utils.create_layer(
            polygon_layer, geom_type="LineString", data_provider="ogr", path=error_uri
        )
        # Create fields already present in feature. Fid is created by default
        at.create_fields(graph_layer, [("eid", QVariant.Int), ("pid", QVariant.Int)])
        graph_layer.dataProvider().addFeatures(features)
        error_layer.dataProvider().addFeatures(errors)

        graph_layer = self.smooth_complex_vertices(polygon_layer, graph_layer)

        if unprocess:
            feedback.pushWarning(
                f"Following polygons couldn't be processed (pid): \n{unprocess}"
            )

        output_id = self._write_to_sink(parameters, self.OUTPUT, context, graph_layer)
        error_id = self._write_to_sink(parameters, self.ERROR, context, error_layer)

        return {self.OUTPUT: output_id, self.ERROR: error_id}

    def _ensure_pid_field(self, polygon_layer: QgsVectorLayer, feedback) -> None:
        """
        Create the "pid" field in polygon_layer, filled with feature ids,
        when it does not exist yet. The input layer is modified in place.
        """
        if polygon_layer.fields().indexFromName("pid") != -1:
            return
        feedback.pushInfo(
            f"pid field will be created in {polygon_layer.name()}. "
            "Please do not remove it if you want to use hedge tools"
        )
        idx = at.create_fields(polygon_layer, [("pid", QVariant.Int)])[0]
        attr_map = {f.id(): {idx: f.id()} for f in polygon_layer.getFeatures()}
        polygon_layer.dataProvider().changeAttributeValues(attr_map)

    @staticmethod
    def _rebuild_graph(graph: hG.HedgeGraph) -> hG.HedgeGraph:
        """
        Re-instantiate a graph from its own edges (invalid geometries removed).

        Creating edges/vertices returns max(id) + 1 while the topology stores
        the first vacant id, so "holes" left by deletions make the topology
        ambiguous; rebuilding from scratch renumbers everything.
        Empty geometries are dropped because they can create topological errors.
        """
        return hG.HedgeGraph(g.remove_invalid_geom(graph.edges_to_lines()))

    def _skeletonize(self, poly, min_width: float, dangle_lgth: int) -> tuple:
        """
        Run the median axis pipeline on a single polygon feature.

        Returns
        -------
        (features, errors)
            Features produced by HedgeGraph.edges_to_polylines for the kept
            axis, and the features it reported as errors (empty list if none).
        """
        polygon_geometry = g.prepare_geometry(poly.geometry(), min_width)
        skeleton = g.compute_skeleton(polygon_geometry)
        graph = hG.HedgeGraph(skeleton)
        graph.non_recursive_connected_component()
        graph.remove_fork(dangle_lgth)

        graph = self._rebuild_graph(graph)
        graph.correct_leaf(poly.geometry())  # Pass original geometry

        # Update topology by instantiating again
        graph = self._rebuild_graph(graph)
        temp_geom = self.handle_complex_vertices(graph)

        # Update topology by instantiating again
        graph = hG.HedgeGraph(temp_geom)
        graph.non_recursive_connected_component()
        results = graph.edges_to_polylines(output_type=True)
        if isinstance(results, tuple):  # There is error and results is a tuple
            return results[0], results[1]
        return results, []

    def _write_to_sink(self, parameters, name: str, context, layer: QgsVectorLayer):
        """
        Create the sink `name` with the schema of `layer` and copy every
        feature of `layer` into it.

        Returns the sink destination id.

        Raises
        ------
        QgsProcessingException
            If the sink could not be created.
        """
        sink, dest_id = self.parameterAsSink(
            parameters,
            name,
            context,
            layer.fields(),
            layer.wkbType(),
            layer.sourceCrs(),
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, name))
        for feature in layer.getFeatures():
            sink.addFeature(feature, QgsFeatureSink.FastInsert)
        return dest_id

    def handle_complex_vertices(
        self, graph: hG.HedgeGraph, max_iterations: int = 500
    ) -> list[QgsGeometry]:
        """
        Iterate over graph object to snap close d3 vertices.
        It is necessary to create a new graph for each
        modification to preserve topology.

        Parameters
        ----------

        graph: HedgeGraph
            Graph object to correct
        max_iterations: int
            Hard cap on pruning passes, in case the edge count never settles.

        Returns
        -------
        temp_geom: list[QgsGeometry]
            List of line geometries representing the graph edges.
        """
        for _ in range(max_iterations):
            edge_count = graph.edgeCount()
            graph.delete_d3_too_close()
            # If it's not pruning anymore we stop
            if graph.edgeCount() == edge_count:
                break
            graph = self._rebuild_graph(graph)

        return g.remove_invalid_geom(graph.edges_to_lines())

    def smooth_complex_vertices(
        self, polygon_layer: QgsVectorLayer, graph_layer: QgsVectorLayer
    ) -> QgsVectorLayer:
        """
        From the graph layer fetch complex (degree >= 4) nodes and smooth the
        neighbouring vertices to avoid creating too narrow angles. It is done
        abruptly by deleting the vertex next to the node and keeping the change
        only if the resulting shortcut segment stays within the polygon.

        Parameters
        ----------
        polygon_layer: Polygon layer carrying a "pid" field.
        graph_layer: Arc layer carrying "pid", "vid_1" and "vid_2" fields.
            Its geometries are modified in place and "vid_1"/"vid_2" are deleted.

        Returns
        -------
        graph_layer: Smoothed
        """
        alg_name = "hedgetools:topologicalnodes"
        params = {"INPUT": graph_layer, "OUTPUT": "TEMPORARY_OUTPUT"}
        node_layer = processing.run(alg_name, params)["OUTPUT"]

        # Step 1: Build lookup maps once
        vid_arc_map = {}  # vid -> list of arcs
        needed_pids = set()
        for arc in graph_layer.getFeatures():
            vid1 = arc["vid_1"]
            vid2 = arc["vid_2"]
            # A loop arc starts and ends on the same node: register it once.
            for vid in (vid1,) if vid1 == vid2 else (vid1, vid2):
                vid_arc_map.setdefault(vid, []).append(arc)
            needed_pids.add(arc["pid"])

        # Single pass over polygons (instead of one filtered query per pid);
        # setdefault keeps the first polygon found for a pid.
        pid_polygon_map = {}  # pid -> polygon geometry
        for polygon in polygon_layer.getFeatures():
            pid = polygon["pid"]
            if pid in needed_pids:
                pid_polygon_map.setdefault(pid, polygon.geometry())

        # Step 2: Process high-degree nodes
        request = QgsFeatureRequest().setFilterExpression("Degree >= 4")
        geom_map = {}  # arc feature id -> smoothed geometry

        for node in node_layer.getFeatures(request):
            node_geom = node.geometry()

            for arc in vid_arc_map.get(node["vid"], []):
                poly_geom = pid_polygon_map.get(arc["pid"])
                if not poly_geom:
                    continue

                # Start from the already smoothed geometry when the arc also
                # touches a previously processed complex node.
                polyline = geom_map.get(arc.id(), arc.geometry()).asPolyline()
                if len(polyline) <= 2:
                    continue

                first, last = polyline[0], polyline[-1]
                at_start = node_geom.isGeosEqual(QgsGeometry(QgsPoint(first)))
                deletion = polyline[1] if at_start else polyline[-2]

                # Remove point by value (NOT index), but never drop the end
                # points: they are topological nodes.
                inner = [pt for pt in polyline[1:-1] if pt != deletion]
                vertices = [first, *inner, last]

                # Only the segment replacing the removed vertex is new; keep the
                # change only if it stays inside the polygon. (A line/polygon
                # `overlaps` test is always False, so it cannot be used here.)
                shortcut = vertices[:2] if at_start else vertices[-2:]
                if QgsGeometry(QgsLineString(shortcut)).within(poly_geom):
                    geom_map[arc.id()] = QgsGeometry(QgsLineString(vertices))

        # Step 3: Apply changes
        graph_layer.dataProvider().changeGeometryValues(geom_map)
        at.delete_fields(graph_layer, ["vid_1", "vid_2"])

        return graph_layer

    def postProcessAlgorithm(self, context, feedback):
        """
        Tasks done when processAlgorithm is finished
        """
        utils.delete_processing_workspace()

        return {}

    def icon(self):
        """
        Should return a QIcon which is used for your provider inside
        the Processing toolbox.
        """
        return QIcon(":/plugins/hedge_tools/images/hedge_tools.png")

    def shortHelpString(self):
        """
        Returns a localised short help string for the algorithm.
        """
        # NOTE: text kept byte-identical so existing translations still match.
        return self.tr(
            "This algorithm computes the median axis of each \
                        polygon in a vector layer before to create a \
                        topological graph. \n \
                        A lower value of densification enables to obtain a \
                        smoother median axis but it can significantly slow \
                        down the process. The default value of -1 means that \
                        the algorithm will estimate the value to be used for \
                        each polygons. \n \
                        The dangle length value is used to trim the median \
                        axis at a junction. All the dangles lower than the \
                        specified value are removed to keep only two. \n \
                        The 'Eid' field is created and used as the unique \
                        identifier in the median axis layer. \n \
                        For polygons with narrow spot, the median axis could \
                        be discontinuous. In this case, the longest will be \
                        kept in the output layer and all the others will be \
                        put in the error layer. The user can reconnect some \
                        of them manually but the field values of the object \
                        in the output layer must be preserved for further \
                        processing in HedgeTools."
        )

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "topologicalarc"

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("1 - Create topological arc")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr("1 - Data preparation")

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "datapreparation"

    def tr(self, string):
        """Translate `string` in the "Processing" context."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return TopologicalArcAlgorithm()
