# -*- coding: utf-8 -*-

"""
/***************************************************************************
 Animove
                                 A QGIS plugin
 AniMove for QGIS
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2021-07-19
        copyright            : (C) 2021 by Matteo Ghetta (Faunalia)
        email                : matteo.ghetta@faunalia.eu
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata generated by the QGIS Plugin Builder.
__author__ = "Matteo Ghetta (Faunalia)"
__date__ = "2021-07-19"
__copyright__ = "(C) 2021 by Matteo Ghetta (Faunalia)"

# Placeholder that `git archive` replaces with the commit SHA1.
__revision__ = "$Format:%H$"

import datetime
import os

import numpy as np
from osgeo import gdal, osr
from statsmodels.nonparametric import bandwidths, kernel_density

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.PyQt.QtGui import QIcon

from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingException,
                       QgsProcessingParameterField,
                       QgsProcessingParameterNumber,
                       QgsWkbTypes,
                       QgsFields,
                       QgsField,
                       QgsRasterLayer,
                       QgsDistanceArea,
                       QgsGeometry,
                       QgsFeature,
                       QgsProcessingParameterBoolean,
                       QgsProcessingUtils,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterFeatureSink)

import processing


class AnimoveKernelDensity(QgsProcessingAlgorithm):
    """
    Kernel Density Estimation (KDE) home ranges for grouped point fixes.

    The input points are grouped by the values of a chosen field. For each
    group a bivariate kernel density is estimated with statsmodels'
    ``KDEMultivariate`` and evaluated on a square grid. The grid covers the
    group's bounding box, padded by half its width/height on every side.

    The grid is written to a temporary GeoTIFF (clipped at 0 and scaled so
    the peak is 100). It is then contoured with ``gdal:contour`` at the
    isopleth that encloses the requested percentage of the utilization
    distribution (UD) volume.

    One MultiLineString feature per group is written to the output sink. It
    carries the group value and the summed area and perimeter of the contour
    rings, measured with the processing context's ellipsoid.
    """

    # Inputs
    INPUT = 'INPUT'
    FIELD = 'FIELD'
    PERCENT = 'PERCENT'
    RESOLUTION = 'RESOLUTION'
    ADD_RASTER_OUTPUTS = 'ADD_RASTER_OUTPUTS'
    BW_METHOD = 'BW_METHOD'
    BW_VALUE = 'BW_VALUE'

    # Output
    OUTPUT = 'OUTPUT'

    # Bandwidth methods: GUI label -> method key.
    # The enum parameter stores the position of the chosen label, so the
    # insertion order of this dict matters.
    # - 'normal_reference', 'cv_ml' and 'cv_ls' are native KDEMultivariate
    #   selectors.
    # - 'scott' and 'silverman' are evaluated in _bandwidth(), because
    #   KDEMultivariate does not recognise them.
    BW_METHODS = {
        'Rule of thumb (default)': 'normal_reference',
        "Cross validation maximum likelihood": 'cv_ml',
        "Cross validation least squares": 'cv_ls',
        "Rule of thumb (Scott)": 'scott',
        "Rule of thumb (Silverman)": 'silverman',
        # "Custom value": 'custom',
    }

    # gdal:contour emits lines at OFFSET + k * INTERVAL. Rasters written by
    # to_geotiff() span [0, 100], so any interval wider than 100 leaves
    # OFFSET as the only level inside the data range.
    _SINGLE_LEVEL_INTERVAL = 1000.0

    def initAlgorithm(self, config):
        """Declare the algorithm's input and output parameters."""

        # point layer
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                name=self.INPUT,
                description=self.tr('Input layer'),
                types=[QgsProcessing.TypeVectorPoint]
            )
        )

        # field whose values split the fixes into groups
        self.addParameter(
            QgsProcessingParameterField(
                name=self.FIELD,
                description=self.tr("Group fixes by"),
                parentLayerParameterName=self.INPUT
            )
        )

        # percentage of the UD volume enclosed by the output contour
        self.addParameter(
            QgsProcessingParameterNumber(
                name=self.PERCENT,
                description=self.tr("Percentage of Utilization Distribution (UD)"),
                type=QgsProcessingParameterNumber.Integer,
                minValue=5,
                maxValue=100,
                defaultValue=95
            )
        )

        # number of grid samples along each axis of the KDE grid
        self.addParameter(
            QgsProcessingParameterNumber(
                name=self.RESOLUTION,
                description=self.tr("Output raster resolution"),
                type=QgsProcessingParameterNumber.Integer,
                minValue=1,
                defaultValue=5
            )
        )

        # bandwidth method; the first entry is the one labelled "default"
        self.addParameter(
            QgsProcessingParameterEnum(
                name=self.BW_METHOD,
                description=self.tr("Bandwidth method"),
                options=list(self.BW_METHODS),
                defaultValue=0
            )
        )

        # bandwidth value (only meaningful with the currently disabled
        # 'Custom value' method)
        # self.addParameter(
        #     QgsProcessingParameterNumber(
        #         name=self.BW_VALUE,
        #         description=self.tr("Bandwidth value (only used  if 'Custom value' "
        #                 "bandwidth method selected)"),
        #         type=QgsProcessingParameterNumber.Double,
        #         minValue=0.0,
        #         defaultValue=0.2
        #     )
        # )

        # whether the intermediate GeoTIFFs are loaded into the project
        self.addParameter(
            QgsProcessingParameterBoolean(
                name=self.ADD_RASTER_OUTPUTS,
                description=self.tr("Add raster outputs to QGIS")
            )
        )

        # vector output
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                name=self.OUTPUT,
                description=self.tr("Kernel Density Estimation")
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Estimate one KDE home range per group and write it to the sink.

        :returns: dict mapping OUTPUT to the sink's destination id.
        :raises QgsProcessingException: if the input source or the output
            sink cannot be obtained, or a raster/contour step fails.
        """
        inputLayer = self.parameterAsSource(parameters, self.INPUT, context)
        if inputLayer is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT))
        crs = inputLayer.sourceCrs()

        # the field parameter accepts a single field: take the only entry
        field = self.parameterAsFields(parameters, self.FIELD, context)[0]

        perc = self.parameterAsInt(parameters, self.PERCENT, context)
        resolution = self.parameterAsInt(parameters, self.RESOLUTION, context)

        # the enum stores the position of the chosen label in BW_METHODS
        bw_index = self.parameterAsEnum(parameters, self.BW_METHOD, context)
        bandwidth = list(self.BW_METHODS.values())[bw_index]
        # BW_VALUE is only read for the custom method, which is currently
        # not offered in BW_METHODS
        if bandwidth == 'custom':
            bandwidth = self.parameterAsDouble(
                parameters, self.BW_VALUE, context)

        addRasterOutputs = self.parameterAsBool(
            parameters, self.ADD_RASTER_OUTPUTS, context)

        fields = QgsFields()
        fields.append(QgsField('ID', QVariant.String))
        fields.append(QgsField('Area', QVariant.Double))
        fields.append(QgsField('Perim', QVariant.Double))

        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            fields,
            QgsWkbTypes.MultiLineString,
            crs
        )
        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT))

        # read the source once instead of once per distinct field value
        groups = self._collect_points(inputLayer, field)
        count = len(groups)

        measure = QgsDistanceArea()
        measure.setSourceCrs(crs, context.transformContext())
        measure.setEllipsoid(context.ellipsoid())

        raster_list = []

        for current, (value, (xPoints, yPoints)) in enumerate(groups.items()):
            if feedback.isCanceled():
                break
            # progress = share of groups already processed
            feedback.setProgress(int(current * 100 / count))

            # a zero extent gives a zero bandwidth and a collapsed grid
            if (len(xPoints) < 2 or min(xPoints) == max(xPoints)
                    or min(yPoints) == max(yPoints)):
                feedback.reportError(self.tr(
                    'Skipping {}: at least two fixes spread along both axes '
                    'are required').format(value))
                continue

            # grid extent: bounding box padded by half its size on each side
            xspan = max(xPoints) - min(xPoints)
            yspan = max(yPoints) - min(yPoints)
            xmin = min(xPoints) - 0.5 * xspan
            xmax = max(xPoints) + 0.5 * xspan
            ymin = min(yPoints) - 0.5 * yspan
            ymax = max(yPoints) + 0.5 * yspan

            # An imaginary step makes mgrid return `resolution` samples per
            # axis, endpoints included: (X[i, j], Y[i, j]) == (x_i, y_j).
            X, Y = np.mgrid[xmin:xmax:resolution * 1j,
                            ymin:ymax:resolution * 1j]

            feedback.pushDebugInfo(f'X shape: {X.shape}')
            feedback.pushDebugInfo(f'Y shape: {Y.shape}')

            # KDEMultivariate wants (n_observations, n_variables) arrays.
            # statsmodels only transposes a (2, n) array when n != 2, so the
            # previous layout misread groups of exactly two fixes.
            positions = np.column_stack([X.ravel(), Y.ravel()])
            values = np.column_stack([xPoints, yPoints])

            feedback.pushDebugInfo(f'Positions shape : {positions.shape}')
            feedback.pushDebugInfo(f'Values shape : {values.shape}')

            kernel = kernel_density.KDEMultivariate(
                data=values,
                var_type='cc',
                bw=self._bandwidth(bandwidth, xPoints, yPoints)
            )

            # Z[i, j] is the estimated density at (x_i, y_j)
            Z = np.reshape(kernel.pdf(positions), X.shape)

            feedback.pushDebugInfo(f'Bandwidth value for: {str(value)} : {str(kernel.bw)}')
            feedback.pushDebugInfo(f'Shape of evaluation transponse : : {str(Z.T.shape)}')

            raster_name = f'{inputLayer.sourceName()}_{perc}_{value}_{datetime.date.today()}'

            fileName = os.path.join(QgsProcessingUtils.tempFolder(), f'{raster_name}.tif')

            feedback.pushDebugInfo(f'Writing {fileName} to disc')

            # pass the CRS as WKT: srsid() is a QGIS-internal id, not EPSG
            self.to_geotiff(fileName, xmin, xmax, ymin, ymax, X, Y, Z, crs.toWkt())
            raster_list.append(fileName)

            level = self._isopleth_level(Z, perc)
            feedback.pushDebugInfo(f'Isopleth level for {perc}% UD: {level}')

            feedback.pushDebugInfo('Creating contour lines')
            outGeom = self._contour_polylines(fileName, level, context, feedback)

            # measure each contour line as the ring of a polygon
            area = 0
            perim = 0
            for polyline in outGeom:
                polygon = QgsGeometry.fromPolygonXY([polyline])
                perim += measure.measurePerimeter(polygon)
                area += measure.measureArea(polygon)

            feedback.pushDebugInfo('Writing polylines features')

            sink_feature = QgsFeature()
            sink_feature.setAttributes([value, area, perim])
            sink_feature.setGeometry(QgsGeometry.fromMultiPolylineXY(outGeom))
            sink.addFeature(sink_feature)

        # optionally load the intermediate rasters when the algorithm ends
        if addRasterOutputs:
            for item in raster_list:
                raster_layer = QgsProcessingUtils.mapLayerFromString(
                    item, context, allowLoadingNewLayers=True)
                if raster_layer is None:
                    feedback.reportError(
                        self.tr('Could not load raster {}').format(item))
                    continue
                raster_layer.setCrs(crs)
                context.addLayerToLoadOnCompletion(
                    raster_layer.source(),
                    context.LayerDetails(
                        name=os.path.basename(item),
                        project=context.project()
                    )
                )

        return {self.OUTPUT: dest_id}

    @staticmethod
    def _collect_points(source, field_name):
        """
        Group the vertex coordinates of `source`'s geometries by field value.

        Every vertex of every non-empty geometry is collected, so MultiPoint
        features contribute all their parts (``asPoint()`` only handles
        single points). Null or empty geometries are skipped.

        :param source: feature source to read.
        :param field_name: name of the grouping field.
        :returns: dict mapping each field value to an ``(xs, ys)`` pair of
            coordinate lists, in order of first appearance.
        """
        groups = {}
        for feature in source.getFeatures():
            geometry = feature.geometry()
            if geometry.isNull() or geometry.isEmpty():
                continue
            xs, ys = groups.setdefault(feature[field_name], ([], []))
            for vertex in geometry.vertices():
                xs.append(vertex.x())
                ys.append(vertex.y())
        return groups

    @staticmethod
    def _bandwidth(method, xs, ys):
        """
        Translate a BW_METHODS value into a KDEMultivariate ``bw`` argument.

        KDEMultivariate only understands 'normal_reference', 'cv_ml' and
        'cv_ls' (or explicit numbers). Those, and numeric values, are
        returned unchanged.

        For 'scott' and 'silverman', statsmodels' univariate rules of thumb
        are evaluated separately on the x and y coordinates, and the
        per-axis bandwidth list is returned.
        """
        rules = {
            'scott': bandwidths.bw_scott,
            'silverman': bandwidths.bw_silverman,
        }
        rule = rules.get(method)
        if rule is None:
            return method
        return [float(rule(np.asarray(xs, dtype=float))),
                float(rule(np.asarray(ys, dtype=float)))]

    @staticmethod
    def _isopleth_level(Z, percent):
        """
        Return the contour level that encloses `percent` % of the density.

        The level uses the same 0-100 scale that to_geotiff() writes
        (density clipped at 0, peak scaled to 100). Grid cells whose value
        is at or above the level hold `percent` % of the summed density.
        Returns 0.0 when the grid has no positive density.
        """
        cells = np.clip(np.asarray(Z, dtype=float).ravel(), 0, None)
        if cells.size == 0:
            return 0.0
        peak = cells.max()
        total = cells.sum()
        # also rejects NaN peaks/totals
        if not (peak > 0 and total > 0):
            return 0.0
        ordered = np.sort(cells)[::-1]
        cumulative = np.cumsum(ordered)
        # first (densest-first) cell at which the target volume is reached
        index = int(np.searchsorted(cumulative, total * percent / 100.0))
        index = min(index, ordered.size - 1)
        return float(ordered[index] * 100.0 / peak)

    def _contour_polylines(self, raster_path, level, context, feedback):
        """
        Run ``gdal:contour`` on band 1 of `raster_path` at a single level.

        With OFFSET=level and an INTERVAL wider than the raster's 0-100
        range, `level` is the only contour generated inside the data.

        :returns: list of polylines (lists of points), one per line part.
        :raises QgsProcessingException: if the contour output cannot be
            loaded.
        """
        param = {
            'INPUT': raster_path,
            'BAND': 1,
            'INTERVAL': self._SINGLE_LEVEL_INTERVAL,
            'FIELD_NAME': 'values',
            'CREATE_3D': False,
            'IGNORE_NODATA': False,
            'NODATA': None,
            'OFFSET': level,
            'EXTRA': '',
            'OUTPUT': 'TEMPORARY_OUTPUT'
        }
        contour_layer_str = processing.run(
            "gdal:contour", param, feedback=feedback, context=context)['OUTPUT']

        contour_layer = QgsProcessingUtils.mapLayerFromString(contour_layer_str, context)
        if contour_layer is None:
            raise QgsProcessingException(
                self.tr('Could not load contour lines from {}').format(contour_layer_str))

        polylines = []
        for feat in contour_layer.getFeatures():
            geometry = feat.geometry()
            if geometry.isNull() or geometry.isEmpty():
                continue
            # the temporary output may report its lines as multi-part
            if geometry.isMultipart():
                polylines.extend(geometry.asMultiPolyline())
            else:
                polylines.append(geometry.asPolyline())
        return polylines

    @staticmethod
    def _pixel_size(lo, hi, count):
        """Spacing of `count` evenly spaced samples from lo to hi inclusive
        (the full span, or 1.0, when there is a single sample)."""
        if count > 1:
            return (hi - lo) / float(count - 1)
        return (hi - lo) or 1.0

    def to_geotiff(self, fname, xmin, xmax, ymin, ymax, X, Y, Z, epsg):
        """
        Save a kernel density grid as a single-band Float64 GeoTIFF.

        Expected grid layout:
        - the grid comes from ``np.mgrid[xmin:xmax:nx*1j, ymin:ymax:ny*1j]``;
        - ``Z[i, j]`` is the density at ``(X[i, j], Y[i, j])``;
        - the outermost samples lie exactly on xmin/xmax and ymin/ymax.

        Samples are written as pixel centres in a north-up raster. Values
        are clipped at 0 and scaled so the peak is 100; an all-zero or NaN
        grid is written as zeros.

        :param fname: output file path.
        :param xmin, xmax, ymin, ymax: coordinates of the outermost samples.
        :param X, Y: sample coordinates (kept for interface compatibility;
            the grid size is taken from ``Z.shape``).
        :param Z: 2-D density array of shape (nx, ny).
        :param epsg: CRS as an integer EPSG code, or any string accepted by
            ``osr.SpatialReference.SetFromUserInput`` (e.g. WKT).
        :raises QgsProcessingException: if GDAL cannot create the file.
        """
        density = np.asarray(Z, dtype=float)
        nx, ny = density.shape

        xps = self._pixel_size(xmin, xmax, nx)
        yps = self._pixel_size(ymin, ymax, ny)

        coord_system = osr.SpatialReference()
        if isinstance(epsg, str):
            if epsg:
                coord_system.SetFromUserInput(epsg)
        else:
            coord_system.ImportFromEPSG(int(epsg))

        clipped = density.clip(0)
        peak = clipped.max()
        if peak > 0:
            scaled = clipped * 100.0 / peak
        else:
            scaled = np.zeros_like(clipped)

        driver = gdal.GetDriverByName("GTiff")
        out = driver.Create(fname, nx, ny, 1, gdal.GDT_Float64)
        if out is None:
            raise QgsProcessingException(
                self.tr('Could not create GeoTIFF {}').format(fname))
        try:
            # north-up: origin at the top-left pixel corner, negative row step
            out.SetGeoTransform(
                (xmin - xps / 2.0, xps, 0, ymax + yps / 2.0, 0, -yps))
            out.SetProjection(coord_system.ExportToWkt())
            # raster rows follow y from north to south, columns follow x
            out.GetRasterBand(1).WriteArray(
                np.ascontiguousarray(np.flipud(scaled.T)))
            out.FlushCache()
        finally:
            # dropping the last reference closes the GDAL dataset
            out = None

    def name(self):
        # Used by QGIS as the algorithm id.
        # NOTE(review): QGIS recommends lowercase ids without spaces;
        # changing it would break existing models/scripts.
        return 'Kernel Density Estimation'

    def displayName(self):
        return self.tr(self.name())

    def group(self):
        return self.tr(self.groupId())

    def groupId(self):
        return 'Tools'

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)

    def icon(self):
        icon_path = os.path.join(
            os.path.dirname(__file__),
            'icons',
            'kernelDensity.png'
        )
        return QIcon(icon_path)

    def tags(self):
        return self.tr('kernel,density,animal').split(',')

    def createInstance(self):
        return AnimoveKernelDensity()
