from typing import Union, List, Optional, Tuple, Dict, Any

from PyQt5.QtCore import QVariant
from qgis._core import QgsProcessingParameterEnum, QgsProcessingParameterVectorLayer, QgsProcessingParameterString, \
    QgsVectorLayer, QgsProject, QgsField, QgsFeature, QgsProcessingParameterField, Qgis, QgsFeedback
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean
from qgis import processing
import numpy as np

from landsklim.lk.landsklim_analysis import LandsklimAnalysis, LandsklimAnalysisMode
from landsklim.lk.landsklim_configuration import LandsklimConfiguration
from landsklim.lk.landsklim_interpolation import LandsklimInterpolationType
from landsklim.lk.landsklim_project import LandsklimProject
from landsklim.lk.map_layer import MapLayer, VectorLayer, RasterLayer
from landsklim.lk.phase_kriging import PhaseKriging
from landsklim.lk.phase_multiple_regression import PhaseMultipleRegression
from landsklim.lk.regressor import Regressor
from landsklim.processing.landsklim_processing_tool_algorithm import LandsklimProcessingToolAlgorithm


class BatchLocalAnalysesAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Processing algorithm computing batch local analysis from a station file

    Selected situations (numeric fields of the stations layer) are grouped according to which stations hold
    a valid value for them, and one analysis is created per group. Each analysis uses a multiple regression
    model as phase 1, kriging as phase 2, and every regressor of the Landsklim project.
    """

    INPUT_STATIONS = 'INPUT_STATIONS'
    INPUT_FIELDS = 'INPUT_FIELDS'
    # NOTE(review): the parameter key differs from the attribute name; presumably kept for compatibility
    # with existing callers of the algorithm — confirm before renaming
    INPUT_ANALYSES_PATTERN = 'INPUT_ANALYSES_PREFIXES'
    INPUT_NEIGHBORHOOD_SIZE = 'INPUT_NEIGHBORHOOD_SIZE'

    interpolation_types = list(LandsklimInterpolationType)
    interpolation_types_str = [i.str() for i in interpolation_types]

    def __init__(self):
        super().__init__()
        # Progress increment (in percent) applied each time a model computation step is reported
        self.__p_step: float = 0

    def createInstance(self):
        return BatchLocalAnalysesAlgorithm()

    def initAlgorithm(self, config=None):
        self.addParameter(
            QgsProcessingParameterVectorLayer(
                self.INPUT_STATIONS,
                self.tr('Stations', 'BatchLocalAnalysesAlgorithm'),
                [QgsProcessing.TypeVectorPoint]
            )
        )

        self.addParameter(
            QgsProcessingParameterField(
                self.INPUT_FIELDS,
                self.tr('Situations', 'BatchLocalAnalysesAlgorithm'),
                parentLayerParameterName=self.INPUT_STATIONS,
                type=QgsProcessingParameterField.DataType.Numeric,
                #type=Qgis.QgsProcessingParameterFieldDataType.Numeric,
                allowMultiple=True,
                optional=False,
                defaultToAllFields=False
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_NEIGHBORHOOD_SIZE,
                self.tr('Neighborhood size', 'BatchLocalAnalysesAlgorithm'),
                QgsProcessingParameterNumber.Type.Integer
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.INPUT_ANALYSES_PATTERN,
                self.tr('Pattern of analyses names', "BatchLocalAnalysesAlgorithm")
            )
        )

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'batch_local_analyses'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Batch of local analyses')

    def shortHelpString(self) -> str:
        return self.tr('Compute batch local analyses from stations, splitting analyses according to the availability of stations data from one day to the next\n\nWarning: Analyses generated from this algorithm will follow the pattern :\n- Phases 1 are multiple regression models\n- Phases 2 are kriging\n- All regressors available will be included on analyses')

    def check_layer_is_registered(self, lk_project: LandsklimProject, layer: QgsVectorLayer) -> Optional[VectorLayer]:
        """
        Looks for the Landsklim stations source wrapping a QGIS vector layer

        :param lk_project: The Landsklim project
        :type lk_project: LandsklimProject

        :param layer: The QGIS vector layer to look for
        :type layer: QgsVectorLayer

        :returns: The stations source whose QGIS layer is ``layer`` (the last one if several match),
                  or None if the layer is not registered as stations in the Landsklim project
        :rtype: Optional[VectorLayer]
        """
        res: Optional[VectorLayer] = None
        for station in lk_project.get_stations_sources():  # type: VectorLayer
            if station.qgis_layer() == layer:
                res = station
        return res

    def check_no_names_conflict(self, lk_project: LandsklimProject, name_pattern: str, configurations_count: int) -> bool:
        """
        Checks that no analysis to be created shares its name with an analysis already in the project

        :param lk_project: The Landsklim project
        :param name_pattern: Prefix of the names of the analyses to create
        :param configurations_count: Number of analyses to create

        :returns: True if there is no name conflict, False otherwise
        :rtype: bool
        """
        # Names are computed once instead of once per existing analysis
        new_names = set(self.get_analysis_names(name_pattern, configurations_count))
        return not any(a.get_name() in new_names for a in lk_project.get_analysis())

    @staticmethod
    def _is_available_value(value: Any, stations_no_data: Optional[float]) -> bool:
        """
        :returns: False if ``value`` is missing (None, null QVariant, or a number equal to ``stations_no_data``),
                  True otherwise
        """
        if value is None:
            return False
        if isinstance(value, QVariant):
            return not value.isNull()
        if isinstance(value, (int, float)):
            return value != stations_no_data
        return True

    def get_stations_configurations(self, layer: QgsVectorLayer, fields: List[str], stations_no_data: Optional[float]) -> Tuple[List[List[int]], List[int]]:
        """
        Groups situations according to data availability

        Two situations belong to the same group when exactly the same stations hold a valid value for both.
        A value is missing when it is None, a null QVariant, or a number equal to ``stations_no_data``.

        :param layer: The vector layer
        :type layer: QgsVectorLayer

        :param fields: Fields of the layer selected by the user
        :type fields: List[str]

        :param stations_no_data: Optional value representing NO_DATA for stations

        :returns: A tuple ``(situations_groups, groups_size)``. ``situations_groups`` lists the groups of situations,
                  each situation being represented by the index of its field in the layer. ``groups_size[k]`` is the
                  number of stations holding a valid value for the situations of ``situations_groups[k]``.
                  Groups are ordered by first appearance along the layer's field order.
        """
        selected_fields = set(fields)
        fields_idx: List[int] = [i for i, f in enumerate(layer.fields()) if f.name() in selected_fields]

        # One list per selected field, holding the positions of the features with a valid value for it
        features_by_attributes: List[List[int]] = [[] for _ in fields_idx]

        for i_f, feature in enumerate(layer.getFeatures()):  # type: int, QgsFeature
            # Attributes are indexed directly: converting them to a numpy array first would coerce a mixed-type
            # row (e.g. a text field next to numeric ones) to strings, hiding NO_DATA values from the check below
            all_attributes: List[Any] = feature.attributes()
            for i_a, field_idx in enumerate(fields_idx):
                if self._is_available_value(all_attributes[field_idx], stations_no_data):
                    features_by_attributes[i_a].append(i_f)

        # Situations sharing the same available stations end up in the same group (dicts keep insertion order)
        groups: Dict[Tuple[int, ...], List[int]] = {}
        for situation, available_features in zip(fields_idx, features_by_attributes):
            groups.setdefault(tuple(available_features), []).append(situation)

        situations_groups: List[List[int]] = list(groups.values())
        groups_size: List[int] = [len(available_features) for available_features in groups]
        return situations_groups, groups_size

    def get_analysis_names(self, name_pattern: str, configurations_count: int) -> List[str]:
        """
        :returns: Names of the analyses to create: ``<name_pattern>conf1`` ... ``<name_pattern>conf<configurations_count>``
        """
        return ["{0}conf{1}".format(name_pattern, cn+1) for cn in range(configurations_count)]

    def create_analyses(self, lk_prj: LandsklimProject, input_layer: VectorLayer, configurations: List[List[int]], configurations_size: List[int], name_pattern: str, neighborhood_size: int) -> List[LandsklimAnalysis]:
        """
        Builds one analysis per group of situations (analyses are not added to the project here)

        :param lk_prj: The Landsklim project
        :param input_layer: Registered stations source
        :param configurations: Groups of situations, as returned by ``get_stations_configurations``
        :param configurations_size: Number of stations with valid data for each group
        :param name_pattern: Prefix of the analyses names
        :param neighborhood_size: Neighborhood size used by the analyses

        :returns: The created analyses, in the order of ``configurations``
        """
        analyses: List[LandsklimAnalysis] = []
        lk_config: LandsklimConfiguration = lk_prj.get_configurations()[0]
        stations_no_data = lk_prj.get_stations_no_data()
        a_names: List[str] = self.get_analysis_names(name_pattern, len(configurations_size))
        for name, situations_config, config_size in zip(a_names, configurations, configurations_size):
            # Fetched per analysis, as originally, so no list object gets shared between analyses unless the
            # project itself already shares it
            regressors: List[Regressor] = lk_prj.get_regressors()
            # Local mode only when the group has more stations than the neighborhood size, Global otherwise
            mode = LandsklimAnalysisMode.Local if config_size > neighborhood_size else LandsklimAnalysisMode.Global
            analysis = LandsklimAnalysis(name, lk_config, mode, neighborhood_size, input_layer, situations_config, False, PhaseMultipleRegression.class_name(), PhaseKriging.class_name(), regressors, stations_no_data)
            analyses.append(analysis)
        return analyses

    def on_polygons_failed(self, analysis: LandsklimAnalysis):
        """
        Callback for polygons creation failure: aborts the algorithm

        :raises RuntimeError: Always
        """
        raise RuntimeError("{0} : Polygons can't be created.".format(analysis.get_name()))

    def on_models_failed(self, analysis: LandsklimAnalysis, fields: str):
        """
        Callback for models computation failure: aborts the algorithm (``fields`` is not used)

        :raises RuntimeError: Always
        """
        raise RuntimeError("{0} : Regression models can't be created.".format(analysis.get_name()))

    def on_models_computed(self, feedback, text: str, i: int):
        """
        Callback for models computation steps: advances the progress bar by one step, capped at 100%
        """
        feedback.setProgress(min(100.0, feedback.progress() + self.__p_step))

    def processAlgorithm(self, parameters, context, feedback: QgsFeedback):
        """
        Called when a processing algorithm is run

        Groups the selected situations by data availability, creates one analysis per group, computes its
        polygons, datasets and models, then adds it to the first configuration of the opened Landsklim project.

        :raises RuntimeError: If no Landsklim project is opened, if the stations layer is not registered in it,
                              if an analysis name is already used, or if polygons/models can't be computed
        :returns: An empty dict
        """

        qgs_version_major, qgs_version_minor = self.qgis_version()

        input_layer: QgsVectorLayer = self.parameterAsVectorLayer(parameters, self.INPUT_STATIONS, context)
        # QGIS >= 3.32 reads field lists with parameterAsStrings; older versions only provide parameterAsFields.
        # A tuple comparison also routes future major versions (4.x) to the newer API.
        if (qgs_version_major, qgs_version_minor) >= (3, 32):
            input_fields: List[str] = self.parameterAsStrings(parameters, self.INPUT_FIELDS, context)
        else:
            input_fields: List[str] = self.parameterAsFields(parameters, self.INPUT_FIELDS, context)
        input_name_pattern: str = self.parameterAsString(parameters, self.INPUT_ANALYSES_PATTERN, context)
        input_neighborhood_size: int = self.parameterAsInt(parameters, self.INPUT_NEIGHBORHOOD_SIZE, context)

        from landsklim.landsklim import Landsklim
        lk_instance: Landsklim = Landsklim.instance()
        lk_prj: LandsklimProject = lk_instance.get_landsklim_project()

        # Must be checked before any use of the project, otherwise a missing project surfaces as an AttributeError
        if lk_prj is None:
            raise RuntimeError("A Landsklim project must be opened")

        lk_stations_no_data: Optional[float] = lk_prj.get_stations_no_data()

        registered_layer: Optional[VectorLayer] = self.check_layer_is_registered(lk_prj, input_layer)
        if registered_layer is None:
            raise RuntimeError("The selected layer must be registered as stations layer on the Landsklim project")

        configurations, groups_size = self.get_stations_configurations(input_layer, input_fields, lk_stations_no_data)

        if not self.check_no_names_conflict(lk_prj, input_name_pattern, len(configurations)):
            raise RuntimeError("Name conflict between analyses to create and existing analyses")

        analyses: List[LandsklimAnalysis] = self.create_analyses(lk_prj, registered_layer, configurations, groups_size, input_name_pattern, input_neighborhood_size)

        lk_config: LandsklimConfiguration = lk_prj.get_configurations()[0]

        feedback.setProgress(1)
        # Progress is split evenly between the situations actually grouped (one step per situation is assumed,
        # as in the previous per-field computation); max() avoids a division by zero when nothing is selected
        situations_count: int = sum(len(situations) for situations in configurations)
        self.__p_step = 100 / max(1, situations_count)
        for a in analyses:
            a.handle_on_polygons_failed(self.on_polygons_failed)
            a.handle_on_models_computation_fail(self.on_models_failed)
            a.handle_on_models_computation_step(lambda t, i: self.on_models_computed(feedback, t, i))
            a.create_polygons()

            if a.is_local():
                r: QgsRasterLayer = QgsRasterLayer(a.get_polygons_raster_path(), a.get_polygons_layer_displayed_name())
                QgsProject.instance().addMapLayer(r)
                getattr(a, '_polygons').set_polygon_path(a.get_polygons_raster_path())

            if feedback.isCanceled():
                break

            a.construct_datasets()
            a.construct_models()

            lk_config.add_analysis(a)

            if feedback.isCanceled():
                break

        return {}
