from typing import List

from osgeo.osr import SpatialReference
from qgis._core import QgsProcessingParameterEnum
from qgis.core import QgsProcessingParameterRasterDestination, QgsRasterLayer
import numpy as np

from landsklim.lk.landsklim_interpolation import LandsklimInterpolationType
from landsklim.lk.map_layer import MapLayer
from landsklim.processing.landsklim_processing_tool_algorithm import LandsklimProcessingToolAlgorithm


class AverageInterpolationsProcessingAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Processing algorithm computing the average of the interpolations of an analysis.

    For the selected interpolation and interpolation type, the layers returned by the
    interpolation are averaged and written to a single output raster. Cells that hold the
    no-data value in the first (reference) layer are set to no-data in the output.
    """

    INPUT_ANALYSIS = 'INPUT_ANALYSIS'
    INPUT_INTERPOLATION = 'INPUT_INTERPOLATION'
    INPUT_INTERPOLATION_TYPE = 'INPUT_INTERPOLATION_TYPE'
    OUTPUT_RASTER = 'OUTPUT_RASTER'

    # The index selected in the INPUT_INTERPOLATION_TYPE enum parameter maps into this list
    interpolation_types = list(LandsklimInterpolationType)
    interpolation_types_str = [i.str() for i in interpolation_types]

    def createInstance(self):
        """
        Return a fresh instance of this algorithm
        """
        return AverageInterpolationsProcessingAlgorithm()

    def initAlgorithm(self, config=None):
        """
        Declare the algorithm parameters: analysis, interpolation of that analysis,
        interpolation type and output raster destination
        """
        # Local import, as in the original code (presumably to avoid an import cycle — TODO confirm)
        from landsklim.processing.processing_parameter_analysis import QgsProcessingParameterAnalysis, QgsProcessingParameterInterpolation
        self.addParameter(
            QgsProcessingParameterAnalysis(
                self.INPUT_ANALYSIS,
                self.tr('Analysis', "AverageInterpolationsProcessingAlgorithm")
            )
        )

        self.addParameter(
            QgsProcessingParameterInterpolation(
                self.INPUT_INTERPOLATION,
                self.tr('Interpolation', "AverageInterpolationsProcessingAlgorithm"),
                parent=self.INPUT_ANALYSIS
            )
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_INTERPOLATION_TYPE,
                self.tr('Interpolation type', 'AverageInterpolationsProcessingAlgorithm'),
                self.interpolation_types_str,
                allowMultiple=False,
                defaultValue=0
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT_RASTER,
                self.tr('Output')
            )
        )

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'average_interpolations'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Average of interpolations')

    def shortHelpString(self) -> str:
        """
        Short description displayed in the algorithm dialog
        """
        return self.tr('Compute the average of the interpolations of an analysis')

    @staticmethod
    def _nodata_mask(array: np.ndarray, no_data) -> np.ndarray:
        """
        Boolean mask of the cells of ``array`` holding the ``no_data`` value.

        A plain ``==`` comparison never matches NaN, so a NaN no-data value is
        detected with ``np.isnan`` instead.

        :param array: Raster values
        :param no_data: No-data value (must not be None)
        :returns: Boolean array, True where the cell is no-data
        """
        try:
            nodata_is_nan = bool(np.isnan(no_data))
        except TypeError:
            # Non-numeric no-data value: fall back to plain equality
            nodata_is_nan = False
        if nodata_is_nan:
            return np.isnan(array)
        return array == no_data

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run.

        Averages the layers of the selected interpolation for the selected interpolation
        type and writes the result to the output raster. Spatial reference, geotransform and
        no-data value are taken from the first layer.

        :param parameters: Parameter values supplied by QGIS
        :param context: Processing context
        :param feedback: Processing feedback (not used here)
        :returns: Dict mapping OUTPUT_RASTER to the path of the written raster
        :raises RuntimeError: If the interpolation type has not been calculated for the
            interpolation, or if the interpolation has no layer for that type
        """
        interpolation: "LandsklimInterpolation" = self.parameterAsInterpolation(parameters, self.INPUT_INTERPOLATION, context)
        interpolation_type: LandsklimInterpolationType = self.interpolation_types[self.parameterAsEnum(parameters, self.INPUT_INTERPOLATION_TYPE, context)]

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)

        if interpolation_type not in interpolation.get_interpolation_types():
            raise RuntimeError("{0} has not been calculated in the interpolation {1}".format(interpolation_type.str(), interpolation.get_name()))

        map_layers: List[MapLayer] = list(interpolation.get_layers(interpolation_type).values())
        layers: List[QgsRasterLayer] = [ml.qgis_layer() for ml in map_layers]
        if not layers:
            raise RuntimeError("No layer found in the interpolation {0}".format(interpolation.get_name()))

        # The first layer provides the grid metadata and the no-data footprint of the output
        ref_layer: QgsRasterLayer = layers[0]
        np_ref_layer: np.ndarray = self.source_to_array(ref_layer.source())

        no_data, geotransform = self.get_raster_metadata(parameters, context, source_layer=ref_layer)
        out_srs: SpatialReference = self.get_spatial_reference(ref_layer)

        np_mean: np.ndarray = self.rasters_average(layers)
        if no_data is not None:
            # Propagate the reference layer's no-data cells (NaN-aware) to the averaged raster
            np_mean[self._nodata_mask(np_ref_layer, no_data)] = no_data
        self.write_raster(out_path, np_mean, out_srs, geotransform, no_data)

        return {self.OUTPUT_RASTER: out_path}
