import time
from typing import Union

from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterNumber, QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterBoolean
from qgis import processing
import numpy as np
import math
from scipy.signal import convolve2d
from scipy.ndimage import maximum_filter, minimum_filter, convolve, generic_filter

from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk import environment


class RoughnessProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing a windowed roughness raster from a DEM.

    Two implementations are available, selected by ``environment.USE_QGIS_IMPLEMENTATION``:

    - the Landsklim implementation, built from the fourth value returned by ``dem_variables``
      and the number of valid neighbours of each cell inside the window;
    - the QGIS/GDAL implementation: the DEM is smoothed with ``landsklim:altitude`` and then
      passed to ``gdal:roughness``.
    """

    def __init__(self):
        super().__init__({'OUTPUT': 'Roughness raster'})

    def createInstance(self):
        return RoughnessProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'roughness'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Roughness')

    def shortHelpString(self) -> str:
        return self.tr('Compute windowed roughness from DEM')

    def add_dependencies(self):
        """
        No dependencies
        """
        pass

    def compute_roughness(self, raster: np.ndarray, kernel_size: int, no_data: Union[int, float],
                          sdif: np.ndarray) -> np.ndarray:
        """
        Compute the roughness of each cell as ``sqrt(10 * sdif / n)``, where ``n`` is the number of
        valid neighbours of the cell inside the window. Cells with fewer than 3 valid neighbours get 0.

        :param raster: Input DEM, only used to count the valid neighbours of each cell
        :type raster: np.ndarray

        :param kernel_size: Size of the window
        :type kernel_size: int

        :param no_data: No data value of the input DEM
        :type no_data: Union[int, float]

        :param sdif: Fourth value returned by ``dem_variables`` (presumably a windowed sum of
                     squared differences — TODO confirm)
        :type sdif: np.ndarray

        :returns: Roughness raster
        :rtype: np.ndarray
        """
        k = self.create_kernel(kernel_size)
        no = self.get_neighboors_count_raster(raster, k, no_data)
        # Cells without any valid neighbour divide by zero; they are overwritten with 0 just below
        # (no < 3), so the divide/invalid warnings raised by the division are only noise.
        # The sqrt is kept outside so that genuinely negative values still warn.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = (sdif * 10) / no
        roughness = np.sqrt(ratio)
        roughness[no < 3] = 0
        return roughness

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run.

        Reads the DEM (``INPUT``) and the window size (``INPUT_WINDOW``), computes the roughness
        raster and writes it to ``OUTPUT``.

        :returns: Dictionary holding the output raster path under the key ``'OUTPUT'``
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        no_data, geotransform = self.get_raster_metadata(parameters, context, input_raster)

        # Load other params
        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)
        additional_variables_folder = self.parameterAsString(parameters, 'ADDITIONAL_VARIABLES_FOLDER', context).strip()

        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        np_input = LandsklimUtils.raster_to_array(input_raster)

        if not environment.USE_QGIS_IMPLEMENTATION:
            _, _, _, sdif = self.dem_variables(np_input, no_data, self.create_kernel(input_window), additional_variables_folder)

            # Work on a copy so that np_input stays untouched for the no data mask below
            np_output = np.copy(np_input)
            # Compute windowed roughness
            np_output = self.compute_roughness(np_output, input_window, no_data, sdif)
            # Erase every computed data by no data where necessary
            output_no_data = -9999
            np_output[np_input == no_data] = output_no_data

            self.write_raster(out_path, np_output, out_srs, geotransform, output_no_data)
        else:
            if feedback is not None:
                feedback.pushInfo("[Call QGIS implementation]")
            self.qgis_roughness(input_raster, input_window, out_path, context=context, feedback=feedback)

        return {'OUTPUT': out_path}

    def qgis_roughness(self, raster: QgsRasterLayer, kernel_size: int, out_path: str, context=None, feedback=None):
        """
        Use the QGIS/GDAL implementation of the roughness algorithm (Terrain Roughness Index).

        The input DEM is first smoothed with ``landsklim:altitude`` using ``kernel_size``, then
        ``gdal:roughness`` is applied to the smoothed raster.

        :param raster: Input DEM (smoothed inside this method according to the kernel size)
        :type raster: QgsRasterLayer

        :param kernel_size: Kernel size
        :type kernel_size: int

        :param out_path: Destination file of the output
        :type out_path: str

        :param context: Processing context of the parent algorithm, forwarded to the child
                        algorithms (optional)
        :type context: QgsProcessingContext

        :param feedback: Feedback of the parent algorithm, forwarded to the child algorithms so
                         that progress and cancellation propagate (optional)
        :type feedback: QgsProcessingFeedback
        """

        source: str = raster.source()
        smoothed_source = processing.run("landsklim:altitude",
                                         {
                                             'INPUT': source,
                                             'INPUT_WINDOW': kernel_size,
                                             'INPUT_CUSTOM_NO_DATA': None,
                                             'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
                                         },
                                         context=context,
                                         feedback=feedback,
                                         is_child_algorithm=True)['OUTPUT']

        processing.run("gdal:roughness",
                       {
                           'INPUT': smoothed_source,
                           'BAND': 1,
                           'COMPUTE_EDGES': False,
                           'OPTIONS': '',
                           'OUTPUT': out_path
                       },
                       context=context,
                       feedback=feedback,
                       is_child_algorithm=True)  # NO_DATA is set as -9999
