import os
from enum import Enum
from typing import List, Optional, Dict, Tuple, Union, Any
import time

import numpy as np
from qgis import processing
from qgis._core import QgsProject, QgsRectangle

from landsklim.lk.map_layer import MapLayerCollection, MapLayer, RasterLayer, VectorLayer
from landsklim.lk.utils import LandsklimUtils
from landsklim.processing.algorithm_interpolation import InterpolationProcessingAlgorithm


class LandsklimInterpolationType(Enum):
    """
    Select which phase model(s) an interpolation is built from.
    """
    PartialPhase1 = 0
    """
    Interpolation relies only on the model of the first phase
    """
    PartialPhase2 = 1
    """
    Interpolation relies only on the model of the second phase
    """
    Global = 2
    """
    Interpolation relies on both phase models
    """
    AutoRegression = 3
    """
    Work in progress
    """

    def str(self):
        """
        Short label of this interpolation type (e.g. ``"Phase1"``).

        The label is used as a path component for the interpolation outputs.
        """
        labels_by_member_name = {
            "PartialPhase1": "Phase1",
            "PartialPhase2": "Phase2",
            "Global": "Phase3",
            "AutoRegression": "AutoReg",
        }
        return labels_by_member_name[self.name]

class ExtrapolationMode(Enum):
    """
    Strategy applied to pixels whose interpolated value lies beyond the extrapolation limit.
    """
    Smooth = 0
    """
    Pixels beyond the extrapolation limit are smoothed
    """
    NoData = 1
    """
    Pixels beyond the extrapolation limit are set to NO_DATA
    """
    Value = 2
    """
    Pixels beyond the extrapolation limit are set to a specific value
    """

class LandsklimRectangle:
    """
    Picklable copy of the four bounds of a ``QgsRectangle``.
    """

    def __setstate__(self, state):
        # Since 0.6.0 this class is named LandsklimRectangle (formerly LISDQSRectangle).
        # Each stored attribute name goes through rename_attr, presumably to rewrite the
        # old name-mangled prefix of pickles created before the rename.
        for attr_name, attr_value in state.items():
            renamed = LandsklimUtils.rename_attr("LISDQSRectangle", "LandsklimRectangle", attr_name)
            setattr(self, renamed, attr_value)

    def __init__(self, rectangle: QgsRectangle):
        # Attributes are created in the order x_min, x_max, y_min, y_max; to_json reflects this order
        self.__x_min, self.__x_max = rectangle.xMinimum(), rectangle.xMaximum()
        self.__y_min, self.__y_max = rectangle.yMinimum(), rectangle.yMaximum()

    def to_qgis_rectangle(self):
        """
        Rebuild a ``QgsRectangle`` from the stored bounds, without normalizing them.
        """
        left, bottom = self.__x_min, self.__y_min
        right, top = self.__x_max, self.__y_max
        return QgsRectangle(left, bottom, right, top, normalize=False)

    def to_json(self) -> Dict:
        """
        Shallow copy of the instance attributes (keys are the name-mangled attribute names).
        """
        return dict(vars(self))


class LandsklimInterpolation:
    """
    Represents an interpolation based on an analysis previously defined.

    :param name: Unique name of the interpolation
    :type name: str

    :param analysis: LandsklimAnalysis containing this interpolation
    :type analysis: LandsklimAnalysis

    :param min_value: The interpolated value can't be lower than this limit if not ``None``
    :type min_value: Optional[float]

    :param max_value: The interpolated value can't be higher than this limit if not ``None``
    :type max_value: Optional[float]

    :param extrapolation_mode: Extrapolation behavior.
        Values interpolated out-of-bounds are considered invalid.
    :type extrapolation_mode: ExtrapolationMode

    :param extrapolation_value: Value to use for pixels where interpolated values are out-of-bounds
        (only used if extrapolation_mode is ExtrapolationMode.Value)
    :type extrapolation_value: Optional[float]

    :param extrapolation_margin: Accepted margin beyond the maximum value recorded by a station.
        No extrapolation allowed if ``None``
    :type extrapolation_margin: Optional[float]

    :param phases: Define which phases to include for the interpolation (a copy of the list is kept)
    :type phases: List[LandsklimInterpolationType]

    :param on_grid: If true, interpolation is made for each point on a defined extent.
        If false, interpolation is computed for each point of a specified point layer
    :type on_grid: bool

    :param extent: Rectangle to interpolate on when ``on_grid`` is true, point layer to interpolate on otherwise
    :type extent: Union[VectorLayer, LandsklimRectangle]

    :ivar Optional[float] _min_value: The interpolated value can't be lower than this limit if not ``None``

    :ivar Optional[float] _max_value: The interpolated value can't be higher than this limit if not ``None``

    """

    def __init__(self, name: str, analysis: "LandsklimAnalysis", min_value: Optional[float], max_value: Optional[float], extrapolation_mode: ExtrapolationMode, extrapolation_value: Optional[float], extrapolation_margin: Optional[float], phases: List[LandsklimInterpolationType], on_grid: bool, extent: Union[VectorLayer, LandsklimRectangle]):
        self.__name: str = name
        self._analysis: "LandsklimAnalysis" = analysis
        self._min_value: Optional[float] = min_value
        self._max_value: Optional[float] = max_value
        self._extrapolation_mode: ExtrapolationMode = extrapolation_mode
        self._extrapolation_value: Optional[float] = extrapolation_value
        self._extrapolation_margin: Optional[float] = extrapolation_margin
        # Copy so later mutations of the caller's list can't desync the phases from self._layers
        self._interpolation_types: List[LandsklimInterpolationType] = list(phases)
        self._interpolation_on_grid: bool = on_grid
        self._extent: Union[VectorLayer, LandsklimRectangle] = extent
        self._layers: Dict[LandsklimInterpolationType, Dict[int, "RasterLayer"]] = {itype: {} for itype in self._interpolation_types}

    def __setstate__(self, state: Dict[str, Any]):
        # The class was formerly named LISDQSInterpolation: each stored attribute name goes through
        # rename_attr, presumably to rewrite the old name-mangled prefix (e.g. of __name) on unpickling.
        for k, v in state.items():
            setattr(self, LandsklimUtils.rename_attr("LISDQSInterpolation", "LandsklimInterpolation", k), v)

    def get_situations_count(self) -> int:
        """
        :returns: Number of station situations of the parent analysis
        :rtype: int
        """
        return len(self._analysis.get_station_situations())

    def get_name(self) -> str:
        """
        :returns: Name of the interpolation
        :rtype: str
        """
        return self.__name

    def get_interpolation_types(self) -> List[LandsklimInterpolationType]:
        """
        :returns: Copy of the phases included in this interpolation
        :rtype: List[LandsklimInterpolationType]
        """
        return list(self._interpolation_types)

    def is_on_grid(self) -> bool:
        """
        :returns: True if the interpolation is computed on a rectangular extent, False if on a point layer
        :rtype: bool
        """
        return self._interpolation_on_grid

    def get_minimum_value(self) -> Optional[float]:
        """
        :returns: Lower limit of interpolated values, ``None`` if unbounded
        :rtype: Optional[float]
        """
        return self._min_value

    def get_maximum_value(self) -> Optional[float]:
        """
        :returns: Upper limit of interpolated values, ``None`` if unbounded
        :rtype: Optional[float]
        """
        return self._max_value

    def get_paths(self, situation: int) -> List[str]:
        """
        Build the output file path of each included phase for a situation.

        Paths have the form ``<analysis path>/<interpolation name>/<phase label>/<file>``.
        On a grid, ``<file>`` is the slugified situation name with a ``tif`` extension.
        Otherwise it is ``interpolation.gpkg`` for every situation.

        :param situation: Situation to build the paths for
        :type situation: int

        :returns: One path per phase, in the same order as the interpolation types
        :rtype: List[str]
        """
        if not self._interpolation_types:
            return []

        # Extension, file name and base directory don't depend on the phase: compute them once
        if self._interpolation_on_grid:
            ext = "tif"
            situation_name: str = self._analysis.slugify(self._analysis.get_situation_name(situation))
        else:
            ext = "gpkg"
            situation_name = "interpolation"
        file_name = "{0}.{1}".format(situation_name, ext)
        base_path = self.get_path()
        return [os.path.join(base_path, phase.str(), file_name) for phase in self._interpolation_types]

    def get_path(self):
        """
        :returns: Directory of this interpolation, inside the analysis directory
        :rtype: str
        """
        return os.path.join(self._analysis.get_path(), self.get_name())

    def get_layers(self, interpolation_type: LandsklimInterpolationType) -> Dict[int, "MapLayer"]:
        """
        Get the interpolation layer of each situation for one interpolation type.

        Note: when self._interpolation_on_grid is false, the map layer is the same for each situation.

        :param interpolation_type: Phase to get the layers of
        :type interpolation_type: LandsklimInterpolationType

        :returns: Dictionary mapping each situation to its interpolation layer (the stored dict itself, not a copy)
        :rtype: Dict[int, MapLayer]

        :raises KeyError: If ``interpolation_type`` is not part of this interpolation
        """
        return self._layers[interpolation_type]

    def propagate_to_composite(self, situation: int, interpolation_type: LandsklimInterpolationType, output: np.ndarray):
        """
        Hand the output of a partial phase to the composite phase of a situation.

        This is only done when the Global phase is part of this interpolation and the situation has
        exactly one composite phase. The output of ``PartialPhase1`` (resp. ``PartialPhase2``) is
        registered under the class name of the situation's phase at index 0 (resp. 1). Other
        interpolation types are ignored.

        :param situation: Situation the output was computed for
        :type situation: int

        :param interpolation_type: Interpolation type that produced ``output``
        :type interpolation_type: LandsklimInterpolationType

        :param output: Phase output returned by the interpolation algorithm
        :type output: np.ndarray
        """
        # There needs to be a composite phase to create the cache
        if LandsklimInterpolationType.Global not in self._interpolation_types:
            return

        # Partial phase -> index of the matching phase among the analysis phases of the situation
        phase_index = {
            LandsklimInterpolationType.PartialPhase1: 0,
            LandsklimInterpolationType.PartialPhase2: 1,
        }.get(interpolation_type)
        if phase_index is None:
            return

        # Local import, presumably to avoid an import cycle with the phase modules
        from landsklim.lk.phase_composite import PhaseComposite
        phases = self._analysis.get_phases(situation)
        composite_phases = [p for p in phases if p.class_name() == PhaseComposite.class_name()]
        if len(composite_phases) == 1:  # There is a composite phase
            composite_phases[0].update_parameters(phase={phases[phase_index].class_name(): output})

    def make_interpolation(self, situation: int):
        """
        Call the processing algorithm to compute the interpolation of each included phase for a situation.

        Output directories are created if needed. Each phase output is then forwarded to the composite
        phase through ``propagate_to_composite``.

        :param situation: Situation to compute
        :type situation: int
        """
        output_paths = self.get_paths(situation)
        processing_algo_name = "landsklim:interpolationraster" if self._interpolation_on_grid else "landsklim:interpolationvector"

        # NOTE(review): phases run in the order given at construction. If the Global phase relies on the
        # partial outputs propagated below, it must be listed after them. Confirm.
        for interpolation_type, output_path in zip(self._interpolation_types, output_paths):

            params = {
                InterpolationProcessingAlgorithm.INPUT_ANALYSIS: self._analysis,
                InterpolationProcessingAlgorithm.INPUT_SITUATION: situation,
                InterpolationProcessingAlgorithm.INPUT_MIN_VALUE: self._min_value,
                InterpolationProcessingAlgorithm.INPUT_MAX_VALUE: self._max_value,
                InterpolationProcessingAlgorithm.INPUT_EXTRAPOLATION_MARGIN: self._extrapolation_margin,
                InterpolationProcessingAlgorithm.INPUT_EXTRAPOLATION_MODE: self._extrapolation_mode.value,
                InterpolationProcessingAlgorithm.INPUT_EXTRAPOLATION_VALUE: self._extrapolation_value,
                InterpolationProcessingAlgorithm.INPUT_INTERPOLATION_TYPE: interpolation_type.value,
                InterpolationProcessingAlgorithm.INPUT_CUSTOM_NO_DATA: None,
            }

            if self._interpolation_on_grid:
                params[InterpolationProcessingAlgorithm.INPUT_EXTENT_GRID] = self._extent.to_qgis_rectangle()
                params[InterpolationProcessingAlgorithm.OUTPUT_RASTER] = output_path
            else:
                params[InterpolationProcessingAlgorithm.INPUT_EXTENT_POINTS] = self._extent.qgis_layer()
                params[InterpolationProcessingAlgorithm.OUTPUT_VECTOR] = output_path

            # exist_ok avoids the check-then-create race of an exists()/makedirs() pair
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            output = processing.run(processing_algo_name, params)

            phase_np = output[InterpolationProcessingAlgorithm.OUTPUT_PHASE]
            self.propagate_to_composite(situation, interpolation_type, phase_np)

    def to_string(self):
        """
        :returns: Name of the interpolation
        :rtype: str
        """
        return self.get_name()

    def to_json(self) -> Dict:
        """
        Shallow copy of the instance state, without the parent analysis.

        :returns: Instance attributes, with the keys of ``_layers`` converted with ``str()``
        :rtype: Dict
        """
        state_dict: Dict = self.__dict__.copy()
        state_dict.pop("_analysis", None)

        # LandsklimInterpolationType can't be serialized as dictionary key
        state_dict["_layers"] = {str(itype): layers for itype, layers in state_dict["_layers"].items()}

        return state_dict
