import gc
import os
from functools import partial
from multiprocessing import Pool
import multiprocessing as mp
from typing import Union, Tuple, Callable, List
import ctypes

from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean
from qgis import processing
import numpy as np

from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk import environment


class SlopeProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing a windowed slope (in degrees) from a DEM.

    Two implementations exist, selected by ``environment.USE_QGIS_IMPLEMENTATION``:
    a NumPy implementation based on local minimum/maximum altitudes within the window,
    and one smoothing the DEM with ``landsklim:altitude`` before running ``gdal:slope``.
    """

    # No-data value written in the output raster by the NumPy implementation
    OUTPUT_NO_DATA = -9999

    def __init__(self):
        super().__init__({'OUTPUT': 'Slope raster'})

    def createInstance(self):
        """
        Create a new instance of this algorithm
        """
        return SlopeProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'slope'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Slope')

    def shortHelpString(self) -> str:
        """
        Short help text of the algorithm
        """
        return self.tr('Compute windowed slope from DEM')

    def add_dependencies(self):
        """
        No dependencies
        """

    def compute_slope(self, a: np.ndarray, kernel_size: int, no_data: Union[int, float], raster_pixel_size: float,
                      local_min: np.ndarray, local_max: np.ndarray) -> np.ndarray:
        """
        Compute slope of a raster, in degrees, from the local altitude range of each cell.

        The slope is ``degrees(arctan((local_max - local_min) / ((raster_pixel_size * kernel_size) / 2)))``.
        Cells where ``a`` equals ``no_data`` (or is NaN when ``no_data`` is NaN) are set to
        ``OUTPUT_NO_DATA``.

        :param a: Input array, only used to locate no-data cells
        :type a: np.ndarray

        :param kernel_size: Window size, as given by the INPUT_WINDOW parameter
        :type kernel_size: int

        :param no_data: Value to consider as no data (may be NaN or None)
        :type no_data: Union[int, float]

        :param raster_pixel_size: Pixel size according to the raster unit
        :type raster_pixel_size: float

        :param local_min: Raster containing local minimums
        :type local_min: np.ndarray

        :param local_max: Raster containing local maximums
        :type local_max: np.ndarray

        :returns: Slope for each cell of the raster
        :rtype: np.ndarray
        """
        altitude_unit = 1
        # Horizontal run used for the slope: half of the window extent in raster units
        horizontal_run = (raster_pixel_size * kernel_size) / 2
        elevation_range = (local_max - local_min) / altitude_unit
        slopes = np.degrees(np.arctan(elevation_range / horizontal_run))

        # Erase every computed data by no data where necessary.
        # `x == nan` is always False, so a NaN no-data value has to be detected with np.isnan
        if no_data is not None and np.isnan(no_data):
            no_data_mask = np.isnan(a)
        else:
            no_data_mask = a == no_data
        slopes[no_data_mask] = self.OUTPUT_NO_DATA
        return slopes

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run.

        Reads the INPUT DEM and the INPUT_WINDOW size, computes the slope and writes it to the OUTPUT destination.

        :returns: Dictionary mapping 'OUTPUT' to the destination path of the slope raster
        :rtype: dict
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        no_data, geotransform = self.get_raster_metadata(parameters, context, input_raster)

        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)
        additional_variables_folder = self.parameterAsString(parameters, 'ADDITIONAL_VARIABLES_FOLDER', context).strip()

        # TODO: Only handle unique raster unit (pixel_size_y not used)
        pixel_size_x, pixel_size_y = self.get_pixel_size(input_raster, geotransform)

        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_slope = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        if environment.USE_QGIS_IMPLEMENTATION:
            self.qgis_slope(input_raster, input_window, out_slope, feedback=feedback)
        else:
            # The DEM is only loaded in memory when the NumPy implementation is used
            np_input = LandsklimUtils.raster_to_array(input_raster)
            # Copy taken before dem_variables() so the no-data mask is built from the untouched input
            # NOTE(review): presumably guards against dem_variables modifying np_input in place — confirm
            dem_copy = np.copy(np_input)
            local_min, local_max, _, _ = self.dem_variables(np_input, no_data, self.create_kernel(input_window), additional_variables_folder)

            slopes = self.compute_slope(dem_copy, input_window, no_data, pixel_size_x, local_min, local_max)
            self.write_raster(out_slope, slopes, out_srs, geotransform, self.OUTPUT_NO_DATA)

        return {'OUTPUT': out_slope}

    def qgis_slope(self, raster: QgsRasterLayer, kernel_size: int, out_path_slope: str, feedback=None):
        """
        Use the QGIS/GDAL implementation of the slope algorithm.

        The DEM is first smoothed by ``landsklim:altitude`` with the same window, then ``gdal:slope``
        is run on the smoothed raster and written to ``out_path_slope``.

        :param raster: Input raster (DEM, smoothing according the kernel)
        :type raster: QgsRasterLayer

        :param kernel_size: Kernel size
        :type kernel_size: int

        :param out_path_slope: Destination file of the 'slope' output
        :type out_path_slope: str

        :param feedback: Optional feedback forwarded to the nested algorithms, so their progress, errors and
                         cancellation are tied to the calling algorithm
        :type feedback: QgsProcessingFeedback
        """
        source: str = raster.source()
        smoothed_source = processing.run("landsklim:altitude",
                                         {
                                             'INPUT': source,
                                             'INPUT_WINDOW': kernel_size,
                                             'INPUT_CUSTOM_NO_DATA': None,
                                             'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
                                         },
                                         feedback=feedback)['OUTPUT']

        # NOTE(review): gdal:slope's SCALE is documented as the ratio of vertical units to horizontal units,
        # whereas this value is the ratio of pixel height to pixel width (1 for square pixels) — confirm intent
        ratio_vh = raster.rasterUnitsPerPixelY() / raster.rasterUnitsPerPixelX()

        processing.run("gdal:slope",
                       {
                           'INPUT': smoothed_source,
                           'BAND': 1,
                           'SCALE': ratio_vh,
                           'AS_PERCENT': False,
                           'COMPUTE_EDGES': False,
                           'ZEVENBERGEN': False,
                           'OPTIONS': '',
                           'EXTRA': '',
                           'OUTPUT': out_path_slope
                       },
                       feedback=feedback)
