import gc
import os.path
from functools import partial
from multiprocessing import Pool
import multiprocessing as mp
from typing import Union, Tuple, Callable
import ctypes

from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis._core import QgsProcessingParameterRasterLayer
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsRasterLayer
from qgis import processing
import numpy as np
import math

from scipy.signal import convolve2d

from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk import environment


class OrientationProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing orientation from a DEM
    """

    INPUT_SLOPE = 'INPUT_SLOPE'

    def __init__(self):
        """
        Initialise the parent algorithm with the single 'OUTPUT' entry of this algorithm
        """
        # NOTE(review): the mapping is presumably output key -> human readable label; confirm against the parent class
        outputs = {'OUTPUT': 'Orientation raster'}
        super().__init__(outputs)

    def createInstance(self):
        """
        Build and return a fresh instance of this algorithm
        """
        instance = OrientationProcessingAlgorithm()
        return instance

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'orientation'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Orientation')

    def shortHelpString(self) -> str:
        return self.tr('Compute windowed orientation from DEM')

    def add_dependencies(self):
        """
        Declare the additional raster input of this algorithm: the slope layer,
        registered under the ``INPUT_SLOPE`` parameter name
        """
        slope_label = self.tr('Slope', "LandsklimProcessingRegressorAlgorithm")
        slope_parameter = QgsProcessingParameterRasterLayer(
            OrientationProcessingAlgorithm.INPUT_SLOPE,
            slope_label
        )
        self.addParameter(slope_parameter)

    def compute_orientation(self, a: np.array, kernel_size: int, no_data: Union[int, float], raster_pixel_size: int, raster_a: np.array, slopes: np.ndarray) -> np.array:
        """
        Compute orientation of a raster

        :param a: Input array
        :type a: np.ndarray

        :param kernel_size: Window radius
        :type kernel_size: int

        :param no_data: Value to consider as no data
        :type no_data: Union[int, float]

        :param raster_pixel_size: Pixel size according to the raster unit
        :type raster_pixel_size: int

        :param local_min: Raster containing local minimums
        :type local_min: np.ndarray

        :param local_max: Raster containing local maximums
        :type local_max: np.ndarray

        :param raster_a: 2-band raster containing a values
        :type raster_a: np.ndarray

        :param slopes: Slopes
        :type slopes: np.ndarray

        :returns: Orientation values for each cell of the raster
        :rtype: np.ndarray
        """
        print("[Compute orientation] ", kernel_size)
        k = self.create_kernel(kernel_size)

        # Warning : C++ implementation leads to raster_a[:, :] = 1.0000000 when Python implementation leads to raster_a[:, :] having 0.99999999999999, making big changes when casting to int
        pti = np.degrees(np.arctan(((np.abs(raster_a[:, 0] * 10).astype(int)) / raster_pixel_size)/10))  # double parsing precision is different from C++
        ptj = np.degrees(np.arctan(((np.abs(raster_a[:, 1] * 10).astype(int)) / raster_pixel_size)/10))  # double parsing precision is different from C++
        pti[pti < 0.001] = 0.001
        ptj[ptj < 0.001] = 0.001
        ra = pti - ptj
        raa = np.zeros(pti.shape).astype(int)
        raa[ra > 0] = 1
        raa[~(ra > 0)] = 2
        a1 = np.zeros(pti.shape).astype(int)
        a2 = np.zeros(pti.shape).astype(int)
        a1[raster_a[:, 0] > 0] = 1
        a1[~(raster_a[:, 0] > 0)] = 2
        a2[raster_a[:, 1] > 0] = 1
        a2[~(raster_a[:, 1] > 0)] = 2

        orientations = np.zeros(pti.shape)

        orientations[(a1 == 0) & (a2 < 0)] = 90
        orientations[(a1 == 0) & ~(a2 < 0)] = 270

        mask = (a1 == 1) & (a2 == 1) & (raa == 1)
        orientations[mask] = 315+(((90.0/pti[mask]) * ra[mask])/2)
        orientations[(a1 == 1) & (a2 == 1) & (raa == 2)] = 315+(((90.0/ptj[(a1 == 1) & (a2 == 1) & (raa == 2)]) * ra[(a1 == 1) & (a2 == 1) & (raa == 2)])/2)

        orientations[(a1 == 1) & (a2 == 2) & (raa == 1)] = 45-(((90.0/pti[(a1 == 1) & (a2 == 2) & (raa == 1)]) * ra[(a1 == 1) & (a2 == 2) & (raa == 1)])/2)
        orientations[(a1 == 1) & (a2 == 2) & (raa == 2)] = 45-(((90.0/ptj[(a1 == 1) & (a2 == 2) & (raa == 2)]) * ra[(a1 == 1) & (a2 == 2) & (raa == 2)])/2)

        orientations[(a1 == 1) & (ptj < 0.002)] = 0

        orientations[(a1 == 2) & (a2 == 1) & (raa == 1)] = 225-(((90.0/pti[(a1 == 2) & (a2 == 1) & (raa == 1)]) * ra[(a1 == 2) & (a2 == 1) & (raa == 1)])/2)
        orientations[(a1 == 2) & (a2 == 1) & (raa == 2)] = 225-(((90.0/ptj[(a1 == 2) & (a2 == 1) & (raa == 2)]) * ra[(a1 == 2) & (a2 == 1) & (raa == 2)])/2)

        orientations[(a1 == 2) & (a2 == 2) & (raa == 1)] = 135+(((90.0/pti[(a1 == 2) & (a2 == 2) & (raa == 1)]) * ra[(a1 == 2) & (a2 == 2) & (raa == 1)])/2)
        orientations[(a1 == 2) & (a2 == 2) & (raa == 2)] = 135+(((90.0/ptj[(a1 == 2) & (a2 == 2) & (raa == 2)]) * ra[(a1 == 2) & (a2 == 2) & (raa == 2)])/2)

        orientations[(a1 == 2) & (ptj < 0.002)] = 180

        orientations[slopes.ravel() == 0] = 180

        orientations = orientations.reshape(a.shape)

        mask = np.ones(a.shape, dtype=bool)
        mask[a == no_data] = 0
        window_count = convolve2d(mask, k, mode='same')
        orientations[window_count < 3] = 0

        # Erase every computed data by no data where necessary
        orientations[a == no_data] = -9999

        return orientations

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run

        Reads the 'INPUT' raster and the slope raster, computes the orientation either with
        the internal implementation or with the QGIS/GDAL one (depending on
        ``environment.USE_QGIS_IMPLEMENTATION``) and writes the result to the 'OUTPUT' layer.

        :param parameters: Parameters of the algorithm run
        :param context: Processing context
        :param feedback: Processing feedback (not used by this method)

        :returns: Dictionary mapping 'OUTPUT' to the path of the orientation raster
        :rtype: dict

        :raises QgsProcessingException: If the input raster or the slope raster cannot be loaded
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        if input_raster is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, 'INPUT'))
        input_slope: QgsRasterLayer = self.parameterAsRasterLayer(parameters, OrientationProcessingAlgorithm.INPUT_SLOPE, context)

        no_data, geotransform = self.get_raster_metadata(parameters, context, input_raster)
        pixel_size_x, pixel_size_y = self.get_pixel_size(input_raster, geotransform)  # TODO: Only handle unique raster unit (pixel_size_y not used)
        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)
        additional_variables_folder = self.parameterAsString(parameters, 'ADDITIONAL_VARIABLES_FOLDER', context).strip()

        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_orientation = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        if environment.USE_QGIS_IMPLEMENTATION:
            # The QGIS/GDAL implementation works from the layer source: no need to load the rasters in memory
            self.qgis_orientation(input_raster, input_window, out_orientation)
            return {'OUTPUT': out_orientation}

        if input_slope is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, OrientationProcessingAlgorithm.INPUT_SLOPE))

        np_input = LandsklimUtils.raster_to_array(input_raster)
        np_slopes = LandsklimUtils.raster_to_array(input_slope)
        # NaN cannot be stored in an integer array: promote the slopes to float first
        if not np.issubdtype(np_slopes.dtype, np.floating):
            np_slopes = np_slopes.astype(float)
        np_slopes[np_slopes == no_data] = np.nan  # no_data is the same for each regressors

        _, _, raster_a, _ = self.dem_variables(np_input, no_data, self.create_kernel(input_window), additional_variables_folder)
        np_orientations = self.compute_orientation(np_input, input_window, no_data, pixel_size_x, raster_a, np_slopes)
        self.write_raster(out_orientation, np_orientations, out_srs, geotransform, -9999)

        return {'OUTPUT': out_orientation}

    def qgis_orientation(self, raster: QgsRasterLayer, kernel_size: int, out_path: str):
        """
        Use the QGIS/GDAL implementation of the orientation algorithm (gdal:aspect)

        The source of the raster is first processed by "landsklim:altitude" with the given
        window (presumably a smoothing of the DEM over the kernel), then "gdal:aspect" is run
        on that intermediate result and written to ``out_path``.

        :param raster: Input raster (DEM)
        :type raster: QgsRasterLayer

        :param kernel_size: Kernel size
        :type kernel_size: int

        :param out_path: Destination file of the output
        :type out_path: str
        """
        altitude_parameters = {
            'INPUT': raster.source(),
            'INPUT_WINDOW': kernel_size,
            'INPUT_CUSTOM_NO_DATA': None,
            'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
        }
        altitude_result = processing.run("landsklim:altitude", altitude_parameters)
        smoothed_source = altitude_result['OUTPUT']

        aspect_parameters = {
            'INPUT': smoothed_source,
            'BAND': 1,
            'TRIG_ANGLE': False,
            'ZERO_FLAT': False,
            'COMPUTE_EDGES': False,
            'ZEVENBERGEN': False,
            'OPTIONS': '',
            'EXTRA': '',
            'OUTPUT': out_path
        }
        # NO_DATA is set as -9999
        processing.run("gdal:aspect", aspect_parameters)
