from typing import Dict, Any, Union, Optional, Tuple, List

import time
import numpy as np
from PyQt5.QtCore import QCoreApplication

from landsklim.lk.logger import Log
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.phase_multiple_regression import PhaseMultipleRegression

try:
    import pandas as pd
except ImportError:
    Log.critical("pandas not available")

from landsklim.lk.phase import IPhase


class PhaseComposite(IPhase):
    """
    Defines a phase consisting of combining predictions of two phases.

    The composite prediction is the element-wise sum of the predictions of the
    two sub-phases received as ``phase1`` and ``phase2`` keyword arguments.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: Forwarded to the parent constructor. Must contain the keys
                       ``phase1`` and ``phase2`` (the two phases to combine).
        :raises KeyError: If ``phase1`` or ``phase2`` is missing from ``kwargs``
        """
        super().__init__(**kwargs)
        self._phases: List[IPhase] = [kwargs["phase1"], kwargs["phase2"]]
        # Pre-computed predictions keyed by the class name of the phase that produced them.
        # Despite the attribute name, values are arrays (see predict()), not file paths.
        self._phases_raster_path: Dict[str, np.ndarray] = {}

    def update_parameters(self, **kwargs):
        """
        Update the parameters of the phase.

        :param kwargs: Forwarded to the parent implementation. When the key ``phase`` is
                       present, its mapping is merged into the cache of pre-computed
                       sub-phase predictions (phase class name -> array).
        """
        super().update_parameters(**kwargs)
        if 'phase' in kwargs:
            self._phases_raster_path.update(kwargs['phase'].items())

    @staticmethod
    def class_name() -> str:
        """
        :returns: Identifier of this kind of phase
        :rtype: str
        """
        return "PHASE_COMPOSITE"

    @staticmethod
    def name() -> str:
        """
        :returns: Translated display name of this kind of phase
        :rtype: str
        """
        return QCoreApplication.translate("Landsklim", "Composite")

    def remove_cache(self):
        """
        Drop every cached pre-computed sub-phase prediction.
        """
        self._phases_raster_path = {}

    def get_formula(self, unstandardized: bool) -> str:
        """
        Not available for a composite phase.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError

    def _regression_labels(self):
        """
        Get the labels of the first sub-phase which is a multiple regression.

        :returns: Result of ``get_labels()`` of the first multiple regression sub-phase
        :raises IndexError: If none of the sub-phases is a multiple regression
                            (exception type kept for backward compatibility)
        """
        # TODO: Not good to call phase.class_name()
        regression_class_name = PhaseMultipleRegression.class_name()
        for phase in self._phases:
            if phase.class_name() == regression_class_name:
                return phase.get_labels()
        raise IndexError("PhaseComposite has no multiple regression sub-phase")

    def get_adjusted_r2(self) -> float:
        """
        Adjusted R², using the number of labels of the regression sub-phase
        as the number of predictors.

        :returns: Adjusted R² of the composite model
        :rtype: float
        """
        _, y = self.prepare_dataset()
        labels = self._regression_labels()
        return LandsklimUtils.adjusted_r2(self.r2(), n=len(y), p=len(labels))

    def get_residuals_standard_deviation(self) -> float:
        """
        :returns: Unbiased estimate of the standard deviation of the residuals
        :rtype: float
        """
        return LandsklimUtils.unbiased_estimate_standard_deviation(self.get_residuals())

    def compute_each_variable_correlation(self, X, y):
        """
        Not available for a composite phase.

        :raises NotImplementedError: Always
        """
        # TODO: Keep it ? Better to use Pearson correlation ?
        raise NotImplementedError

    def get_residuals_cv(self) -> np.ndarray:
        """
        Cross-validation residuals: observed values minus the composite
        cross-validation prediction.

        :returns: Cross-validation residuals
        :rtype: np.ndarray
        """
        _, y = self.prepare_dataset()
        # Reuse predict_cv() so both methods always agree on how phases are combined
        y_hat: np.ndarray = self.predict_cv()
        return y.values - y_hat

    def predict_cv(self) -> np.ndarray:
        """
        Cross-validation prediction: sum of the cross-validation predictions of the two phases.

        :returns: Composite cross-validation prediction
        :rtype: np.ndarray
        """
        first_phase, second_phase = self._phases
        return first_phase.predict_cv() + second_phase.predict_cv()

    def get_rse(self) -> float:
        """
        Residual Standard Error, computed from the cross-validation residuals and
        the number of labels of the regression sub-phase.

        :returns: Residual standard error
        :rtype: float
        """
        residuals = self.get_residuals_cv()
        n = len(residuals)
        model_features = len(self._regression_labels())
        return LandsklimUtils.rse(residuals, n, model_features)

    def valid_interpolation_mask(self, points: "pd.DataFrame", extrapolation_margin: float) -> np.ndarray:
        """
        Get a prediction validity mask according to values of predictors and predicted variable.
        A point is valid only if it is valid for both sub-phases.

        :param points: List of predictors
        :type points: pd.DataFrame

        :param extrapolation_margin: Extrapolation margin.
                                     Accept predictions where predictors (and predictions) are in the range of values
                                     used to build model + a margin in percentage, specified by extrapolation_margin
        :type extrapolation_margin: float

        :returns: Prediction validity mask
        :rtype: np.ndarray[bool]
        """
        first_phase, second_phase = self._phases
        first_mask = first_phase.valid_interpolation_mask(points, extrapolation_margin)
        second_mask = second_phase.valid_interpolation_mask(points, extrapolation_margin)
        return first_mask & second_mask

    def construct_model(self, dataset: "pd.DataFrame"):
        """
        Store the dataset. Nothing is fitted here.

        :param dataset: Dataset of the phase
        :type dataset: pd.DataFrame
        """
        self._dataset = dataset

    def predict(self, points: "pd.DataFrame", extrapolation_margin: Optional[float] = None, no_data: Optional[Union[int, float]] = None) -> np.ndarray:
        """
        Prediction is the sum of the prediction of the two models.

        For each sub-phase, a cached pre-computed prediction (see update_parameters())
        is used when available, and is consumed (removed from the cache) by this call.
        Otherwise the sub-phase is asked to predict.

        :param points: List of predictors
        :type points: pd.DataFrame

        :param extrapolation_margin: Extrapolation margin, forwarded to the sub-phases
        :type extrapolation_margin: Optional[float]

        :param no_data: No data value, forwarded to the sub-phases. When specified, every
                        position where at least one sub-prediction equals it is set to it
                        in the result.
        :type no_data: Optional[Union[int, float]]

        :returns: Composite prediction
        :rtype: np.ndarray
        """
        interpolations: List[np.ndarray] = []
        for i, phase in enumerate(self._phases):  # type: IPhase
            time_start = time.perf_counter()
            # Single lookup which also consumes the cached entry
            cached: Optional[np.ndarray] = self._phases_raster_path.pop(phase.class_name(), None)
            interpolation: np.ndarray
            if cached is not None:
                interpolation = cached.ravel()
            else:
                interpolation = phase.predict(points, extrapolation_margin, no_data)
            time_end = time.perf_counter()
            Log.info("[predict][Phase {2}] Took {0:.3f}s to predict {1} points".format(time_end - time_start, points.shape, i+1))
            interpolations.append(interpolation)

        first, second = interpolations
        # The sum allocates a new array, so the sub-predictions are left untouched
        composite_interpolation = first + second
        if no_data is not None:
            # The sum of a no data value with anything is meaningless: restore no data.
            # (A NaN no_data never compares equal, but NaN already propagates through the sum)
            invalid = (first == no_data) | (second == no_data)
            composite_interpolation[invalid] = no_data
        return composite_interpolation

    def get_kwargs(self) -> Dict[str, Any]:
        """
        :returns: An empty dictionary
        :rtype: Dict[str, Any]
        """
        return {}

    def to_json(self) -> Dict:
        """
        Serialize the phase, without the sub-phases and the cached predictions.

        :returns: Parent serialization without the ``_phases`` and ``_phases_raster_path`` keys
        :rtype: Dict
        """
        dct = super().to_json()
        # Tolerate a parent serialization which already omits these keys
        for excluded_key in ('_phases', '_phases_raster_path'):
            dct.pop(excluded_key, None)
        return dct
