import time
from math import ceil
from typing import Union, Tuple, List, Optional

from PyQt5.QtCore import QCoreApplication
from osgeo import osr, gdal
from osgeo.osr import SpatialReference
from pykrige.ok import OrdinaryKriging

from qgis._core import QgsProcessingParameterEnum, QgsProcessingParameterDefinition, QgsProcessingParameterExtent, \
    QgsProcessingParameterVectorLayer, QgsRectangle, QgsVectorLayer, QgsProcessingParameterVectorDestination, \
    QgsProcessingParameterFeatureSink, QgsProcessingParameterField, QgsField, QgsGeometry, QgsFeature, QgsUnitTypes, \
    QgsCoordinateTransform, QgsProject
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean

import numpy as np

from landsklim.lk.utils import LandsklimUtils
from landsklim.processing.landsklim_processing_algorithm import LandsklimProcessingAlgorithm


class KrigingProcessingAlgorithm(LandsklimProcessingAlgorithm):
    """
    Processing algorithm computing kriging from a vector layer
    """
    # Parameter name of the raster layer declared with the label 'Mask raster'
    INPUT_MASK = 'INPUT_MASK'
    # Parameter name of the point vector layer declared with the label 'Shapefile'
    INPUT_POINTS_SHAPEFILE = 'INPUT_POINTS_SHAPEFILE'
    # Parameter name of the numeric field of the point layer to krige
    INPUT_FIELD = 'INPUT_FIELD'
    # Parameter name of the boolean enabling cross-validation
    INPUT_CV = 'INPUT_CV'
    # Parameter name of the (optional) raster destination; also used as a key of the result dictionary
    OUTPUT_PATH = 'OUTPUT_PATH'
    # Key of the result dictionary holding the kriging model
    OUTPUT_MODEL = 'OUTPUT_MODEL'
    # Key of the result dictionary holding the cross-validation predictions (set only when cross-validation is enabled)
    OUTPUT_CV = 'OUTPUT_CV'

    def createInstance(self):
        """
        Build a new, independent instance of this algorithm
        """
        new_instance = KrigingProcessingAlgorithm()
        return new_instance

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'kriging'

    def displayName(self) -> str:
        """
        Translated, human-readable name of the algorithm
        """
        translated_name: str = QCoreApplication.translate("Landsklim", "Kriging")
        return translated_name

    def group(self) -> str:
        return "Interpolation"

    def groupId(self) -> str:
        return 'interpolation'

    def shortHelpString(self) -> str:
        """
        Translated short help text of the algorithm
        """
        help_text: str = QCoreApplication.translate("Landsklim", "Compute ordinary kriging")
        return help_text

    def parameters(self):
        """
        Declare the parameters of the algorithm, in this order: point layer,
        field to krige, cross-validation switch, mask raster and the optional
        output raster destination.
        """
        # Point layer holding the values to interpolate
        points_parameter = QgsProcessingParameterVectorLayer(
            self.INPUT_POINTS_SHAPEFILE,
            QCoreApplication.translate('Landsklim', 'Shapefile'),
            [QgsProcessing.TypeVectorPoint]
        )
        self.addParameter(points_parameter)

        # Numeric field of the point layer (the point layer is its parent parameter)
        field_parameter = QgsProcessingParameterField(
            self.INPUT_FIELD,
            QCoreApplication.translate('Landsklim', 'Field to krige'),
            type=QgsProcessingParameterField.Numeric,
            parentLayerParameterName=self.INPUT_POINTS_SHAPEFILE
        )
        self.addParameter(field_parameter)

        # Cross-validation switch
        cross_validation_parameter = QgsProcessingParameterBoolean(
            self.INPUT_CV,
            QCoreApplication.translate('Landsklim', 'Enable cross-validation. If enabled, kriging model will be the same as if not, but a prediction for each point of the input layer will be given, excluding this point of the neighborhood during the estimation.')
        )
        self.addParameter(cross_validation_parameter)

        # Raster used as mask
        mask_parameter = QgsProcessingParameterRasterLayer(
            self.INPUT_MASK,
            QCoreApplication.translate('Landsklim', 'Mask raster')
        )
        self.addParameter(mask_parameter)

        # Optional destination of the kriged raster
        destination_parameter = QgsProcessingParameterRasterDestination(
            self.OUTPUT_PATH,
            QCoreApplication.translate('Landsklim','Output raster'),
            optional=True
        )
        self.addParameter(destination_parameter)

    def initAlgorithm(self, config=None):
        """
        Register the parameters of the algorithm.

        :param config: Not used by this implementation
        """
        self.parameters()

    def max_distance_between_two_points(self, layer_source: QgsVectorLayer, transform: QgsCoordinateTransform) -> float:
        """
        Get the maximum distance possible between two points on the vector layer

        :param layer_source: Layer containing points
        :type layer_source: QgsVectorLayer

        :param transform: Coordinate transform applied to every point before distances are measured
        :type transform: QgsCoordinateTransform

        :returns: Maximum distance found between two (transformed) points.
                  0.0 for a single point, NaN when the layer has no feature.
        :rtype: float
        """
        transformed_points = [
            transform.transform(feature.geometry().asPoint())
            for feature in layer_source.getFeatures()
        ]
        if not transformed_points:
            return np.nan

        # A point is at distance 0 from itself, so 0.0 is the natural lower bound.
        # Each unordered pair is measured once: the distance is symmetric, there is
        # no need to evaluate (a, b) and (b, a).
        max_distance = 0.0
        for index, first_point in enumerate(transformed_points):
            for second_point in transformed_points[index + 1:]:
                pair_distance = first_point.distance(second_point)
                if pair_distance > max_distance:
                    max_distance = pair_distance
        return max_distance

    def average_nearest_neighbors(self, layer_source: QgsVectorLayer, transform: QgsCoordinateTransform) -> float:
        """
        Get the average distance between each point and its nearest neighbour

        :param layer_source: Layer containing points
        :type layer_source: QgsVectorLayer

        :param transform: Coordinate transform applied to every point before distances are measured
        :type transform: QgsCoordinateTransform

        :returns: Average distance between each (transformed) point and its nearest neighbour.
                  NaN when the layer holds fewer than two points (no neighbour exists).
        :rtype: float
        """
        transformed_points = [
            transform.transform(feature.geometry().asPoint())
            for feature in layer_source.getFeatures()
        ]
        point_count = len(transformed_points)
        if point_count < 2:
            # No neighbour to measure: return NaN explicitly rather than averaging an empty/NaN list
            return np.nan

        # The distance is symmetric, so each unordered pair is measured once and
        # used to update the nearest-neighbour distance of both of its points.
        nearest_neighbors = [np.inf] * point_count
        for i, first_point in enumerate(transformed_points):
            for j in range(i + 1, point_count):
                pair_distance = first_point.distance(transformed_points[j])
                if pair_distance < nearest_neighbors[i]:
                    nearest_neighbors[i] = pair_distance
                if pair_distance < nearest_neighbors[j]:
                    nearest_neighbors[j] = pair_distance
        return np.array(nearest_neighbors).mean()

    """def run_kriging_base(self, kriging: OrdinaryKriging, output_shape: Tuple[int, int], layer_mask: QgsRasterLayer, resolution_x: float, resolution_y: float, no_data_mask: np.ndarray):
        backend = "C"
        gridx = np.linspace(layer_mask.extent().xMinimum(), layer_mask.extent().xMaximum() - resolution_x, output_shape[1])  # TODO: Geographic coordinates ?
        gridy = np.linspace(layer_mask.extent().yMaximum(), layer_mask.extent().yMinimum() + resolution_y, output_shape[0])  # TODO: Geographic coordinates ?

        output = np.zeros(output_shape)
        # Split computations to avoid allocating too many memory at once
        start_time = time.perf_counter()
        for i, y in enumerate(gridy):
            if i % 100 == 0 and i > 0:
                top_time = time.perf_counter() - start_time
                time_by_unit = top_time / i
                remaining_time = (time_by_unit * len(gridy)) - top_time
                print("[row] {0}/{1}. Remaining : {2:.2f}s".format(i, len(gridy), remaining_time))
            zy, ssy = kriging.execute("grid", gridx, [y], backend=backend)
            output[i, :] = zy
        output[no_data_mask] = 0
        print("[kriging] Took {0:.2f}s".format(time.perf_counter() - start_time))
        return output

    def run_kriging_batch_grid(self, kriging: OrdinaryKriging, output_shape: Tuple[int, int], layer_mask: QgsRasterLayer, resolution_x: float, resolution_y: float, no_data_mask: np.ndarray):
        backend = "vectorized"
        gridx = np.linspace(layer_mask.extent().xMinimum(), layer_mask.extent().xMaximum() - resolution_x,
                            output_shape[1])  # TODO: Geographic coordinates ?
        gridy = np.linspace(layer_mask.extent().yMaximum(), layer_mask.extent().yMinimum() + resolution_y,
                            output_shape[0])  # TODO: Geographic coordinates ?

        output = np.zeros(output_shape)
        # Split computations to avoid allocating too many memory at once
        batch_size = 1024  # [rows]
        batch_count = ceil(len(gridy) / batch_size)
        start_time = time.perf_counter()
        for i in range(batch_count):
            start_i = i * batch_size
            end_i = (i + 1) * batch_size
            y_rows = gridy[start_i:end_i]
            if i % 10 == 0 and i > 0:
                top_time = time.perf_counter() - start_time
                time_by_unit = top_time / i
                remaining_time = (time_by_unit * batch_count) - top_time
                print("[row] {0}-{1}/{2} Remaining : {3:.2f}s".format(start_i, end_i, len(gridy), remaining_time))

            zy, ssy = kriging.execute("grid", gridx, y_rows, backend=backend)
            output[start_i:end_i, :] = zy
        print("[kriging] Took {0:.2f}s".format(time.perf_counter() - start_time))
        return output"""

    """def run_kriging(self, kriging: OrdinaryKriging, output_shape: Tuple[int, int], layer_mask: QgsRasterLayer, resolution_x: float, resolution_y: float, no_data_mask: np.ndarray):
        backend = "C"  # "C"
        start_time = time.perf_counter()
        gridx = np.linspace(layer_mask.extent().xMinimum(), layer_mask.extent().xMaximum() - resolution_x, output_shape[1])  # TODO: Geographic coordinates ?
        gridy = np.linspace(layer_mask.extent().yMaximum(), layer_mask.extent().yMinimum() + resolution_y, output_shape[0])  # TODO: Geographic coordinates ?
        grid_x, grid_y = np.meshgrid(gridx, gridy)
        xpts = grid_x.flatten()
        ypts = grid_y.flatten()
        zy, _ = kriging.execute("grid", gridx, gridy, backend=backend)
        output = zy.reshape(output_shape)
        print("[kriging] Took {0:.2f}s".format(time.perf_counter() - start_time))
        return output"""

    @staticmethod
    def run_kriging(kriging_model: OrdinaryKriging, xpts: np.ndarray, ypts: np.ndarray) -> np.ndarray:
        """
        Krige a list of points, batch by batch

        :param kriging_model: Kriging model to use
        :type kriging_model: OrdinaryKriging

        :param xpts: x-coordinates of the points to estimate
        :type xpts: np.ndarray

        :param ypts: y-coordinates of the points to estimate
        :type ypts: np.ndarray

        :returns: Kriged value of each point, in the order of the input coordinates
        :rtype: np.ndarray
        """
        kriged: np.ndarray = np.zeros(len(xpts))
        station_count: int = len(kriging_model.Z)

        # The batch size is halved for every power of two of stations above 2**9,
        # so that batches shrink when the number of stations becomes big.
        halvings: int = max(ceil(np.log2(station_count)) - 9, 0)
        batch_size = 524288 // (2 ** halvings)
        batch_count = ceil(len(xpts) / batch_size)

        # Arbitrary threshold. For big rasters with many stations, kriging time becomes huge:
        # when (points to estimate * stations) exceeds 4 000 000 000, each estimation is limited
        # to its 15 closest stations. 15 was chosen because, above this value, searching the
        # closest points becomes even more time expensive.
        neighbour_limit = 15 if len(xpts) * station_count > 4000000000 else None

        start_time = time.perf_counter()
        # Computations are split to avoid allocating too much memory at once
        for batch_index in range(batch_count):
            batch = slice(batch_index * batch_size, (batch_index + 1) * batch_size)
            if batch_index > 0:
                # Remaining time is extrapolated from the mean duration of the batches already done
                elapsed = time.perf_counter() - start_time
                mean_batch_time = elapsed / batch_index
                remaining_time = (mean_batch_time * batch_count) - elapsed
                print("[kriging][batch] {0}/{1} Remaining : {2:.2f}s".format(batch_index, batch_count, remaining_time))
            batch_values, _ = kriging_model.execute("points", xpts[batch], ypts[batch], backend="C",
                                                    n_closest_points=neighbour_limit)
            kriged[batch] = batch_values
        return kriged

    def __run_kriging(self, kriging: OrdinaryKriging, output_shape: Tuple[int, int], layer_mask: QgsRasterLayer, resolution_x: float, resolution_y: float, no_data_mask: np.ndarray) -> np.ndarray:
        """
        Krige every cell of the mask raster that is not flagged as no-data

        :param kriging: Kriging model to use
        :type kriging: OrdinaryKriging

        :param output_shape: Shape (rows, columns) of the array to produce
        :type output_shape: Tuple[int, int]

        :param layer_mask: Raster whose extent defines the coordinates of the cells
        :type layer_mask: QgsRasterLayer

        :param resolution_x: Cell size along x
        :type resolution_x: float

        :param resolution_y: Cell size along y
        :type resolution_y: float

        :param no_data_mask: Boolean array, True where the cell must not be kriged
        :type no_data_mask: np.ndarray

        :returns: Array of shape output_shape; cells flagged in no_data_mask are left to 0
        :rtype: np.ndarray
        """
        extent = layer_mask.extent()
        # One coordinate per column (from xMinimum rightwards) and per row (from yMaximum downwards)
        # NOTE(review): these look like cell-corner rather than cell-centre coordinates - confirm intended
        # TODO: Geographic coordinates ?
        columns_x = np.linspace(extent.xMinimum(), extent.xMaximum() - resolution_x, output_shape[1])
        rows_y = np.linspace(extent.yMaximum(), extent.yMinimum() + resolution_y, output_shape[0])
        mesh_x, mesh_y = np.meshgrid(columns_x, rows_y)

        flat_no_data = no_data_mask.flatten()
        cells_to_krige = ~flat_no_data
        cells_x = mesh_x.flatten()[cells_to_krige]
        cells_y = mesh_y.flatten()[cells_to_krige]
        # Cells which are not kriged stay at 0 here; the no-data value is filled afterwards by the caller
        flat_result = np.zeros(len(flat_no_data))

        start_time = time.perf_counter()
        flat_result[cells_to_krige] = self.run_kriging(kriging, cells_x, cells_y)
        print("[kriging] Took {0:.2f}s".format(time.perf_counter() - start_time))
        return flat_result.reshape(output_shape)

    def ordinary_kriging(self, layer_source: QgsVectorLayer, data_field_index: int, layer_mask: QgsRasterLayer, execute_kriging: bool, with_cv: bool, no_data: Optional[Union[int, float]]) -> Tuple[np.ndarray, OrdinaryKriging, np.ndarray]:
        """
        Compute ordinary kriging using PyKrige

        :param layer_source: Point layer to krige
        :type layer_source: QgsVectorLayer

        :param data_field_index: Field index on the point layer where data is located
        :type data_field_index: int

        :param layer_mask: Extent for the kriged raster
        :type layer_mask: QgsRasterLayer

        :param execute_kriging: Krige the layer mask if True. If False, only compute the kriging model
        :type execute_kriging: bool

        :param with_cv: Enable CV (leave-one-out: each point is predicted by a model built without it)
        :type with_cv: bool

        :param no_data: NO_DATA value of the mask raster; cells equal to it are not kriged
        :type no_data: Optional[Union[int, float]]

        :returns:
            - Kriged array (filled with zeros when execute_kriging is False)
            - Kriging model
            - Predictions of each point of the input layer (zeros when with_cv is False)
        :rtype: Tuple[np.ndarray, OrdinaryKriging, np.ndarray]

        :raises QgsProcessingException: If no usable lag distance can be derived from the points
            (fewer than two points, or a null average nearest-neighbour distance)
        """
        source_crs = layer_source.crs()
        mask_crs = layer_mask.crs()
        tr: QgsCoordinateTransform = QgsCoordinateTransform(source_crs, mask_crs, QgsProject.instance())

        np_mask = LandsklimUtils.raster_to_array(layer_mask)

        # The lag size is the average nearest-neighbour distance; the number of lags is
        # how many of them fit in the largest distance between two points.
        lag_distance = self.average_nearest_neighbors(layer_source, tr)
        if not np.isfinite(lag_distance) or lag_distance <= 0:
            # Fail with an explicit message instead of an opaque error raised by ceil() below
            raise QgsProcessingException(
                "Kriging: unable to derive a lag distance from the input points "
                "(at least two points with distinct locations are required)")
        lags = ceil(self.max_distance_between_two_points(layer_source, tr) / lag_distance)

        # Features are read once so that values and coordinates are guaranteed to be aligned
        features = list(layer_source.getFeatures())
        points_value = [feature.attributes()[data_field_index] for feature in features]
        points = [feature.geometry().asPoint() for feature in features]
        # If the CRS of these two layers are not the same,
        # we must convert the position of each point to the destination CRS
        if mask_crs != source_crs:
            points = [tr.transform(point) for point in points]
        points_x: List[float] = [point.x() for point in points]
        points_y: List[float] = [point.y() for point in points]

        coordinates_type = 'geographic' if mask_crs.mapUnits() == QgsUnitTypes.DistanceDegrees else 'euclidean'

        # NOTE(review): an earlier note stated that a 'power' variogram (fitted exponent close to 0.3)
        # matches LISDQS kriging better, but the model actually used is 'exponential' - confirm intent.
        variogram_model = 'exponential'
        # Options shared by the main model and the cross-validation models, so both always stay in sync
        model_options = dict(nlags=lags, weight=True, variogram_model=variogram_model,
                             coordinates_type=coordinates_type, enable_statistics=False, verbose=False)
        kriging: OrdinaryKriging = OrdinaryKriging(x=points_x, y=points_y, z=points_value, **model_options)

        z: np.ndarray = np.zeros(np_mask.shape)
        if execute_kriging:
            resolution_x: float = layer_mask.rasterUnitsPerPixelX()
            resolution_y: float = layer_mask.rasterUnitsPerPixelY()
            no_data_mask = np_mask == no_data
            z = self.__run_kriging(kriging, np_mask.shape, layer_mask, resolution_x, resolution_y, no_data_mask)

        cv_results = np.zeros(len(points_value))
        if with_cv:
            # Leave-one-out: rebuild a model without point i, then predict point i from its 5 closest neighbours
            for i, (point_x_cv, point_y_cv) in enumerate(zip(points_x, points_y)):
                if i % 100 == 0:
                    print("[kriging][cross-validation]", i)
                # Fitting the model is the longest step
                kriging_cv: OrdinaryKriging = OrdinaryKriging(
                    x=points_x[:i] + points_x[i + 1:],
                    y=points_y[:i] + points_y[i + 1:],
                    z=points_value[:i] + points_value[i + 1:],
                    **model_options)
                z_cv, _ = kriging_cv.execute("points", [point_x_cv], [point_y_cv], backend="C", n_closest_points=5)
                cv_results[i] = z_cv[0]

        return z, kriging, cv_results

    def variogram(self, m: List[float], d):
        # Work in progress
        slope = float(m[0])
        nugget = float(m[1])
        return slope * (d*d) + nugget

    def _build_variogram(self, x: np.ndarray, y: np.ndarray, z: np.ndarray):
        # Work in progress
        pass

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run

        :param parameters: Values of the parameters of the algorithm
        :param context: Processing context
        :param feedback: Processing feedback (not used by this implementation)

        :returns: Dictionary holding the kriging model (OUTPUT_MODEL), the path of the raster
                  (OUTPUT_PATH, only when a raster is written) and the cross-validation
                  predictions (OUTPUT_CV, only when cross-validation is enabled)
        :rtype: dict

        :raises QgsProcessingException: If the point layer or the mask raster cannot be resolved,
            or if the field to krige is not a field of the point layer
        """
        input_points: QgsVectorLayer = self.parameterAsVectorLayer(parameters, self.INPUT_POINTS_SHAPEFILE, context)
        if input_points is None:
            raise QgsProcessingException(QCoreApplication.translate("Landsklim", "Kriging: the input point layer could not be loaded"))

        # Field name is got as a unique string (only one field is expected)
        input_field: str = self.parameterAsString(parameters, self.INPUT_FIELD, context)
        input_fields = [field.name() for field in input_points.fields()]
        if input_field not in input_fields:
            raise QgsProcessingException(QCoreApplication.translate("Landsklim", "Kriging: the field to krige was not found in the input point layer"))
        field_index = input_fields.index(input_field)
        input_cv: bool = self.parameterAsBool(parameters, self.INPUT_CV, context)

        input_mask: QgsRasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT_MASK, context)
        if input_mask is None:
            raise QgsProcessingException(QCoreApplication.translate("Landsklim", "Kriging: the mask raster could not be loaded"))
        no_data, geotransform = self.get_raster_metadata(parameters, context, input_mask)

        out_srs: SpatialReference = self.get_spatial_reference(input_mask)

        # Without an output destination among the parameters, only the model (and the cross-validation) is computed
        out_path: Optional[str] = self.parameterAsOutputLayer(parameters, self.OUTPUT_PATH, context) if self.OUTPUT_PATH in parameters else None

        raster_output, kriging_model, cv = self.ordinary_kriging(input_points, field_index, input_mask, execute_kriging=out_path is not None, with_cv=input_cv, no_data=no_data)
        output = {self.OUTPUT_MODEL: kriging_model}
        print("[out_path]", out_path)
        if out_path is not None:
            # The mask is only needed (and only read) when a raster is actually written:
            # cells which are no-data in the mask are no-data in the output
            np_mask = LandsklimUtils.raster_to_array(input_mask)
            raster_output[np_mask == no_data] = no_data
            self.write_raster(out_path, raster_output, out_srs, geotransform, no_data)
            output[self.OUTPUT_PATH] = out_path
        if input_cv:
            output[self.OUTPUT_CV] = cv

        return output