# -*- coding: utf-8 -*-

"""
/***************************************************************************
 FaunaliaToolkit
                                 A QGIS plugin
 Faunalia Spatial Analysis Toolkit
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2022-03-30
        copyright            : (C) 2022 by Matteo Ghetta (Faunalia)
        email                : matteo.ghetta@faunalia.eu
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (Plugin Builder template); kept as plain literal
# assignments so that tools scanning the file textually can read them.
__author__ = 'Matteo Ghetta (Faunalia)'
__date__ = '2022-03-30'
__copyright__ = '(C) 2022 by Matteo Ghetta (Faunalia)'

# This will get replaced with a git SHA1 when you do a git archive
# NOTE(review): git only expands `$Format:...$` placeholders in files that
# carry the `export-subst` attribute in .gitattributes — confirm it is set.

__revision__ = '$Format:%H$'

import math

from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (QgsProcessing,
                       QgsFeatureSink,
                       QgsProcessingAlgorithm,
                       QgsProcessingException,
                       QgsProcessingParameterField,
                       QgsStatisticalSummary,
                       QgsStringStatisticalSummary,
                       QgsDateTimeStatisticalSummary,
                       QgsFeatureRequest,
                       QgsGeometry,
                       QgsFields,
                       QgsField,
                       QgsFeature,
                       QgsSpatialIndex,
                       QgsProcessingUtils,
                       NULL,
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterFeatureSink)

from faunalia_toolkit.__about__ import DIR_PLUGIN_ROOT


class PointStatisticsWithinPolygon(QgsProcessingAlgorithm):
    """
    Summarise the attributes of the points falling inside each polygon.

    For every feature of the polygon layer the algorithm collects the points
    whose geometry is contained in the polygon, computes the selected
    statistics on the selected point fields and writes the polygon, with one
    extra ``<field>_<statistic>`` attribute per (field, statistic) pair, to
    the output sink.
    """

    # Parameter identifiers
    POLYGON = 'POLYGON'
    POINT = 'POINT'
    JOIN_FIELDS = 'JOIN_FIELDS'
    SUMMARIES = 'SUMMARIES'
    OUTPUT = 'OUTPUT'

    # Statistics available per field category. Each entry is
    # (statistic key, output field type, summary method name):
    # - an output type of None keeps the type of the source field;
    # - a method name of None means the value is computed explicitly in the
    #   matching _*_summary helper instead of through getattr().
    _NUMERIC_STATS = (
        ('count', QVariant.Int, 'count'),
        ('unique', QVariant.Int, 'variety'),
        ('min', QVariant.Double, 'min'),
        ('max', QVariant.Double, 'max'),
        ('range', QVariant.Double, 'range'),
        ('sum', QVariant.Double, 'sum'),
        ('mean', QVariant.Double, 'mean'),
        ('median', QVariant.Double, 'median'),
        ('stddev', QVariant.Double, 'stDev'),
        ('minority', QVariant.Double, 'minority'),
        ('majority', QVariant.Double, 'majority'),
        ('q1', QVariant.Double, 'firstQuartile'),
        ('q3', QVariant.Double, 'thirdQuartile'),
        ('iqr', QVariant.Double, 'interQuartileRange'),
    )

    _DATETIME_STATS = (
        ('count', QVariant.Int, 'count'),
        ('unique', QVariant.Int, 'countDistinct'),
        ('empty', QVariant.Int, 'countMissing'),
        ('filled', QVariant.Int, None),
        ('min', None, None),
        ('max', None, None),
    )

    _STRING_STATS = (
        ('count', QVariant.Int, 'count'),
        ('unique', QVariant.Int, 'countDistinct'),
        ('empty', QVariant.Int, 'countMissing'),
        ('filled', QVariant.Int, None),
        ('min', None, 'min'),
        ('max', None, 'max'),
        ('min_length', QVariant.Int, 'minLength'),
        ('max_length', QVariant.Int, 'maxLength'),
        ('mean_length', QVariant.Double, 'meanLength'),
    )

    _STATS_BY_CATEGORY = {
        'numeric': _NUMERIC_STATS,
        'datetime': _DATETIME_STATS,
        'string': _STRING_STATS,
    }

    def initAlgorithm(self, config):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.
        """

        # Polygon layer: each of its features receives the statistics.
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.POLYGON,
                self.tr('Polygons'),
                [QgsProcessing.TypeVectorPolygon]
            )
        )

        # Point layer whose attributes are summarised.
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.POINT,
                self.tr('Points'),
                [QgsProcessing.TypeVectorPoint]
            )
        )

        self.addParameter(
            QgsProcessingParameterField(
                self.JOIN_FIELDS,
                self.tr('Field to calculate the statistics on (leave empty to use all fields)'),
                parentLayerParameterName=self.POINT,
                allowMultiple=True,
                optional=True
            )
        )

        # (statistic key, displayed label); the enum parameter below is
        # indexed by position in this list.
        self.statistics = [
            ('count', self.tr('count')),
            ('unique', self.tr('unique')),
            ('min', self.tr('min')),
            ('max', self.tr('max')),
            ('range', self.tr('range')),
            ('sum', self.tr('sum')),
            ('mean', self.tr('mean')),
            ('median', self.tr('median')),
            ('stddev', self.tr('stddev')),
            ('minority', self.tr('minority')),
            ('majority', self.tr('majority')),
            ('q1', self.tr('q1')),
            ('q3', self.tr('q3')),
            ('iqr', self.tr('iqr')),
            ('empty', self.tr('empty')),
            ('filled', self.tr('filled')),
            ('min_length', self.tr('min_length')),
            ('max_length', self.tr('max_length')),
            ('mean_length', self.tr('mean_length'))
        ]

        self.addParameter(
            QgsProcessingParameterEnum(
                self.SUMMARIES,
                self.tr('Statistics (leave empty to use all available)'),
                options=[p[1] for p in self.statistics],
                allowMultiple=True,
                # count, min, max, sum, mean
                defaultValue=[0, 2, 3, 5, 6],
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr('Polygon with statistics')
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Compute the statistics and write the polygons to the output sink.

        :param parameters: algorithm parameter values, keyed by identifier
        :param context: processing context (provides the CRS transform context)
        :param feedback: used to report progress and to honour cancellation
        :returns: ``{OUTPUT: destination id}`` of the written layer
        :raises QgsProcessingException: if an input source cannot be loaded
            or the output sink cannot be created
        """
        polygon = self.parameterAsSource(parameters, self.POLYGON, context)
        if polygon is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.POLYGON))

        point = self.parameterAsSource(parameters, self.POINT, context)
        if point is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.POINT))

        join_fields = self.parameterAsFields(parameters, self.JOIN_FIELDS, context)

        summaries = [
            self.statistics[i][0]
            for i in sorted(self.parameterAsEnums(parameters, self.SUMMARIES, context))
        ]
        if not summaries:
            # none selected, so use all
            summaries = [s[0] for s in self.statistics]

        point_fields = point.fields()
        if not join_fields:
            # no fields selected, use all
            join_fields = [point_fields.at(i).name() for i in range(len(point_fields))]

        fields_to_join, join_field_indexes, categories = self._prepare_join_fields(
            point_fields, join_fields, summaries)

        out_fields = QgsProcessingUtils.combineFields(polygon.fields(), fields_to_join)

        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            out_fields,
            polygon.wkbType(),
            polygon.sourceCrs()
        )
        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT))

        # Points are fetched reprojected to the polygon CRS, so the spatial
        # index and the containment test below both work in that CRS.
        point_request = QgsFeatureRequest()
        point_request.setDestinationCrs(polygon.sourceCrs(), context.transformContext())

        index = QgsSpatialIndex()
        index.addFeatures(point.getFeatures(point_request))

        # featureCount() may be -1 when the count is unknown
        feature_count = polygon.featureCount()
        total = 100.0 / feature_count if feature_count > 0 else 0

        for current, polygon_feature in enumerate(polygon.getFeatures()):
            if feedback.isCanceled():
                break

            values = self._values_within(
                polygon_feature, point, index, point_request,
                join_field_indexes, feedback)
            if feedback.isCanceled():
                # the point scan was interrupted: don't write partial statistics
                break

            feedback.setProgress(int(current * total))

            attrs = polygon_feature.attributes()
            for i, category in enumerate(categories):
                attribute_values = [v[i] for v in values]
                attrs.extend(self._summarize(category, attribute_values, summaries))

            feature = QgsFeature()
            feature.setAttributes(attrs)
            feature.setGeometry(polygon_feature.geometry())

            sink.addFeature(feature, QgsFeatureSink.FastInsert)

        return {self.OUTPUT: dest_id}

    def _prepare_join_fields(self, point_fields, join_fields, summaries):
        """
        Build the output statistic fields for the requested point fields.

        Names that cannot be found in ``point_fields`` are skipped.

        :returns: ``(fields_to_join, join_field_indexes, categories)`` where
            ``fields_to_join`` is a QgsFields with one field per selected
            (field, statistic) pair, ``join_field_indexes`` the indexes of the
            used point fields and ``categories`` their category ('numeric',
            'datetime' or 'string'), aligned with ``join_field_indexes``.
        """
        fields_to_join = QgsFields()
        join_field_indexes = []
        categories = []
        for field_name in join_fields:
            idx = point_fields.lookupField(field_name)
            if idx < 0:
                continue

            join_field = point_fields.at(idx)
            category = self._field_category(join_field)
            join_field_indexes.append(idx)
            categories.append(category)

            for stat_name, stat_type, _ in self._STATS_BY_CATEGORY[category]:
                if stat_name in summaries:
                    fields_to_join.append(
                        self._summary_field(join_field, stat_name, stat_type))

        return fields_to_join, join_field_indexes, categories

    @staticmethod
    def _field_category(field):
        """Return 'numeric', 'datetime' or 'string' (the fallback) for a QgsField."""
        if field.isNumeric():
            return 'numeric'
        if field.type() in (QVariant.Date, QVariant.Time, QVariant.DateTime):
            return 'datetime'
        return 'string'

    @staticmethod
    def _summary_field(original, stat, field_type=None):
        """
        Return a copy of ``original`` renamed to ``<name>_<stat>``.

        When ``field_type`` is given the copy gets that type (double fields
        also get length 20 / precision 6); otherwise the original type is kept.
        """
        field = QgsField(original)
        field.setName(field.name() + '_' + stat)
        if field_type is not None:
            field.setType(field_type)
            if field_type == QVariant.Double:
                field.setLength(20)
                field.setPrecision(6)
        return field

    @staticmethod
    def _values_within(polygon_feature, point_source, index, request,
                       join_field_indexes, feedback):
        """
        Return the join-field values of the points inside ``polygon_feature``.

        The result has one list per contained point, holding that point's
        values for ``join_field_indexes`` in the same order. A polygon without
        geometry contains no points. ``request`` is reused: its fid filter is
        overwritten with this polygon's spatial-index candidates.
        """
        if not polygon_feature.hasGeometry():
            return []
        geometry = polygon_feature.geometry()

        # cheap bounding-box pre-filter through the spatial index
        candidate_ids = index.intersects(geometry.boundingBox())
        if not candidate_ids:
            return []
        request.setFilterFids(candidate_ids)

        # exact containment test against a prepared geometry engine
        engine = QgsGeometry.createGeometryEngine(geometry.constGet())
        engine.prepareGeometry()

        values = []
        for point_feature in point_source.getFeatures(request):
            if feedback.isCanceled():
                break
            if not point_feature.hasGeometry():
                continue
            if engine.contains(point_feature.geometry().constGet()):
                values.append([point_feature[i] for i in join_field_indexes])
        return values

    def _summarize(self, category, values, summaries):
        """
        Return the selected statistics of ``values`` for a field of the given
        category, ordered as in that category's statistic table.
        """
        if category == 'numeric':
            return self._numeric_summary(values, summaries)
        if category == 'datetime':
            return self._datetime_summary(values, summaries)
        return self._string_summary(values, summaries)

    @classmethod
    def _numeric_summary(cls, values, summaries):
        """Numeric statistics via QgsStatisticalSummary; NaN results become NULL."""
        stat = QgsStatisticalSummary()
        for v in values:
            stat.addVariant(v)
        stat.finalize()

        result = []
        for stat_name, _, method in cls._NUMERIC_STATS:
            if stat_name in summaries:
                val = getattr(stat, method)()
                result.append(NULL if math.isnan(val) else val)
        return result

    @classmethod
    def _datetime_summary(cls, values, summaries):
        """Date/time statistics via QgsDateTimeStatisticalSummary."""
        stat = QgsDateTimeStatisticalSummary()
        stat.calculate(values)

        result = []
        for stat_name, _, method in cls._DATETIME_STATS:
            if stat_name not in summaries:
                continue
            if stat_name == 'filled':
                result.append(stat.count() - stat.countMissing())
            elif stat_name == 'min':
                result.append(stat.statistic(QgsDateTimeStatisticalSummary.Min))
            elif stat_name == 'max':
                result.append(stat.statistic(QgsDateTimeStatisticalSummary.Max))
            else:
                result.append(getattr(stat, method)())
        return result

    @classmethod
    def _string_summary(cls, values, summaries):
        """
        String statistics via QgsStringStatisticalSummary.

        NULL/None values are fed as empty strings; any other value is
        converted with ``str()``.
        """
        stat = QgsStringStatisticalSummary()
        for v in values:
            stat.addString('' if v is None or v == NULL else str(v))
        stat.finalize()

        result = []
        for stat_name, _, method in cls._STRING_STATS:
            if stat_name not in summaries:
                continue
            if stat_name == 'filled':
                result.append(stat.count() - stat.countMissing())
            else:
                result.append(getattr(stat, method)())
        return result

    def name(self):
        return 'point_statistics_within_polygon'

    def displayName(self):
        return self.tr('Point Statistics Within Polygon')

    def group(self):
        return self.tr(self.groupId())

    def groupId(self):
        return 'Vector Analysis'

    def icon(self):
        return QIcon(str(DIR_PLUGIN_ROOT / "resources/images/mAlgorithmSumPoints.svg"))

    def tags(self):
        return self.tr('analysis,polygon,statistics,within').split(',')

    def shortHelpString(self):
        help_string = '''
        Given a point and a polygon layer, the algorithm takes only the points within each polygon and calculates the statistics of the chosen fields. If no fields are selected, all fields are taken into account.

        The statistics are added as fields to the output polygon layer (e.g. field_statistic).

        The available statistics depend on the field type (numeric, text, date) and are the <i>standard</i> statistics of QGIS (min, max, mean, median, count, count missing, count distinct, variety, sum, range, st dev, minority, majority, first quartile, third quartile, inter quartile range, min/max/mean length).
        '''

        return help_string

    def helpUrl(self):
        return 'https://faunalia.gitlab.io/faunalia-toolkit/usage/algorithms.html#point-statistics-within-polygon'

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        return PointStatisticsWithinPolygon()
