import gc
import os.path
import time
from typing import Union, Tuple, Callable, Optional
import ctypes

import pandas as pd
import sklearn
from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis._core import QgsProcessingParameterNumber
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsRasterLayer, \
    QgsProcessingParameterRasterLayer
from qgis import processing
import numpy as np
import math

from scipy import signal, ndimage
from scipy.interpolate import interp2d, RectBivariateSpline, RegularGridInterpolator, LinearNDInterpolator
from scipy.linalg import lstsq
from scipy.ndimage import binary_erosion, binary_dilation
from sklearn.linear_model import LinearRegression

from landsklim.lk.regressor_orientation import RegressorOrientation
from landsklim.processing.algorithm_altitude import AltitudeProcessingAlgorithm
from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk import environment


class AmplitudeProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing amplitude of the lines of crest or thalweg from a DEM.

    Outline of :meth:`compute_amplitude`:

    1. On a sub-sampled grid, for each pixel, collect the nearest crest/thalweg pixel of each of
       ``nb_q`` angular sectors; sectors without any such pixel are completed by walking along the
       sector direction (see :meth:`bal_sur_mat`).
    2. Fit a plane (least squares) through the collected points and evaluate it at the pixel.
    3. Densify the sparse estimates by averaging neighbouring estimates.
    4. Take the difference between these estimates and the DEM (sign depends on the code),
       rescale negative differences logarithmically and smooth the result.
    """

    INPUT_CREST_THALWEG = 'INPUT_CREST_THALWEG'
    INPUT_CODE = 'INPUT_CODE'

    # Codes compared against the crest/thalweg raster to select which lines are processed
    CODE_THALWEG = 30
    CODE_CREST = 10

    def __init__(self):
        super().__init__({'OUTPUT': 'Amplitude of the lines of crest or thalweg'})
        # Crest/thalweg raster. Read from INPUT_CREST_THALWEG in processAlgorithm unless already set.
        self._crest_thalweg: Optional[np.ndarray] = None
        # When True, intermediate results are truncated with np.trunc
        # (presumably to match the reference LISDQS implementation — confirm)
        self._test_case: bool = False

    def createInstance(self):
        return AmplitudeProcessingAlgorithm()

    def add_dependencies(self):
        """
        No dependencies
        """
        pass

    def initAlgorithm(self, config=None):
        """
        Declare the parameters of the algorithm: the inherited ones, the crest/thalweg raster
        and the integer code of the lines to process.
        """
        super().initAlgorithm(config)

        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT_CREST_THALWEG,
                self.tr('Lines of crest and thalweg layer')
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_CODE,
                self.tr('Code', "AmplitudeProcessingAlgorithm"),
                QgsProcessingParameterNumber.Type.Integer
            )
        )

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'amplitude'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Amplitude of the lines of crest or thalweg')

    def shortHelpString(self) -> str:
        """
        Short help displayed in the processing dialog
        """
        return self.tr('Compute amplitude of the lines of crest or thalweg from a DEM')

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run.

        Reads the DEM (``INPUT``), the crest/thalweg raster and the code to process.
        Averages the DEM over ``INPUT_WINDOW`` with :meth:`windowed_average`, computes the amplitude
        and writes it to ``OUTPUT`` with ``-9999`` as NO_DATA.

        :returns: ``{'OUTPUT': path of the output raster}``
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        code: int = self.parameterAsInt(parameters, self.INPUT_CODE, context)

        # The crest/thalweg array is only read from the parameters when it has not been set beforehand
        # (presumably injected directly by tests — confirm)
        if self._crest_thalweg is None:
            input_crest_thalweg: QgsRasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT_CREST_THALWEG, context)
            self._crest_thalweg = LandsklimUtils.raster_to_array(input_crest_thalweg)

        no_data, geotransform = self.get_raster_metadata(parameters, context, input_raster)
        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)
        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        self.out_srs, self.geotransform = out_srs, geotransform

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_ampl = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        input_array: np.ndarray = LandsklimUtils.raster_to_array(input_raster)
        self._altitudes = input_array

        if environment.TEST_MODE:
            print("[input array]", input_window)
        # Average the DEM over the analysis window; NO_DATA pixels keep their NO_DATA value
        self._altitudes = self.windowed_average(input_array, input_window, no_data)
        self._altitudes[input_array == no_data] = no_data
        if self._test_case:
            self._altitudes = np.trunc(self._altitudes)

        output_no_data = -9999
        np_ampl = self.compute_amplitude(self._altitudes, self._crest_thalweg, no_data, input_window, code, output_no_data)
        np_ampl[np.isnan(np_ampl)] = output_no_data

        self.write_raster(out_ampl, np_ampl, out_srs, geotransform, output_no_data)

        return {'OUTPUT': out_ampl}

    def compute_amplitude(self, input: np.ndarray, crest_thalweg: np.ndarray, no_data: Optional[Union[int, float]], kernel_size: int, code: int, output_no_data: Union[int, float]) -> np.ndarray:
        """
        Compute amplitude of crest or thalweg lines

        :param input: DEM
        :type input: np.ndarray

        :param crest_thalweg: Raster of crest and thalweg lines (same shape as ``input``)
        :type crest_thalweg: np.ndarray

        :param no_data: NO_DATA value of ``input``
        :type no_data: Optional[Union[int, float]]

        :param kernel_size: Kernel size. It drives the sampling step of the regressions and the search radius
                            (``kernel_size * 10``, capped at 500 pixels).
        :type kernel_size: int

        :param code: Code which represents crests or thalwegs in ``crest_thalweg``
                     (:attr:`CODE_CREST` or :attr:`CODE_THALWEG` select the sign of the difference)
        :type code: int

        :param output_no_data: Value written where ``input`` is NO_DATA
        :type output_no_data: Union[int, float]

        :returns: Amplitude raster
        :rtype: np.ndarray
        """
        # Marker of pixels of the regression grid that have not been computed
        unprocessed = -30000

        self.debug = (0, 0)

        if environment.TEST_MODE:
            print("[{0}], {1}".format(self.debug, input[self.debug]))

        rows, cols = input.shape
        no_data_mask = input == no_data
        input_without_no_data = input.copy()
        input_without_no_data[no_data_mask] = np.nan
        res = np.zeros(input.shape)
        # Sampling step: a regression is computed every `echant` rows, for one candidate pixel out of `echant`
        echant = int(np.log10(kernel_size * 50) * 2)
        echant_conv = int(np.log10(kernel_size * 50) * 2)

        # Profiling accumulators (only reported in TEST_MODE)
        tot_time_idw = 0
        tot_time_rv = 0
        tot_time_tr = 0
        tot_time_regr = 0
        tot_time_lr = 0
        tot_time_cat_angles = 0

        tot_time_unique = 0
        tot_time_hist = 0
        tot_time_r = 0
        tot_time_angles = 0
        tot_time_init = 0
        tot_time_i = 0

        # Search radius (in pixels)
        taille_f = min(kernel_size * 10, 500)

        # Number of angular sectors (30° each)
        nb_q = 360 // 30
        deb_i, deb_j = 0, 0

        # For each sector, the step-by-step (row, col) offsets of a discrete walk in the sector direction
        si_collection = np.zeros((nb_q, taille_f + 1), dtype=int)
        sj_collection = np.zeros((nb_q, taille_f + 1), dtype=int)
        for q in range(nb_q):
            si_collection[q], sj_collection[q] = self.bal_sur_mat(q * (360.0 / nb_q), taille_f)

        # Cumulated offsets: position of each step of the walk relative to the starting pixel
        si_cumsum = si_collection.cumsum(axis=1)
        sj_cumsum = sj_collection.cumsum(axis=1)

        if environment.TEST_MODE:
            print("[Code]", (crest_thalweg == code).sum())
            print("[NoData]", (no_data_mask).sum())
        res[deb_i:rows, deb_j:cols] = unprocessed
        res[no_data_mask] = no_data
        # Crest/thalweg pixels keep their altitude and are not regressed
        res[crest_thalweg == code] = input[crest_thalweg == code]

        cpt_i = echant - 1
        cpt_j = echant - 1
        crest_thalweg_code_mask = crest_thalweg == code
        if not self._test_case:
            # Morphological closing of the lines to fill one-pixel gaps
            crest_thalweg_code_mask = binary_dilation(crest_thalweg_code_mask)
            crest_thalweg_code_mask = binary_erosion(crest_thalweg_code_mask)
        qs = np.arange(nb_q)

        for i in range(deb_i, rows):
            if environment.TEST_MODE and not i % 100:
                print("[row]", i)
            cpt_i += 1
            if cpt_i != echant:
                continue
            cpt_i = 0
            for j in range(deb_j, cols):
                debug = environment.TEST_MODE and i == self.debug[0] and j == self.debug[1]
                if res[i, j] != unprocessed:
                    if debug:
                        print("Le point n'a pas été traité")
                    continue
                # NOTE: cpt_j is deliberately not reset between rows
                cpt_j += 1
                if cpt_j != echant:
                    continue
                cpt_j = 0

                # === Step 1 ==========
                # Value search.
                # For the point (i, j), find every crest (or thalweg) point within the search radius,
                # sorted by distance.
                # =====================
                time_rv = time.perf_counter()

                time_r = time.perf_counter()
                i_array, j_array, angles, time_init, time_angle, time_dst = self.search_values(taille_f, i, j, crest_thalweg_code_mask, debug)
                tot_time_angles += time_angle
                tot_time_init += time_init
                tot_time_i += time_dst
                tot_time_r += (time.perf_counter() - time_r)

                time_cat_angles = time.perf_counter()
                # Angular sector of each point (0 to nb_q - 1)
                cat_angles = self.cat_angles(angles, nb_q)
                tot_time_cat_angles += (time.perf_counter() - time_cat_angles)

                # Data table for the regression: (altitude, i, j, angle, angle sector)
                z_n = np.zeros((len(i_array), 5))
                z_n[:, 0] = input[i_array, j_array]
                z_n[:, 1] = i_array
                z_n[:, 2] = j_array
                z_n[:, 3] = angles
                z_n[:, 4] = cat_angles

                # Keep only the first (i.e. closest to (i, j)) point of each angular sector
                time_un = time.perf_counter()
                if debug:
                    with np.printoptions(threshold=np.inf, suppress=True):
                        print("[z_n]")
                        print("[input, i, j, angle, cat_angle]")
                        print(z_n)
                z_n = z_n[np.unique(z_n[:, 4], return_index=True)[1]]
                tot_time_unique += (time.perf_counter() - time_un)

                time_hist = time.perf_counter()
                fnd = np.isin(qs, z_n[:, 4])
                nb_pt = fnd.sum()
                tot_time_hist += (time.perf_counter() - time_hist)

                tot_time_rv += (time.perf_counter() - time_rv)

                # =========== Step 2 ========
                # Completion of the regression data table: the search does not necessarily find a point
                # in every direction, so each empty sector gets one synthetic point, in order to have
                # one sample per angular sector.
                # ===========================
                time_tr = time.perf_counter()
                # h1: angular sectors without any point found
                h1 = qs[~fnd]
                nb_pt += len(h1)
                h = h1 * (360.0 / nb_q)
                # Absolute grid position of each step of the walk, one row per empty sector.
                # (Row-slicing these arrays with [:taille_f] used to drop sectors — and crash the
                # column_stack below — when there were more empty sectors than taille_f.)
                vi = si_cumsum[h1] + i
                vj = sj_cumsum[h1] + j

                # Steps of the walks that fall outside the raster, then clamp positions to the raster
                out_of_bounds_mask = (vi < 0) | (vj < 0) | (vi >= rows) | (vj >= cols)
                vi[vi < 0] = 0
                vj[vj < 0] = 0
                vi[vi >= rows] = rows - 1
                vj[vj >= cols] = cols - 1

                # For each walk, index of the first step which is NO_DATA or out of bounds (a row-wise np.where()),
                # -1 when the walk never meets one (argmax returns 0 in that case, which would be ambiguous)
                out_points = (no_data_mask[vi, vj]) | (out_of_bounds_mask)
                indices_nodata = np.argmax(out_points, axis=1)
                indices_nodata[(~out_points).all(axis=1)] = -1

                # Synthetic points: altitude and position
                _max = np.zeros(len(indices_nodata))
                pi = np.zeros(len(indices_nodata))
                pj = np.zeros(len(indices_nodata))

                # Case 1: walks without NO_DATA. The altitude is the max met along the walk and the position
                # is the end of the walk, both ignoring the last step.
                cas_1 = indices_nodata == -1
                _max[cas_1] = np.nanmax(input_without_no_data[vi[:, :-1], vj[:, :-1]][cas_1], axis=1)
                pi[cas_1] = vi[cas_1, -2]
                pj[cas_1] = vj[cas_1, -2]

                # Case 2: walks meeting NO_DATA. As in LISDQS, the walk stops at the first NO_DATA met:
                # every value after the first NaN becomes NaN (so ignored by nanmax),
                # e.g. [3, 4, 5, nan, 6, 7] -> [3, 4, 5, nan, nan, nan]
                cas_2 = indices_nodata != -1
                walk_values = input_without_no_data[vi, vj]
                walk_values[np.isnan(walk_values.cumsum(axis=1))] = np.nan

                _max[cas_2] = np.nanmax(walk_values[cas_2], axis=1)
                # Clamp so that the index never goes past the second-to-last step
                indices_nodata_with_nodata = np.minimum(indices_nodata[cas_2], vi.shape[1] - 2)
                pi[cas_2] = vi[cas_2, indices_nodata_with_nodata]
                pj[cas_2] = vj[cas_2, indices_nodata_with_nodata]

                if debug:
                    print("[CAS 2]")
                    print("[ind no data]")
                    print(indices_nodata)
                    print("[ind no data with no data]")
                    print(indices_nodata_with_nodata)
                    print("[vi]")
                    print(vi)
                    print("[vj]")
                    print(vj)
                    print("[pi]")
                    print(pi)
                    print("[pj]")
                    print(pj)
                    print("[max]")
                    print(_max)

                z_n2 = np.column_stack((_max, pi, pj, h, h1))
                z = np.vstack((z_n, z_n2))
                tot_time_tr += (time.perf_counter() - time_tr)

                # =========== Step 3 ==========
                # Regression: fit a plane altitude = a * j + b * i + c through the data table (z)
                # and evaluate it at (i, j).
                # NOTE: found sectors + completed sectors always sum to nb_q, so nb_pt == nb_q here.
                # =============================
                time_regr = time.perf_counter()
                if nb_pt > 4:
                    nb_pt += 1
                    time_lr = time.perf_counter()
                    if debug:
                        with np.printoptions(suppress=True):
                            print(z)
                    res[i, j] = self.linear_regression(z[:nb_pt, [2, 1]], z[:nb_pt, 0], j, i)
                    tot_time_lr += (time.perf_counter() - time_lr)
                    if debug:
                        print("[debug][cb]", crest_thalweg_code_mask[i, j])
                    if crest_thalweg_code_mask[i, j]:
                        res[i, j] = max(input[i, j], res[i, j])
                else:
                    res[i, j] = input[i, j]
                tot_time_regr += (time.perf_counter() - time_regr)

        if self._test_case:
            res = np.trunc(res)

        if environment.TEST_MODE:
            print("[rv] {0:.3f}s".format(tot_time_rv))
            print("\t[r] {0:.3f}s".format(tot_time_r))
            print("\t\t[init] {0:.3f}s".format(tot_time_init))
            print("\t\t[angl] {0:.3f}s".format(tot_time_angles))
            print("\t\t[dst] {0:.3f}s".format(tot_time_i))
            print("\t[cat] {0:.3f}s".format(tot_time_cat_angles))
            print("\t[un] {0:.3f}s".format(tot_time_unique))
            print("\t[hist] {0:.3f}s".format(tot_time_hist))
            print("[tr] {0:.3f}s".format(tot_time_tr))
            print("[idw] {0:.3f}s".format(tot_time_idw))
            print("[regr] {0:.3f}s".format(tot_time_regr))
            print("\t[lr] {0:.3f}s".format(tot_time_lr))

            print("[reg]", res[self.debug])

        # =========== Densification ==========
        # Each pixel receives the (floor) mean of the regression values found in its neighbourhood
        time_2 = time.perf_counter()
        rayon = echant_conv + 3

        ck_size = rayon * 2 + 1
        convolution_kernel = np.zeros((ck_size, ck_size))
        convolution_kernel[:-1, 1:] = 1
        if environment.TEST_MODE:
            debug_zoom = res[self.debug[0]-rayon:self.debug[0]+rayon, self.debug[1]-rayon+1:self.debug[1]+rayon+1]
            print("[int]", debug_zoom[debug_zoom != unprocessed])

        convolution_kernel = convolution_kernel.T  # Convolution flip the kernel so prevent this by flipping the kernel here
        valid_neighbors = self.get_neighboors_count_raster_with_min_value(res, convolution_kernel, no_data, unprocessed, self.debug)
        valid_neighbors[res == no_data] = 0

        raster = res.copy()
        raster[(raster == no_data) | (raster == unprocessed)] = 0
        if environment.TEST_MODE:
            with np.printoptions(linewidth=np.inf):
                print("[raster]", raster[self.debug])
                print("[neighbors]", valid_neighbors[self.debug])
        # Pixels without valid neighbours divide by 0; the resulting NaN are handled just below
        with np.errstate(divide='ignore', invalid='ignore'):
            res2 = np.trunc(ndimage.convolve(raster, convolution_kernel, mode='constant') // valid_neighbors)

        res2[res == no_data] = no_data

        # [Rare special case]: a valid (not NO_DATA) pixel which is neither a crest nor a thalweg pixel
        # may have no neighbourhood (no regression was computed around it).
        # Instead of leaving it undefined, give it its altitude to avoid adding noise to the amplitude.
        correction_zone = (np.isnan(res2)) & (input != no_data) & (crest_thalweg == 0)
        if environment.TEST_MODE:
            print("[correction_zone]", np.nonzero(correction_zone))
        res2[correction_zone] = input[correction_zone]

        res = res2.copy()
        if environment.TEST_MODE:
            print("[step 2] {0:.3f}s".format(time.perf_counter() - time_2))
            with np.printoptions(linewidth=np.inf):
                print("[int]", raster[self.debug[0]-rayon:self.debug[0]+rayon, self.debug[1]-rayon+1:self.debug[1]+rayon+1])
                print("[int]", res2[self.debug])

        # =========== Amplitude of thalwegs / crests ==========
        time_3 = time.perf_counter()
        res2 = input.copy()
        valid_mask = (input != no_data) & (res > unprocessed) & (res != no_data)
        if code == self.CODE_THALWEG:
            res2[valid_mask] = res[valid_mask] - input[valid_mask]
        if code == self.CODE_CREST:
            res2[valid_mask] = input[valid_mask] - res[valid_mask]
        # .min() raises on an empty selection: without any valid pixel, no rescale is needed
        min_ = res2[valid_mask].min() if valid_mask.any() else 0

        if environment.TEST_MODE:
            print("[step 3] {0:.3f}s".format(time.perf_counter() - time_3))
        time_4 = time.perf_counter()

        # =========== Rescale of negative amplitudes ==========
        if self._test_case:
            min_ = np.trunc(min_)
        e = 10
        if min_ < 0:
            # int() truncates to 0 when -0.01 < min_ < 0, and log10(0) would give -inf to the whole raster
            e = max(int(min_ * (-100.0)), 1)
        mi = np.log10(e)
        if self._test_case:
            mi = np.trunc(mi)

        if min_ < 0:
            res3 = res2.copy()
            valid_values = res2 != no_data
            negative = valid_values & (res2 < 0)
            positive = valid_values & (res2 >= 0)
            res3[negative] = mi - (np.log10(res2[negative] * (-100)))
            res3[positive] = res2[positive] + mi - 1
            if self._test_case:
                res3 = np.trunc(res3)
            res2 = res3

        if environment.TEST_MODE:
            print("[step 4] {0:.3f}s".format(time.perf_counter() - time_4))

        # =========== Smoothing ==========
        # Mean over a (2 * 2 + 1)² window; NO_DATA pixels contribute 0 to the sum
        # (valid_neighbors presumably counts non-NO_DATA neighbours — confirm in the parent class)
        time_5 = time.perf_counter()
        rayon = 2
        ck_size = rayon * 2 + 1

        convolution_kernel = np.ones((ck_size, ck_size))
        valid_neighbors = self.get_neighboors_count_raster(res2, convolution_kernel, no_data)
        raster = res2.copy()

        raster[raster == no_data] = 0
        res = ndimage.convolve(raster, convolution_kernel, mode='constant')
        if self._test_case:
            res = np.trunc(res)
        # Pixels without valid neighbours divide by 0; they are overwritten just below
        with np.errstate(divide='ignore', invalid='ignore'):
            res = res / valid_neighbors
        res[valid_neighbors == 0] = res2[valid_neighbors == 0]
        res[no_data_mask] = output_no_data
        if self._test_case:
            res = np.trunc(res)

        if environment.TEST_MODE:
            print("[step 5] {0:.3f}s".format(time.perf_counter() - time_5))
            print("[debug] [amplc] ", res[self.debug])

        return res

    def linear_regression(self, xy, z, xi, yi):
        """
        Fit the plane ``z = a * x + b * y + c`` by least squares and evaluate it at ``(xi, yi)``.

        :param xy: Coordinates of the samples, one ``(x, y)`` pair per row
        :param z: Values of the samples
        :param xi: x coordinate of the evaluated point
        :param yi: y coordinate of the evaluated point

        :returns: Value of the fitted plane at ``(xi, yi)``
        """
        # TODO : https://stackoverflow.com/questions/24744927/how-can-i-increase-speed-performance-with-scikit-learn-regression-and-pandas
        # TODO : https://stackoverflow.com/questions/75751631/sklearn-linear-regression-on-grouped-pandas-dataframe-without-aggregation
        # TODO : https://stackoverflow.com/questions/49895000/regression-by-group-in-python-pandas
        # Column of ones for the intercept
        xy = np.concatenate((xy, np.ones((len(xy), 1))), axis=1)
        coefs = lstsq(xy, z, lapack_driver='gelsy', check_finite=False)[0]
        return np.dot(coefs, [xi, yi, 1])

    def search_values(self, rayon: int, i: int, j: int, cb_code_mask: np.ndarray, debug: bool) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float]:
        """
        Find the crest or thalweg pixels in a square window around a point.

        :param rayon: Search radius (half size of the square window)
        :type rayon: int

        :param i: Row of search point
        :type i: int

        :param j: Col of search point
        :type j: int

        :param cb_code_mask: Mask where interest pixels are True
        :type cb_code_mask: np.ndarray

        :param debug: Print the points found
        :type debug: bool

        :returns: rows, cols and angles (degrees) of the points found, sorted by Chebyshev distance to (i, j)
                  then by row then by col, followed by the computation times (seconds) of the search,
                  of the angles and of the sort
        :rtype: Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float]
        """
        deb_i = max(0, i - rayon)
        deb_j = max(0, j - rayon)
        fin_i = i + rayon
        fin_j = j + rayon

        time_init = time.perf_counter()
        r, c = np.nonzero(cb_code_mask[deb_i:fin_i + 1, deb_j:fin_j + 1])
        r = r + deb_i
        c = c + deb_j

        if self._test_case:
            # TODO: Point non traité sur LISDQS (erreur) — this point is not processed by LISDQS (bug there)
            r, c = r[~((r == i) & (c == j + 1))], c[~((r == i) & (c == j + 1))]
        time_init = time.perf_counter() - time_init

        time_angle = time.perf_counter()
        angles = self.fast_angle(r, c, i, j) if not self._test_case else self.angle_lisdqs(r, c, i, j)
        time_angle = time.perf_counter() - time_angle

        time_dst = time.perf_counter()
        # Chebyshev distance to (i, j)
        distances = np.maximum(np.abs(r - i), np.abs(c - j))
        # Sort by distance, then by row, then by col
        sort_indices = np.lexsort((c, r, distances))

        if debug:
            print("[Rechercher_Valeurs]")
            print(np.hstack((r[sort_indices].reshape(-1, 1), c[sort_indices].reshape(-1, 1), distances[sort_indices].reshape(-1, 1))))
        time_dst = time.perf_counter() - time_dst

        return r[sort_indices].astype(int), c[sort_indices].astype(int), angles[sort_indices], time_init, time_angle, time_dst

    def cat_angles(self, angles: np.ndarray, nb_q: int) -> np.ndarray:
        """
        Angular sector of each angle: ``angles // (360 // nb_q)``.

        :param angles: Angles in degrees, in [0, 360)
        :param nb_q: Number of sectors

        :returns: Sector index of each angle, in [0, nb_q)
        """
        cats = 360 // nb_q
        return angles // cats

    def angle_lisdqs(self, i2: np.ndarray, j2: np.ndarray, i: Union[np.ndarray, int], j: Union[np.ndarray, int]):
        """
        Angle (degrees, truncated) of the points (i2, j2) seen from (i, j), as computed by LISDQS.

        Angles are measured clockwise from the "up" direction (decreasing row index):
        0 above, 90 on the right (also used when (i2, j2) == (i, j)), 180 below, 270 on the left.

        :param i2: Rows of the points
        :param j2: Cols of the points
        :param i: Row(s) of the reference point
        :param j: Col(s) of the reference point

        :returns: Truncated angles in degrees
        """
        size = len(i2)
        az = np.zeros(size)
        al = np.zeros(size, dtype=int)
        in_ = np.zeros(size, dtype=int)
        cass = np.zeros(size, dtype=int)

        i2_is_i = i2 == i
        is2_below_i = i2 < i
        is2_above_i = i2 >= i
        j2_is_j = j2 == j
        j2_below_j = j2 < j
        j2_above_j = j2 > j

        # cass == -1: the angle is set directly; cass > 0: quadrant where the angle is computed with arctan
        cass[(i2_is_i)] = -1

        az[(i2_is_i) & (j2_below_j)] = 270

        az[(i2_is_i) & (j2 >= j)] = 90

        group_1 = (is2_below_i) & (j2_is_j)
        cass[group_1] = -1
        az[group_1] = 0

        group_2 = (is2_below_i) & (j2_below_j)
        al[group_2] = 270
        in_[group_2] = 0
        cass[group_2] = 4

        group_3 = (is2_below_i) & (j2_above_j)
        in_[group_3] = 91
        cass[group_3] = 1

        group_4 = (is2_above_i) & (cass > -1) & (j2_is_j)
        az[group_4] = 180
        cass[group_4] = -1

        group_5 = (is2_above_i) & (cass > -1) & (j2_below_j)
        cass[group_5] = 3
        in_[group_5] = 90
        al[group_5] = 181

        group_6 = (is2_above_i) & (cass > -1) & (j2_above_j)
        cass[group_6] = 2
        in_[group_6] = 0
        al[group_6] = 90

        cass_0 = cass > 0
        i_cass_0 = i[cass_0] if isinstance(i, np.ndarray) else i
        j_cass_0 = j[cass_0] if isinstance(j, np.ndarray) else j
        d1 = np.abs(i_cass_0 - i2[cass_0])
        d1[d1 == 0] = .0000000001
        d2 = np.abs(j_cass_0 - j2[cass_0])
        d2[d2 == 0] = .0000000001
        az[cass_0] = np.abs(in_[cass_0] - np.arctan(d1 / d2) * (180.0 / np.pi)) + al[cass_0]
        return np.trunc(az)

    def fast_angle(self, ya: np.ndarray, xa: np.ndarray, yb: Union[np.ndarray, int], xb: Union[np.ndarray, int]):
        """
        Angle (degrees, truncated, in [0, 360)) of the points (ya, xa) seen from (yb, xb).

        Angles are measured clockwise from the "up" direction (decreasing row index):
        0 above, 90 on the right, 180 below, 270 on the left.

        :param ya: Rows of the points
        :param xa: Cols of the points
        :param yb: Row(s) of the reference point
        :param xb: Col(s) of the reference point

        :returns: Truncated angles in degrees
        """
        angl = np.arctan2(xa - xb, yb - ya)
        angl[angl < 0] = angl[angl < 0] + (2 * np.pi)
        return np.trunc(np.degrees(angl))

    def bal_sur_mat(self, z: int, dist1: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Discrete walk on the grid in the direction ``z`` (degrees, same convention as :meth:`fast_angle`).

        :param z: Direction in degrees
        :param dist1: Number of steps of the walk

        :returns: Row and col offsets (-1, 0 or 1) of each step, arrays of length ``dist1 + 1``
                  whose first element is always 0
        """
        a1 = np.pi / 180.0
        q = 0
        dist1 += 1
        si, sj = np.zeros((dist1), dtype=int), np.zeros((dist1), dtype=int)
        zz2 = np.zeros((1801), dtype=int)
        # Direction of a row step (ia) and of a col step (ja)
        ia, ja = 0, 0
        if z < 90 or z > 270: ia = -1
        if z > 90 and z < 270: ia = 1
        if z > 0 and z < 180: ja = 1
        if z > 180: ja = -1
        # zz2[k]: truncated lateral displacement after k steps along the main axis
        k = -1
        zz1 = np.pi / 180.0
        while np.abs(zz1) < dist1 and k < (dist1 + 1):
            k += 1
            zz1 = k * np.tan(z * a1)
            if zz1 > -0.000001:
                zz1 = zz1 + 0.000001
            if zz1 > 32000:
                zz2[k] = 32000
            else:
                zz2[k] = np.abs(zz1)
        k = 0
        p = -1
        r = -1

        # q selects the kind of step: 1 = row only, 2 = diagonal, 3 = col only.
        # NOTE: q keeps its previous value when no condition matches (as in the original algorithm)
        while k < (dist1 - 1):
            p += 1
            k += 1
            r += 1
            nk = zz2[p + 1] - zz2[p]
            nkk = nk - r
            if nk == 0:
                q = 1
            if nk > 1:
                q = 3
            if nkk == 0 and nk > 1:
                q = 2
            if nkk == 1:
                q = 2
            if q == 1:
                si[k] = ia
                r -= 1
            elif q == 2:
                si[k] = 1 * ia
                sj[k] = 1 * ja
                r = -1
            elif q == 3:
                sj[k] = 1 * ja
                p -= 1
        return si, sj
