import abc
import cProfile
import time
from typing import Union, Dict, Tuple, List
import os

from osgeo import gdal
from osgeo import osr
from qgis.PyQt.QtCore import QCoreApplication
from qgis._core import QgsProcessingParameterString, QgsProcessingParameterDefinition, Qgis, QgsDistanceArea, \
    QgsProject, QgsGeometry, QgsPointXY
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean, QgsProcessingContext
from qgis import processing
import numpy as np
from scipy import ndimage
from scipy.ndimage import generic_filter
from scipy.signal import convolve2d, convolve
from abc import ABC

from landsklim.lk import environment
from landsklim.processing.landsklim_processing_algorithm import LandsklimProcessingAlgorithm
from landsklim.processing.dem_variables_algorithm import DEMVariablesAlgorithm


# Optional import of the Landsklim C++ API bindings. dem_variables() below only
# calls lisdqsapi when environment.USE_CPP is true, so an ImportError is ignored
# here and the module still loads when the bindings are unavailable.
try:
    from landsklim.lib import lisdqsapi
except ImportError:
    pass


class LandsklimProcessingRegressorAlgorithm(LandsklimProcessingAlgorithm):
    """
    Abstract processing algorithm computing variables from a raster according to a specific algorithm.

    - Declares the input/output parameters shared by the regressor algorithms
    - Provides utilities shared by the algorithms (kernel creation, windowed average,
      neighbors count, intermediate DEM variables, pixel size)
    """

    def __init__(self, output_rasters: Dict):
        """
        :param output_rasters: Mapping whose ``'OUTPUT'`` entry is used as the label of the output raster parameter
        :type output_rasters: Dict
        """
        super().__init__()
        self.__output_value: str = output_rasters['OUTPUT']

    @abc.abstractmethod
    def createInstance(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        raise NotImplementedError()

    def group(self) -> str:
        """
        Displayed name of the group this algorithm belongs to
        """
        return self.tr('Regressors', "LandsklimProcessingRegressorAlgorithm")

    def groupId(self) -> str:
        """
        Unique identifier of the group this algorithm belongs to
        """
        return 'regressors'

    @abc.abstractmethod
    def shortHelpString(self) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def add_dependencies(self):
        """
        Hook called by :meth:`initAlgorithm` right after the ``INPUT`` parameter is declared
        and before the other shared parameters are added
        """
        raise NotImplementedError()

    def initAlgorithm(self, config=None):
        """
        Define inputs and outputs for the main input

        Parameters are added in this order: ``INPUT``, the parameters added by
        :meth:`add_dependencies`, ``INPUT_WINDOW``, ``INPUT_CUSTOM_NO_DATA``,
        ``ADDITIONAL_VARIABLES_FOLDER`` (hidden) and ``OUTPUT``.
        """
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                'INPUT',
                self.tr('Input raster layer', "LandsklimProcessingRegressorAlgorithm")
            )
        )

        self.add_dependencies()

        self.addParameter(
            QgsProcessingParameterNumber(
                'INPUT_WINDOW',
                self.tr('Window', "LandsklimProcessingRegressorAlgorithm")
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                'INPUT_CUSTOM_NO_DATA',
                self.tr('Define NO_DATA value (override raster NO_DATA)', "LandsklimProcessingRegressorAlgorithm"),
                optional=True,
                defaultValue=None
            )
        )

        # Optional parameter flagged as hidden (FlagHidden)
        additional_variables_folder_parameter = QgsProcessingParameterString(
            'ADDITIONAL_VARIABLES_FOLDER',
            self.tr('Destination to keep intermediate variables', "LandsklimProcessingRegressorAlgorithm"),
            optional=True
        )
        additional_variables_folder_parameter.setFlags(
            additional_variables_folder_parameter.flags() | QgsProcessingParameterDefinition.FlagHidden
        )
        self.addParameter(additional_variables_folder_parameter)

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                'OUTPUT',
                self.tr(self.__output_value)  # TODO: Not translated
            )
        )

    def create_kernel(self, radius: int) -> np.ndarray:
        """
        Create circular kernel used to compute the windowed average

        :param radius: Side length (in cells) of the square kernel. Despite its name, the
                       disc drawn inside the kernel has a radius of ``radius / 2`` cells.
        :type radius: int

        :returns: Integer array of shape ``(radius, radius)`` with 1 inside the disc and 0 outside
        :rtype: np.ndarray
        """
        # Cells strictly closer to the center than this distance belong to the disc
        dist_inside_circle = radius / 2.0
        # Row and column indices of each cell of the kernel
        row_indices, col_indices = np.mgrid[0:radius, 0:radius]
        # Center of the kernel expressed in (fractional) cell indices
        center = (radius / 2) - 0.5
        # Euclidean distance of each cell to the center
        dist_to_center = np.sqrt((row_indices - center) ** 2 + (col_indices - center) ** 2)

        kernel = np.zeros((radius, radius), dtype=int)
        kernel[dist_to_center < dist_inside_circle] = 1
        # NOTE: a former sub-sampling step applied to kernels with more than 50 active
        # cells was disabled in the original code and is not performed.
        return kernel

    def windowed_average(self, a: np.ndarray, kernel_size: int, no_data: Union[int, float]) -> np.ndarray:
        """
        Compute windows average on a raster.

        :param a: Input array
        :type a: np.ndarray

        :param kernel_size: Side length of the circular kernel, forwarded to :meth:`create_kernel`
        :type kernel_size: int

        :param no_data: Value to consider as no data
        :type no_data: Union[int, float]

        :returns: Windowed average of the array. Cells without any valid neighbor get 0
                  (sum of 0 divided by a count forced to 1).
        :rtype: np.ndarray
        """
        kernel = self.create_kernel(kernel_size)

        # no_data cells are set to 0 so they do not contribute to the sum
        values = a.copy()
        values[values == no_data] = 0
        # NOTE(review): the sum uses ndimage.convolve while the count uses
        # scipy.signal.convolve2d; confirm both are aligned for even kernel sizes.
        window_sum = ndimage.convolve(values, kernel, mode='constant')

        window_count = self.get_neighboors_count_raster(a, kernel, no_data)

        return window_sum / window_count

    @abc.abstractmethod
    def processAlgorithm(self, parameters, context, feedback):
        raise NotImplementedError()

    def dem_variables(self, raster: np.ndarray, no_data: Union[int, float], k: np.ndarray, path: str) -> tuple:
        """
        Compute intermediate rasters needed to compute slopes and orientations, either with the
        Landsklim C++ API (when ``environment.USE_CPP`` is true) or with the Python implementation.

        Results are cached in ``path`` as ``local_min_<n>.npz``, ``local_max_<n>.npz``, ``a0_<n>.npz``,
        ``a1_<n>.npz`` and ``sdif_<n>.npz`` where ``<n>`` is ``k.shape[0]``. When the five files
        exist they are loaded instead of being recomputed.

        :param raster: Input raster
        :type raster: np.ndarray

        :param no_data: No data value. When None, ``raster.max() + 1`` is used.
        :type no_data: Union[int, float]

        :param k: Convolution kernel
        :type k: np.ndarray

        :param path: Folder where intermediate layers are stored. Nothing is written when it is empty.
        :type path: str

        :returns: Intermediate values used to compute regressors: ``(local_min, local_max, raster_a, sdif)``
                  where ``raster_a`` holds two columns (saved separately as ``a0`` and ``a1``)
        :rtype: tuple
        """

        cpp_implementation: bool = environment.USE_CPP

        if no_data is None:
            no_data = raster.max() + 1

        kernel_size = k.shape[0]
        local_min_path = os.path.join(path, "local_min_{0}.npz".format(kernel_size))
        local_max_path = os.path.join(path, "local_max_{0}.npz".format(kernel_size))
        a0_path = os.path.join(path, "a0_{0}.npz".format(kernel_size))
        a1_path = os.path.join(path, "a1_{0}.npz".format(kernel_size))
        sdif_path = os.path.join(path, "sdif_{0}.npz".format(kernel_size))
        cache_paths = (local_min_path, local_max_path, a0_path, a1_path, sdif_path)

        # NOTE(review): with an empty ``path`` the cache files are looked up relative to the
        # current working directory; confirm this is intended.
        if all(os.path.exists(cache_path) for cache_path in cache_paths):
            local_min = self.load_numpy_array(local_min_path)
            local_max = self.load_numpy_array(local_max_path)
            a0 = self.load_numpy_array(a0_path)
            a1 = self.load_numpy_array(a1_path)
            raster_a = np.stack((a0.ravel(), a1.ravel()), axis=1)
            sdif = self.load_numpy_array(sdif_path)
            # Only to match GDAL datatype
            local_min = local_min.astype(np.float32).astype(np.float64)
            local_max = local_max.astype(np.float32).astype(np.float64)
            raster_a = raster_a.astype(np.float32).astype(np.float64)
            sdif = sdif.astype(np.float32).astype(np.float64)
        else:
            time_intermediate_variables = time.perf_counter()
            if cpp_implementation:
                # Kernel must be converted to int32 as C int uses 4 octets
                local_min, local_max, raster_a, sdif = lisdqsapi.dem_variables(
                    raster.astype(np.float64), float(no_data), k.astype(np.int32)
                )
            else:
                print("[DEM Python implementation]")
                n = self.get_neighboors_count_raster(raster, k, no_data)
                dem_variables_algorithm = DEMVariablesAlgorithm()
                if environment.TEST_MODE:
                    # The profile is only dumped in test mode, so the profiler (and its
                    # overhead) is only enabled in that mode.
                    with cProfile.Profile() as pr:
                        local_min, local_max, raster_a, sdif = dem_variables_algorithm.compute(
                            raster.astype(np.float64), float(no_data), k.astype(np.int32), n
                        )
                        pr.dump_stats('dem.prof')
                else:
                    local_min, local_max, raster_a, sdif = dem_variables_algorithm.compute(
                        raster.astype(np.float64), float(no_data), k.astype(np.int32), n
                    )
            print("[intermediate_variables] {0:.3f}s".format(time.perf_counter() - time_intermediate_variables))

            print("Memory usage : {0}".format(self.get_current_memory_usage()))
            if len(path) > 0:
                os.makedirs(path, exist_ok=True)
                self.write_array_if_not_exists(local_min_path, local_min)
                self.write_array_if_not_exists(local_max_path, local_max)
                # The two columns of raster_a are stored as two rasters shaped like the input
                self.write_array_if_not_exists(a0_path, raster_a[:, 0].reshape(raster.shape))
                self.write_array_if_not_exists(a1_path, raster_a[:, 1].reshape(raster.shape))
                self.write_array_if_not_exists(sdif_path, sdif)
        return local_min, local_max, raster_a, sdif

    def write_array_if_not_exists(self, path: str, array: np.ndarray):
        """
        Write ``array`` to ``path`` with ``write_numpy_array`` unless a file already exists there

        :param path: Destination file
        :type path: str

        :param array: Array to write
        :type array: np.ndarray
        """
        if not os.path.exists(path):
            self.write_numpy_array(path, array)

    def get_neighboors_count_raster(self, a: np.ndarray, k: np.ndarray, no_data: Union[int, float]):
        """
        Get neighbors for each cell excluding no data values

        :param a: Input raster
        :type a: np.ndarray

        :param k: Kernel
        :type k: np.ndarray

        :param no_data: No data value
        :type no_data: Union[int, float]

        :returns: Raster containing for each cell its neighbors count. Cells without any
                  valid neighbor get 1 to avoid a division by 0 when averaging.
        :rtype: np.ndarray
        """
        # True for valid cells, False for no_data cells so they are ignored by the count
        valid_cells = np.ones(a.shape, dtype=bool)
        valid_cells[a == no_data] = 0
        window_count = convolve2d(valid_cells, k, mode='same')
        # Set 1 for no_data cell only surrounded by no_data cells to avoid division by 0
        # (but the result will always be 0)
        # TODO: Could set np.nan instead
        window_count[window_count == 0] = 1
        return window_count

    def get_neighboors_count_raster_with_min_value(self, a: np.ndarray, k: np.ndarray, no_data: Union[int, float], min_value_allowed: float, debug):
        """
        Get neighbors for each cell excluding no data values and values not greater than a threshold

        :param a: Input raster
        :type a: np.ndarray

        :param k: Kernel
        :type k: np.ndarray

        :param no_data: No data value
        :type no_data: Union[int, float]

        :param min_value_allowed: Cells whose value is lower than or equal to this threshold are not counted
        :type min_value_allowed: float

        :param debug: Unused, kept for backward compatibility of the signature

        :returns: Floating point raster containing for each cell its neighbors count,
                  ``np.nan`` where a cell has no valid neighbor
        :rtype: np.ndarray
        """
        # True for cells taken into account by the count
        valid_cells = np.ones(a.shape, dtype=bool)
        valid_cells[(a == no_data) | (a <= min_value_allowed)] = 0

        window_count = convolve2d(valid_cells, k, mode='same')

        # A boolean mask convolved with an integer kernel yields an integer array, which
        # cannot hold NaN: promote to float before flagging empty windows.
        if not np.issubdtype(window_count.dtype, np.inexact):
            window_count = window_count.astype(np.float64)

        # Flag cells without any valid neighbor with NaN
        window_count[window_count == 0] = np.nan
        return window_count

    def get_pixel_size(self, input_raster: QgsRasterLayer, geotransform: Tuple) -> Tuple[float, float]:
        """
        Get pixel size of a raster, in meters

        For a non geographic CRS the size is read from the geotransform (``geotransform[1]``
        and ``-geotransform[5]``). For a geographic CRS, the lengths of one-pixel wide and
        one-pixel high segments starting at the raster origin are measured on the CRS
        ellipsoid and converted to meters.

        :param input_raster: Source raster to compute pixel size from
        :type input_raster: QgsRasterLayer

        :param geotransform: Geotransform tuple
        :type geotransform: Tuple

        :returns: X size of pixel in meters and Y size of pixel in meters
        :rtype: Tuple[float, float]
        """
        pixel_size_x: float
        pixel_size_y: float
        if not input_raster.crs().isGeographic():
            pixel_size_x, pixel_size_y = geotransform[1], -geotransform[5]
        else:
            # Pick the distance unit enum matching the running QGIS version
            # (presumably index 1 of qgis_version() is the minor version - TODO confirm)
            if self.qgis_version()[1] >= 30:
                unit = Qgis.DistanceUnit.Meters
            else:
                from qgis._core import QgsUnitTypes
                unit = QgsUnitTypes.DistanceMeters

            distance = QgsDistanceArea()
            crs = input_raster.crs()
            distance.setSourceCrs(crs, QgsProject.instance().transformContext())
            distance.setEllipsoid(crs.ellipsoidAcronym())

            origin_x, origin_y = geotransform[0], geotransform[3]
            # Segment covering one pixel along X, starting at the raster origin
            points_x = QgsGeometry.fromPolylineXY(
                [
                    QgsPointXY(origin_x, origin_y),
                    QgsPointXY(origin_x + geotransform[1], origin_y)
                ]
            )
            # Segment covering one pixel along Y, starting at the raster origin
            points_y = QgsGeometry.fromPolylineXY(
                [
                    QgsPointXY(origin_x, origin_y),
                    QgsPointXY(origin_x, origin_y + geotransform[5])
                ]
            )
            pixel_size_x = distance.convertLengthMeasurement(distance.measureLength(points_x), unit)
            pixel_size_y = distance.convertLengthMeasurement(distance.measureLength(points_y), unit)

        return pixel_size_x, pixel_size_y

