import os
import time
from math import ceil
from typing import Dict, Any, Union, List, Optional

from PyQt5.QtCore import QCoreApplication
from pykrige import OrdinaryKriging
from qgis._core import QgsPoint, QgsVectorLayer, QgsFeature, QgsRectangle, QgsRasterLayer, QgsGeometry, QgsRaster, \
    QgsCoordinateTransform, QgsProject, QgsPointXY, QgsCoordinateReferenceSystem

from landsklim.lk.landsklim_constants import DATASET_COLUMN_X, DATASET_COLUMN_Y, DATASET_RESPONSE_VARIABLE
from landsklim.lk.landsklim_interpolation import LandsklimRectangle
from landsklim.lk.logger import Log
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.map_layer import RasterLayer, VectorLayer
from landsklim.processing.algorithm_kriging import KrigingProcessingAlgorithm
import numpy as np

try:
    import pandas as pd
except ImportError:
    Log.critical("pandas not available")

from landsklim.lk.phase import IPhase


class PhaseKriging(IPhase):
    """
    Kriging phase.

    Builds a kriging model from station values through the ``landsklim:kriging``
    processing algorithm and uses it to estimate values at arbitrary positions.

    :param layer_path: Path of the kriged layer associated with this phase
                       (stored, but not read by the current implementation)
    :type layer_path: str

    :param dem: Raster layer given to the kriging algorithm as its mask
    :type dem: RasterLayer

    :ivar Optional[np.ndarray] __cv_results: Store results of each point prediction using CV
    :ivar int __kriging_estimation_neighborhood: Number of neighbors intended for the kriging estimation.
                                                 Stored only: the estimation currently does not use it
                                                 (see the note in :meth:`predict`)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._layer_path: str = kwargs["layer_path"]
        self._kriging_mask: RasterLayer = kwargs["dem"]
        # Filled by construct_model(); stays None until a model has been built
        self.__cv_results: Optional[np.ndarray] = None
        self.__kriging_estimation_neighborhood: int = 5

    def update_parameters(self, **kwargs):
        """
        Update phase parameters.

        On top of the parameters handled by the parent class, accepts
        ``kriging_estimation_neighborhood`` which is converted to ``int``.
        """
        super().update_parameters(**kwargs)
        if "kriging_estimation_neighborhood" in kwargs:
            self.__kriging_estimation_neighborhood = int(kwargs["kriging_estimation_neighborhood"])

    @staticmethod
    def class_name() -> str:
        """
        :returns: Identifier of this phase type
        :rtype: str
        """
        return "PHASE_KRIGING"

    @staticmethod
    def name() -> str:
        """
        :returns: Translated display name of this phase
        :rtype: str
        """
        return QCoreApplication.translate("Landsklim", "Kriging")

    def get_formula(self, unstandardized: bool) -> str:
        """
        Not available for a kriging phase.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError

    def get_adjusted_r2(self) -> float:
        """
        :returns: Adjusted R², computed from ``self.r2()``, the number of samples
                  and the number of columns of the prepared dataset
        :rtype: float
        """
        X, y = self.prepare_dataset()
        return LandsklimUtils.adjusted_r2(self.r2(), n=len(y), p=X.shape[1])

    def get_residuals_standard_deviation(self) -> float:
        """
        :returns: Unbiased estimate of the standard deviation of the cross-validation residuals
        :rtype: float
        """
        return LandsklimUtils.unbiased_estimate_standard_deviation(self.get_residuals_cv())

    def compute_each_variable_correlation(self) -> float:
        """
        Not available for a kriging phase.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError

    def get_residuals_cv(self):
        """
        :returns: Difference between the observed values of the prepared dataset
                  and the cross-validation predictions
        """
        _, y = self.prepare_dataset()
        return y.values - self.predict_cv()

    def get_residuals_for_phase_2(self):
        """
        :returns: Residuals handed over to phase 2, i.e. the cross-validation residuals
        """
        return self.get_residuals_cv()

    def valid_interpolation_mask(self, points: "pd.DataFrame", extrapolation_margin: float) -> np.ndarray:
        """
        Get a prediction validity mask according to values of predictors and predicted variable

        For kriging, every point is considered valid: ``extrapolation_margin`` is ignored.

        :param points: List of predictors
        :type points: pd.DataFrame

        :param extrapolation_margin: Extrapolation margin.
                                     Accept predictions where predictors (and predictions) are in the range of values
                                     used to build model + a margin in percentage, specified by extrapolation_margin
        :type extrapolation_margin: float

        :returns: Prediction validity mask
        :rtype: np.ndarray[bool]
        """
        return np.ones(len(points), dtype=bool)

    def predict_cv(self) -> np.ndarray:
        """
        :returns: Cross-validation predictions stored by :meth:`construct_model`
                  (``None`` while no model has been built)
        :rtype: np.ndarray
        """
        return self.__cv_results

    def construct_model(self, dataset: "pd.DataFrame"):
        """
        Build the kriging model and store its cross-validation results.

        Rows containing NaN are dropped, the remaining stations are written to an
        in-memory point layer and the ``landsklim:kriging`` algorithm is run on it
        with cross-validation enabled.

        :param dataset: Dataset containing, for each station, its position
                        (``DATASET_COLUMN_X``, ``DATASET_COLUMN_Y``) and the value to krige
                        (``DATASET_RESPONSE_VARIABLE``)
        :type dataset: pd.DataFrame
        """
        from qgis import processing
        Log.info("[construct_model] kriging")
        self._dataset = dataset
        data = dataset.dropna()
        station_values = data[DATASET_RESPONSE_VARIABLE].values.ravel()
        station_x = data[DATASET_COLUMN_X].values.ravel()
        station_y = data[DATASET_COLUMN_Y].values.ravel()

        layer_values: QgsVectorLayer = QgsVectorLayer("Point?crs=EPSG:4326&field=value:Double", "layer_residuals", "memory")
        # fromOgcWmsCrs() is a static factory returning a new CRS: its result must be
        # kept, otherwise the layer is assigned an empty (invalid) CRS.
        # NOTE(review): self._crs_auth_id is presumably set by the parent class - confirm
        crs: QgsCoordinateReferenceSystem = QgsCoordinateReferenceSystem.fromOgcWmsCrs(self._crs_auth_id)
        layer_values.setCrs(crs)

        features: List[QgsFeature] = []
        for (value, position_x, position_y) in zip(station_values, station_x, station_y):
            feat: QgsFeature = QgsFeature(layer_values.fields())
            feat.setGeometry(QgsPoint(position_x, position_y))
            feat.setAttributes([float(value)])
            features.append(feat)

        layer_values.startEditing()
        # Single provider call instead of one call per station
        layer_values.dataProvider().addFeatures(features)
        layer_values.commitChanges()

        params = {
            KrigingProcessingAlgorithm.INPUT_POINTS_SHAPEFILE: layer_values,
            KrigingProcessingAlgorithm.INPUT_FIELD: 'value',
            KrigingProcessingAlgorithm.INPUT_MASK: self._kriging_mask.qgis_layer(),
            KrigingProcessingAlgorithm.INPUT_CV: True
        }
        output = processing.run("landsklim:kriging", params)
        self.__cv_results = output[KrigingProcessingAlgorithm.OUTPUT_CV]
        self._model: OrdinaryKriging = output[KrigingProcessingAlgorithm.OUTPUT_MODEL]

    def predict(self, points: "pd.DataFrame", extrapolation_margin: Optional[float] = None, no_data: Optional[Union[int, float]] = None) -> np.ndarray:
        """
        Estimate kriged values at the positions listed in ``points``.

        Used solution : Add geographic position of points to predict (X, Y of the kriged CRS)

        - The best option for the architecture (no extra param and no needed reference to self._analysis)

        But :

        - Kriging will be recomputed each time 'predict' is called (which is not semantically wrong)
        - Geographic position must be added to points when predict() is called, and must be ignored on
          PhaseMultipleRegression when building model because they are not predictors

        The old solution : (with additional parameter [extent: Union[QgsVectorLayer, QgsRectangle] = None ])
        As the kriged layers were already computed, instead of remaking predictions, results were extracted
        from a specified extent. So, "points" param was not used but "extent" was.
        "extent" contained a VectorLayer or a rectangle defining where the results must be got.
        This avoided recomputing new kriging but had three drawbacks :

        - Can't extrapolate kriging outside of DEM extent (it's okay because it loses its value statistically)
        - Another parameter ("extent") had to be defined, only used on PhaseKriging and ignored on
          PhaseMultipleRegression which is very bad
        - "_analysis" had to be referenced to get the opened kriged layer on qgis (RasterLayer.clip() needs an
          openable layer), which is non-sense because PhaseKriging should be able to get its data by itself.

        :param points: Points to predict. Must contain the ``DATASET_COLUMN_X`` and ``DATASET_COLUMN_Y`` columns
        :type points: pd.DataFrame

        :param extrapolation_margin: Unused by the kriging phase
        :type extrapolation_margin: Optional[float]

        :param no_data: If not None, value written on predictions of rows of ``points`` containing at least one NaN
        :type no_data: Optional[Union[int, float]]

        :returns: Kriged values, as returned by ``KrigingProcessingAlgorithm.run_kriging``
        :rtype: np.ndarray
        """
        time_start = time.perf_counter()
        x = points[DATASET_COLUMN_X].values
        y = points[DATASET_COLUMN_Y].values
        # FIXME: When using n_closest_points, kriging estimation time is very huge (10 min vs 5 sec)
        #  (probably because n_closest_points is generally over 20)
        #  Currently, n_closest_points is disabled.
        #  With n_closest_points=[20, 25], range of kriged values are mostly the same than without,
        #  but with a visible nugget effect
        z = KrigingProcessingAlgorithm.run_kriging(self._model, x, y)

        # No data is retrieved based on dataset rows containing nan on predictors.
        # Skipping this check would also be possible: Phase 2 would not be cropped on DEM extent,
        # but Phase 1 and Phase 3 based on regression will crop the resulting raster anyway.
        if no_data is not None:
            rows_with_nan = np.any(points.isna().values, axis=1)
            z[np.nonzero(rows_with_nan)[0]] = no_data
        Log.info("[predict][kriging] {0:.3f}s".format(time.perf_counter() - time_start))
        return z

    def get_kwargs(self) -> Dict[str, Any]:
        """
        :returns: Extra keyword arguments of this phase (none for kriging)
        :rtype: Dict[str, Any]
        """
        return {}
