import gc
import os.path
import time
from functools import partial
from multiprocessing import Pool
import multiprocessing as mp
from typing import Union, Tuple, Callable, Optional, List
import ctypes

from PIL import Image
from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsRasterLayer
from qgis import processing
import numpy as np
import math
from scipy.signal import convolve2d

from landsklim.lk.regressor_factory import RegressorFactory
from landsklim.lk.regressor_orientation import RegressorOrientation
from landsklim.processing.algorithm_ampl import AmplitudeProcessingAlgorithm
from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk import environment


class CrestThalwegProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing lines of crest and thalweg from a DEM (from its orientation regressor)
    """

    def __init__(self):
        """
        Forward the output description to the parent class and start without orientation data
        """
        output_labels = {'OUTPUT': 'Lines of crest and thalweg raster'}
        super().__init__(output_labels)
        # Both attributes stay unset until set_orientation_regressor() is called
        self._orientation_raster: Optional[np.ndarray] = None
        self._orientation_no_data: Optional[Union[int, float]] = None

    def createInstance(self):
        """
        Build and return a new instance of this algorithm
        """
        fresh_instance = CrestThalwegProcessingAlgorithm()
        return fresh_instance

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'crestthalweg'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Crest and thalweg lines')

    def shortHelpString(self) -> str:
        return self.tr('Compute lines of crest and thalweg from DEM')

    def add_dependencies(self):
        """
        No dependencies
        """
        pass

    def set_orientation_regressor(self, source: str):
        """
        Load the orientation raster and its no-data value from a raster source

        The instance attributes are only updated once both the pixel values and the
        no-data value have been read, so a failure leaves the previous state untouched.

        :param source: Raster source readable by LandsklimUtils.source_to_array() and gdal.Open()
        :raises RuntimeError: If GDAL cannot open ``source``
        """
        orientation_raster = LandsklimUtils.source_to_array(source)

        # gdal.Open() returns None instead of raising when GDAL exceptions are disabled,
        # so the failure is reported explicitly rather than as an AttributeError on None
        dataset = gdal.Open(source)
        if dataset is None:
            raise RuntimeError(f"Orientation raster cannot be opened: {source}")
        orientation_no_data = dataset.GetRasterBand(1).GetNoDataValue()
        # Dropping the reference releases the GDAL dataset
        dataset = None

        self._orientation_raster = orientation_raster
        self._orientation_no_data = orientation_no_data

    def processAlgorithm(self, parameters, context, feedback):
        """
        Entry point executed when the processing algorithm is run

        Reads the input raster and the window size, computes the crest/thalweg raster
        through crest_thalweg_qgis() and writes it to the requested output.

        :param parameters: Algorithm parameters; 'INPUT', 'INPUT_WINDOW',
            'ADDITIONAL_VARIABLES_FOLDER' and 'OUTPUT' are read here
        :param context: Processing context used to resolve the parameters
        :param feedback: Processing feedback object (not used in this method)
        :returns: Dictionary mapping 'OUTPUT' to the destination of the written raster
        """
        # Resolve the input raster along with the metadata reused for the output raster
        dem_layer: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        no_data, geotransform = self.get_raster_metadata(parameters, context, dem_layer)
        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)
        # Read like the other parameters, but the value is not used further down
        additional_variables_folder = self.parameterAsString(parameters, 'ADDITIONAL_VARIABLES_FOLDER', context).strip()
        out_srs: SpatialReference = self.get_spatial_reference(dem_layer)

        # Destination of the result. If a temporary layer is selected, it is created in the QGIS temp dir
        output_path = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        # Only the shape of this array is used by the active code path below
        input_array: np.ndarray = LandsklimUtils.raster_to_array(dem_layer)

        """if self._orientation_raster is None:
            regressor_orientation = RegressorFactory.get_regressor(RegressorOrientation.class_name(), input_window, 1)
            orientation_raster_path = regressor_orientation.get_path()
            if not os.path.exists(orientation_raster_path):
                raise RuntimeError("Orientation raster is not found")
            self._orientation_raster = LandsklimUtils.source_to_array(orientation_raster_path)
        np_ct = self.compute_crest_thalweg(input_array, no_data, np.trunc(self._orientation_raster).astype(int), self._orientation_no_data, input_window)"""

        expected_shape = input_array.shape
        np_ct = self.crest_thalweg_qgis(dem_layer, input_window, expected_shape)
        #np_ct[np.isnan(np_ct)] = no_data
        self.write_raster(output_path, np_ct, out_srs, geotransform, no_data)

        results = {'OUTPUT': output_path}
        return results

    """def compute_crest_thalweg(self, input: np.array, no_data: Optional[Union[int, float]], orientations: np.ndarray, no_data_orientation: Optional[Union[int, float]], kernel_size: int) -> np.array:
        time_bm = time.perf_counter()
        self.__si, self.__sj = self.bal_sur_mat(kernel_size)
        no_data_mask: np.ndarray = input == no_data
        rows: int = input.shape[0]
        cols: int = input.shape[1]
        resultat = np.zeros((5, rows+2, cols+2))
        adiez = 180.0/np.pi
        for i in range(rows):
            print("[row]", i)
            for j in range(cols):
                if not no_data_mask[i, j]:
                    cpt_o = np.zeros(17)
                    rayon = 2
                    f1 = i - rayon
                    f2 = i + rayon + 1
                    f3 = j - rayon
                    f4 = j + rayon + 1
                    if f1 < 0: f1 = 0
                    if f2 > rows: f2 = rows
                    if f3 < 0: f3 = 0
                    if f4 > cols: f4 = cols

                    for k1 in range(f1, f2):
                        for k2 in range(f3, f4):
                            o = orientations[k1, k2]
                            if o != no_data_orientation:
                                o //= 22
                                if o < 17:
                                    cpt_o[o] += 1

                    mx = -1
                    rep = 0
                    for k in range(17):
                        if cpt_o[k] > mx:
                            mx = cpt_o[k]
                            rep = k

                    resultat[1, i, j] = rep
                else:
                    resultat[1, i, j] = no_data

        orientations = resultat[1, :, :]

        # Etape 1: identification des changements majeurs d'orientation des versants
        e = np.log10(kernel_size) * 3
        rayon = int(e)
        for i in range(rows):
            print("Etape 1 [row]", i)
            for j in range(cols):
                resultat[0, i, j] = no_data
                if not no_data_mask[i, j]:
                    resultat[0, i, j] = self.change_maj(i, j, rayon, no_data_mask, rows, cols, orientations)

        # Etape 2 : Calculs creux bosse

        resultat[3, :, :] = resultat[0, :, :]
        resultat[0, :, :] = 0

        l_test = kernel_size // 2
        rayon *= 2
        if rayon > l_test:
            rayon = l_test
        if rayon > 500:
            rayon = 500
        if rayon < 4:
            rayon = 4

        for i in range(rows):
            print("Etape 2 [row]", i)
            for j in range(cols):
                if not no_data_mask[i, j]:
                    if resultat[0, i, j] == 0:
                        resultat[0, i, j] = self.traitement_creux_bosse(i, j, adiez, rayon)
                        if resultat[0, i, j] == 100:
                            resultat[0, i, j] = 0
                else:
                    resultat[0, i, j] = no_data

        return resultat[3, :rows, :cols]

    def traitement_creux_bosse(self, i: int, j: int, adiez: float, rayon: int) -> float:
        return 20

    def change_maj(self, i: int, j: int, rayon: int, no_data_mask: np.ndarray, rows: int, cols: int, orientations: np.ndarray) -> np.ndarray:
        deb_i = i - rayon
        deb_j = j - rayon
        change_majeur = 0
        if deb_i < 0: deb_i = 0
        if deb_j < 0: deb_j = 0
        for k1 in range(8):
            h = (k1 * 3) - 2
            if h < 0: h += 24
            no = -1
            matr = np.zeros((rayon*2+1, rayon*2+1))
            cpt_dir = 0
            for k2 in range(3):
                h += 1
                if h > 23:
                    h = h - 24

                vi, vj = i, j
                k = -1 if k2 == 0 else 0
                while k < rayon:
                    k += 1
                    vi += self.__si[h, k]
                    vj += self.__sj[h, k]
                    if vi < 0 or vi >= rows or vj < 0 or vj >= cols:
                        k = rayon
                    else:
                        if not no_data_mask[vi, vj] and matr[vi - deb_i, vj - deb_j] == 0:
                            no += 1
                            o1 = orientations[i, j] * 22
                            o2 = orientations[vi, vj] * 22
                            ec_dir = abs(o1 - o2)
                            if ec_dir > 180:
                                ec_dir = 360 - ec_dir
                            if ec_dir > 25:
                                cpt_dir += 1

                            matr[vi - deb_i, vj - deb_j] = 1



        k = 30 if change_majeur > 2 else 0
        return k

    def bal_sur_mat(self, kernel_size: int) -> Tuple[np.ndarray, np.ndarray]:
        l_test = kernel_size * 10
        a4 = np.pi / 180.0
        si, sj = np.zeros((24, l_test+1)), np.zeros((24, l_test+1))
        z2 = np.zeros(1800)
        for h in range(24):
            z = h*15
            ia = 0
            if z < 90 or z > 270: ia = -1
            if z > 90 and z < 270: ia = 1
            if z > 0 and z < 180: ja = 1
            if z > 180: ja = -1
            k = -1
            z1 = 0
            while abs(int(z1)) < l_test and k < (l_test+1):
                k += 1
                z1 = k * np.tan(z * a4)
                if z1 > -0.000001: z1 = z1 + 0.00001
                if z1 > 32000: z2[k] = 32000
                else: z2[k] = abs(int(z1))
            k = 0
            p = -1
            r = -1

            while k < l_test:
                p += 1
                k += 1
                r += 1
                nk1 = z2[p+1]-z2[p]
                nkk = nk1 - r
                if nk1 == 0: q = 1
                if nk1 > 1: q = 3
                if nkk == 0 and nk1 > 1: q = 2
                if nkk == 1 : q = 2
                if q == 1:
                    si[h, k] = ia
                    r -= 1
                elif q == 2:
                    si[h, k] = 1 * ia
                    sj[h, k] = 1 * ja
                    r = -1
                elif q == 3:
                    sj[h, k] = 1 * ja
                    p -= 1
        return si, sj"""

    def crest_thalweg_qgis(self, dem: QgsRasterLayer, input_window: int, output_shape: List[int]):
        """
        Derive crest and thalweg cells from a DEM with the GRASS r.geomorphon algorithm

        :param dem: DEM raster layer; its source is used both as elevation input and as GRASS region
        :param input_window: Value given to r.geomorphon as its 'search' parameter
        :param output_shape: Shape of the returned array; the r.geomorphon result is resampled
            (nearest neighbour) to it
        :returns: Array of shape ``output_shape`` filled with 0, except cells where r.geomorphon
            returned 9 (set to AmplitudeProcessingAlgorithm.CODE_THALWEG) and cells where it
            returned 3 (set to AmplitudeProcessingAlgorithm.CODE_CREST)
        """
        dem_source = dem.source()
        geomorphon_params = {
            'elevation': dem_source,
            'search': input_window,
            'skip': 0,
            'flat': 1,
            'dist': 0,
            'forms': 'TEMPORARY_OUTPUT',
            '-m': False,
            '-e': False,
            'GRASS_REGION_PARAMETER': dem_source,
            'GRASS_REGION_CELLSIZE_PARAMETER': 0,
            'GRASS_RASTER_FORMAT_OPT': '',
            'GRASS_RASTER_FORMAT_META': ''
        }
        geomorphon_output = processing.run("grass7:r.geomorphon", geomorphon_params)
        forms: np.ndarray = LandsklimUtils.source_to_array(geomorphon_output["forms"])

        # TODO: r.geomorphon can return a raster with a different shape because GRASS seems to handle only one pixel size for both horizontal and vertical axes.
        # Pillow releases exposing Image.Resampling are preferred; otherwise fall back to Image.NEAREST
        try:
            nearest = Image.Resampling.NEAREST
        except AttributeError:
            nearest = Image.NEAREST
        # Transposed before and after the resize so that the final array has shape output_shape
        resized_forms = Image.fromarray(forms.T).resize(output_shape, nearest)
        forms = np.array(resized_forms).T

        # Form codes kept from the r.geomorphon classification: 9 -> thalweg, 3 -> crest
        is_thalweg = forms == 9
        is_crest = forms == 3

        crest_thalweg = np.zeros(forms.shape)
        crest_thalweg[is_thalweg] = AmplitudeProcessingAlgorithm.CODE_THALWEG
        crest_thalweg[is_crest] = AmplitudeProcessingAlgorithm.CODE_CREST
        return crest_thalweg