from datetime import datetime
import os
from typing import Optional, Any, List, Tuple, Union, Dict

import numpy as np
from osgeo import gdal
from osgeo.osr import SpatialReference
from qgis._core import QgsProcessingContext, QgsProcessingParameterEnum, QgsProcessing
from qgis.core import (QgsProcessingAlgorithm, QgsRasterLayer, QgsProcessingParameterMultipleLayers, Qgis,
                       QgsProcessingParameterString, QgsProcessingParameterFile)

from landsklim.processing.landsklim_processing_tool_algorithm import LandsklimProcessingToolAlgorithm


class AverageRastersByGroupsAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Average rasters grouped by month or by year.

    The date of each input raster is parsed from its layer name with a user-supplied ``strptime`` format.
    Rasters falling in the same (year, month) pair, or in the same year, are averaged together and one
    ``.tif`` file per group is written to the output folder.
    """

    INPUT_RASTERS = 'INPUT_RASTERS'
    INPUT_DATE_FORMAT = 'INPUT_DATE_FORMAT'
    INPUT_AGGREGATION_TYPE = 'INPUT_AGGREGATION_TYPE'
    OUTPUT_FOLDER = 'OUTPUT_FOLDER'

    # Option indices of the INPUT_AGGREGATION_TYPE enum parameter (order of the options in initAlgorithm)
    AGG_MONTHLY = 0
    AGG_YEARLY = 1

    def __init__(self):
        super().__init__()
        # GDAL GeoTIFF driver handle.
        # NOTE(review): not referenced by the methods visible in this class — confirm it is still needed.
        self.__driver = gdal.GetDriverByName('GTiff')

    def createInstance(self) -> Optional['QgsProcessingAlgorithm']:
        """
        Create a new instance of the algorithm
        """
        return AverageRastersByGroupsAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return "average_rasters_by_groups"

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return "Average rasters by groups"

    def shortHelpString(self) -> str:
        """
        Short help text of the algorithm
        """
        return self.tr('Compute the average of the rasters, grouped by month or by year')

    def initAlgorithm(self, configuration: dict[str, Any] = ...) -> None:
        """
        Declare the parameters of the algorithm: input rasters, date format, aggregation type and output folder.
        """
        version = self.qgis_version()
        # QGIS 3.36 moved the processing enums to the Qgis namespace; older versions only provide the legacy ones
        if (version[0], version[1]) >= (3, 36):
            layer_type = Qgis.ProcessingSourceType.Raster
            folder_behavior = Qgis.ProcessingFileParameterBehavior.Folder
        else:
            layer_type = QgsProcessing.TypeRaster
            folder_behavior = QgsProcessingParameterFile.Folder

        self.addParameter(
            QgsProcessingParameterMultipleLayers(
                self.INPUT_RASTERS,
                'Rasters',
                layerType=layer_type
            )
        )

        # The date of each raster is parsed from its layer name with this format (see make_groups)
        self.addParameter(
            QgsProcessingParameterString(
                self.INPUT_DATE_FORMAT,
                'Date format (%Y for years, %m for months, %d for days)'
            )
        )

        # The option order must match AGG_MONTHLY / AGG_YEARLY
        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_AGGREGATION_TYPE,
                'Aggregation type',
                ['Monthly', 'Yearly'],
                allowMultiple=False,
                defaultValue=self.AGG_MONTHLY
            )
        )

        self.addParameter(
            QgsProcessingParameterFile(
                self.OUTPUT_FOLDER,
                'Output folder',
                behavior=folder_behavior
            )
        )

    def group_name(self, date_identifier: Union[Tuple[int, int], int], agg_type: int) -> str:
        """
        Build the output file name (without extension) of a group.

        :param date_identifier: ``(year, month)`` for monthly aggregation, ``year`` for yearly aggregation
        :param agg_type: AGG_MONTHLY or AGG_YEARLY
        :returns: ``"<year>-<month>"`` (month not zero-padded, e.g. ``"2020-3"``) or ``"<year>"``
        :raises ValueError: If ``agg_type`` is not a known aggregation type
        """
        if agg_type == self.AGG_MONTHLY:
            return "{0}-{1}".format(date_identifier[0], date_identifier[1])
        if agg_type == self.AGG_YEARLY:
            return "{0}".format(date_identifier)
        # Previously fell through and returned None, producing a "None.tif" output file
        raise ValueError("Unknown aggregation type: {0}".format(agg_type))

    def processAlgorithm(self, parameters: dict[str, Any], context: 'QgsProcessingContext', feedback: Optional['QgsProcessingFeedback']) -> dict[str, Any]:
        """
        Group the input rasters by date and write the average of each group to the output folder.

        The first input raster is used as reference for the no-data mask, the geotransform and the spatial
        reference of every output.

        :param parameters: Parameter values of the algorithm
        :param context: Processing context
        :param feedback: Processing feedback, used for progress reporting and cancellation (may be None)
        :returns: An empty results dictionary
        :raises QgsProcessingException: If no input raster layer could be resolved
        :raises ValueError: If a layer name does not match the date format
        """
        from qgis.core import QgsProcessingException

        layers: List[QgsRasterLayer] = self.parameterAsLayerList(parameters, self.INPUT_RASTERS, context)
        if not layers:
            raise QgsProcessingException(self.tr('At least one raster layer is required'))
        date_format: str = self.parameterAsString(parameters, self.INPUT_DATE_FORMAT, context)
        agg_type: int = self.parameterAsEnum(parameters, self.INPUT_AGGREGATION_TYPE, context)
        output_folder: str = self.parameterAsFile(parameters, self.OUTPUT_FOLDER, context)

        ref_layer: QgsRasterLayer = layers[0]
        np_ref_layer: np.ndarray = self.source_to_array(ref_layer.source())
        no_data, geotransform = self.get_raster_metadata(parameters, context, source_layer=ref_layer)
        out_srs: SpatialReference = self.get_spatial_reference(ref_layer)

        date_groups: Dict[Union[Tuple[int, int], int], List[QgsRasterLayer]] = self.make_groups(layers, date_format, agg_type)
        self.export_groups(date_groups, agg_type, output_folder, np_ref_layer, no_data, out_srs, geotransform,
                           feedback=feedback)

        return {}

    def export_groups(self,
                      groups: Dict[Union[Tuple[int, int], int], List[QgsRasterLayer]],
                      agg_type: int,
                      output_folder: str,
                      np_ref_layer: np.ndarray,
                      no_data: Optional[Union[int, float]],
                      out_srs: SpatialReference,
                      geotransform: Tuple,
                      feedback: Optional['QgsProcessingFeedback'] = None):
        """
        Average each group of rasters and write the result to ``<output_folder>/<group name>.tif``.

        Cells that are no-data in the reference raster are set to ``no_data`` in every output.

        :param groups: Rasters grouped by date identifier (see make_groups)
        :param agg_type: AGG_MONTHLY or AGG_YEARLY, used to name the output files
        :param output_folder: Folder receiving the output files
        :param np_ref_layer: Values of the reference raster, used to build the no-data mask
        :param no_data: No-data value of the outputs, or None when there is none
        :param out_srs: Spatial reference of the outputs
        :param geotransform: Geotransform of the outputs
        :param feedback: Optional processing feedback for progress reporting and cancellation
        """
        # The reference raster does not change between groups: compute its no-data mask once
        nodata_mask: Optional[np.ndarray] = self._nodata_mask(np_ref_layer, no_data)
        total: int = len(groups)
        for index, (date_identifier, rasters_group) in enumerate(groups.items()):
            if feedback is not None and feedback.isCanceled():
                break
            avg: np.ndarray = self.rasters_average(rasters_group)
            if nodata_mask is not None:
                avg[nodata_mask] = no_data
            out_path: str = os.path.join(output_folder, "{0}.tif".format(self.group_name(date_identifier, agg_type)))
            self.write_raster(out_path, avg, out_srs, geotransform, no_data)
            if feedback is not None:
                feedback.setProgress(100.0 * (index + 1) / total)

    @staticmethod
    def _nodata_mask(array: np.ndarray, no_data: Optional[Union[int, float]]) -> Optional[np.ndarray]:
        """
        Boolean mask of the cells of ``array`` holding ``no_data``, or None when there is no no-data value.

        NaN is handled explicitly because ``NaN == NaN`` is always False, so an equality test would never match.
        """
        if no_data is None:
            return None
        if isinstance(no_data, (float, np.floating)) and np.isnan(no_data):
            return np.isnan(array)
        return array == no_data

    def make_groups(self, layers: List[QgsRasterLayer], date_format: str, agg_type: int) -> Dict[Union[Tuple[int, int], int], List[QgsRasterLayer]]:
        """
        Group layers by the date parsed from their layer name.

        Each layer name must fully match ``date_format`` (``datetime.strptime`` semantics). Directives missing
        from the format fall back to the strptime defaults (year 1900, month 1).

        :param layers: Layers to group, in the order they are added to their group
        :param date_format: ``strptime`` format applied to each layer name
        :param agg_type: AGG_MONTHLY groups by ``(year, month)``, any other value groups by ``year``
        :returns: Layers grouped by date identifier, in first-seen order
        :raises ValueError: If a layer name does not match ``date_format``
        """
        date_groups: Dict[Union[Tuple[int, int], int], List[QgsRasterLayer]] = {}
        for raster in layers:  # type: QgsRasterLayer
            layer_name: str = raster.name()
            try:
                dt: datetime = datetime.strptime(layer_name, date_format)
            except ValueError as err:
                # The bare strptime message does not say which layer failed
                raise ValueError("Layer name '{0}' does not match the date format '{1}'".format(layer_name, date_format)) from err
            self.add_to_date_groups(date_groups, agg_type, dt.year, dt.month, raster)
        return date_groups

    def add_to_date_groups(self, groups: Dict[Union[Tuple[int, int], int], List[QgsRasterLayer]], agg_type: int, year: int, month: int, raster: QgsRasterLayer):
        """
        Append ``raster`` to its date group in ``groups`` (mutated in place), creating the group if needed.

        The group key is ``(year, month)`` for AGG_MONTHLY and ``year`` for any other aggregation type.
        """
        date_identifier = (year, month) if agg_type == self.AGG_MONTHLY else year
        groups.setdefault(date_identifier, []).append(raster)
