import cProfile
import time
from typing import Union, Tuple, List, Optional

from PyQt5.QtWidgets import QComboBox
from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis._core import QgsProcessingParameterEnum, QgsProcessingParameterDefinition, QgsProcessingParameterExtent, \
    QgsProcessingParameterVectorLayer, QgsRectangle, QgsVectorLayer, QgsProcessingParameterVectorDestination, \
    QgsProcessingParameterFeatureSink
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean
from qgis import processing

import numpy as np
import pandas as pd

from landsklim.lk import environment
from landsklim.lk.landsklim_constants import DATASET_RESPONSE_VARIABLE
from landsklim.processing.algorithm_smoothing import SmoothingProcessingAlgorithm
from landsklim.processing.landsklim_processing_algorithm import LandsklimProcessingAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.map_layer import VectorLayer
import landsklim.lk.cache as lkcache


class InterpolationProcessingAlgorithm(LandsklimProcessingAlgorithm):
    """
    Processing algorithm computing an interpolation from a regression model of a Landsklim analysis.

    This base class declares the parameters shared by every interpolation algorithm. Subclasses declare,
    through :meth:`specific_parameters`, where the interpolation is computed (a grid extent or a point layer)
    and the matching output destination.
    """
    INPUT_ANALYSIS = 'INPUT_ANALYSIS'
    INPUT_SITUATION = 'INPUT_SITUATION'
    INPUT_CUSTOM_NO_DATA = 'INPUT_CUSTOM_NO_DATA'
    INPUT_MIN_VALUE = 'INPUT_MIN_VALUE'
    INPUT_MAX_VALUE = 'INPUT_MAX_VALUE'
    INPUT_EXTENT_GRID = 'INPUT_EXTENT_GRID'
    INPUT_EXTENT_POINTS = 'INPUT_EXTENT_POINTS'
    INPUT_EXTRAPOLATION_MARGIN = 'INPUT_EXTRAPOLATION_MARGIN'
    INPUT_EXTRAPOLATION_MODE = 'INPUT_EXTRAPOLATION_MODE'
    INPUT_EXTRAPOLATION_VALUE = 'INPUT_EXTRAPOLATION_VALUE'
    INPUT_INTERPOLATION_TYPE = 'INPUT_INTERPOLATION_TYPE'
    OUTPUT_RASTER = 'OUTPUT_RASTER'
    OUTPUT_VECTOR = 'OUTPUT_VECTOR'
    OUTPUT_PHASE = 'OUTPUT_PHASE'

    # Option labels of the INPUT_INTERPOLATION_TYPE enum; the algorithm receives the selected option index
    INTERPOLATION_TYPES = ['Partial (phase 1)', 'Partial (phase 2)', 'Global', 'Auto regression on residuals']
    # Option labels of the INPUT_EXTRAPOLATION_MODE enum; the algorithm receives the selected option index
    EXTRAPOLATION_TYPES = ['Smooth values', 'NO_DATA', 'Value']

    def __init__(self):
        super().__init__()

    def createInstance(self):
        return InterpolationProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'interpolation'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Interpolation', "InterpolationProcessingAlgorithm")

    def group(self) -> str:
        return self.tr('Interpolation', "InterpolationProcessingAlgorithm")

    def groupId(self) -> str:
        return 'interpolation'

    def shortHelpString(self) -> str:
        return self.tr('Compute interpolation from a regression model', "InterpolationProcessingAlgorithm")

    def specific_parameters(self):
        """
        Declare the parameters specific to a subclass (interpolation target and output destination).

        :raises NotImplementedError: always, subclasses must override it
        """
        raise NotImplementedError

    def parameters(self):
        """
        Declare the parameters shared by every interpolation algorithm.
        """
        from landsklim.processing.processing_parameter_analysis import QgsProcessingParameterAnalysis, QgsProcessingParameterSituation
        self.addParameter(
            QgsProcessingParameterAnalysis(
                self.INPUT_ANALYSIS,
                self.tr('Analysis', "InterpolationProcessingAlgorithm")
            )
        )

        self.addParameter(
            QgsProcessingParameterSituation(
                self.INPUT_SITUATION,
                self.tr('Situation', "InterpolationProcessingAlgorithm"),
                parent=self.INPUT_ANALYSIS
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_MIN_VALUE,
                self.tr('Minimum possible value', "InterpolationProcessingAlgorithm"),
                QgsProcessingParameterNumber.Type.Double,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_MAX_VALUE,
                self.tr('Maximum possible value', "InterpolationProcessingAlgorithm"),
                QgsProcessingParameterNumber.Type.Double,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_EXTRAPOLATION_MARGIN,
                self.tr('Extrapolation margin as a percent (0.05 = 5%)', "InterpolationProcessingAlgorithm"),
                QgsProcessingParameterNumber.Type.Double,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_EXTRAPOLATION_MODE,
                self.tr('Extrapolation behaviour (if enabled)', "InterpolationProcessingAlgorithm"),
                self.EXTRAPOLATION_TYPES,
                # Enum default values are option indices, not option labels
                defaultValue=self.EXTRAPOLATION_TYPES.index('Smooth values'),
                allowMultiple=False,
                optional=False
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_EXTRAPOLATION_VALUE,
                self.tr('Out-of-bound value (if extrapolation behaviour is "Value")', "InterpolationProcessingAlgorithm"),
                QgsProcessingParameterNumber.Type.Double,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_INTERPOLATION_TYPE,
                self.tr('Define which phases to include for the interpolation', "InterpolationProcessingAlgorithm"),
                self.INTERPOLATION_TYPES,
                # A label such as 'Global' is not a valid index and silently resolved to option 0
                defaultValue=self.INTERPOLATION_TYPES.index('Global'),
                allowMultiple=False,
                optional=False
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_CUSTOM_NO_DATA,
                self.tr('Custom NO_DATA', "InterpolationProcessingAlgorithm"),
                optional=True,
                defaultValue=None
            )
        )

    def initAlgorithm(self, config=None):
        """
        Define inputs and outputs for the main input
        """
        self.parameters()
        self.specific_parameters()

    def set_interpolations_limit(self, interpolation: np.ndarray, min_value: Optional[float], max_value: Optional[float], no_data: Optional[Union[int, float]]) -> np.ndarray:
        """
        Clamp an array to ``[min_value, max_value]`` without touching cells equal to ``no_data``.

        :param interpolation: Values to clamp (not modified)
        :param min_value: Lower bound, ignored if None
        :param max_value: Upper bound, ignored if None
        :param no_data: Value of cells that must be left unchanged
        :returns: A clamped copy of ``interpolation``
        """
        res: np.ndarray = np.copy(interpolation)
        if min_value is not None:
            res[((res < min_value) & (res != no_data))] = min_value
        if max_value is not None:
            res[((res > max_value) & (res != no_data))] = max_value
        return res

    def valid_interpolation_mask(self, source_dataset: pd.DataFrame, predictions: np.ndarray, extrapolation_margin: Union[int, float]) -> np.ndarray:
        """
        Check interpolation validity (if the predicted value is inside the wanted range)

        The allowed range is ``[min - |min * margin|, max + |max * margin|]`` where ``min``/``max`` are the extrema
        (NaN ignored) of the response variable in ``source_dataset``.
        For example, this rejects a temperature predicted at the top of a mountain if it is far outside the values
        recorded by the stations used to build the model.

        :param source_dataset: Dataset used to build the model
        :type source_dataset: pd.DataFrame

        :param predictions: Interpolated points
        :type predictions: np.ndarray

        :param extrapolation_margin: Relative margin allowed around the observed response range (0.05 = 5%)
        :type extrapolation_margin: Union[int, float]

        :returns: Boolean mask, True where the prediction is inside the allowed range. This method does not modify
                  ``predictions``; replacing invalid points is up to the caller.
        :rtype: np.ndarray
        """
        dataset_min: float = source_dataset[DATASET_RESPONSE_VARIABLE].min(skipna=True)
        dataset_max: float = source_dataset[DATASET_RESPONSE_VARIABLE].max(skipna=True)
        # abs() keeps the range widening outward whatever the sign of the extrema
        response_min_allowed: float = dataset_min - np.abs(dataset_min * extrapolation_margin)
        response_max_allowed: float = dataset_max + np.abs(dataset_max * extrapolation_margin)
        mask: np.ndarray = ((predictions >= response_min_allowed) & (predictions <= response_max_allowed))
        return mask

    def compute_interpolation(self, analysis: "LandsklimAnalysis", situation: int, pixels: pd.DataFrame, minimum_value: Optional[float], maximum_value: Optional[float], extrapolation_margin: Optional[float], extrapolation_mode: int, extrapolation_value: float, interpolation_type: int, target_shape: Tuple[int, ...], no_data: Optional[Union[int, float]], to_raster: bool) -> np.ndarray:
        """
        Predict the response variable on ``pixels`` with the phase selected by ``interpolation_type``.

        Predictions are clamped to ``[minimum_value, maximum_value]`` (``no_data`` cells are kept). When
        ``extrapolation_mode > 0`` and ``extrapolation_margin`` is set, predictions judged invalid by the phase
        (validity based on predictor values) or by :meth:`valid_interpolation_mask` (validity of the predicted value)
        are replaced by ``extrapolation_value``. Cells predicted as ``no_data`` are then restored.

        :param analysis: Analysis holding the phases
        :param situation: Situation index whose phases are used
        :param pixels: Predictor values of each point to interpolate
        :param minimum_value: Optional lower bound of the interpolation
        :param maximum_value: Optional upper bound of the interpolation
        :param extrapolation_margin: Optional relative margin (0.05 = 5%), passed to the phase and the validity check
        :param extrapolation_mode: Index in ``EXTRAPOLATION_TYPES``; 0 disables out-of-range masking
        :param extrapolation_value: Value written in out-of-range cells
        :param interpolation_type: Index in ``INTERPOLATION_TYPES``; only 0 to 2 are supported
        :param target_shape: Output shape when ``to_raster`` is True
        :param no_data: No data value
        :param to_raster: Reshape the flat predictions to ``target_shape``
        :returns: Interpolated values
        :raises QgsProcessingException: if ``interpolation_type`` does not select a supported phase
        """
        if interpolation_type >= 3:
            # Previously the phase was set to None here and the next call failed with an AttributeError
            label = self.INTERPOLATION_TYPES[interpolation_type] if interpolation_type < len(self.INTERPOLATION_TYPES) else str(interpolation_type)
            raise QgsProcessingException("Interpolation type '{0}' is not supported".format(label))
        phase: "IPhase" = analysis.get_phases(situation)[interpolation_type]

        time_start = time.perf_counter()
        interpolation: np.ndarray = phase.predict(pixels, extrapolation_margin, no_data)
        interpolation_no_data_mask = interpolation == no_data
        print("[make_interpolation] Took {0:.3f}s to make model prediction".format(time.perf_counter() - time_start))
        interpolation = self.set_interpolations_limit(interpolation, minimum_value, maximum_value, no_data)
        limit_to_extent = extrapolation_mode > 0
        if limit_to_extent and extrapolation_margin is not None:
            # The "validity" mask is the combination between validity based on predictors values (defined for each phase) and validity in regards with interpolated value (defined here)
            mask = phase.valid_interpolation_mask(pixels, extrapolation_margin) & self.valid_interpolation_mask(phase.get_dataset(), interpolation, extrapolation_margin)
            interpolation[~mask] = extrapolation_value
            interpolation[interpolation_no_data_mask] = no_data
        if to_raster:
            interpolation = interpolation.reshape(target_shape)
        time_end = time.perf_counter()
        print("[make_interpolation] Took {0:.3f}s to predict {1} points".format(time_end - time_start, pixels.shape))
        return interpolation

    def _optional_double(self, parameters, name: str, context) -> Optional[float]:
        """
        Read an optional numeric parameter, returning None when it is unset or absent from ``parameters``.
        """
        if parameters.get(name) is None:
            return None
        return self.parameterAsDouble(parameters, name, context)

    def _smooth_interpolation(self, np_interpolation: np.ndarray, extrapolation_mode: int, extrapolation_value: float, no_data: Optional[Union[int, float]]) -> np.ndarray:
        """
        Smooth a raster interpolation, keeping NO_DATA cells and, in "Value" mode, out-of-bound cells unchanged.
        """
        smoothing_algorithm: SmoothingProcessingAlgorithm = SmoothingProcessingAlgorithm()
        np_interpolation_input = np_interpolation.copy()
        if extrapolation_mode == 2:
            # Out-of-bound cells are passed to the smoothing as no_data (presumably so they don't spread)
            np_interpolation_input[np_interpolation_input == extrapolation_value] = no_data
        np_output = smoothing_algorithm.smoothing(np_interpolation_input, 5, no_data)
        # Restore the cells that smoothing must not alter
        np_output[np_interpolation == no_data] = no_data
        if extrapolation_mode == 2:
            np_output[np_interpolation == extrapolation_value] = extrapolation_value
        return np_output

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run

        Interpolates the selected situation either on a raster grid (no point layer given) or on the features of a
        point layer, then writes the result to the output destination.

        :returns: Dict holding the raw interpolation under ``OUTPUT_PHASE`` and the output path under
                  ``OUTPUT_VECTOR`` or ``OUTPUT_RASTER``
        """
        print("[interpolation] processAlgorithm")
        # Load input raster and its metadata
        analysis: "LandsklimAnalysis" = self.parameterAsAnalysis(parameters, self.INPUT_ANALYSIS, context)
        situation: int = self.parameterAsSituation(parameters, self.INPUT_SITUATION, context)

        # Optional parameters may be missing from the parameters dict: treat them as unset instead of raising KeyError
        minimum_value: Optional[float] = self._optional_double(parameters, self.INPUT_MIN_VALUE, context)
        maximum_value: Optional[float] = self._optional_double(parameters, self.INPUT_MAX_VALUE, context)
        extrapolation_margin: Optional[float] = self._optional_double(parameters, self.INPUT_EXTRAPOLATION_MARGIN, context)

        interpolation_type: int = self.parameterAsEnum(parameters, self.INPUT_INTERPOLATION_TYPE, context)
        extrapolation_mode: int = self.parameterAsEnum(parameters, self.INPUT_EXTRAPOLATION_MODE, context)
        extrapolation_value: float = self.parameterAsDouble(parameters, self.INPUT_EXTRAPOLATION_VALUE, context)

        # input_extent can't be None: if not defined, QgsRectangle(0, 0, 0, 0) is returned
        input_extent: QgsRectangle = self.parameterAsExtent(parameters, self.INPUT_EXTENT_GRID, context)
        input_points: Optional[QgsVectorLayer] = self.parameterAsVectorLayer(parameters, self.INPUT_EXTENT_POINTS, context)

        dem: QgsRasterLayer = analysis.get_dem().qgis_layer()
        raster_resolution_x: float = dem.rasterUnitsPerPixelX()
        raster_resolution_y: float = dem.rasterUnitsPerPixelY()

        # Correct given extent to match DEM resolution for clipping: width/height become multiples of the resolution
        correction_x: float = input_extent.width() % raster_resolution_x
        input_extent.setXMaximum(input_extent.xMaximum() - correction_x)
        correction_y: float = input_extent.height() % raster_resolution_y
        input_extent.setYMaximum(input_extent.yMaximum() - correction_y)

        reference_raster: QgsRasterLayer = analysis.get_dem().qgis_layer()
        no_data, geotransform = self.get_raster_metadata(parameters, context, source_layer=reference_raster) if input_points is not None else analysis.get_dem().clip_metadata(input_extent)
        # Only the "Value" mode (index 2) uses the user value; other modes write no_data (or NaN if undefined)
        extrapolation_value = (no_data if no_data is not None else np.nan) if extrapolation_mode != 2 else extrapolation_value

        # Load other params
        out_srs: SpatialReference = self.get_spatial_reference(reference_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_VECTOR, context) if self.OUTPUT_VECTOR in parameters else self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)

        if not lkcache.interpolation_cache_enabled() or lkcache.variables_dataset_cache() is None:
            pixels: pd.DataFrame = analysis.get_variables_dataset(input_extent) if input_points is None else analysis.get_variables_dataset_from_points(VectorLayer(input_points), response_variable_index=None)
            lkcache.update_variables_dataset_cache(pixels)
        else:
            pixels: pd.DataFrame = lkcache.variables_dataset_cache()

        target_shape: Tuple[int, ...] = LandsklimUtils.raster_to_array(reference_raster).shape if input_points is not None else analysis.get_dem().clip(input_extent).shape
        np_interpolation: np.ndarray = self.compute_interpolation(analysis, situation, pixels, minimum_value, maximum_value, extrapolation_margin, extrapolation_mode, extrapolation_value, interpolation_type, target_shape, no_data, to_raster=input_points is None)

        if input_points is None and not environment.TEST_MODE:
            np_output = self._smooth_interpolation(np_interpolation, extrapolation_mode, extrapolation_value, no_data)
        else:
            np_output = np_interpolation

        if input_points is None:
            self.write_raster(out_path, np_output, out_srs, geotransform, no_data)
        else:
            # Features without a row in `pixels` keep the no_data value
            valid_points = pixels.index
            np_output_points = np.zeros(input_points.featureCount())
            np_output_points[:] = no_data
            np_output_points[valid_points] = np_output
            self.write_point_geopackage(out_path, input_points, np_output_points, analysis.get_situation_vector_field_name(situation), no_data)

        output = {self.OUTPUT_PHASE: np_interpolation}
        if self.OUTPUT_VECTOR in parameters:
            output.update({self.OUTPUT_VECTOR: out_path})
        else:
            output.update({self.OUTPUT_RASTER: out_path})
        return output


class InterpolationRasterProcessingAlgorithm(InterpolationProcessingAlgorithm):
    """
    Interpolation algorithm writing its result as a raster, optionally restricted to an extent.
    """

    def createInstance(self):
        return InterpolationRasterProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique identifier of the algorithm
        """
        algorithm_id = 'interpolationraster'
        return algorithm_id

    def displayName(self) -> str:
        """
        Label of the algorithm shown to the user
        """
        label = self.tr('Interpolation (to raster)')
        return label

    def specific_parameters(self):
        """
        Declare the optional interpolation extent and the raster output destination.
        """
        extent_parameter = QgsProcessingParameterExtent(
            self.INPUT_EXTENT_GRID,
            self.tr('Extent'),
            optional=True
        )
        output_parameter = QgsProcessingParameterRasterDestination(
            self.OUTPUT_RASTER,
            self.tr('Output')
        )
        for definition in (extent_parameter, output_parameter):
            self.addParameter(definition)


class InterpolationVectorProcessingAlgorithm(InterpolationProcessingAlgorithm):
    """
    Interpolation algorithm computing values on the features of a point layer.
    """

    def createInstance(self):
        return InterpolationVectorProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique identifier of the algorithm
        """
        algorithm_id = 'interpolationvector'
        return algorithm_id

    def displayName(self) -> str:
        """
        Label of the algorithm shown to the user
        """
        label = self.tr('Interpolation (to points)')
        return label

    def specific_parameters(self):
        """
        Declare the input point layer and the point vector output destination.
        """
        points_parameter = QgsProcessingParameterVectorLayer(
            self.INPUT_EXTENT_POINTS,
            self.tr('On points'),
            [QgsProcessing.TypeVectorPoint]
        )
        output_parameter = QgsProcessingParameterVectorDestination(
            self.OUTPUT_VECTOR,
            self.tr('Output'),
            type=QgsProcessing.TypeVectorPoint
        )
        for definition in (points_parameter, output_parameter):
            self.addParameter(definition)
