# -*- coding: utf-8 -*-

"""
/***************************************************************************
 Landscape metrics
                                 A QGIS plugin
 Metrics at tile scale
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2022-01-22
        copyright            : (C) 2022 by Gabriel Marquès
        email                : gabriel.marques@toulouse-inp.fr
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (author, creation date, copyright and revision).
__author__ = 'Gabriel Marquès'
__date__ = '2022-01-22'
__copyright__ = '(C) 2022 by Gabriel Marquès'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from qgis.PyQt.QtCore import (QCoreApplication, 
                              QVariant)
from qgis.core import (QgsProcessing,
                       QgsProcessingUtils,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterFeatureSource,
                       QgsFeature,
                       QgsRectangle,
                       QgsProcessingParameterField,
                       QgsProcessingParameterEnum,
                       QgsFeatureRequest,
                       QgsVectorLayer,
                       NULL)
from qgis.PyQt.QtGui import QIcon
from hedge_tools import resources # Only need in hedge_tools.py normaly but just to keep track of import

import numpy as np

from hedge_tools.tools.vector import geometry as g
from hedge_tools.tools.vector import qgis_wrapper as qw
from hedge_tools.tools.vector import attribute_table as at
from hedge_tools.tools.vector import utils

class LandscapeMetricsAlgorithm(QgsProcessingAlgorithm):
    """
    Compute landscape metrics for each tile (feature) of a grid layer.

    Available metrics (index in the METRICS enum -> output field) :
    - 0 to 4 : count of O, L, T, X and M nodes (O_count ... M_count) ;
    - 5 : Icon (connectivity index) ;
    - 6 : Icoh (%) (consistency index) ;
    - 7 : Density (ml/ha) ;
    - 8 : Density (%) ;
    - 9 : Icod ;
    - 10 : Total length.

    The metric fields are added to the input grid layer and the values are
    written through its data provider (the grid layer is edited in place and
    returned as OUTPUT_GRID).

    Parameters
    ----------
    GRID_LAYER : (QgisObject : QgsVectorLayer) : Polygon layer.
    NODE_LAYER : (QgisObject : QgsVectorLayer) : Point layer.
    METRICS : ite[int] : indices of the metrics to compute.
    NODE_TYPE : QgsField : Str : node type field of NODE_LAYER.
    FOREST_FIELD : QgsField : Double : distance to forest field of NODE_LAYER,
        required by Icoh and Icod.
    ARC_LAYER (QgisObject : QgsVectorLayer) : LineString layer,
        required by Density (ml/ha), Icod and Total length.
    POLY_LAYER (QgisObject : QgsVectorLayer) : Polygon layer,
        required by Density (%).

    Return
    ------
    OUTPUT_GRID (QgisObject : QgsVectorLayer) : Polygon layer.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    GRID_LAYER = "GRID_LAYER"
    NODE_LAYER = "NODE_LAYER"
    METRICS = "METRICS"
    NODE_TYPE = "NODE_TYPE"
    FOREST_FIELD = "FOREST_FIELD"
    ARC_LAYER = "ARC_LAYER"
    POLY_LAYER = "POLY_LAYER"
    OUTPUT_GRID = "OUTPUT_GRID"

    def initAlgorithm(self, config):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.
        """
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.GRID_LAYER,
                self.tr("Grid layer"),
                [QgsProcessing.TypeVectorPolygon],
                optional=False
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.NODE_LAYER,
                self.tr("Node layer"),
                [QgsProcessing.TypeVectorPoint]
            )
        )

        # Which metrics to compute (option order = keys of self.field_map)
        self.addParameter(
            QgsProcessingParameterEnum(
                name=self.METRICS,
                description=self.tr("Landscape metrics"),
                options=["Sum of O nodes", "Sum of L nodes",
                         "Sum of T nodes", "Sum of X nodes",
                         "Sum of M nodes", "Icon", "Icoh (%)",
                         "Density (ml/ha)", "Density (%)", "Icod",
                         "Total length"],
                optional=False,
                allowMultiple=True
            )
        )

        # Node type
        self.addParameter(
            QgsProcessingParameterField(
                self.NODE_TYPE,
                self.tr("Node type"),
                type=QgsProcessingParameterField.String,
                parentLayerParameterName='NODE_LAYER'
            )
        )

        # Distance to forest field
        self.addParameter(
            QgsProcessingParameterField(
                self.FOREST_FIELD,
                self.tr("Distance to forest"),
                type=QgsProcessingParameterField.Numeric,
                parentLayerParameterName='NODE_LAYER',
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.ARC_LAYER,
                self.tr("Arc layer"),
                [QgsProcessing.TypeVectorLine],
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.POLY_LAYER,
                self.tr("Polygon layer"),
                [QgsProcessing.TypeVectorPolygon],
                optional=True
            )
        )

    def __init__(self):
        super().__init__()

        # key = index of the metrics (same order as the METRICS enum options),
        # value : (name, type) of field to create
        self.field_map = {0: ("O_count", QVariant.Int),
                          1: ("L_count", QVariant.Int),
                          2: ("T_count", QVariant.Int),
                          3: ("X_count", QVariant.Int),
                          4: ("M_count", QVariant.Int),
                          5: ("Icon", QVariant.Int),
                          6: ("Icoh", QVariant.Double),
                          7: ("Density", QVariant.Double),
                          8: ("Density_%", QVariant.Double),
                          9: ("Icod", QVariant.Double),
                          10: ("Length", QVariant.Double)}

        # Field names chosen by the user, set in processAlgorithm.
        # forest_field stays None when the optional parameter is unset.
        self.node_field = None
        self.forest_field = None

    def flags(self):
        """
        Algorithm manipulating project (toggle layer) or using external library are not thread safe
        See : https://api.qgis.org/api/classQgsProcessingAlgorithm.html#a6a8c21fab75e03f50f45e41a9d67dbc3a229dea6cedf8c59132196dee09d4f2f6
        """
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def _get_forest_field(self, parameters, context):
        """
        Return the name of the selected distance to forest field,
        or None when the optional parameter is not set.

        parameterAsFields is used instead of indexing `parameters` directly
        so that a missing key (script call without the optional parameter)
        does not raise a KeyError.
        """
        names = [name for name in
                 self.parameterAsFields(parameters, self.FOREST_FIELD, context)
                 if name]
        return names[0] if names else None

    def processAlgorithm(self, parameters, context, feedback):
        """
        Here is where the processing itself takes place.

        Creates the requested metric fields on the grid layer, computes the
        metrics tile by tile and writes all values at once through the grid
        layer's data provider.
        """

        grid_layer = self.parameterAsVectorLayer(parameters, self.GRID_LAYER, context)
        node_layer = self.parameterAsVectorLayer(parameters, self.NODE_LAYER, context)
        metrics = self.parameterAsEnums(parameters, self.METRICS, context)
        self.node_field = self.parameterAsFields(parameters, self.NODE_TYPE, context)[0]
        arc_layer = self.parameterAsVectorLayer(parameters, self.ARC_LAYER, context)
        poly_layer = self.parameterAsVectorLayer(parameters, self.POLY_LAYER, context)
        self.forest_field = self._get_forest_field(parameters, context)

        feedback.pushInfo("Starting processing")

        # Check for cancellation
        if feedback.isCanceled():
            return {}

        index_map = self.create_output_fields(metrics, grid_layer)

        attr_map = {}

        count = grid_layer.featureCount()
        # featureCount() can be 0 (empty grid) or -1 (unknown): avoid a
        # division by zero and simply skip progress reporting in that case.
        step = 100 / count if count > 0 else 0

        for current, grid in enumerate(grid_layer.getFeatures()):
            attr_map[grid.id()] = self.compute_metrics(index_map, grid,
                                                       node_layer,
                                                       poly_layer, arc_layer)

            feedback.setProgress(int((current + 1) * step))
            # Check for cancellation
            if feedback.isCanceled():
                return {}

        grid_layer.dataProvider().changeAttributeValues(attr_map)

        return {"OUTPUT_GRID": parameters[self.GRID_LAYER]}

    def get_extent(self, grid):
        """
        From a grid feature or geometry return its extent as a QgsRectangle.

        The bounding box is used rather than specific polygon vertices so that
        multipart geometries and non-rectangular cells are supported; for the
        rectangular cells of a grid it is exactly the cell itself.

        Parameters
        ---
        grid : QgsFeature or QgsGeometry:Polygon

        Return
        ---
        extent : QgsRectangle
        """
        if isinstance(grid, QgsFeature):
            grid = grid.geometry()

        return grid.boundingBox()

    def create_output_fields(self, metrics, layer):
        """
        Create the metrics field wanted by user and
        return a map of field position for each index

        Parameters
        ----------
        metrics : ite[int] : indices of the METRICS enum (keys of self.field_map)
        layer : QgsVectorLayer:Polygon

        Return
        ------
        index_map : dict{metrics index: field index}

        TODO: Remove index_map and use create_fields output
        """
        fields = [self.field_map[key] for key in metrics]
        _ = at.create_fields(layer, fields)

        index_map = {k: layer.fields().indexFromName(name)
                     for k, (name, _) in self.field_map.items() if k in metrics}

        return index_map

    def get_node_type_count(self, features):
        """
        From a list of node type values retrieve the count of each
        different node type and store it as a dict.

        Parameters
        ----------
        features : list[str] : node type value of each node ("O", "L", "T", "X" or "M")

        Return
        ------
        occ : dict{node_type: count}
            Keys are "E" (O nodes), "L", "T", "X" and "X-" (M nodes).
        """
        occ = {}
        occ["E"] = features.count("O")
        occ["L"] = features.count("L")
        occ["T"] = features.count("T")
        occ["X"] = features.count("X")
        occ["X-"] = features.count("M")

        return occ

    def compute_icon(self, occ):
        """
        Compute icon index (weighted sum of node counts)

        Parameters
        ----------
        occ : dict{node_type:count}

        Return
        ------
        icon : int
        """
        return occ["E"] + 2*occ["L"] + 6*occ["T"] + 12*occ["X"] + 20*occ["X-"]

    def compute_icoh(self, occ, b):
        """
        Compute icoh index

        Parameters
        ----------
        occ : dict{node_type:count}
        b : int : number of connexion between a forest and an hedge

        Return
        ------
        icoh : float : percentage rounded to 2 decimals, 0 when there is no node
        """
        a = occ["L"] + 2*occ["T"] + 2*b + 3*occ["X"] + 4*occ["X-"]

        if a + occ["E"] != 0:
            return round((a/(a+occ["E"])) * 100, 2)
        else:
            return 0

    def compute_density(self, grid, features, percent=False):
        """
        Compute hedge density inside the tile

        Parameters
        ----------
        grid : QgsGeometry : Polygon or his corresponding QgsFeature
        features : list[QgsFeature] : hedge features inside the grid
            (polygons when percent is True, arcs otherwise)
        percent : boolean
            If percent is True then density will be computed using hedge polygons,
            otherwise it'll use hedge arc
        Return
        ------
        density : float rounded to 2 decimals
            percent of the tile area covered by hedges, or ml/ha
        """
        if isinstance(grid, QgsFeature):
            grid = grid.geometry()
        if percent:
            # Round after the conversion to percent (as compute_icoh does) so
            # that two decimals of the percentage are kept.
            covered = sum(f.geometry().area() for f in features)
            return round(covered / grid.area() * 100, 2)
        # * 10000 to put the result in ml/ha (assumes a CRS in metres)
        return round(sum(f.geometry().length() for f in features) * (10000/grid.area()), 2)

    def compute_icod(self, icoh, density):
        """
        Compute icod index
        Usually used with 1 square km tile

        Parameters
        ----------
        icoh : float : consistency indicator
        density : density of hedge in a grid (ml/ha)

        Return
        ------
        icod : float
        """
        return round(density*icoh/100, 2)

    def compute_total_length(self, features):
        """
        Compute hedges total length inside a tile

        Parameters
        ----------
        features : list[QgsFeature] : Feature of the hedge (arc) inside the grid

        Return
        ------
        length : float
        """
        return round(sum(f.geometry().length() for f in features), 2)

    def compute_metrics(self, index_map, grid, node_layer, poly_layer=None, arc_layer=None):
        """
        Compute the metrics wanted by the users
        and build the attr_map for the current feature

        Only the intermediate data needed by the requested metrics is
        computed (node type counts, forest connections, clipped polygons
        and/or clipped arcs).

        Parameters
        ----------
        index_map : dict{metrics index: field index}
        grid : QgsFeature
        node_layer : QgsVectorLayer : Point
        poly_layer : QgsVectorLayer : Polygon, needed for metric 8 (Density %)
        arc_layer : QgsVectorLayer : LineString, needed for metrics 7, 9 and 10

        Return
        ------
        results : dict{field_index: value} : Attr map for the current feature
        """
        requested = set(index_map)

        # Icod (9) is built from Icoh and Density (ml/ha), so it needs the
        # node counts, the forest connections and the arcs as well.
        need_occ = bool(requested & {0, 1, 2, 3, 4, 5, 6, 9})
        need_icoh = bool(requested & {6, 9})
        need_density = bool(requested & {7, 9})
        need_arcs = bool(requested & {7, 9, 10})
        need_polygons = 8 in requested

        occ = None
        b = 0
        if need_occ:
            _, features = g.get_clementini(node_layer, grid.geometry())
            # Materialised because it may be iterated twice below
            features = list(features)
            occ = self.get_node_type_count([f[self.node_field] for f in features])
            if need_icoh:
                # Number of nodes connected to a forest (distance == 0)
                b = sum(1 for f in features if f[self.forest_field] == 0)

        # The extent is shared by both clippings
        extent = self.get_extent(grid) if (need_arcs or need_polygons) else None
        polygons = []
        if need_polygons:
            clipped = qw.extract_by_extent(extent, poly_layer, clip_by=True)
            polygons = list(clipped.getFeatures())
        arcs = []
        if need_arcs:
            clipped = qw.extract_by_extent(extent, arc_layer, clip_by=True)
            arcs = list(clipped.getFeatures())

        icoh = self.compute_icoh(occ, b) if need_icoh else None
        density = self.compute_density(grid, arcs) if need_density else None

        results = {}
        for metric, field_index in index_map.items():
            if metric == 0:
                value = occ["E"]
            elif metric == 1:
                value = occ["L"]
            elif metric == 2:
                value = occ["T"]
            elif metric == 3:
                value = occ["X"]
            elif metric == 4:
                value = occ["X-"]
            elif metric == 5:
                value = self.compute_icon(occ)
            elif metric == 6:
                value = icoh
            elif metric == 7:
                value = density
            elif metric == 8:
                value = self.compute_density(grid, polygons, percent=True)
            elif metric == 9:
                value = self.compute_icod(icoh, density)
            elif metric == 10:
                value = self.compute_total_length(arcs)
            else:
                # Unknown metric index: nothing to write
                continue
            results[field_index] = value

        return results

    def postProcessAlgorithm(self, context, feedback):
        """
        Tasks done when processAlgorithm is finished
        """
        utils.delete_processing_workspace()

        return {}

    def icon(self):
        """
        Should return a QIcon which is used for your provider inside
        the Processing toolbox.
        """
        return QIcon(":/plugins/hedge_tools/images/hedge_tools.png")

    def shortHelpString(self):
        """
        Returns a localised short help string for the algorithm.
        """
        return self.tr("Please note that the algorithm is currently freezing QGIS interface until completion. \n\
                        Compute landscape metrics for each tile of a grid layer \
                        The grid layer must be polygonal and can be created with the native qgis algorithm: create grid \n\
                        Landscape metrics can be: \n\
                        - Icon: Connectivity index, weighted sum of hedge intersections. \n\
                        - Icoh (%): Consistency index, percentage of hedge intersections \
                          with other hedges and forest/grove over all hedge extremities;\n\
                        - Density (ml/ha): Total length of hedges bring back to hectare;\n\
                        - Density (%): Density of hedge for each tile;\n\
                        - Icod: Density and structure of bocage index;\n\
                        - Total length: Sum of arc length inside each tile\n\
                        Some indices need additional data to be computed: \n\
                        - Icoh needs a field of distance to the forest from the 'Connection to a forest' tool;\n\
                        - Density (%) needs a polygon layer of hedges;\n\
                        - Density (ml/ha), icod and total length need an arc layer of hedges.")

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "landscapemetrics"

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Landscape metrics")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr("7 - Landscape level: grid")

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return 'hedgegrid'

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)

    def checkParameterValues(self, parameters, context):
        """
        Validate the parameters before running the algorithm.

        Runs the base class validation first, then checks that the optional
        inputs required by the selected metrics are provided.

        Return
        ------
        (bool, str) : validity flag and error message
        """
        ok, message = super().checkParameterValues(parameters, context)
        if not ok:
            return ok, message

        def is_unset(key):
            # Optional layer parameters may be missing, None or an empty string
            value = parameters.get(key)
            return value is None or (isinstance(value, str) and not value.strip())

        # parameterAsEnums accepts every supported encoding of the enum value
        # (int, list of int, comma separated string...)
        metrics = set(self.parameterAsEnums(parameters, self.METRICS, context))

        if metrics & {6, 9} and self._get_forest_field(parameters, context) is None:
            return (False, "Forest field is mandatory for icoh or icod computation.")
        if 8 in metrics and is_unset(self.POLY_LAYER):
            return (False, "Hedge polygon layer is mandatory for density (%) computation")
        if metrics & {7, 9, 10} and is_unset(self.ARC_LAYER):
            return (False, "Hedge arc layer is mandatory for density (ml/ha), icod and total length computation")

        return (True, '')

    def createInstance(self):
        return LandscapeMetricsAlgorithm()
