# -*- coding: utf-8 -*-

"""
Rast_ZonalStatistics.py
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""
__author__ = 'Leandro França'
__date__ = '2023-06-04'
__copyright__ = '(C) 2023, Leandro França'

from qgis.PyQt.QtCore import QVariant
from qgis.core import (QgsProcessing,
                       QgsFeatureSink,
                       QgsWkbTypes,
                       QgsFields,
                       QgsField,
                       QgsFeature,
                       QgsPointXY,
                       QgsGeometry,
                       QgsProject,
                       QgsProcessingException,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterString,
                       QgsProcessingParameterBand,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterCrs,
                       QgsProcessingParameterEnum,
                       QgsFeatureRequest,
                       QgsExpression,
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterFeatureSink,
                       QgsProcessingParameterFileDestination,
                       QgsProcessingParameterMultipleLayers,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterRasterDestination,
                       QgsApplication,
                       QgsProject,
                       QgsRasterLayer,
                       QgsCoordinateTransform,
                       QgsCoordinateReferenceSystem)

from math import floor, ceil
from osgeo import osr, gdal_array, gdal #https://gdal.org/python/
from matplotlib import path
import numpy as np
from lftools.geocapt.imgs import Imgs
from lftools.translations.translate import translate
import os
from qgis.PyQt.QtGui import QIcon

class ZonalStatistics(QgsProcessingAlgorithm):
    """QGIS Processing algorithm computing zonal statistics of raster bands.

    For each polygon of the input vector layer, the values of the raster cells
    whose pixel centre falls inside the polygon (holes excluded) are collected
    per band. NoData and NaN cells are discarded, and the selected statistics
    are appended as new fields to a copy of the polygon layer.
    """

    # Two-letter locale code handed to the lftools translate helper.
    LOC = QgsApplication.locale()[:2]

    def tr(self, *string):
        """Translate via lftools' ``translate``; callers pass (English, Portuguese) pairs."""
        return translate(string, self.LOC)

    def createInstance(self):
        return ZonalStatistics()

    def name(self):
        return 'zonalstatistics'

    def displayName(self):
        return self.tr('Zonal Statistics', 'Estatísticas zonais')

    def group(self):
        return self.tr('Raster')

    def groupId(self):
        return 'raster'

    def tags(self):
        return 'GeoOne,estatísticas,statistics,zonal,zonais,amostra,sample,mean,average,std,bands,values'.split(',')

    def icon(self):
        return QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'images/raster.png'))

    txt_en = '''This algorithm calculates statistics for the bands of a raster layer, categorized by zones defined in a polygon type vector layer.
The values of the raster cells where the pixel center is exactly inside the polygon are considered in the statistics.'''
    txt_pt = '''Este algoritmo calcula estatísticas para as bandas de uma camada raster, categorizadas por zonas definidas em camada vetorial do tipo polígono.
Os valores das células do raster onde o centro do pixel se encontra exatamente dentro do polígono são considerados nas estatísticas.'''
    figure = 'images/tutorial/raster_zonalstatistics.jpg'

    def shortHelpString(self):
        """Return the localized help text followed by the tutorial figure and author footer."""
        social_BW = Imgs().social_BW
        footer = '''<div align="center">
                      <img src="'''+ os.path.join(os.path.dirname(os.path.dirname(__file__)), self.figure) +'''">
                      </div>
                      <div align="right">
                      <p align="right">
                      <b>'''+self.tr('Author: Leandro Franca', 'Autor: Leandro França')+'''</b>
                      </p>'''+ social_BW + '''</div>
                    </div>'''
        return self.tr(self.txt_en, self.txt_pt) + footer

    # Parameter keys
    INPUT ='INPUT'
    BAND = 'BAND'
    POLYGONS = 'POLYGONS'
    PREFIX = 'PREFIX'
    STATISTICS = 'STATISTICS'
    OUTPUT = 'OUTPUT'
    # Available statistics; the STATISTICS enum stores indices into this list.
    STATS = ['count','sum','mean','median', 'std', 'min', 'max']

    def initAlgorithm(self, config=None):
        """Declare the algorithm inputs (raster, band, polygons, prefix, statistics) and output sink."""
        # INPUT
        # NOTE: the third positional argument of QgsProcessingParameterRasterLayer
        # is the default value, not a layer-type filter, so nothing is passed here.
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT,
                self.tr('Input Raster', 'Raster de Entrada')
            )
        )

        # Optional: when unset (0) statistics are computed for every band.
        self.addParameter(
            QgsProcessingParameterBand(
                self.BAND,
                self.tr('Band', 'Banda'),
                parentLayerParameterName=self.INPUT,
                optional = True
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.POLYGONS,
                self.tr('Polygons', 'Polígonos'),
                [QgsProcessing.TypeVectorPolygon]
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.PREFIX,
                self.tr('Output column prefix', 'Prefixo da coluna de saída'),
                defaultValue = self.tr('stat_', 'estat_')
            )
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.STATISTICS,
                self.tr('Statistics', 'Estatísticas'),
                options = self.STATS,
                allowMultiple = True,
                defaultValue = [0,2,4]
            )
        )

        # OUTPUT
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr('Zonal statistics', 'Estatísticas zonais')
            )
        )

    @staticmethod
    def _ring_to_pixel(ring, geotransform):
        """Convert ring vertices (map units) to fractional (row, col) raster coordinates.

        Uses the signed pixel sizes of the GDAL geotransform, so north-up rasters
        (negative y resolution) and south-up rasters are both handled. Rotation
        terms of the geotransform are ignored.
        """
        ulx, xres, _, uly, _, yres = geotransform
        return [((p.y() - uly) / yres, (p.x() - ulx) / xres) for p in ring]

    def _zone_values(self, geom, banda, geotransform, nodata):
        """Return a 1-D float array with the band values inside a (multi)polygon.

        A cell is selected when its pixel centre lies inside an outer ring and
        outside every hole of the same part. Cells equal to ``nodata`` and NaN
        cells are discarded. ``banda`` is the 2-D array read from the band
        (rows x cols).
        """
        rows, cols = banda.shape
        if geom.isMultipart():
            polygons = geom.asMultiPolygon()   # -> [ [ring0, ring1, ...], [ring0, ...], ... ]
        else:
            polygons = [geom.asPolygon()]      # -> [ [ring0, ring1, ...] ]

        chunks = []
        for poly_rings in polygons:
            if not poly_rings or not poly_rings[0]:
                continue

            # Bounding box of the outer ring, in integer cell indices, clamped to the raster.
            outer = self._ring_to_pixel(poly_rings[0], geotransform)
            linhas = [p[0] for p in outer]
            colunas = [p[1] for p in outer]
            lin_min = max(0, int(np.floor(min(linhas))))
            lin_max = min(rows - 1, int(np.floor(max(linhas))))
            col_min = max(0, int(np.floor(min(colunas))))
            col_max = min(cols - 1, int(np.floor(max(colunas))))
            if lin_max < lin_min or col_max < col_min:
                # Part lies completely outside the raster.
                continue

            COL, LIN = np.meshgrid(np.arange(col_min, col_max + 1),
                                   np.arange(lin_min, lin_max + 1))
            shape = LIN.shape
            # Pixel centres, in the same (row, col) space as the ring vertices.
            pts = np.column_stack(((LIN + 0.5).ravel(), (COL + 0.5).ravel()))
            mask = path.Path(outer).contains_points(pts).reshape(shape)

            # Remove cells whose centre falls inside a hole.
            for hole_ring in poly_rings[1:]:
                if hole_ring:
                    p_hole = path.Path(self._ring_to_pixel(hole_ring, geotransform))
                    mask &= ~p_hole.contains_points(pts).reshape(shape)

            vals = banda[lin_min:lin_max + 1, col_min:col_max + 1][mask]
            # NaN NoData compares unequal to everything, so this keeps all cells in
            # that case; the isnan filter below then removes the NaN cells.
            vals = vals[vals != nodata].astype(float)
            vals = vals[~np.isnan(vals)]
            if vals.size:
                chunks.append(vals)

        return np.concatenate(chunks) if chunks else np.array([], dtype=float)

    def _compute_stats(self, valores, stats):
        """Return the statistics selected by the indices in ``stats``, in that order.

        For an empty zone, 'count' is 0 and 'sum' is 0.0. The remaining
        statistics are undefined without values, so they are returned as None
        (written as NULL) instead of crashing (min/max) or producing NaN.
        """
        funcs = {
            'count': lambda v: int(v.size),
            'sum': lambda v: float(v.sum()),
            'mean': lambda v: float(v.mean()),
            'median': lambda v: float(np.median(v)),
            'std': lambda v: float(v.std()),   # population std (ddof=0)
            'min': lambda v: float(v.min()),
            'max': lambda v: float(v.max()),
        }
        lista_stats = []
        for st in stats:
            nome = self.STATS[st]
            if valores.size == 0 and nome not in ('count', 'sum'):
                lista_stats.append(None)
            else:
                lista_stats.append(funcs[nome](valores))
        return lista_stats

    def processAlgorithm(self, parameters, context, feedback):
        """Compute the zonal statistics and write them to the OUTPUT sink.

        Returns {OUTPUT: dest_id}, or {} if the user cancels during the
        statistics pass. Raises QgsProcessingException when an input cannot be
        loaded or the output sink cannot be created.
        """
        RasterIN = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        if RasterIN is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT))
        RasterIN = RasterIN.dataProvider().dataSourceUri()

        layer = self.parameterAsSource(parameters, self.POLYGONS, context)
        if layer is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.POLYGONS))

        stats = self.parameterAsEnums(parameters, self.STATISTICS, context)
        n_banda = self.parameterAsInt(parameters, self.BAND, context)
        prefixo = self.parameterAsString(parameters, self.PREFIX, context)

        # Open the raster with GDAL
        feedback.pushInfo(self.tr('Opening raster...', 'Abrindo raster...'))
        image = gdal.Open(RasterIN)
        if image is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT))
        prj = image.GetProjection()
        geotransform = image.GetGeoTransform()
        if geotransform[2] != 0 or geotransform[4] != 0:
            feedback.reportError(self.tr(
                'Warning: the raster geotransform has rotation terms, which are ignored; results may be wrong.',
                'Aviso: o geotransform do raster possui termos de rotação, que são ignorados; os resultados podem estar incorretos.'),
                False)

        # Reproject polygons to the raster CRS when they differ
        crsSrc = layer.sourceCrs()
        crsDest = QgsCoordinateReferenceSystem(prj)
        coordTransf = None
        if crsSrc != crsDest:
            coordTransf = QgsCoordinateTransform(crsSrc, crsDest, context.transformContext())

        # Bands to process: all of them (band parameter unset -> 0) or a single one
        if n_banda == 0:
            bandas = list(range(1, image.RasterCount + 1))
        else:
            bandas = [n_banda]

        n_feats = layer.featureCount()  # may be -1 when the provider cannot tell
        total = 100.0 / (len(bandas) * n_feats) if n_feats > 0 and bandas else 0
        cont = 0

        # dic[feature id][band number] -> list of statistics (ordered as `stats`)
        dic = {}
        feedback.pushInfo(self.tr('Calculating zonal statistics...', 'Calculando estatísticas zonais...'))
        for k in bandas:
            raster_band = image.GetRasterBand(int(k))
            nodata = raster_band.GetNoDataValue()
            if nodata is None:
                # Keep the original convention: 0 is treated as NoData when none is declared.
                nodata = 0
            banda = raster_band.ReadAsArray()
            for feat in layer.getFeatures():
                if feedback.isCanceled():
                    return {}
                geom = feat.geometry()
                if coordTransf is not None:
                    geom.transform(coordTransf)

                valores = self._zone_values(geom, banda, geotransform, nodata)
                dic.setdefault(feat.id(), {})[k] = self._compute_stats(valores, stats)

                cont += 1
                feedback.setProgress(int(cont * total))
        banda = None
        image = None  # release the GDAL dataset

        # Create the output polygon layer
        feedback.pushInfo(self.tr('Creating layer with results...', 'Criando camada com resultados...'))
        Fields = layer.fields()
        for band in bandas:
            for st in stats:
                Fields.append(QgsField(prefixo + self.tr('band{}_'.format(band), 'banda{}_'.format(band)) + self.STATS[st], QVariant.Double))

        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            Fields,
            layer.wkbType(),
            layer.sourceCrs()
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        # Export the results: original attributes followed by band statistics
        for feat in layer.getFeatures():
            if feedback.isCanceled():
                break
            att = feat.attributes()
            for band in bandas:
                att += dic[feat.id()][band]
            feature = QgsFeature(Fields)
            feature.setAttributes(att)
            feature.setGeometry(feat.geometry())
            sink.addFeature(feature, QgsFeatureSink.FastInsert)

        feedback.pushInfo(self.tr('Operation completed successfully!', 'Operação finalizada com sucesso!'))
        feedback.pushInfo(self.tr('Leandro Franca - Cartographic Engineer', 'Leandro França - Eng Cart'))

        return {self.OUTPUT: dest_id}