import time
from typing import Union, Tuple, List, Optional

import numpy as np
import scipy
from osgeo.osr import SpatialReference
from qgis.core import QgsProcessingParameterString, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterVectorLayer, QgsVectorLayer, \
    QgsRectangle, QgsProcessingParameterBoolean, QgsProcessingParameterNumber, QgsProcessingParameterEnum
from qgis import processing
from scipy.ndimage import label, generate_binary_structure, find_objects

from landsklim.lk.utils import LandsklimUtils
from landsklim.processing.landsklim_processing_tool_algorithm import LandsklimProcessingToolAlgorithm


class InfluenceProcessingAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Processing algorithm computing an influence raster.

    Every 8-connected patch of pixels whose value belongs to the selected OS codes spreads
    an influence around it. The radius of that influence grows with the patch size.
    """
    INPUT_OS = 'INPUT_OS'
    INPUT_CODES = 'INPUT_CODES'
    INPUT_CENTROIDS = 'INPUT_CENTROIDS'
    # Not exposed as a parameter at the moment (see initAlgorithm)
    INPUT_QUANTITATIVE_RASTER = 'INPUT_QUANTITATIVE_RASTER'
    INPUT_ORIENTATIONS = 'INPUT_ORIENTATIONS'
    INPUT_MIN_SIZE = 'INPUT_MIN_SIZE'
    INPUT_MAX_SIZE = 'INPUT_MAX_SIZE'
    INPUT_DILATION_FACTOR = 'INPUT_DILATION_FACTOR'
    OUTPUT_RASTER = 'OUTPUT_RASTER'

    # Angular bounds (degrees, clockwise from the top of the raster, see `angle`) of the
    # orientation sectors. Sector ``o`` (1..8, index of the 'Orientations' enum) keeps the
    # angles in ``[bounds[o], bounds[o % 8 + 1]]``. The last sector (337 -> 23) wraps around 0.
    _ORIENTATION_BOUNDS = (0, 23, 67, 112, 157, 202, 247, 292, 337)

    def __init__(self):
        super().__init__()
        # When True, per-pixel contributions and the final rescaling are truncated with np.fix
        self.__int_mode: bool = False

    def createInstance(self):
        return InfluenceProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'influence'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Influence')

    def shortHelpString(self) -> str:
        return self.tr('Area of influence of an OS category')

    def initAlgorithm(self, config=None):
        """
        Define inputs and outputs for the main input
        """
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT_OS,
                self.tr('Input OS')
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.INPUT_CODES,
                self.tr('Codes. Values are separated by a comma. Decimal mark is defined by a point')
            )
        )

        self.addParameter(
            QgsProcessingParameterBoolean(
                self.INPUT_CENTROIDS,
                self.tr('Centroids (warning : strongly slower if not checked)'),
                defaultValue=True
            )
        )

        # Index 0 is the circular (non-oriented) influence. Indices 1..8 select a sector of
        # _ORIENTATION_BOUNDS.
        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_ORIENTATIONS,
                self.tr('Orientations'),
                [
                    self.tr('Circulaire'),
                    self.tr('SO-NE'),
                    self.tr('O-E'),
                    self.tr('NO-SE'),
                    self.tr('N-S'),
                    self.tr('NE-SO'),
                    self.tr('E-O'),
                    self.tr('SE-NO'),
                    self.tr('S-N')
                ],
                allowMultiple=False,
                defaultValue=0
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_MIN_SIZE,
                self.tr('Min size (pixels)'),
                QgsProcessingParameterNumber.Type.Integer,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_MAX_SIZE,
                self.tr('Max size (pixels)'),
                QgsProcessingParameterNumber.Type.Integer,
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.INPUT_DILATION_FACTOR,
                self.tr('Dilation factor'),
                QgsProcessingParameterNumber.Type.Integer,
                defaultValue=1
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT_RASTER,
                self.tr('Output raster')
            )
        )

    def parse_code(self, code: str) -> Union[int, float]:
        """
        Parse one OS code.

        :param code: Textual code. A point is the decimal mark.
        :returns: An int when the code is an integer literal, otherwise a float
                  (including exponent notation such as ``1e3``).
        :raises ValueError: If the code is not a number.
        """
        code = code.strip()
        if "." in code:
            return float(code)
        try:
            return int(code)
        except ValueError:
            # e.g. "1e3": not an int literal but a valid float
            return float(code)

    def parse_codes(self, codes: str) -> List[Union[int, float]]:
        """
        Parse a comma-separated list of OS codes.

        Spaces are ignored. Empty items (e.g. from a trailing comma) are skipped.
        """
        return [self.parse_code(code) for code in codes.replace(" ", "").split(",") if code]

    def angle(self, xa: int, ya: int, xb: int, yb: int):
        """
        Bearing from pixel A to pixel B, truncated to whole degrees in [0, 360).

        x is the column and y the row. 0 points to decreasing rows (top of the raster),
        and angles grow clockwise (90 points to increasing columns).
        """
        angl = np.arctan2(xb - xa, ya - yb)
        if angl < 0:
            angl = angl + (2 * np.pi)
        return np.fix(np.degrees(angl))

    def angle_v(self, xa: int, ya: int, xb: np.ndarray, yb: np.ndarray):
        """
        Vectorized version of `angle`: bearings from pixel A to every pixel (xb[i], yb[i]).
        """
        angl = np.arctan2(xb - xa, ya - yb)
        angl[angl < 0] = angl[angl < 0] + (2 * np.pi)
        return np.fix(np.degrees(angl))

    # TODO: Almost duplicated code
    def create_kernel(self, radius: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create a circular kernel.

        :param radius: Side length of the square kernel. Callers pass ``2 * r + 1`` to
                       get a disk of radius ``r``.
        :type radius: int

        :returns: ``(kernel, distances)``. ``kernel`` is an int matrix set to 1 for cells
                  strictly closer than ``(radius - 1) / 2`` to the centre. ``distances`` is
                  the Euclidean distance of every cell to the centre.
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        dist_inside_circle = (radius - 1) / 2.0
        # X and Y indices of each cell of the kernel
        x_indices, y_indices = np.mgrid[0:radius, 0:radius]
        # Center index of matrix
        center = (radius / 2) - 0.5
        # Euclidean distance matrix to center
        matrix_dist_center = np.sqrt((x_indices - center) ** 2 + (y_indices - center) ** 2)
        # Keep indices inside circle
        k = np.zeros((radius, radius), dtype=int)
        k[matrix_dist_center < dist_inside_circle] = 1
        return k, matrix_dist_center

    def _spread_from_pixel(self, accumulator: np.ndarray, no_data_mask: np.ndarray, p: int, pp: int,
                           rayon: int, kernel: np.ndarray, kernel_dist: np.ndarray,
                           orientation: int) -> None:
        """
        Add, in place, the influence of one anchor pixel to ``accumulator``.

        A cell at distance ``d`` (< ``rayon``) from the anchor receives ``rayon - d``.
        No-data cells receive nothing. When ``orientation > 0``, only cells within the
        selected angular sector receive a contribution.

        :param accumulator: Raster-shaped float array receiving the contributions
        :param no_data_mask: Raster-shaped boolean mask of no-data pixels
        :param p: Row of the anchor pixel
        :param pp: Column of the anchor pixel
        :param rayon: Influence radius, in pixels
        :param kernel: Boolean circular kernel of side ``2 * rayon + 1``
        :param kernel_dist: Distance of each kernel cell to the kernel centre
        :param orientation: Index of the 'Orientations' enum (0 = circular)
        """
        rows, cols = accumulator.shape
        # Kernel window in raster coordinates, possibly extending past the raster edges
        top, left = p - rayon, pp - rayon
        bottom, right = p + rayon + 1, pp + rayon + 1
        # Same window clipped to the raster extent
        r0, r1 = max(top, 0), min(bottom, rows)
        c0, c1 = max(left, 0), min(right, cols)
        # Part of the kernel overlapping the clipped window. Offsets are relative to the
        # unclipped window, so a window clipped on top/left skips the first kernel rows/cols.
        kernel_view = (slice(r0 - top, r1 - top), slice(c0 - left, c1 - left))

        inside = kernel[kernel_view] & ~no_data_mask[r0:r1, c0:c1]
        transport = np.zeros(inside.shape)
        transport[inside] = rayon - kernel_dist[kernel_view][inside]

        if orientation > 0:
            positive = transport > 0
            grid_rows, grid_cols = np.meshgrid(np.arange(r0, r1), np.arange(c0, c1), indexing="ij")
            angles = np.zeros(transport.shape)
            angles[positive] = self.angle_v(pp, p, grid_cols[positive], grid_rows[positive])
            lower = self._ORIENTATION_BOUNDS[orientation]
            upper = self._ORIENTATION_BOUNDS[orientation % 8 + 1]
            if lower <= upper:
                outside = (angles < lower) | (angles > upper)
            else:
                # Sector wrapping around 0 (337 -> 23): outside means strictly between both bounds
                outside = (angles < lower) & (angles > upper)
            transport[positive & outside] = 0

        # Slicing gives a view, so += writes straight into the accumulator
        window = accumulator[r0:r1, c0:c1]
        window += np.fix(transport) if self.__int_mode else transport

    def compute_influence(self, os_raster: np.ndarray, no_data: Optional[Union[int, float]],
                          codes: List[Union[int, float]], use_centroids: bool, min_size: float,
                          max_size: float, dilation_factor: int, orientation: int,
                          output_no_data: Union[int, float]) -> np.ndarray:
        """
        Compute the influence raster of the patches made of the given OS codes.

        Every 8-connected patch of pixels whose value is in ``codes`` is processed when its
        size ``n`` (in pixels) satisfies ``min_size < n < max_size``. It spreads an influence
        of radius ``int(sqrt(n * dilation_factor + 1))`` from its anchors. The anchor is the
        patch centroid when ``use_centroids`` is True; otherwise every pixel of the patch is
        an anchor. Contributions are summed, then divided by ``max / 30000 + 1``, so the
        output stays below 30000.

        :param os_raster: 2D OS raster
        :param no_data: No-data value of ``os_raster``
        :param codes: OS codes whose patches spread influence
        :param use_centroids: Use the patch centroid as the single anchor
        :param min_size: Exclusive lower bound of the patch size
        :param max_size: Exclusive upper bound of the patch size
        :param dilation_factor: Multiplier applied to the patch size before computing the radius
        :param orientation: Index of the 'Orientations' enum (0 = circular)
        :param output_no_data: Value written on no-data pixels of the output

        :returns: Float array with the shape of ``os_raster``
        """
        time_total = time.perf_counter()
        no_data_mask = os_raster == no_data
        mask_code = np.isin(os_raster, codes)

        # 8-connected patches of selected pixels; label 0 is the background
        labels, _ = label(mask_code, structure=generate_binary_structure(2, 2))

        val_dist_pond = np.zeros(os_raster.shape)
        # Cache of (boolean kernel, distance matrix) per radius
        kernels = {}
        count_passages = 0

        # find_objects returns the bounding box of label n at index n - 1
        for patch_id, bbox in enumerate(find_objects(labels), start=1):
            if bbox is None:
                continue
            # Pixels of this patch only (other patches may share its bounding box)
            patch_rows, patch_cols = np.nonzero(labels[bbox] == patch_id)
            patch_rows += bbox[0].start
            patch_cols += bbox[1].start
            patch_size = patch_rows.size
            if not min_size < patch_size < max_size:
                continue
            count_passages += 1

            rayon = int(np.sqrt(patch_size * dilation_factor + 1))
            if rayon not in kernels:
                kernel, kernel_dist = self.create_kernel(rayon * 2 + 1)
                kernels[rayon] = (kernel.astype(bool), kernel_dist)
            kernel, kernel_dist = kernels[rayon]

            if use_centroids:
                anchors = [(int(patch_rows.mean()), int(patch_cols.mean()))]
            else:
                anchors = zip(patch_rows.tolist(), patch_cols.tolist())

            for p, pp in anchors:
                self._spread_from_pixel(val_dist_pond, no_data_mask, p, pp, rayon,
                                        kernel, kernel_dist, orientation)

        # NOTE: an "inversion" mode was sketched in earlier versions but never implemented.
        valid = ~no_data_mask
        max_value = val_dist_pond[valid].max() if valid.any() else 0.0
        divis = (max_value / 30000) + 1
        if self.__int_mode:
            val_dist_pond = np.fix(val_dist_pond / np.fix(divis))
        else:
            val_dist_pond = val_dist_pond / divis
        val_dist_pond[no_data_mask] = output_no_data

        print("[time] {0:.2f}s".format(time.perf_counter() - time_total))
        print("[count_passages]", count_passages)
        return val_dist_pond

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run
        """
        print("[influence] processAlgorithm")
        # Load input raster and its metadata
        os_layer: QgsRasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT_OS, context)
        os_raster: np.ndarray = LandsklimUtils.raster_to_array(os_layer)
        no_data, geotransform = self.get_raster_metadata(parameters, context, os_layer)
        out_srs: SpatialReference = self.get_spatial_reference(os_layer)
        # Get index of selected item
        orientation: int = self.parameterAsEnum(parameters, self.INPUT_ORIENTATIONS, context)
        values: str = self.parameterAsString(parameters, self.INPUT_CODES, context)
        use_centroids: bool = self.parameterAsBoolean(parameters, self.INPUT_CENTROIDS, context)
        # A missing size bound disables it. parameters.get() avoids a KeyError when the caller
        # omits these optional parameters from the dict.
        min_size: float = (self.parameterAsDouble(parameters, self.INPUT_MIN_SIZE, context)
                           if parameters.get(self.INPUT_MIN_SIZE) is not None else -np.inf)
        max_size: float = (self.parameterAsDouble(parameters, self.INPUT_MAX_SIZE, context)
                           if parameters.get(self.INPUT_MAX_SIZE) is not None else np.inf)
        dilation_factor: int = self.parameterAsInt(parameters, self.INPUT_DILATION_FACTOR, context)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)

        codes: List[Union[int, float]] = self.parse_codes(values)
        output_no_data = -9999
        influence_array: np.ndarray = self.compute_influence(os_raster, no_data, codes, use_centroids, min_size,
                                                             max_size, dilation_factor, orientation, output_no_data)

        influence_array[os_raster == no_data] = output_no_data
        self.write_raster(out_path, influence_array, out_srs, geotransform, output_no_data)

        return {self.OUTPUT_RASTER: out_path}
