import time
from typing import Union, Tuple, List, Optional

from qgis.core import QgsProcessingParameterString, QgsProcessingParameterRasterLayer, \
    QgsProcessingParameterRasterDestination, QgsRasterLayer, QgsProcessingParameterVectorLayer, QgsVectorLayer, \
    QgsRectangle, QgsFeatureSource, QgsProcessingException
from qgis import processing

from landsklim.processing.landsklim_processing_tool_algorithm import LandsklimProcessingToolAlgorithm


class DistanceRasterProcessingAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Processing algorithm computing distance raster from raster values.

    The work is delegated to GDAL's ``gdal:proximity`` algorithm, run on band 1 of the input raster
    with the user-supplied list of target values.
    """
    INPUT_RASTER = 'INPUT_RASTER'
    INPUT_CODES = 'INPUT_CODES'
    OUTPUT_RASTER = 'OUTPUT_RASTER'

    def createInstance(self):
        """
        Return a fresh instance of this algorithm (required by the Processing framework)
        """
        return DistanceRasterProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'distanceraster'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Distance (raster)')

    def shortHelpString(self) -> str:
        return self.tr('Compute distance raster to a list of values of a raster')

    def initAlgorithm(self, config=None):
        """
        Define inputs and outputs for the main input
        """
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT_RASTER,
                self.tr('Input raster')
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.INPUT_CODES,
                self.tr('List to values to compute distance from. Values are separated by a comma. Decimal mark is defined by a point')
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT_RASTER,
                self.tr('Output raster')
            )
        )

    def parse_code(self, code: str) -> Union[int, float]:
        """
        Convert a single textual code to a number.

        Integer literals are returned as ``int``. Anything else that ``float`` accepts
        (``"1.5"``, ``"-2.0"``, ``"1e3"``) is returned as ``float``.

        :param code: Textual representation of the code
        :type code: str

        :returns: The parsed numeric code
        :raises ValueError: if ``code`` is not a valid number
        """
        try:
            return int(code)
        except ValueError:
            # Not an integer literal: fall back to float, which also accepts decimals and scientific notation
            return float(code)

    def parse_codes(self, codes: str) -> List[Union[int, float]]:
        """
        Parse a comma-separated list of numeric codes.

        Whitespace around each entry is ignored. Empty entries, such as those produced by a trailing
        comma or an empty string, are skipped.

        :param codes: Comma-separated list of codes, with '.' as decimal mark
        :type codes: str

        :returns: Parsed codes, in input order
        :raises ValueError: if any non-empty entry is not a valid number
        """
        tokens = (token.strip() for token in codes.split(","))
        return [self.parse_code(token) for token in tokens if token]

    def compute_distance(self, raster: QgsRasterLayer, codes: List[Union[int, float]], out_path: str,
                         context=None, feedback=None):
        """
        compute_distance consist of calling gdal:proximity processing algorithm

        :param raster: Raster whose band 1 values are compared against ``codes``
        :type raster: QgsRasterLayer

        :param codes: Target pixel values to compute distance from
        :type codes: List[Union[int, float]]

        :param out_path: Raster destination filepath
        :type out_path: str

        :param context: Processing context of the calling algorithm, forwarded to the child algorithm
        :param feedback: Processing feedback of the calling algorithm, forwarded so progress and
            cancellation propagate to the GDAL run
        """
        params = {
            'INPUT': raster.source(),
            'BAND': 1,
            'VALUES': ", ".join([str(code) for code in codes]),
            'UNITS': 1,  # 1: pixel coordinates (0 would be georeferenced units)
            'OUTPUT': out_path
        }
        processing.run("gdal:proximity", params, context=context, feedback=feedback, is_child_algorithm=True)

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run

        :raises QgsProcessingException: if the input raster is invalid or the value list cannot be parsed
        """
        # Load input raster and its metadata
        raster_layer: Optional[QgsRasterLayer] = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        if raster_layer is None:
            raise QgsProcessingException(self.tr('Invalid input raster'))
        values: str = self.parameterAsString(parameters, self.INPUT_CODES, context)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)

        # Report parsing problems as a processing error rather than a raw Python traceback
        try:
            codes: List[Union[int, float]] = self.parse_codes(values)
        except ValueError as error:
            raise QgsProcessingException(self.tr('Invalid list of values: {}').format(values)) from error
        if not codes:
            raise QgsProcessingException(self.tr('At least one value is required'))

        self.compute_distance(raster_layer, codes, out_path, context, feedback)

        return {self.OUTPUT_RASTER: out_path}


class DistanceVectorProcessingAlgorithm(LandsklimProcessingToolAlgorithm):
    """
    Processing algorithm computing distance raster from features.

    The vector layer is rasterized on the grid of a reference raster, then ``gdal:proximity``
    computes each pixel's distance to the nearest burned pixel.
    """
    INPUT_RASTER = 'INPUT_RASTER'
    INPUT_VECTOR = 'INPUT_VECTOR'
    OUTPUT_RASTER = 'OUTPUT_RASTER'

    def createInstance(self):
        """
        Return a fresh instance of this algorithm (required by the Processing framework)
        """
        return DistanceVectorProcessingAlgorithm()

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'distancevector'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Distance (from features)')

    def shortHelpString(self) -> str:
        return self.tr('Compute raster distance to features')

    def initAlgorithm(self, config=None):
        """
        Define inputs and outputs for the main input
        """

        self.addParameter(
            QgsProcessingParameterVectorLayer(
                self.INPUT_VECTOR,
                self.tr('Input vector')
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT_RASTER,
                self.tr('Raster extent')
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT_RASTER,
                self.tr('Output raster')
            )
        )

    def compute_distance(self, mask: QgsRasterLayer, vector: QgsVectorLayer, out_path: str,
                         context=None, feedback=None):
        """
        Distance is computed in two steps.

        - The vector layer is rasterized. Extent and pixel size are given by the raster mask passed as an argument.
            Value of 1 is given to pixels on which features were present.
        - gdal:proximity is called to build the distance map from the rasterized vector layer.
            gdal write the result to the destination raster passed as an argument, so once this is done, we don't have to do additional treatments.

        :param mask: Extent of target raster
        :type mask: QgsRasterLayer

        :param vector: Vector layer to compute distance from features
        :type vector: QgsVectorLayer

        :param out_path: Raster destination filepath
        :type out_path: str

        :param context: Processing context of the calling algorithm, forwarded to child algorithms
        :param feedback: Processing feedback of the calling algorithm, forwarded to child algorithms.
            It is also checked for cancellation between the two steps.
        """
        # hasSpatialIndex() returns a SpatialIndexPresence enum, not a bool. "Not present" is a
        # non-zero/truthy value, so a plain `not` test would (almost) never create the index.
        if vector.hasSpatialIndex() != QgsFeatureSource.SpatialIndexPresent:
            vector.dataProvider().createSpatialIndex()

        raster_resolution_x: float = mask.rasterUnitsPerPixelX()
        raster_resolution_y: float = mask.rasterUnitsPerPixelY()
        # NOTE(review): the extent is passed without a CRS, so it is assumed that vector and mask
        # share the same CRS. Confirm, or reproject the vector beforehand.
        extent: QgsRectangle = mask.extent()

        rasterize_params = {'INPUT': vector,
                            'FIELD': '',
                            'BURN': 1,
                            'UNITS': 1,  # 1: WIDTH/HEIGHT are resolutions in georeferenced units
                            'WIDTH': raster_resolution_x,
                            'HEIGHT': raster_resolution_y,
                            'EXTENT': extent,
                            'OUTPUT': 'TEMPORARY_OUTPUT'}
        rasterize_output = processing.run("gdal:rasterize", rasterize_params,
                                          context=context, feedback=feedback, is_child_algorithm=True)

        # Do not start the second (potentially long) GDAL run if the user cancelled
        if feedback is not None and feedback.isCanceled():
            return

        proximity_params = {
            'INPUT': rasterize_output['OUTPUT'],
            'BAND': 1,
            'VALUES': "1",
            'UNITS': 1,  # 1: pixel coordinates (0 would be georeferenced units)
            'OUTPUT': out_path
        }
        processing.run("gdal:proximity", proximity_params,
                       context=context, feedback=feedback, is_child_algorithm=True)

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run

        :raises QgsProcessingException: if one of the input layers is invalid
        """
        # Load input layers
        mask_layer: Optional[QgsRasterLayer] = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        if mask_layer is None:
            raise QgsProcessingException(self.tr('Invalid raster extent layer'))
        vector_layer: Optional[QgsVectorLayer] = self.parameterAsVectorLayer(parameters, self.INPUT_VECTOR, context)
        if vector_layer is None:
            raise QgsProcessingException(self.tr('Invalid input vector layer'))

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)

        self.compute_distance(mask_layer, vector_layer, out_path, context, feedback)

        return {self.OUTPUT_RASTER: out_path}
