# -*- coding: utf-8 -*-

"""
/***************************************************************************
 DiffForestHedge
                                 A QGIS plugin
 Allow to split forest from hedge and others tree area
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2022-01-11
        copyright            : (C) 2022 by Dynafor
        email                : gabriel.marques@toulouse-inp.fr
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = "Dynafor"
__date__ = "2022-01-11"
__copyright__ = "(C) 2022 by Dynafor"

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = "$Format:%H$"

import cProfile
import io
import math
import os
import pstats
from pstats import SortKey

import processing
from hedge_tools import (
    resources,
)  # Only need in hedge_tools.py normaly but just to keep track of import
from hedge_tools.tools.vector import attribute_table as at
from hedge_tools.tools.vector import geometry as g
from hedge_tools.tools.vector import qgis_wrapper as qw
from hedge_tools.tools.vector import utils
from qgis.core import (
    QgsFeature,
    QgsFeatureRequest,
    QgsFeatureSink,
    QgsGeometry,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterNumber,
    QgsProcessingUtils,
    QgsVectorLayer,
    QgsWkbTypes,
)
from qgis.core.additions.edit import edit
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.PyQt.QtGui import QIcon


class DiffForestHedgeAlgorithm(QgsProcessingAlgorithm):
    """
    Categorize forests, hedges and other tree patches (such as groves)
    from a classified (machine-learning output) forest layer

    Parameters
    ---
    INPUT (QgisObject : QgsVectorLayer) : Layer input from users.
    OPENING (int) : Value for erosion/dilation.
    FOREST_AREA (int) : Value threshold for forest classification.
    GROVE_AREA (int) : Value threshold for grove/lone tree classification.

    Return
    ---
    FOREST (str) : Forest polygons layer
    HEDGE (str) : Hedge polygons layer
    GROVE (str) : Grove polygons layer
    LONE_TREE (str) : Lone tree polygons layer

    TODO: change first step of post process to addFeature to not overload buffer
    TODO: simplification in pre process seems to degrade results, but skipping it takes longer. Need to loop through features to not overload memory
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    # Input parameter names
    INPUT = "INPUT"
    # NOTE(review): the literal below starts with the digit zero ("0PENING"),
    # not the letter "O". It is the parameter name external callers must pass,
    # so renaming it would break existing scripts/models - confirm whether
    # this spelling is intentional before changing it.
    OPENING = "0PENING"
    FOREST_AREA = "FOREST_AREA"
    GROVE_AREA = "GROVE_AREA"
    # Output (feature sink) names
    FOREST = "FOREST"
    HEDGE = "HEDGE"
    GROVE = "GROVE"
    LONE_TREE = "LONE_TREE"

    def initAlgorithm(self, config):
        """
        Declare the inputs and outputs of the algorithm.

        Parameters are registered in this order: the polygon source,
        three integer parameters (opening distance, forest area threshold,
        grove area threshold) and four feature sinks (forest, linear
        element, grove, scattered tree).
        """
        # Input polygon layer.
        source_parameter = QgsProcessingParameterFeatureSource(
            self.INPUT,
            self.tr("Polygons vector layer"),
            [QgsProcessing.TypeVectorPolygon],
        )
        self.addParameter(source_parameter)

        # Integer parameters: (name, label, default value, maximum value).
        # All of them are mandatory and have a minimum of 0.
        number_specs = (
            (self.OPENING, self.tr("Buffer size (meters)"), 20, 150),
            (
                self.FOREST_AREA,
                self.tr("Forest area threshold (square meters)"),
                5000,
                10000,
            ),
            (
                self.GROVE_AREA,
                self.tr(
                    "Grove area threshold (should be lower than forest area threshold) (square meters)"
                ),
                500,
                10000,
            ),
        )
        for param_name, label, default, maximum in number_specs:
            self.addParameter(
                QgsProcessingParameterNumber(
                    name=param_name,
                    description=label,
                    type=QgsProcessingParameterNumber.Integer,
                    defaultValue=default,
                    optional=False,
                    minValue=0,
                    maxValue=maximum,
                )
            )

        # Feature sinks receiving the classified polygons (one per class).
        sink_specs = (
            (self.FOREST, self.tr("Forest")),
            (self.HEDGE, self.tr("Linear element")),
            (self.GROVE, self.tr("Grove")),
            (self.LONE_TREE, self.tr("Scattered tree")),
        )
        for sink_name, label in sink_specs:
            self.addParameter(QgsProcessingParameterFeatureSink(sink_name, label))

    def __init__(self):
        """
        Initialise the per-run state of the algorithm.

        All instance attributes are declared here; they receive their real
        values at the start of processAlgorithm.
        """
        super().__init__()
        # Processing context; assigned in processAlgorithm and read by the
        # helper methods to reach the temporary layer store.
        self.context = None
        # CRS of the input source, assigned in processAlgorithm.
        self.CRS = None
        # Names of the fields that are kept in the output layers.
        self.input_field_list = []

    def processAlgorithm(self, parameters, context, feedback):
        """
        Run the forest / hedge / grove / lone tree differentiation.

        Main steps:
        1. Pre-process the input (single parts, simplification, opening).
        2. Slice the input polygons with their inner voronoi diagram.
        3. Separate slices disjoint from the opened layer ("no forest")
           from the others, then classify the non-forest patches with
           morphological metrics.
        4. Post-process the forest / non-forest connection and split the
           result into one layer per class.
        5. Write each class to its feature sink.

        Parameters
        ---
        parameters : dict : Parameter values keyed by parameter name.
        context : QgsProcessingContext
        feedback : QgsProcessingFeedback

        Return
        ---
        dict : Destination id of each sink keyed by output name
            (FOREST, HEDGE, GROVE, LONE_TREE). An empty dict is returned
            when the run is cancelled or the unique constraint update fails.
        """
        self.context = context
        store = self.context.temporaryLayerStore()

        # Init
        input_layer = self.parameterAsSource(parameters, self.INPUT, self.context)
        opening = self.parameterAsInt(parameters, self.OPENING, self.context)
        forest_area = self.parameterAsInt(parameters, self.FOREST_AREA, self.context)
        grove_area = self.parameterAsInt(parameters, self.GROVE_AREA, self.context)

        alg_number = 15
        step_per_alg = int(100 / alg_number)
        step = 1

        def advance():
            """Report the next progress increment."""
            nonlocal step
            feedback.setProgress(step_per_alg * step)
            step += 1

        feedback.pushInfo("Starting processing")
        feedback.setProgress(0)

        # Get fields of input layer
        self.input_field_list = input_layer.fields().names()
        self.input_field_list.append("ht_class")
        self.CRS = input_layer.sourceCrs()

        # Check for cancellation
        if feedback.isCanceled():
            return {}

        # Just in case
        input_layer = qw.multipart_to_singleparts(parameters[self.INPUT])
        store.addMapLayer(input_layer)

        if "fid" in input_layer.fields().names():
            self.input_field_list.append("fid")

            idx = input_layer.fields().indexFromName("fid")
            success = utils.update_unique_constraint(feedback, input_layer, idx)
            if success[0] is False:
                return {}  # Feedback should already be pushed inside function
            elif isinstance(success[1], QgsVectorLayer):
                input_layer = success[1]
                store.addMapLayer(input_layer)

        simplify_layer = qw.simplification(input_layer, tolerance=1)
        store.addMapLayer(simplify_layer)

        forest_layer = qw.opening(simplify_layer, opening)
        store.addMapLayer(forest_layer)
        advance()

        forest_layer = qw.remove_null_geometries(forest_layer)
        store.addMapLayer(forest_layer)
        advance()

        forest_single = qw.multipart_to_singleparts(forest_layer)
        store.addMapLayer(forest_single)
        advance()

        if feedback.isCanceled():
            return {}

        feedback.pushInfo("Slicing features")

        slice_layer = self.voronoi_polygon(input_layer)
        store.addMapLayer(slice_layer)
        advance()

        if feedback.isCanceled():
            return {}

        feedback.pushInfo("Starting classification")

        no_forest = qw.extract_by_location(slice_layer, forest_single, 2)
        store.addMapLayer(no_forest)
        advance()

        # Stays None when the classification branch below does not produce
        # a merged layer; handled right after the branch.
        classif_ahf = None

        if no_forest.featureCount() != 0:
            forest_layer = qw.extract_by_location(slice_layer, forest_single, 0)
            store.addMapLayer(forest_layer)

            forest_single = qw.dissolve(forest_layer, separate_disjoint=True)
            store.addMapLayer(forest_single)
            advance()

            # Explicit spaces around "and" so the expression does not rely
            # on the parser splitting "<number>and".
            expression = f"$area < {forest_area} and $area > 1"
            forest_false = qw.extract_by_expression(forest_single, expression)
            store.addMapLayer(forest_false)

            expression = f"$area >= {forest_area}"
            forest_real = qw.extract_by_expression(forest_single, expression)
            store.addMapLayer(forest_real)

            no_forest_full = qw.merge_layers([no_forest, forest_false])
            store.addMapLayer(no_forest_full)

            no_forest_full = qw.dissolve(no_forest_full, separate_disjoint=True)
            store.addMapLayer(no_forest_full)
            advance()

            no_forest_full = qw.field_calculator(no_forest_full, "pid", 1, "$id", 10, 0)
            store.addMapLayer(no_forest_full)
            advance()

            if feedback.isCanceled():
                return {}

            feedback.pushInfo("Creating classification metrics")
            parameters_layer = self.compute_classification_parameters(no_forest_full)
            store.addMapLayer(parameters_layer)
            advance()

            if feedback.isCanceled():
                return {}

            feedback.pushInfo("Attributing class")
            class_ahf = self.morphological_classification(
                parameters_layer, forest_area, grove_area
            )
            store.addMapLayer(class_ahf)
            advance()

            class_ahf = qw.remove_holes(class_ahf, 0.5)
            store.addMapLayer(class_ahf)

            forest_real = qw.remove_holes(forest_real, 0.5)
            store.addMapLayer(forest_real)
            advance()

            if feedback.isCanceled():
                return {}

            feedback.pushInfo("Starting post-processing")
            if forest_real.featureCount() != 0:
                forest_layer, class_ahf = self.remove_missclass_forest(
                    forest_real, class_ahf
                )
                store.addMapLayer(forest_layer)
                store.addMapLayer(class_ahf)

                forest_real_single, class_ahf = self.remove_missclass_near_forest(
                    forest_layer, class_ahf, opening
                )
                store.addMapLayer(forest_real_single)
                store.addMapLayer(class_ahf)

                classif_ahf = self.merge_forest_with_others(
                    forest_real_single, class_ahf, input_layer
                )
                store.addMapLayer(classif_ahf)

        if classif_ahf is None:
            # Reached when every slice intersects the opened layer
            # (no_forest is empty; this case previously left classif_ahf
            # unbound and crashed below) or when no dissolved patch reaches
            # forest_area. All slices are labelled 'Forest'.
            # NOTE(review): for the second case the original code also
            # labelled everything 'Forest'; behaviour kept as-is - confirm
            # this is intended rather than using the classified layer.
            classif_ahf = slice_layer
            classif_ahf = qw.field_calculator(classif_ahf, "fid", 1, "$id")
            classif_ahf = qw.field_calculator(classif_ahf, "pid", 1, "$id")
            classif_ahf = qw.field_calculator(
                classif_ahf, "ht_class", 2, "'Forest'"
            )
            store.addMapLayer(classif_ahf)

        advance()
        if feedback.isCanceled():
            return {}

        feedback.pushInfo("Correcting hedge/forest connection")

        # Rebuild the layer when its first feature has id 0; an empty layer
        # is left untouched instead of raising StopIteration.
        first_feature = next(classif_ahf.getFeatures(), None)
        if first_feature is not None and first_feature.id() == 0:
            classif_ahf = utils.create_layer(
                classif_ahf, copy_feat=True, copy_field=True
            )
            store.addMapLayer(classif_ahf)

        if "fid" in classif_ahf.fields().names() and "fid" not in self.input_field_list:
            self.input_field_list.append("fid")

        # Drop every field that is neither an input field nor one of the
        # fields registered in input_field_list during processing.
        new_field_list = classif_ahf.fields().names()
        field_del_list = list(set(new_field_list) - set(self.input_field_list))
        classif_layer = qw.delete_column(classif_ahf, field_del_list)
        store.addMapLayer(classif_layer)

        forest_layer, hedge_layer, grove_layer, lone_layer = (
            self.split_layer_by_modality(classif_layer)
        )
        store.addMapLayer(forest_layer)
        store.addMapLayer(hedge_layer)
        store.addMapLayer(grove_layer)
        store.addMapLayer(lone_layer)
        advance()

        hedge_layer, forest_layer = self.interface_correction(hedge_layer, forest_layer)
        store.addMapLayer(hedge_layer)
        store.addMapLayer(forest_layer)
        advance()

        forest_layer = self._renumber_ids(forest_layer)
        hedge_layer = self._renumber_ids(hedge_layer)
        grove_layer = self._renumber_ids(grove_layer)
        lone_layer = self._renumber_ids(lone_layer)

        feedback.pushInfo("Creating and formatting output")
        advance()

        forest_id = self._write_sink(parameters, self.FOREST, forest_layer)
        hedge_id = self._write_sink(parameters, self.HEDGE, hedge_layer)
        grove_id = self._write_sink(parameters, self.GROVE, grove_layer)
        lone_id = self._write_sink(parameters, self.LONE_TREE, lone_layer)

        return {
            self.FOREST: forest_id,
            self.HEDGE: hedge_id,
            self.GROVE: grove_id,
            # Key must match the declared output name (LONE_TREE).
            self.LONE_TREE: lone_id,
            # Legacy key returned by earlier versions, kept for callers
            # that still read it.
            "LONE": lone_id,
        }

    def _renumber_ids(self, layer):
        """
        Recompute the `fid` and `pid` fields of a layer from `$id` and
        register the resulting layer in the temporary layer store.

        Parameters
        ---
        layer : QgsVectorLayer : Polygon

        Return
        ---
        layer : QgsVectorLayer : Polygon : Layer with recomputed fid and pid
        """
        layer = qw.field_calculator(layer, "fid", 1, "$id")
        layer = qw.field_calculator(layer, "pid", 1, "$id")
        self.context.temporaryLayerStore().addMapLayer(layer)

        return layer

    def _write_sink(self, parameters, name, layer):
        """
        Create the feature sink `name` and copy every feature of `layer`
        into it.

        Parameters
        ---
        parameters : dict : Parameter values of the algorithm
        name : str : Name of the sink parameter
        layer : QgsVectorLayer : Layer providing fields, geometry type,
            CRS and features

        Return
        ---
        dest_id : str : Destination identifier of the sink

        Raises
        ---
        QgsProcessingException : If the sink could not be created
        """
        # Local import: the exception class is not part of the module imports.
        from qgis.core import QgsProcessingException

        sink, dest_id = self.parameterAsSink(
            parameters,
            name,
            self.context,
            layer.fields(),
            layer.wkbType(),
            layer.sourceCrs(),
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, name))

        for feature in layer.getFeatures():
            sink.addFeature(feature, QgsFeatureSink.FastInsert)

        return dest_id

    def unpack_geometry_collection(self, gc, fid, retain=None):
        """
        Unpack geometry in a geometry collection with respect
        to the type of geometry specified in retain argument

        Parameters
        ----------
        gc:
            Iterable of geometries to unpack
        fid:
            Unique identifier used for the first retained feature
        retain:
            Geometry types to retain. Defaults to polygons only
            (QgsWkbTypes.PolygonGeometry) when None.

        Returns
        -------
        unpacked:
            List of features, one per retained geometry, each carrying
            its fid as single attribute
        fid:
            Updated fid (incremented once per retained geometry)
        """
        # None sentinel instead of a mutable list as default argument.
        if retain is None:
            retain = (QgsWkbTypes.PolygonGeometry,)

        unpacked = []
        for geometry in gc:
            if geometry.type() not in retain:
                continue

            feature = QgsFeature()
            feature.setGeometry(geometry)
            feature.setAttributes([fid])
            unpacked.append(feature)
            fid += 1

        return unpacked, fid

    def inner_voronoi_diagram(self, polygon_geometry, fid):
        """
        Compute the voronoi diagram of a polygon and keep, for each
        voronoi cell, the polygonal part(s) of its intersection with
        the parent polygon

        Parameters
        ----------
        polygon_geometry:
            Polygon to compute voronoi from
        fid:
            Unique identifier given to the first created feature

        Returns
        -------
        voronoi_list:
            List of features holding the clipped voronoi polygons,
            each carrying its fid as single attribute
        fid:
            Incremented unique identifier
        """
        voronoi_list = []

        voronoi_geom = polygon_geometry.voronoiDiagram()
        # Prepared engine: the same parent polygon is intersected with
        # every cell of the diagram.
        engine = QgsGeometry.createGeometryEngine(polygon_geometry.constGet())
        engine.prepareGeometry()

        # Loop variable is not named `g` to avoid shadowing the geometry
        # module imported under that alias.
        for cell in voronoi_geom.asGeometryCollection():
            geom = QgsGeometry(engine.intersection(cell.constGet()))
            if geom.isEmpty():
                # Nothing of this cell lies inside the parent polygon.
                continue

            # QgsGeometry.type() returns a geometry type, so comparing it to
            # the WKB type QgsWkbTypes.MultiPolygon never matched;
            # isMultipart() covers multi geometries and collections.
            if geom.type() == QgsWkbTypes.UnknownGeometry or geom.isMultipart():
                unpacked, fid = self.unpack_geometry_collection(
                    geom.asGeometryCollection(), fid
                )
                voronoi_list += unpacked
            elif geom.type() == QgsWkbTypes.PolygonGeometry:
                feat = QgsFeature()
                feat.setGeometry(geom)
                feat.setAttributes([fid])
                voronoi_list.append(feat)
                fid += 1

        return voronoi_list, fid

    def voronoi_polygon(self, poly_layer):
        """
        Slice every polygon of a layer with its inner voronoi diagram

        Parameters
        ----------
        poly_layer:
            Input polygon layer

        Returns
        -------
        slice_layer:
            Layer stored in a temporary "slice_layer.gpkg" file and
            holding the voronoi polygons of all input features
        """
        temp_path = QgsProcessingUtils.generateTempFilename("slice_layer.gpkg")
        slice_layer = utils.create_layer(
            poly_layer, data_provider="ogr", path=temp_path
        )

        # The identifier is threaded through every call so that it stays
        # unique across all parent polygons.
        next_fid = 1
        sliced_features = []
        for parent in poly_layer.getFeatures():
            cells, next_fid = self.inner_voronoi_diagram(parent.geometry(), next_fid)
            sliced_features.extend(cells)

        # Single provider call for the whole set of slices.
        slice_layer.dataProvider().addFeatures(sliced_features)

        self.context.temporaryLayerStore().addMapLayer(slice_layer)

        return slice_layer

    def compute_direct_metrics(self, layer):
        """
        Add the direct metrics to a polygonal layer:
        a "perimeter" field ($perimeter) and an "area_HT" field ($area)

        Parameters
        ---
        layer : QgsVectorLayer : Polygon

        Return
        ---
        metrics_layer : QgsVectorLayer : Polygon : Layer with both metrics
        """
        with_perimeter = qw.field_calculator(
            layer, "perimeter", 0, "$perimeter", 10, 2
        )
        metrics_layer = qw.field_calculator(
            with_perimeter, "area_HT", 0, "$area", 10, 2
        )

        self.context.temporaryLayerStore().addMapLayer(metrics_layer)

        return metrics_layer

    def compute_elongation(self, layer):
        """
        Add the "elongation" field used by the morphological classification

        The value is BB_height / BB_width and is only computed when
        CH_area, perimeter and BB_width are all different from 0.

        Parameters
        ---
        layer : QgsVectorLayer : Polygon
            Must hold the CH_area, perimeter, BB_width and BB_height fields

        Return
        ---
        elongation_layer : QgsVectorLayer : Polygon
        """
        expression = "CASE \
                WHEN CH_area <> 0 AND perimeter <> 0 \
                AND BB_width <> 0 THEN BB_height / BB_width \
                END"
        elongation_layer = qw.field_calculator(
            layer, "elongation", 0, expression, 2, 2
        )

        self.context.temporaryLayerStore().addMapLayer(elongation_layer)

        return elongation_layer

    def compute_convexity(self, layer):
        """
        Add the "convexity" field used by the morphological classification

        The value is area_HT / CH_area and is only computed when
        CH_area, perimeter and BB_width are all different from 0.

        Parameters
        ---
        layer : QgsVectorLayer : Polygon
            Must hold the CH_area, perimeter, BB_width and area_HT fields

        Return
        ---
        convexity_layer : QgsVectorLayer : Polygon
        """
        expression = "CASE \
                WHEN CH_area <> 0 AND perimeter <> 0 \
                AND BB_width <> 0 THEN area_HT / CH_area \
                END"
        convexity_layer = qw.field_calculator(
            layer, "convexity", 0, expression, 2, 2
        )

        self.context.temporaryLayerStore().addMapLayer(convexity_layer)

        return convexity_layer

    def compute_compactness(self, layer):
        """
        Add the "compactness" field used by the morphological classification

        The value is 4 * pi * area_HT / perimeter^2 and is only computed
        when CH_area, perimeter and BB_width are all different from 0.

        Parameters
        ---
        layer : QgsVectorLayer : Polygon
            Must hold the CH_area, perimeter, BB_width and area_HT fields

        Return
        ---
        compactness_layer : QgsVectorLayer : Polygon
        """
        expression = "CASE \
                WHEN CH_area <> 0 AND perimeter <> 0 \
                AND BB_width <> 0 THEN 4 * pi() * (area_HT / perimeter^2) \
                END"
        compactness_layer = qw.field_calculator(
            layer, "compactness", 0, expression, 2, 2
        )

        self.context.temporaryLayerStore().addMapLayer(compactness_layer)

        return compactness_layer

    def compute_morphological_metrics(self, layer):
        """
        Compute the indirect metrics used by the morphological classification

        Width and height of the bounding box (prefix "BB_") and area of
        the convex hull (prefix "CH_") are joined to the layer on "pid",
        then elongation, convexity and compactness are computed.
        "pid" is also appended to self.input_field_list.

        Parameters
        ---
        layer : QgsVectorLayer : Polygon

        Return
        ---
        output : QgsVectorLayer : Polygon
        """
        # Support geometries providing the joined metrics
        ombb_layer = qw.create_ombb(layer)
        hull_layer = qw.create_convex_hull(layer)

        # "pid" is the key on the target layer and on both joined layers
        at.table_join(
            layer,
            "pid",
            [ombb_layer, hull_layer],
            "pid",
            [["width", "height"], ["area"]],
            ["BB_", "CH_"],
        )

        self.input_field_list.append("pid")

        # Elongation, then convexity, then compactness
        output = self.compute_compactness(
            self.compute_convexity(self.compute_elongation(layer))
        )

        self.context.temporaryLayerStore().addMapLayer(output)

        return output

    def compute_classification_parameters(self, layer):
        """
        Compute every metric needed by the morphological classification:
        direct metrics first, then the morphological
        (convex hull / bounding box based) ones.

        Intended use case is for a layer with no forest

        Parameters
        ---
        layer : QgsVectorLayer : Polygon

        Return
        ---
        parameters_layer : QgsVectorLayer : Polygon : layer with metrics
        """
        parameters_layer = self.compute_morphological_metrics(
            self.compute_direct_metrics(layer)
        )

        self.context.temporaryLayerStore().addMapLayer(parameters_layer)

        return parameters_layer

    def morphological_classification(self, parameters_layer, forest_area, grove_area):
        """
        Fill the "ht_class" field of a polygonal layer from its
        convexity, compactness, elongation and area.

        Possible classes are 'Hedge', 'Forest', 'Grove' and 'Lone tree';
        features matching none of the rules fall back to 'Grove'.

        Intended use case is for a layer with no forest

        Parameters
        ---
        parameters_layer : QgsVectorLayer : Polygon
            Layer holding the convexity, compactness and elongation fields
        forest_area : int : Minimum size for a geometry to be a forest
        grove_area : int : Minimum size for a geometry to be a grove

        Return
        ---
        classified : QgsVectorLayer : Polygon
        """
        # Rules are evaluated in order, the first matching one wins.
        class_expression = f"CASE \
                    WHEN convexity <= 0.7 AND compactness < 0.4 THEN 'Hedge' \
                    WHEN convexity <= 0.7 AND compactness >= 0.4 \
                        AND $area >= {forest_area} THEN 'Forest' \
                    WHEN convexity > 0.7 AND elongation > 2.5 THEN 'Hedge' \
                    WHEN convexity > 0.7 AND elongation <= 2.5 \
                        AND $area >= {grove_area} AND $area < {forest_area} THEN 'Grove' \
                    WHEN convexity > 0.7 AND elongation <= 2.5 \
                        AND $area < {grove_area} THEN 'Lone tree'\
                    ELSE 'Grove' \
                    END"
        classified = qw.field_calculator(
            parameters_layer, "ht_class", 2, class_expression, 10
        )

        self.context.temporaryLayerStore().addMapLayer(classified)

        return classified

    def remove_missclass_forest(self, forest_layer, classification_layer):
        """
        Move classified features touching a forest polygon to the forest layer.

        For each forest polygon, the features of the classification layer
        whose ht_class is 'Forest' or 'Grove' and that are returned by
        g.get_clementini with the "touches" predicate are appended to the
        forest layer (with a new fid following the current maximum) and
        deleted from the classification layer.
        Both layers are modified in place through their data provider.

        Parameters
        ---
        forest_layer : QgsVectorLayer : Polygon
            Must hold a "fid" field
        classification_layer : QgsVectorLayer : Polygon
            Must hold a "ht_class" field

        Return
        ---
        forest_layer : QgsVectorLayer : Polygon : Updated forest layer
        classification_layer : QgsVectorLayer : Polygon
            Classification layer without the moved features
        """
        forest_fields = forest_layer.fields()

        expression = "ht_class = 'Forest' or ht_class = 'Grove'"
        request = QgsFeatureRequest().setFilterExpression(expression)

        # Keyed by feature id: a feature touching several forest polygons
        # must be moved (added / deleted) only once.
        moved = {}
        for forest in forest_layer.getFeatures():
            count, features = g.get_clementini(
                classification_layer, forest.geometry(), request, "touches"
            )
            if count == 0:
                continue
            for feat in features:
                if feat.id() in moved:
                    continue
                feat.setFields(forest_fields)
                moved[feat.id()] = feat

        del_list = list(moved)
        add_list = list(moved.values())

        if forest_layer.featureCount() != 0:
            fid_idx = forest_fields.indexFromName("fid")
            max_fid = max(forest_layer.uniqueValues(fid_idx))
        else:
            max_fid = 0

        # New identifiers continue after the highest existing fid.
        for offset, feat in enumerate(add_list, start=1):
            feat["fid"] = max_fid + offset

        forest_layer.dataProvider().addFeatures(add_list)
        classification_layer.dataProvider().deleteFeatures(del_list)
        forest_layer.updateExtents()

        # Register the updated layers in the temporary layer store
        self.context.temporaryLayerStore().addMapLayer(forest_layer)
        self.context.temporaryLayerStore().addMapLayer(classification_layer)

        return forest_layer, classification_layer

    def remove_missclass_near_forest(
        self, forest_layer, classification_layer, distance
    ):
        """
        Reclassify as 'Forest' the classified features located near the forest.

        The forest layer is dissolved, exploded to single parts and buffered
        by 1.5 * distance. Features of the classification layer selected with
        the "contains" predicate against the buffer, and then selected again
        by g.get_clementini against the dissolved forest (default predicate
        of that helper), get their ht_class set to 'Forest'.
        The classification layer is modified in place.

        Parameters
        ---
        forest_layer : QgsVectorLayer : Polygon
            Must contain at least one feature
        classification_layer : QgsVectorLayer : Polygon
            Must hold a "ht_class" field
        distance : int : Opening distance; the buffer uses 1.5 times this value

        Return
        ---
        forest_real_single : QgsVectorLayer : Polygon
            Dissolved forest exploded to single parts
        classification_layer : QgsVectorLayer : Polygon
            Classification layer with updated ht_class
        """
        forest_real_diss = qw.dissolve(forest_layer)
        forest_real_single = qw.multipart_to_singleparts(forest_real_diss)
        distance = distance + (distance / 2)
        forest_real_buffer = qw.buffer(forest_real_single, distance, True)

        # Only the first feature of the buffer layer is used.
        # NOTE(review): a feature is passed here whereas
        # remove_missclass_forest passes a geometry to the same helper -
        # confirm g.get_clementini accepts both.
        buffer = next(forest_real_buffer.getFeatures())
        _, contains_list = g.get_clementini(
            classification_layer, buffer, predicate="contains"
        )
        contains_id_list = [feat.id() for feat in contains_list]

        # Restrict the second selection to the candidates found above.
        forest = next(forest_real_diss.getFeatures())
        request = QgsFeatureRequest().setFilterFids(contains_id_list)

        _, inter_list = g.get_clementini(classification_layer, forest, request)
        inter_id_list = [feat.id() for feat in inter_list]

        idx = classification_layer.fields().indexFromName("ht_class")
        if inter_id_list:
            # One edit session (single commit) for all the changes instead
            # of opening and committing a session per feature.
            with edit(classification_layer):
                for feat_id in inter_id_list:
                    classification_layer.changeAttributeValue(feat_id, idx, "Forest")

        # Register the layers in the temporary layer store
        self.context.temporaryLayerStore().addMapLayer(forest_real_single)
        self.context.temporaryLayerStore().addMapLayer(classification_layer)

        return forest_real_single, classification_layer

    def merge_forest_with_others(self, forest_layer, classification_layer, input_layer):
        """
        Build a single classified layer from the forest layer and the
        classified layer (hedge, grove, lone tree).

        The forest layer receives a "ht_class" field set to 'Forest', both
        layers are merged, the result is dissolved on "ht_class" to regroup
        features of the same class and finally exploded to single parts.

        Parameters
        ---
        forest_layer : QgsVectorLayer : Polygon
        classification_layer : QgsVectorLayer : Polygon
        input_layer : QgsVectorLayer : Polygon
            Input layer of the algorithm with matching topology against sliced layers
            AKA "valid_input". Currently unused (the clip step is disabled).

        Return
        ---
        classification_ahf : QgsVectorLayer : Polygon
            New classified layer including forest

        TODO : input_layer and clip Why ? Need to test again to remember
        """
        layer_store = self.context.temporaryLayerStore()

        # Tag every forest feature with its class
        forest_with_class = qw.field_calculator(
            forest_layer, "ht_class", 2, "'Forest'", 10
        )
        layer_store.addMapLayer(forest_with_class)

        del forest_layer

        # Classified layer first, forest second
        merged = qw.merge_layers([classification_layer, forest_with_class])
        layer_store.addMapLayer(merged)

        del forest_with_class

        # One (multi)feature per class, then back to single parts
        dissolved = qw.dissolve(merged, "ht_class")
        layer_store.addMapLayer(dissolved)

        classification_ahf = qw.multipart_to_singleparts(dissolved)
        layer_store.addMapLayer(classification_ahf)

        del merged

        # Disabled step: clipping the result with input_layer (see TODO)

        return classification_ahf

    def split_layer_by_modality(self, classification_layer):
        """
        Compute a "pid" field on the classification layer so the identifier
        is shared by the split layers, extract one layer per hard-coded
        modality of "ht_class" (Forest, Hedge, Grove, Lone tree), then
        compute a "fid" field on each extracted layer.

        Every created layer is registered in the temporary layer store
        of the processing context.

        Parameters
        ---
        classification_layer : QgsVectorLayer : Polygon

        Return
        ---
        forest_layer : QgsVectorLayer : Polygon
        hedge_layer : QgsVectorLayer : Polygon
        grove_layer : QgsVectorLayer : Polygon
        lone_layer : QgsVectorLayer : Polygon
        """

        def _register(layer):
            # Register the layer in the context store and hand it back
            self.context.temporaryLayerStore().addMapLayer(layer)
            return layer

        # pid is computed before the split so every child layer carries it
        with_pid = _register(
            qw.field_calculator(classification_layer, "pid", 1, "$id")
        )

        # One extraction per modality, in the order of the returned tuple
        extracted = []
        for modality in ("Forest", "Hedge", "Grove", "Lone tree"):
            request = QgsFeatureRequest().setFilterExpression(
                f"ht_class is '{modality}'"
            )
            subset = utils.create_layer(
                with_pid, copy_feat=True, request=request, copy_field=True
            )
            extracted.append(_register(subset))

        # fid is computed once all the extractions are done
        numbered = []
        for subset in extracted:
            numbered.append(_register(qw.field_calculator(subset, "fid", 1, "$id")))

        forest_layer, hedge_layer, grove_layer, lone_layer = numbered

        return forest_layer, hedge_layer, grove_layer, lone_layer

    def get_correct_parts(self, splits, main_geom):
        """
        Gather the outputs of a successful splitGeometry (the returned parts
        plus the modified input geometry) and discard the one with the
        largest area, which is considered to be the main geometry.
        (splitGeometry stores a part of the split in splits and the other
        part in the input, so the main geometry can be in either of them.)

        Intended use case is keeping the small cut polygons that will be
        used to correct the junction between a hedge and a forest.

        The splits list is modified in place and is the object returned.
        When several parts share the largest area, the first one is dropped.

        Parameters
        ---
        splits : ite[QgsGeometry] : Results of splitGeometry
        main_geom : QgsGeometry : Modified input of splitGeometry

        Results
        ---
        splits : ite[QgsGeometry] : "Correct" results for the intended use case
        """
        splits.append(main_geom)

        # max keeps the first maximal candidate, hence the first largest part
        largest_position, _ = max(
            enumerate(splits), key=lambda indexed: indexed[1].area()
        )
        del splits[largest_position]

        return splits

    def trim_polygon(self, polygon, line):
        """
        Split a copy of a polygon with a line and keep every resulting
        geometry except the main one (see get_correct_parts).
        The split is done on a copy, the polygon received is left untouched.
        For potential improvement see interface_correction_preprocess.

        Parameters
        ---
        polygon : QgsGeometry : Polygon
        line : QgsGeometry : LineString

        Return
        ---
        success : Boolean : True if successful
        splits : ite[QgsGeometry:Polygon]
            False instead of a list when the split did not succeed
        """
        working_geom = QgsGeometry(polygon)
        status, pieces, _ = working_geom.splitGeometry(line.asPolyline(), True)

        # splitGeometry reports success with the code 0
        if status != 0:
            return False, False

        return True, self.get_correct_parts(pieces, working_geom)

    def create_features(self, geometries):
        """
        Wrap each geometry in a new feature carrying only that geometry.
        The features are returned in the order of the geometries.

        Parameters
        ---
        geometries : ite[QgsGeometry]

        Return
        ---
        features : ite[QgsFeature]
        """

        def _as_feature(geometry):
            # New feature with no field set, only the geometry
            feature = QgsFeature()
            feature.setGeometry(geometry)
            return feature

        return [_as_feature(geometry) for geometry in geometries]

    def interface_correction_preprocess(self, polygon_layer_1, polygon_layer_2):
        """
        First get the corrected interface junction between two polygons.
        Then split both polygons by this new junction and retrieve the
        created parts (everything except the main geometry of each split).

        We could improve the split by also keeping the main_geom in
        trim_polygon and modifying the input layer. The problem would be
        a geometry cut multiple times, as modifying the underlying
        dataProvider while working on it would probably crash QGIS.
        TLDR : Need a workaround for that

        Both output layers are registered in the temporary layer store
        of the processing context.

        Parameters
        ---
        polygon_layer_1 :  QgsVectorLayer : Polygon
        polygon_layer_2 :  QgsVectorLayer : Polygon

        Return
        ---
        poly_p1 :  QgsVectorLayer : Polygon
            Parts cut from the polygons of polygon_layer_1
        poly_p2 :  QgsVectorLayer : Polygon
            Parts cut from the polygons of polygon_layer_2
        """
        # Each output layer is templated on the layer its parts are cut from
        poly_p1 = utils.create_layer(polygon_layer_1)
        poly_p2 = utils.create_layer(polygon_layer_2)
        features_p1 = []
        features_p2 = []

        for polygon_1 in polygon_layer_1.getFeatures():
            geom_1 = polygon_1.geometry()
            # Candidate neighbours of geom_1 in the second layer
            # (presumably selected on a spatial predicate - see g.get_clementini)
            _, polygons_2 = g.get_clementini(polygon_layer_2, geom_1)
            for polygon_2 in polygons_2:
                geom_2 = polygon_2.geometry()
                success, lines = g.get_new_intersection(geom_1, geom_2)
                if success is not True:
                    continue
                for line in lines:
                    # The same line trims both sides of the junction; each
                    # side is always split from its original geometry
                    success1, parts_1 = self.trim_polygon(geom_1, line)
                    success2, parts_2 = self.trim_polygon(geom_2, line)
                    if success1 is True:
                        features_p1.extend(self.create_features(parts_1))
                    if success2 is True:
                        features_p2.extend(self.create_features(parts_2))

        poly_p1.dataProvider().addFeatures(features_p1)
        poly_p2.dataProvider().addFeatures(features_p2)

        self.context.temporaryLayerStore().addMapLayer(poly_p1)
        self.context.temporaryLayerStore().addMapLayer(poly_p2)

        return poly_p1, poly_p2

    def interface_correction(self, layer_1, layer_2):
        """
        Straighten the interface between the geometries of two layers of
        different class.

        The parts cut on each side of the new junction are computed by
        interface_correction_preprocess and repaired. All of them are then
        removed from both inputs, and each input receives the parts that
        were cut from the other one. Finally each layer is dissolved
        (only when it holds features) and exploded into single parts.

        Every intermediate layer is registered in the temporary layer store
        of the processing context.

        Parameters
        ---
        layer_1 : QgsVectorLayer : Polygon
        layer_2 : QgsVectorLayer : Polygon

        Return
        ---
        output_1 : QgsVectorLayer : Polygon : layer_1 corrected
        output_2 : QgsVectorLayer : Polygon : layer_2 corrected
        """

        def _keep(layer):
            # Register the layer in the context store and hand it back
            self.context.temporaryLayerStore().addMapLayer(layer)
            return layer

        def _dissolve_if_populated(layer):
            # An empty layer is passed along untouched (and registered again)
            if layer.featureCount() != 0:
                return _keep(qw.dissolve(layer))
            return _keep(layer)

        cut_from_1, cut_from_2 = self.interface_correction_preprocess(
            layer_1, layer_2
        )

        # Repair
        fixed_1 = _keep(qw.fix_geometries(cut_from_1))
        fixed_2 = _keep(qw.fix_geometries(cut_from_2))

        all_cuts = _keep(qw.merge_layers([fixed_2, fixed_1]))

        # Difference : remove every cut part from both inputs
        trimmed_1 = _keep(qw.difference(layer_1, all_cuts))
        trimmed_2 = _keep(qw.difference(layer_2, all_cuts))

        # Regroup : each side gets the parts cut from the other side
        regrouped_1 = _keep(qw.merge_layers([trimmed_1, fixed_2]))
        regrouped_2 = _keep(qw.merge_layers([trimmed_2, fixed_1]))

        # Dissolve
        dissolved_1 = _dissolve_if_populated(regrouped_1)
        dissolved_2 = _dissolve_if_populated(regrouped_2)

        # Multi to single
        output_1 = _keep(qw.multipart_to_singleparts(dissolved_1))
        output_2 = _keep(qw.multipart_to_singleparts(dissolved_2))

        return output_1, output_2

    def icon(self):
        """
        Return the QIcon built from the hedge_tools image stored in the
        plugin resources, used to display this entry in the Processing
        toolbox.
        """
        icon_path = ":/plugins/hedge_tools/images/hedge_tools.png"
        return QIcon(icon_path)

    def shortHelpString(self):
        """
        Returns a localised short help string for the algorithm.

        The text describes the four output categories, the role of the
        buffer parameter and the role of the area thresholds.
        """
        # The literal is kept inside the tr() call. Each backslash at the end
        # of a line continues the literal, so the leading whitespace of the
        # following line is part of the string passed to tr().
        return self.tr(
            "This algorithm classifies a tree cover layer into \
                        four categories: \
                        forest, grove, linear element, and scattered tree. \n \
                        The buffer parameter is used to disconnect \
                        aggregated forests and linear element using an \
                        erosion-dilatation process. This parameter should be \
                        close to the maximum width of the linear element \
                        (hedgerows). \n \
                        Thresholds related to forest and grove area are used \
                        to discriminate them based on their extent."
        )

    def name(self):
        """
        Return the identifier of the algorithm.

        This string is fixed and must not be localised. It should be unique
        within the provider and contain lowercase alphanumeric characters
        only, with no spaces or other formatting characters.
        """
        algorithm_id = "categorizewoodedarea"
        return algorithm_id

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.

        The label is passed through self.tr() before being returned.
        """
        # The literal stays inside the tr() call on purpose
        return self.tr("5 - Categorize wooded area [bêta]")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.

        The label is passed through self.tr() before being returned.
        """
        # The literal stays inside the tr() call on purpose
        return self.tr("0 - Extraction [optional]")

    def groupId(self):
        """
        Return the identifier of the group this algorithm belongs to.

        This string is fixed and must not be localised. It should be unique
        within the provider and contain lowercase alphanumeric characters
        only, with no spaces or other formatting characters.
        """
        group_identifier = "extraction"
        return group_identifier

    def tr(self, string):
        """
        Translate a string through QCoreApplication, using the
        "Processing" translation context.

        Parameters
        ---
        string : str : Text to translate

        Return
        ---
        translated : str : Result of QCoreApplication.translate
        """
        translation_context = "Processing"
        return QCoreApplication.translate(translation_context, string)

    def createInstance(self):
        """
        Return a new DiffForestHedgeAlgorithm instance.
        """
        fresh_algorithm = DiffForestHedgeAlgorithm()
        return fresh_algorithm
