from typing import Union

from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean
from qgis import processing
import numpy as np
from scipy import ndimage
from scipy.signal import convolve2d

from landsklim.lk.utils import LandsklimUtils
from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm


class SmoothingProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm smoothing a raster with a windowed average that ignores no-data cells
    """

    INPUT = 'INPUT'
    INPUT_WINDOWS = "INPUT_WINDOW"
    OUTPUT = "OUTPUT"

    def __init__(self):
        super().__init__({self.OUTPUT: 'Smooth raster'})

    def createInstance(self):
        return SmoothingProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'smoothing'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Smooth')

    def shortHelpString(self) -> str:
        return self.tr('Smooth raster')

    def add_dependencies(self):
        """
        No dependencies
        """
        pass

    def smoothing(self, a: np.ndarray, kernel_size: int, no_data: Union[int, float]) -> np.ndarray:
        """
        Smooth an array with a windowed average in which no-data cells do not contribute to the sum.

        TODO: Merge with algorith_altitude ?

        The window reproduces the (off-centre) window of the LISDQS code: for a radius ``r >= 1``, the
        value at cell ``(i, j)`` is built from rows ``i-r .. i+r-1`` and columns ``j-r+1 .. j+r``
        (a ``2r x 2r`` window that includes the cell itself). Cells outside the array add 0 to the sum.

        :param a: Input array
        :type a: np.ndarray

        :param kernel_size: Window radius. ``0`` returns an unsmoothed float64 copy of ``a``.
        :type kernel_size: int

        :param no_data: Value to consider as no data
        :type no_data: Union[int, float]

        :returns: Windowed average of the array, as float64. Cells whose neighbour count is 0 get the raw
                  result of the division by 0 (NaN for 0/0), without a warning.
        :rtype: np.ndarray

        :raises ValueError: If ``kernel_size`` is negative
        """
        if kernel_size < 0:
            raise ValueError("kernel_size must be a non-negative window radius, got {0}".format(kernel_size))
        if kernel_size == 0:
            # With radius 0 the kernel below would be empty (all zeros), producing an all-NaN/zero result.
            # A radius-0 window is the cell itself, so return the data unchanged (as float, like the smoothed case).
            return a.astype(np.float64)

        diameter = (kernel_size * 2) + 1

        # Reproduce the left/up offset of the LISDQS code: ones everywhere except the last row and first
        # column, then transposed. ndimage.convolve flips the kernel, so the effective window is
        # rows i-r .. i+r-1 and columns j-r+1 .. j+r around each cell (see docstring).
        k = np.zeros((diameter, diameter))
        k[:-1, 1:] = 1
        k = k.T

        # Work in float64: ndimage.convolve allocates its output with the input dtype, so an integer
        # raster would otherwise overflow/wrap when summing the window.
        a_no_data = a.astype(np.float64)
        # no_data cells are set to 0 so they do not affect the window sum
        a_no_data[a == no_data] = 0
        window_sum = ndimage.convolve(a_no_data, k, mode='constant')

        # Number of cells taken into account under the same kernel (helper defined in the base class)
        window_count = self.get_neighboors_count_raster(a, k, no_data)

        # Cells whose window contains no counted cell yield x / 0; callers overwrite these no-data
        # cells, so silence numpy's divide/invalid warnings instead of spamming the log.
        with np.errstate(divide='ignore', invalid='ignore'):
            return window_sum / window_count

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run: smooth the input raster and write the result.
        Cells that are no data in the input stay no data in the output.

        :returns: Dict mapping OUTPUT to the path of the written raster
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        no_data, geotransform = self.get_raster_metadata(parameters, context, input_raster)

        # Load other params
        input_window = self.parameterAsInt(parameters, self.INPUT_WINDOWS, context)

        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        np_input = LandsklimUtils.raster_to_array(input_raster)

        # Compute windowed average. smoothing() never modifies its input (it works on a float copy),
        # so no defensive copy of np_input is needed here.
        np_output = self.smoothing(np_input, input_window, no_data)

        # Erase every computed data by no data where necessary
        np_output[np_input == no_data] = no_data
        self.write_raster(out_path, np_output, out_srs, geotransform, no_data)

        return {self.OUTPUT: out_path}
