import os
import time
import warnings
from typing import Dict, Any, Union, Optional, Tuple, List
import tempfile

import pyproj
from PyQt5.QtCore import QCoreApplication
from qgis._core import QgsVectorLayer, QgsRasterLayer, QgsFeatureRequest, Qgis, QgsCoordinateReferenceSystem, \
    QgsVectorFileWriter, QgsProject, QgsVectorLayerUtils, QgsApplication, QgsGeometry, QgsPointXY

from qgis import processing

from landsklim.lk.cache import qgis_project_cache
import landsklim.lk.cache as lkcache
from landsklim.lk.landsklim_constants import DATASET_COLUMN_X, DATASET_COLUMN_Y, DATASET_RESPONSE_VARIABLE
from landsklim.lk.logger import Log
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.map_layer import RasterLayer

try:
    import pandas as pd
except ImportError:
    Log.critical("pandas not available")

import numpy as np
from scipy.stats import pearsonr

from landsklim.lk.phase import IPhase
from landsklim.lk.regression_model import MultipleRegressionModel, SmoothingMode

# Prevent error when launching plugin for the first time on a Python installation without sklearn.
# Sklearn will be installed when instantiating the plugin
try:
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error, r2_score
except ImportError:
    Log.critical("sklearn not available")


class PhaseMultipleRegression(IPhase):

    def __init__(self, **kwargs):
        """
        Build the phase and remember the geometry of the DEM raster.

        All keyword arguments are forwarded to the parent constructor. The ``dem`` keyword argument is mandatory:
        its QGIS layer provides the extent, the pixel resolution and the raster size (in pixels) stored on the instance.

        :param kwargs: Phase parameters, including ``dem`` (the DEM raster layer)
        """
        super().__init__(**kwargs)
        self._residuals_std: float = 0

        dem: RasterLayer = kwargs["dem"]
        layer = dem.qgis_layer()
        extent = layer.extent()

        # Bounding box of the DEM, in map units
        self._dem_top = extent.yMaximum()
        self._dem_left = extent.xMinimum()
        self._dem_bottom = extent.yMinimum()
        self._dem_right = extent.xMaximum()

        # Size of one pixel, in map units
        self._dem_resolution_x = layer.rasterUnitsPerPixelX()
        self._dem_resolution_y = layer.rasterUnitsPerPixelY()

        # Raster size in pixels (floor division of the extent by the resolution)
        self._dem_height = (self._dem_top - self._dem_bottom) // self._dem_resolution_y
        self._dem_width = (self._dem_right - self._dem_left) // self._dem_resolution_x

        # Disabled: CRS related attributes
        # self._dem_y_axis_direction: Qgis.CrsAxisDirection = dem.qgis_layer().crs().axisOrdering()[1]
        # self._dem_crs_authid: QgsCoordinateReferenceSystem = dem.qgis_layer().crs().authid()


    def update_parameters(self, **kwargs):
        """
        Update the parameters of the phase.

        All keyword arguments are forwarded unchanged to the parent implementation.
        """
        super().update_parameters(**kwargs)

    @staticmethod
    def class_name() -> str:
        return "PHASE_MULTIPLE_REGRESSION"

    @staticmethod
    def name() -> str:
        """
        :returns: Display name of the phase, passed through the Qt translation system
        :rtype: str
        """
        # Literals are kept inside the call so that translation tooling can still extract them
        translated_name = QCoreApplication.translate("Landsklim", "Multiple regression")
        return translated_name

    def get_models_count(self) -> int:
        """
        :returns: Number of models
        :rtype: int
        """
        return len(self._model)

    def get_coefficients_array(self) -> np.ndarray:
        """
        Get regression coefficients
        # TODO: Test

        :returns: A numpy array with regression coefficients for each regressor, for each model.
        :rtype: np.ndarray
        """
        coefficients: List[List[float]] = []
        regressors = self.get_regressors_name()
        for model in self._model:  # type: MultipleRegressionModel
            model_coefficients_dict: Dict[str, float] = model.get_coefficients()
            model_coefficients: List[float] = [model_coefficients_dict[regressor] if regressor in model_coefficients_dict else np.nan for regressor in regressors]
            coefficients.append(model_coefficients)

        return np.array(coefficients)

    def get_labels(self) -> List[str]:
        labels: List[str] = []
        for model in self._model:  # type: MultipleRegressionModel
            for label in model.get_labels():  # type: str
                if label not in labels:
                    labels.append(label)
        return labels

    def get_regressors_name(self) -> List[str]:
        """
        List the regressor columns of the dataset.

        :returns: Dataset column names without the Y coordinate, X coordinate and response variable columns
        :rtype: List[str]
        """
        remaining_columns = self._dataset.columns
        for reserved_column in (DATASET_COLUMN_Y, DATASET_COLUMN_X, DATASET_RESPONSE_VARIABLE):
            remaining_columns = remaining_columns.drop(reserved_column)
        return remaining_columns.tolist()

    def get_label_frequency(self, label_name: str) -> float:
        """
        Percentage of presence of the label on the regression model through all models

        :returns: Frequency of presence of the label
        :rtype: float
        """
        occurrences = 0
        for model in self._model:  # type: MultipleRegressionModel
            if label_name in model.get_labels():
                occurrences += 1
        return occurrences / len(self._model)

    def get_formula(self, unstandardized: bool) -> str:
        return self._model[0].get_formula(unstandardized) if len(self._model) == 1 else "Multiple models"

    def get_adjusted_r2(self) -> float:
        """
        Adjusted R² of the phase.

        Computed from ``self.r2()``, with the number of observations taken from the prepared dataset
        and the number of features taken from the distinct labels of the models.

        :returns: Adjusted R²
        :rtype: float
        """
        _, y = self.prepare_dataset()
        r2 = self.r2()
        features_count = len(self.get_labels())
        return LandsklimUtils.adjusted_r2(r2, n=len(y), p=features_count)

    def __get_cv_output(self, get_residuals: bool) -> np.ndarray:
        """
        Get predictions of the model trough CV

        :param get_residuals: If true, get residuals instead of predictions
        :type get_residuals: bool

        :returns: CV predictions or residuals
        :rtype: np.ndarray
        """

        if len(self._polygons.polygons_definition()) > 1:
            # Cross-validation - Local
            points = self._dataset
            points_polygons: np.ndarray = self.get_closest_polygon_of_points(points).astype(int)


            """# Method 1 : Building CV model each time __get_cv_output() is called
            # Same results as Method 2
            results = []
            for i, polygon in zip(points.index, points_polygons):
                Xcv, ycv = self._model[polygon].split_dataset()
                Xcv, ycv = Xcv[self._model[polygon].get_labels()], ycv
                X_test, y_test = Xcv.loc[[i], self._model[polygon].get_labels()], ycv.loc[i].values
                Xcvt = Xcv.drop(i)
                ycvt = ycv.drop(i)
                model_cv = LinearRegression()
                model_cv.fit(Xcvt, ycvt)
                y_hat = model_cv.predict(X_test).ravel()[0]
                if self._model[polygon]._predictors_are_integers:
                    y_hat = np.trunc(y_hat)
                results.append((y_test - ycv) if get_residuals else y_hat)
            return results"""

            """
            # Method 2 : CV Models built when creating the general model
            """
            results: List[float] = []
            # For each stations
            for i, polygon in zip(points.index, points_polygons):
                if not points.loc[[i]].isnull().values.any():
                    polygon_dataset: pd.DataFrame = self._filter_dataset(self._dataset, polygon)
                    dataset_index_station: int = polygon_dataset.index.get_loc(i)
                    # dataset_index_station: int = self._model[polygon].split_dataset()[0].index.get_loc(i)
                    if get_residuals:
                        results.append(self._model[polygon].get_residuals_cv()[dataset_index_station])# if not self._forced_models else i)
                    else:
                        results.append(self._model[polygon].get_predictions_cv()[dataset_index_station]) #if not self._forced_models else i)
            return np.array(results)
        else:
            # Cross-validation - Global
            return self._model[0].get_residuals_cv() if get_residuals else self._model[0].get_predictions_cv()

    def get_residuals_cv(self) -> np.ndarray:
        return self.__get_cv_output(get_residuals=True)

    def predict_cv(self) -> np.ndarray:
        return self.__get_cv_output(get_residuals=False)

    def get_rse(self) -> float:
        """
        Residual Standard Error, computed on the cross-validation residuals.

        :returns: Residual standard error
        :rtype: float
        """
        cv_residuals = self.get_residuals_cv()  # self.get_residuals()
        observations_count = len(cv_residuals)
        features_count = len(self.get_labels())
        return LandsklimUtils.rse(cv_residuals, observations_count, features_count)

    def get_mse(self):
        """
        Mean Squared Error
        """
        X, y = self.prepare_dataset()
        """y_hat = self.predict(X)"""
        y_hat = self.predict_cv()
        n = len(X)
        return (1/n) * np.sum(np.power((y - y_hat), 2))

    def get_prediction_interval(self, x, alpha=0.05) -> float:
        """
        TODO
        """
        # z = scipy.stats.norm.ppf(1-(alpha/2))  # t critical value with infinite degrees of freedom
        pass

    def get_rmse(self) -> float:
        """
        Compute the root-mean-square error between the observed values and the cross-validation predictions.

        :returns: Root-mean-square error
        :rtype: float
        """
        _, y = self.prepare_dataset()
        cv_predictions = self.predict_cv()  # self.predict(X)
        features_count = len(self.get_labels())
        return LandsklimUtils.rmse(y, cv_predictions, features_count)

    def get_residuals_standard_deviation(self) -> float:
        """
        Compute model error by cross-validation
        # FIXME: Local analysis
        """
        return self._model[0].get_residuals_standard_deviation()

    def compute_each_variable_correlation(self, X, y):
        # TODO: Keep it ? Better to use Pearson correlation ?
        # WARNING: Only on the case of a global analysis
        for column in X.columns:  # type: str
            window = int(column.split('_')[-1])
            regressor_name = "_".join(column.split('_')[:-1])
            if regressor_name not in self._variables_correlation:
                self._variables_correlation[regressor_name] = {}
            corr_coef = pearsonr(X[column].values, y.values)[0]
            """X_regressor = X[column].values.reshape(-1, 1)
            model_regressor = LinearRegression()
            model_regressor.fit(X_regressor, y)
            corr_coef = 1 - (1 - model_regressor.score(X_regressor, y)) * (len(y) - 1) / (len(y) - X_regressor.shape[1] - 1)"""
            self._variables_correlation[regressor_name][window] = corr_coef

    def valid_interpolation_mask(self, points: "pd.DataFrame", extrapolation_margin: float) -> np.ndarray:
        """
        Get a prediction validity mask according to values of predictors and predicted variable

        :param points: List of predictors
        :type points: pd.DataFrame

        :param extrapolation_margin: Extrapolation margin.
                                     Accept predictions where predictors (and predictions) are in the range of values
                                     used to build model + a margin in percentage, specified by extrapolation_margin
        :type extrapolation_margin: float

        :returns: Prediction validity mask
        :rtype: np.ndarray[bool]
        """

        if len(self._polygons.polygons_definition()) > 1:
            mask = np.zeros(len(points), dtype=bool)
            points_polygons: np.ndarray = self._polygons.get_polygon_of_points(points)
            for i in range(len(self._polygons.polygons_definition())):
                idx = np.nonzero(points_polygons == i)[0]
                if len(idx) > 0:
                    points_of_polygon = points.iloc[idx]
                    mask[idx] = self._model[i].valid_interpolation_mask(points_of_polygon, extrapolation_margin)
        else:
            mask = self._model[0].valid_interpolation_mask(points, extrapolation_margin)

        return mask

    def get_closest_polygon_of_points(self, dataset: "pd.DataFrame") -> np.ndarray:
        """
        Find, for each row of the dataset, the polygon whose centroid is the closest.

        Point coordinates are converted from map units to DEM pixel coordinates before distances are computed
        (presumably ``polygons_centroids()`` is expressed in pixel coordinates too; to be confirmed against its
        implementation).

        :param dataset: Dataset holding at least the X and Y coordinate columns
        :type dataset: pd.DataFrame

        :returns: Float array with one value per row of ``dataset``: index of the closest polygon,
                  or NaN when the row contains a missing value
        :rtype: np.ndarray
        """
        nan_mask: np.ndarray = dataset.isna().any(axis=1).values
        # Explicit copy: columns are overwritten below and must not alias the caller's data
        points = dataset.dropna()[[DATASET_COLUMN_X, DATASET_COLUMN_Y]].copy()

        dataset_polygons = np.full(len(dataset), np.nan)
        if len(points) == 0:
            # No complete row: nothing to assign
            return dataset_polygons

        # Map coordinates -> pixel coordinates (row 0 is the top of the DEM)
        points[DATASET_COLUMN_X] = ((points[DATASET_COLUMN_X] - self._dem_left) / self._dem_resolution_x).astype(int)
        points[DATASET_COLUMN_Y] = ((-points[DATASET_COLUMN_Y] + self._dem_top) / self._dem_resolution_y).astype(int)

        polygons_center_of_gravity = self._polygons.polygons_centroids()
        distances_matrix: np.ndarray = LandsklimUtils.distance_matrix(points, polygons_center_of_gravity)

        # argmin gives the nearest centroid of each point.
        # The former np.argpartition(..., 2)[:, :1] only guaranteed one of the three smallest distances
        # (order inside a partition is undefined) and raised when fewer than three polygons were defined.
        closest_polygons = np.argmin(np.asarray(distances_matrix), axis=1)

        dataset_polygons[~nan_mask] = closest_polygons
        return dataset_polygons

    def construct_model(self, dataset: "pd.DataFrame"):
        """
        Construct multiple regression model.

        One model is created per polygon, from the dataset filtered for this polygon.
        When models are forced (``self._forced_models``), the model is only prepared (dataset without geographic
        features, standardization parameters, minimums and maximums) and is not fitted here.
        Once all models are created, the correlation of each variable with the response is computed and
        the standard deviation of the residuals is stored (-1 when models are forced).

        :param dataset: Input dataset for building the model
        :type dataset: pd.DataFrame
        """
        total_start = time.perf_counter()
        self._dataset = dataset
        self._model = []

        # Durations reported by the models, accumulated over all polygons
        init_duration, selection_duration, fit_duration, cv_duration = 0, 0, 0, 0

        for i, _ in enumerate(self._polygons.polygons_definition()):
            polygon_start = time.perf_counter()
            polygon_dataset: "pd.DataFrame" = self._filter_dataset(self._dataset, i)
            model = MultipleRegressionModel()
            if not self._forced_models:
                model_init, model_selection, model_fit, model_cv = model.construct_model(polygon_dataset)
                init_duration += model_init
                selection_duration += model_selection
                fit_duration += model_fit
                cv_duration += model_cv
            else:
                model._dataset = model.remove_geographic_features(polygon_dataset)
                model.create_standardization_params(polygon_dataset)
                model.get_mins_and_maxs(polygon_dataset)
            self._model.append(model)
            # Only log one polygon out of ten
            if i % 10 == 0:
                print("[phase_multiple_regression][construct model {0}] {1:.3f}s".format(i, time.perf_counter() - polygon_start))

        print("[phase_multiple_regression][construct model] {0:.3f}s".format(time.perf_counter()-total_start))

        print("[Time] Init : {0:1f}s".format(init_duration))
        print("[Time] Feature selection : {0:1f}s".format(selection_duration))
        print("[Time] Models : {0:1f}s".format(fit_duration))
        print("[Time] CV : {0:1f}s".format(cv_duration))

        X, y = self.prepare_dataset()
        X = self.remove_geographic_features(X)
        self.compute_each_variable_correlation(X, y)

        # Compute error by cross-val
        if self._forced_models:
            self._residuals_std = -1
        else:
            self._residuals_std = self.get_residuals_standard_deviation()

    def smooth_interpolation(self, interpolation: np.ndarray, extrapolation_margin: float) -> np.ndarray:
        """
        Smooth interpolation values out-of-bounds (based on extrapolation margin).

        Bounds are the minimum and maximum of the observed response variable, widened by
        ``abs(maximum * extrapolation_margin)``. Values beyond a bound are replaced by the bound shifted outwards by
        ``log(abs(value - bound)) + 1``. The input array is not modified.

        :param interpolation: Interpolated values
        :type interpolation: np.ndarray

        :param extrapolation_margin: Extrapolation margin
        :type extrapolation_margin: float

        :returns: Copy of ``interpolation`` with out-of-bounds values smoothed
        :rtype: np.ndarray
        """
        observed = self._dataset.dropna()[DATASET_RESPONSE_VARIABLE].values
        observed_min = observed.min()
        observed_max = observed.max()
        margin = np.abs(observed_max * extrapolation_margin)
        lower_bound = observed_min - margin
        upper_bound = observed_max + margin

        smoothed = np.copy(interpolation)

        below = smoothed < lower_bound
        smoothed[below] = lower_bound - (np.log(np.abs(smoothed[below] - lower_bound)) + 1)

        # Evaluated after the lower bound has been processed, on the updated values
        above = smoothed > upper_bound
        smoothed[above] = upper_bound + (np.log(np.abs(smoothed[above] - upper_bound)) + 1)
        return smoothed

    def indices_by_polygons(self, points_polygons) -> Dict[int, np.ndarray]:
        return pd.DataFrame(points_polygons, columns=['polygon']).groupby(by='polygon').groups

    def get_distances_to_centroids(self, points_of_polygon, connected_polygons) -> np.ndarray:
        """
        Compute the weight of each connected polygon for each point.

        Weights are the inverse of the squared distance between the point and the centroid of the polygon,
        normalized so that the weights of a point sum to 1. A null distance gets a raw weight of 1.

        :param points_of_polygon: Points, with their X and Y coordinate columns
        :param connected_polygons: Indices of the polygons to weight

        :returns: Array of weights, one row per point and one column per connected polygon
        :rtype: np.ndarray
        """
        # Centroids are converted in place (on a copy) from pixel coordinates to map coordinates
        centroids: np.ndarray = self._polygons.polygons_centroids().copy()
        centroids[:, 0] = self._dem_left + self._dem_resolution_x * centroids[:, 0]
        centroids[:, 1] = self._dem_top - self._dem_resolution_y * centroids[:, 1]

        distances: np.ndarray = np.full((len(points_of_polygon), len(connected_polygons)), np.nan)
        for column, polygon_id in enumerate(connected_polygons):  # type: int
            squared_dx = np.square(points_of_polygon[DATASET_COLUMN_X].values - centroids[polygon_id, 0])
            squared_dy = np.square(points_of_polygon[DATASET_COLUMN_Y].values - centroids[polygon_id, 1])
            distances[:, column] = np.sqrt(squared_dx + squared_dy)

        # Ponderate results of each connected polygon with the distance of their centroid.
        # Null distances keep a weight of 1, which avoids "division by zero" warnings.
        not_null = distances != 0
        weights = np.ones(distances.shape)
        weights[not_null] = 1 / np.square(distances[not_null])

        # [:, np.newaxis] is equivalent to .reshape(-1, 1)
        weights_sums: np.ndarray = weights.sum(axis=1)
        return weights / weights_sums[:, np.newaxis]

    def predict(self, points: "pd.DataFrame", extrapolation_margin: Optional[float] = None, no_data: Union[int, float] = None) -> np.ndarray:
        if len(self._polygons.polygons_definition()) > 1:
            # Get polygon of each row
            time_polygons = time.perf_counter()

            if not lkcache.interpolation_cache_enabled():
                points_polygons: np.ndarray = self._polygons.get_polygon_of_points(points)
            else:
                if lkcache.polygon_points_cache() is None:
                    print("[cache not found]")
                    points_polygons: np.ndarray = self._polygons.get_polygon_of_points(points)
                    lkcache.update_polygon_points_cache(points_polygons)
                else:
                    print("[cache found]")
                    points_polygons: np.ndarray = lkcache.polygon_points_cache()
            print("[predict][part 1 : points polygons] {0:.3f}s".format(time.perf_counter() - time_polygons))
            results: np.ndarray = np.empty((len(points)))
            results.fill(np.nan if no_data is None else no_data)

            time_predict = time.perf_counter()
            idx_polygons = self.indices_by_polygons(points_polygons)
            for i in range(len(self._polygons.polygons_definition())):  # type: int
                # idx = np.nonzero(points_polygons == i)[0]
                if i in idx_polygons:
                    idx = idx_polygons[i]
                    # Get rows who belong to this polygon and make prediction through the model of this polygon
                    points_of_polygon = points.iloc[idx]
                    connected_polygons: List[int] = [i] + self._polygons.polygons_connectedness()[i]
                    results_connected_polygons: np.ndarray = np.empty((len(points_of_polygon), len(connected_polygons)))
                    results_connected_polygons.fill(np.nan)

                    # Compute predictions using regression model of each connected polygons
                    for j, connected_polygon in enumerate(connected_polygons):  # type: int
                        self._model[connected_polygon].set_smoothing_mode(SmoothingMode.Local)
                        results_connected_polygons[:, j] = self._model[connected_polygon].predict(points_of_polygon, extrapolation_margin, no_data)

                    """# Compute distance of each points to the centroid of each connected polygons
                    centroids: np.ndarray = self._polygons.polygons_centroids().copy()
                    centroids[:, 0] = self._dem_left + self._dem_resolution_x * centroids[:, 0]
                    centroids[:, 1] = self._dem_top - self._dem_resolution_y * centroids[:, 1]
                    distances_to_centroids: np.ndarray = np.empty((len(points_of_polygon), len(connected_polygons)))
                    distances_to_centroids.fill(np.nan)
                    for j, connected_polygon in enumerate(connected_polygons):  # type: int
                        x = np.square(points_of_polygon[DATASET_COLUMN_X].values - centroids[connected_polygon, 0])
                        y = np.square(points_of_polygon[DATASET_COLUMN_Y].values - centroids[connected_polygon, 1])
                        dst = np.sqrt(x + y)
                        # print("[distance {0}-{1}] : [{2:.2f}-{3:.2f}]".format(i, connected_polygon, dst.min(), dst.max()))
                        distances_to_centroids[:, j] = dst

                    # Ponderate results of each connected polygon with the distance of their centroid
                    # invert_distances_to_centroids = np.nan_to_num(1/distances_to_centroids, nan=1)
                    # Avoid "division by zero" warnings
                    invert_distances_to_centroids = np.ones(distances_to_centroids.shape)
                    invert_distances_to_centroids[distances_to_centroids != 0] = 1/np.square(distances_to_centroids[distances_to_centroids != 0])

                    invert_distances_to_centroids_sums: np.ndarray = invert_distances_to_centroids.sum(axis=1)
                    # [:, np.newaxis] is equivalent to .reshape(-1, 1)
                    invert_distances_to_centroids = (invert_distances_to_centroids / invert_distances_to_centroids_sums[:, np.newaxis])"""
                    invert_distances_to_centroids = self.get_distances_to_centroids(points_of_polygon, connected_polygons)

                    ponderated_results_polygons: np.ndarray = invert_distances_to_centroids * results_connected_polygons
                    ponderated_results: np.ndarray = ponderated_results_polygons.sum(axis=1)
                    results[idx] = ponderated_results
            print("[predict][part 2 : model prediction] {0:.3f}s".format(time.perf_counter() - time_predict))
            # 0.4.4 : Only predictors are smoothed (in function of extrapolation_margin). Interpolations values are not smoothed anymore
            """if extrapolation_margin is not None:
                results[results != no_data] = self.smooth_interpolation(results[results != no_data], extrapolation_margin)"""

            results = np.fix(results) if self._model[0]._predictors_are_integers else results
            return results
        else:
            # Only one polygon -> only one model
            no_data_mask: np.ndarray = points.isna().any(axis=1).values
            points_valid: pd.DataFrame = points.dropna()
            results: np.ndarray = np.empty(len(points))
            results.fill(no_data if no_data is not None else np.nan)
            results[~no_data_mask] = self._model[0].predict(points_valid, extrapolation_margin, no_data)
            # 0.4.4 : Only predictors are smoothed (in function of extrapolation_margin). Interpolations values are not smoothed anymore
            """if extrapolation_margin is not None:
                results[results != no_data] = self.smooth_interpolation(results[results != no_data], extrapolation_margin)"""
            return results


    def get_kwargs(self) -> Dict[str, Any]:
        return {}

    def to_json(self) -> Dict:
        """
        Serialize the state of the phase.

        The dictionary returned by the parent class is completed with the standardization parameters
        (``scale_means``, ``scale_std``) and the bounds (``dataset_mins``, ``dataset_maxs``) of every model,
        stacked model by model, and with the scaled column names of the first model.

        :returns: State of the phase
        :rtype: Dict
        """
        state_dict: Dict = super().to_json()

        models = self._model
        scale_means_list: List = [model._scale_means for model in models]
        scale_std_list: List = [model._scale_std for model in models]
        dataset_mins: List = [model._dataset_mins for model in models]
        dataset_maxs: List = [model._dataset_maxs for model in models]

        state_dict.update({
            "scale_means": pd.concat(scale_means_list, axis=1).T,
            "scale_std": pd.concat(scale_std_list, axis=1).T,
            "dataset_mins": np.vstack(dataset_mins),
            "dataset_maxs": np.vstack(dataset_maxs),
            # _scale_column_names attr is the same of each models
            "scale_column_names": models[0]._scale_column_names,
        })
        return state_dict
