import abc
from typing import List, Optional, Dict, Any, Union, Tuple

from qgis._core import QgsRectangle, QgsVectorLayer, QgsFeature, QgsPoint, QgsRasterLayer, QgsCoordinateReferenceSystem

from landsklim.lk.landsklim_constants import DATASET_RESPONSE_VARIABLE, DATASET_COLUMN_X, DATASET_COLUMN_Y
import numpy as np

from landsklim.lk.landsklim_interpolation import LandsklimRectangle
from landsklim.lk.logger import Log
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.polygons_definition import PolygonsDefinition
from landsklim.lk.map_layer import VectorLayer, RasterLayer
from landsklim.processing.algorithm_moran_i import MoranIProcessingAlgorithm

# Prevent error when launching plugin for the first time on a Python installation without pandas.
# Pandas will be installed when instantiating the plugin
try:
    import pandas as pd
except ImportError:
    Log.critical("pandas not available")


class IPhase(metaclass=abc.ABCMeta):
    """
    Represents a phase during regression

    Abstract base class: subclasses provide model construction, prediction and the
    cross-validated metrics (see the abstract methods below), while this class holds
    the shared state and the dataset helpers.

    Every ``pd.DataFrame`` annotation in this class is written as a string on purpose:
    the module-level pandas import is guarded by a ``try/except ImportError``, so
    annotations must not be evaluated when the class is created.

    :ivar _model: Model of the phase. Initialised to ``None``; never assigned by this base class

    :ivar Optional[pd.DataFrame] _dataset: Dataset used for building model

    :ivar Dict[str, Dict[int, float]] _variables_correlation: Correlations returned by
        :meth:`get_variables_correlation`. Initialised empty; never filled by this base class

    :ivar Optional[PolygonsDefinition] _polygons: Polygons definition used to build local models

    :ivar Optional[str] _polygons_path: Polygons raster path set through :meth:`set_polygon_path`

    :ivar str _crs_auth_id: CRS identifier given through the ``crs`` keyword argument

    :ivar bool _forced_models: For test performances purpose. If ``True``, model are forced.
        Avoid computing models before setting forced coefficients
    """

    def __init__(self, **kwargs):
        """
        Initialise the shared state of a phase

        :param kwargs: Phase parameters. ``crs`` is mandatory and is stored as the CRS identifier
        :raises KeyError: If ``crs`` is not given
        """
        self._model = None
        self._dataset: Optional["pd.DataFrame"] = None
        self._variables_correlation: Dict[str, Dict[int, float]] = {}

        self._polygons: Optional["PolygonsDefinition"] = None
        self._polygons_path: Optional[str] = None

        self._crs_auth_id: str = kwargs["crs"]

        self._forced_models: bool = False

    def update_parameters(self, **kwargs):
        """
        Allows the IPhase owner to update phase meta-parameters

        The base implementation is a no-op: no IPhase meta-parameters can be updated once initialized.
        """
        pass

    def get_polygons_path(self) -> str:
        """
        Get the polygons path as reported by the polygons definition

        NOTE(review): this reads the path from ``_polygons`` and not from ``_polygons_path``
        (the attribute written by :meth:`set_polygon_path`) - confirm this is intended.

        :returns: Result of ``polygons_path()`` on the current polygons definition
        :rtype: str
        :raises AttributeError: If no polygons definition was set (``_polygons`` is ``None``)
        """
        return self._polygons.polygons_path()

    def set_polygons(self, polygons_definition: "PolygonsDefinition"):
        """
        Set polygons data used to build local models

        :param polygons_definition: Polygons definition
        :type polygons_definition: PolygonsDefinition
        """
        self._polygons = polygons_definition

    def set_polygon_path(self, polygon_path: str):
        """
        Set polygons raster path used to build local models

        The value is only stored in ``_polygons_path``; :meth:`get_polygons_path` does not read it.

        :param polygon_path: The path
        :type polygon_path: str
        """
        self._polygons_path = polygon_path

    def get_variables_correlation(self) -> Dict[str, Dict[int, float]]:
        """
        Get the stored variables correlation mapping (the internal dictionary, not a copy)
        """
        return self._variables_correlation

    def get_model(self):
        """
        Get the model of the phase (``None`` as long as nothing assigned ``_model``)
        """
        return self._model

    def get_dataset(self) -> Optional["pd.DataFrame"]:
        """
        Get the dataset of the phase (``None`` as long as nothing assigned ``_dataset``)
        """
        return self._dataset

    def to_string(self) -> str:
        """
        Get the textual representation of the phase, which is its :meth:`class_name`
        """
        return self.class_name()

    def _filter_dataset(self, dataset: "pd.DataFrame", polygon: int) -> "pd.DataFrame":
        """
        Select stations of a polygon

        :param dataset: Base dataset
        :type dataset: pd.DataFrame

        :param polygon: Polygon's stations to select
        :type polygon: int

        :returns: Same DataFrame but only with the selected rows
        :rtype: pd.DataFrame
        """
        # Selection is done on index labels; labels listed for the polygon but absent
        # from the dataset index are dropped by the intersection instead of raising in .loc
        polygon_keys = self._polygons.polygons_definition()[polygon]
        valid_keys = dataset.index.intersection(polygon_keys)
        return dataset.loc[valid_keys]

    def _prepare_dataset(self, dataset: "pd.DataFrame", drop_na: bool = True) -> Tuple["pd.DataFrame", "pd.Series"]:
        """
        Split a dataset into features and response variable

        :param dataset: Dataset containing the response variable column
        :type dataset: pd.DataFrame

        :param drop_na: If ``True`` (default), rows containing NaN are removed before splitting
        :type drop_na: bool

        :returns: Features (every column but the response variable) and the response variable column
        :rtype: Tuple[pd.DataFrame, pd.Series]
        """
        if drop_na:
            dataset = dataset.dropna()
        X, y = dataset.drop(DATASET_RESPONSE_VARIABLE, axis=1), dataset[DATASET_RESPONSE_VARIABLE]
        return X, y

    def remove_geographic_features(self, dataset: "pd.DataFrame") -> "pd.DataFrame":
        """
        Get a copy of the dataset without the X and Y coordinate columns

        :param dataset: Dataset containing both coordinate columns
        :type dataset: pd.DataFrame

        :returns: Dataset without the coordinate columns
        :rtype: pd.DataFrame
        """
        return dataset.drop([DATASET_COLUMN_X, DATASET_COLUMN_Y], axis=1)

    def prepare_dataset(self) -> Tuple["pd.DataFrame", "pd.Series"]:
        """
        Get dataset, with NaN rows removed.
        Features and response variable are split.

        :returns: Features dataset and response variable serie
        :rtype: Tuple[pd.DataFrame, pd.Series]
        """
        return self._prepare_dataset(self._dataset, drop_na=True)

    def _model_residuals(self) -> np.ndarray:
        """
        Get residuals of the dataset on the retained model

        :returns: Observed response minus the prediction of :meth:`predict` on the prepared dataset
        :rtype: np.ndarray
        """
        X, y = self.prepare_dataset()
        y_hat: np.ndarray = self.predict(X)
        return y.values - y_hat

    def get_residuals(self) -> np.ndarray:
        """
        Get residuals of the dataset on the retained model

        When models are forced (``_forced_models``), no prediction is made and the
        flattened response variable values are returned instead.

        :returns: Residuals
        :rtype: np.ndarray
        """
        if self._forced_models:
            _, y = self.prepare_dataset()
            return y.values.ravel()
        return self._model_residuals()

    def get_residuals_for_phase_2(self):
        """
        Get the residuals handed over to the second phase.
        The base implementation returns :meth:`get_residuals`.
        """
        return self.get_residuals()

    @abc.abstractmethod
    def get_residuals_cv(self) -> np.ndarray:
        """
        Get residuals of the dataset through cross-validation

        :returns: Residuals
        :rtype: np.ndarray
        """
        raise NotImplementedError

    def remove_na(self, points: "pd.DataFrame") -> "pd.DataFrame":
        """
        Replace missing values by the mean of their column

        Despite the name, no row is removed: NaN are filled with the column means
        computed on ``points`` itself.

        :param points: Samples possibly containing NaN
        :type points: pd.DataFrame

        :returns: New DataFrame with NaN filled
        :rtype: pd.DataFrame
        """
        column_means = points.mean()
        return points.fillna(column_means)

    @abc.abstractmethod
    def valid_interpolation_mask(self, points: "pd.DataFrame", extrapolation_margin: float) -> np.ndarray:
        """
        Get a prediction validity mask according to values of predictors

        :param points: List of predictors
        :type points: pd.DataFrame

        :param extrapolation_margin: Extrapolation margin.
                                     Accept predictions where predictors are in the range of values
                                     used to build model + a margin in percentage, specified by extrapolation_margin
        :type extrapolation_margin: float

        :returns: Prediction validity mask
        :rtype: np.ndarray[bool]
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_formula(self, unstandardized: bool) -> str:
        """
        Get the formula of the model as text

        :param unstandardized: NOTE(review): presumably selects the formula expressed on
                               unstandardized variables - confirm against implementations
        :type unstandardized: bool

        :returns: Formula
        :rtype: str
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict_cv(self) -> np.ndarray:
        """
        Get the cross-validated predictions; used by :meth:`r2` as the predicted values

        :returns: Predictions
        :rtype: np.ndarray
        """
        raise NotImplementedError

    def r2(self) -> float:
        """
        Get R2 of the model, after cross-validation

        :returns: R-squared
        :rtype: float
        """
        _, y = self.prepare_dataset()
        y_hat = self.predict_cv()
        return LandsklimUtils.r2(y, y_hat)

    @abc.abstractmethod
    def get_adjusted_r2(self) -> float:
        """
        Get the adjusted R2 of the model

        :returns: Adjusted R-squared
        :rtype: float
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_each_variable_correlation(self) -> float:
        """
        Compute the correlation of each variable

        NOTE(review): annotated as returning ``float``; whether implementations return a
        value or fill ``_variables_correlation`` cannot be told from this class - confirm.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_residuals_standard_deviation(self) -> float:
        """
        Get the standard deviation of the residuals

        :returns: Standard deviation
        :rtype: float
        """
        raise NotImplementedError

    def get_residuals_autocorrelation(self) -> float:
        """
        Get the Moran's Index of the residuals (CV) of the regression.

        A temporary in-memory point layer is built with one feature per station
        (geometry: station position, attribute ``value``: residual) and handed to the
        ``landsklim:morani`` processing algorithm.

        :returns: Moran's I of the residuals (CV) of the regression
        :rtype: float
        """
        from qgis import processing
        residuals: np.ndarray = self.get_residuals_cv()
        # NOTE(review): positions are read from the raw dataset whereas residuals come from
        # get_residuals_cv(); zip() below stops at the shortest input, so both must have the
        # same length and order (e.g. no NaN rows dropped on one side only) - confirm.
        stations_positions_x = self._dataset[DATASET_COLUMN_X].values
        stations_positions_y = self._dataset[DATASET_COLUMN_Y].values

        layer_residuals: QgsVectorLayer = QgsVectorLayer("Point?crs=EPSG:4326&field=value:Double", "layer_residuals", "memory")
        # fromOgcWmsCrs() is a static factory returning a new CRS: its result must be used.
        # Calling it on a default-constructed instance and ignoring the result left that
        # instance (and therefore the layer) with an invalid CRS.
        crs: QgsCoordinateReferenceSystem = QgsCoordinateReferenceSystem.fromOgcWmsCrs(self._crs_auth_id)
        layer_residuals.setCrs(crs)

        # Loop invariants
        fields = layer_residuals.fields()
        provider = layer_residuals.dataProvider()

        layer_residuals.startEditing()
        for (residual, position_x, position_y) in zip(residuals, stations_positions_x, stations_positions_y):
            feat: QgsFeature = QgsFeature(fields)
            position: QgsPoint = QgsPoint(position_x, position_y)
            feat.setGeometry(position)
            feat.setAttributes([float(residual)])
            provider.addFeature(feat)
        layer_residuals.commitChanges()

        params = {
            MoranIProcessingAlgorithm.INPUT_POINTS_SHAPEFILE: layer_residuals,
            MoranIProcessingAlgorithm.INPUT_FIELD: 'value'
        }
        output = processing.run("landsklim:morani", params)
        del layer_residuals
        return output[MoranIProcessingAlgorithm.OUTPUT]

    @staticmethod
    @abc.abstractmethod
    def class_name() -> str:
        """
        ID name of the phase
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def name() -> str:
        """
        Human-readable name of the phase
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """
        Create list of args to be used by the factory to create a copy of this object
        """
        raise NotImplementedError

    @abc.abstractmethod
    def construct_model(self, dataset: "pd.DataFrame"):
        """
        Build the model of the phase from a dataset

        :param dataset: Dataset used for building model
        :type dataset: pd.DataFrame
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, points: "pd.DataFrame", extrapolation_margin: Optional[float] = None, no_data: Optional[Union[int, float]] = None) -> np.ndarray:
        """
        Predict

        :param points: Samples to predict as pandas DataFrame
        :type points: pd.DataFrame

        :param extrapolation_margin: Doesn't smooth predictors values when under known values
        :type extrapolation_margin: Optional[float]

        :param no_data: Prediction value when sample contains NaN. If None (default), NaN values are filled with mean values of the dataset
        :type no_data: Optional[Union[int, float]]

        :returns: Array of predictions
        :rtype: np.ndarray
        """
        raise NotImplementedError

    def to_json(self) -> Dict:
        """
        Get the state of the phase to serialize

        :returns: Shallow copy of the instance attributes without the polygons definition.
                  Values are returned as they are stored, no conversion is applied here.
        :rtype: Dict
        """
        state_dict: Dict = self.__dict__.copy()
        # Default given so that an instance without the attribute does not raise KeyError
        state_dict.pop('_polygons', None)
        return state_dict
