# -*- coding: utf-8 -*-

"""
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""
from qgis.PyQt.QtCore import QCoreApplication
from qgis.PyQt.QtWidgets import QMessageBox
from qgis.core import (
    QgsProcessingLayerPostProcessorInterface,
    QgsProcessing,
    QgsFeatureSink,
    QgsVectorLayer,
    QgsProcessingException,
    QgsProcessingAlgorithm,
    QgsProcessingParameterNumber,
    QgsProcessingParameterVectorDestination,
    QgsProcessingParameterFeatureSource,
    QgsProcessingUtils,
    QgsProcessingParameterBoolean,
    QgsSnappingConfig,
    QgsProject,
    QgsTolerance
)

from qgis import processing
import os, inspect
import requests


def set_snapping_options():
    """
    Configure snapping and editing options of the current QGIS project.

    Sets avoid-intersections mode 1 and enables topological editing on
    ``QgsProject.instance()``. It then enables vertex snapping on the active
    layer with a 10 pixel tolerance.

    Returns:
        The ``QgsProject`` instance that was modified.
    """
    project = QgsProject.instance()
    # snappingConfig() hands back a config object that only takes effect once
    # it is passed to setSnappingConfig() at the end.
    snapping = project.snappingConfig()

    # NOTE(review): mode 1 of QgsProject.AvoidIntersectionsMode is believed to be
    # "avoid intersections with the current layer" -- confirm against the API docs.
    project.setAvoidIntersectionsMode(QgsProject.AvoidIntersectionsMode(1))
    project.setTopologicalEditing(True)

    snapping.setMode(QgsSnappingConfig.ActiveLayer)
    snapping.setType(QgsSnappingConfig.Vertex)
    snapping.setTolerance(10)
    snapping.setUnits(QgsTolerance.Pixels)
    snapping.setEnabled(True)
    project.setSnappingConfig(snapping)

    return project


class Renamer(QgsProcessingLayerPostProcessorInterface):
    """
    Layer post-processor that renames a processing output layer.

    Intended to be attached with ``LayerDetails.setPostProcessor``. QGIS then
    invokes :meth:`postProcessLayer` for the output layer.
    """

    def __init__(self, layer_name):
        super().__init__()
        # Display name applied to the layer in postProcessLayer().
        self.name = layer_name

    def postProcessLayer(self, layer, context, feedback):
        """Give *layer* the name passed to the constructor."""
        layer.setName(self.name)


class DownloadWorldcover(QgsProcessingAlgorithm):
    """
    Download ESA WorldCover 2020 landcover tiles covering a bounding-box
    polygon, resample and polygonize them, and join the landcover roughness
    lookup table ('d', 'z0', 'desc') to the resulting polygons.

    Downloaded GeoTIFF tiles are cached in a ``WorldCoverData`` folder next to
    this script unless "Save downloaded files on disk" is unchecked.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    dest_id = None  # Save a reference to the output layer id
    INPUT = "INPUT"
    SAVE_ON_DISK = "SAVE_ON_DISK"
    # Key of the "Resampling resolution" parameter. The historical key name is
    # kept unchanged because callers pass parameter values by this key.
    VALUE = "roughness_length_outside"
    TABLE = "TABLE"  # not used by this algorithm
    OUTPUT = "OUTPUT"

    # Largest accepted clipping polygon area: 60 km x 60 km, expressed in
    # squared layer units (assumes a metric projected CRS).
    _MAX_AREA = 3.6e9
    # Location of the WorldCover GeoTIFF tiles.
    _S3_URL_PREFIX = "https://esa-worldcover.s3.eu-central-1.amazonaws.com"
    # requests timeout (seconds) for connecting and for each read.
    _DOWNLOAD_TIMEOUT_S = 60
    # Block size used when streaming a tile to disk.
    _DOWNLOAD_CHUNK_BYTES = 1024 * 1024

    def tr(self, string):
        """
        Returns a translatable string with the self.tr() function.
        """
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        return DownloadWorldcover()

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "download_worldcover"

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Get WorldCover landcover polygons")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr("WAsP scripts")

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "wasp_scripts"

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm. This string
        should provide a basic description about what the algorithm does and the
        parameters and outputs associated with it.
        """
        return self.tr(
            """Extracts landcover polygons from Worldcover data

            The input is a 'Bounding box' layer with a single Polygon where we want to retrieve the data. The projection from this clipping layer is used for the output file. Raster data that are inside this area are downloaded from <a href='https://doi.org/10.5281/zenodo.7254221'>here</a> and then converted to polygons. Note that this relies on externally hosted data and might therefore be slow.
            
            The returned vector layer contains 'id', 'z0' and 'd' attributes which represents a polygon with certain landcover 'id' with a roughness length 'z0' and displacement height 'd', respectively. The lookup table that is used for tying a ID to roughness is found <a href='https://doi.org/10.5194/wes-6-1379-2021'>here</a>."""
        )

    def initAlgorithm(self, config=None):
        """
        Define the inputs and outputs of the algorithm: the polygon clipping
        layer, the output vector layer, the resampling resolution
        (integer, 50-500, default 100) and whether downloaded tiles are kept.
        """
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT,
                self.tr("Clipping layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        self.addParameter(
            QgsProcessingParameterVectorDestination(
                self.OUTPUT, self.tr("Output layer")
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                name=self.VALUE,
                description=self.tr("Resampling resolution"),
                type=QgsProcessingParameterNumber.Integer,
                defaultValue=100,
                optional=False,
                minValue=50,
                maxValue=500,
            )
        )

        self.addParameter(
            QgsProcessingParameterBoolean(
                name=self.SAVE_ON_DISK,
                description=self.tr("Save downloaded files on disk"),
                defaultValue=True,
                optional=False,
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Download, polygonize and attribute the WorldCover landcover data.

        Steps:
        1. Validate the clipping polygon and buffer it.
        2. Find the intersecting WorldCover tiles and download the missing ones.
        3. Clip, merge, resample and polygonize the tiles. When no tile
           intersects, fall back to a single water polygon.
        4. Join the roughness lookup table and write the result to the
           OUTPUT sink.

        Raises:
            QgsProcessingException: for invalid inputs, unreadable support
                files, failed downloads or an empty landcover result.

        Returns:
            ``{OUTPUT: dest_id}``, or ``{}`` when the user cancels while
            tiles are being downloaded.
        """
        source = self.parameterAsSource(parameters, self.INPUT, context)

        # If source was not found, throw an exception to indicate that the algorithm
        # encountered a fatal error, using the standard invalidSourceError text.
        if source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )

        # read the resampling resolution (pixel size of the resampled raster)
        resampling_res = self.parameterAsInt(parameters, self.VALUE, context)
        # keeping downloaded tiles avoids re-downloading when the same area is
        # processed again; unchecking it saves disk space
        save_on_disk = self.parameterAsBool(parameters, self.SAVE_ON_DISK, context)

        # Send some information to the user
        feedback.pushInfo("CRS is {}".format(source.sourceCrs().authid()))
        clipping_layer = self.parameterAsVectorLayer(parameters, self.INPUT, context)
        if clipping_layer is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )
        self._check_clipping_layer(clipping_layer)

        # we slightly extend the mask layer so we have enough samples even in
        # corner pixels when we resample to a coarser resolution
        buffered_layer = processing.run(
            "native:buffer",
            {
                "DISSOLVE": False,
                "DISTANCE": resampling_res * 2,
                "END_CAP_STYLE": 2,
                "INPUT": clipping_layer,
                "JOIN_STYLE": 0,
                "MITER_LIMIT": 2,
                "OUTPUT": "TEMPORARY_OUTPUT",
                "SEGMENTS": 5,
            },
            context=context,
            feedback=feedback,
        )["OUTPUT"]

        script_dir = self._script_dir()

        # load worldcover grid
        feedback.pushInfo('Opening Worldcover tiles...')
        grid_path = os.path.join(script_dir, "esa_worldcover_2020_grid.geojson")
        feedback.pushInfo(grid_path)
        tiles = self._open_vector_layer(grid_path)

        feedback.pushInfo('Reprojecting mask layer to Worldcover projection...')
        clipping_layer_reproj = processing.run("native:reprojectlayer", {
            'INPUT': buffered_layer,
            'OUTPUT': 'TEMPORARY_OUTPUT',
            'TARGET_CRS': tiles.crs(),
        }, context=context, feedback=feedback)['OUTPUT']

        feedback.pushInfo('Finding relevant tiles from Worldcover...')
        selected_tiles = processing.run("native:intersection", {
            'INPUT': tiles,
            'OVERLAY': clipping_layer_reproj,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)['OUTPUT']

        # where to store the downloaded worldcover tif files
        file_dir = os.path.join(script_dir, "WorldCoverData")
        os.makedirs(file_dir, exist_ok=True)
        feedback.pushInfo(f'Will save files in {file_dir}')

        # Tiles fetched during this run. Only these are removed when the user
        # does not want to keep the downloads; tiles cached by earlier runs stay.
        downloaded = []
        try:
            raster_layers = self._download_tiles(
                selected_tiles, file_dir, downloaded, feedback
            )
            if feedback.isCanceled():
                # do not build an output from a partial set of tiles
                return {}

            if not raster_layers:
                feedback.pushWarning("The clipping layer did not contain any usuable landcover areas! Make sure you are in an area that is covered by Worldcover data (https://viewer.esa-worldcover.org/worldcover/). A single offshore landcover map with the size of your bounding box will be generated.")
                landcover_polygons = self._water_polygon(clipping_layer, context, feedback)
            else:
                landcover_polygons = self._landcover_polygons(
                    raster_layers,
                    clipping_layer,
                    clipping_layer_reproj,
                    resampling_res,
                    context,
                    feedback,
                )

            # this file contains the landcover lookup table to convert
            # landcover classes to roughness
            table = self._open_vector_layer(
                os.path.join(script_dir, "landcovertables", "WorldCover.gpkg")
            )

            feedback.pushInfo('Joining WorldCover roughness table')
            joined = processing.run('native:joinattributestable', {
                'DISCARD_NONMATCHING': True,  # drop classes missing from the table
                'FIELD': 'id',
                'FIELDS_TO_COPY': ['d', 'z0', 'desc'],
                'FIELD_2': 'id',
                'INPUT': landcover_polygons,
                'INPUT_2': table,
                'METHOD': 0,
                'PREFIX': '',
                'OUTPUT': 'TEMPORARY_OUTPUT',
            }, context=context, feedback=feedback)['OUTPUT']

            self._write_to_sink(joined, parameters, context, feedback)
        finally:
            # clean up files if requested, also when processing failed
            if not save_on_disk:
                self._remove_files(downloaded, feedback)

        # Return the results of the algorithm; the only result is the
        # feature sink which contains the processed features.
        return {self.OUTPUT: self.dest_id}

    def _check_clipping_layer(self, clipping_layer):
        """
        Reject clipping layers that would produce invalid or too large requests.

        Raises:
            QgsProcessingException: if the CRS is geographic, the layer does not
                hold exactly one feature, or a feature exceeds ``_MAX_AREA``.
        """
        if clipping_layer.crs().isGeographic():
            raise QgsProcessingException("WAsP files cannot be saved in geographic coordinates! Please choose a projection with a metric coordinate system!")

        nlines = clipping_layer.featureCount()
        if nlines == 0:
            raise QgsProcessingException(
                "The clipping layer does not contain any features! Make sure you have selected a bounding box layer."
            )
        if nlines > 1:
            raise QgsProcessingException(
                f"The clipping layer must consist of a single feature only to avoid downloading too many files! Yours contains {nlines}. Make sure you have selected a bounding box layer."
            )

        for feature in clipping_layer.getFeatures():
            if feature.geometry().area() > self._MAX_AREA:
                raise QgsProcessingException(
                    "The clipping layer must be smaller than 60x60 km!"
                )

    @staticmethod
    def _script_dir():
        """Return the directory of this script, which also holds its support files."""
        return os.path.split(inspect.getfile(inspect.currentframe()))[0]

    @staticmethod
    def _open_vector_layer(path):
        """Open the vector file at *path* with OGR, raising if it cannot be read."""
        layer = QgsVectorLayer(path, "Vector_Layer", "ogr")
        if not layer.isValid():
            raise QgsProcessingException(f"Could not open vector file {path}")
        return layer

    def _download_tiles(self, selected_tiles, file_dir, downloaded, feedback):
        """
        Make sure every WorldCover tile listed in *selected_tiles* exists in *file_dir*.

        Tiles already present in *file_dir* are reused, and missing tiles are
        downloaded. The path of every tile fetched by this call is appended to
        *downloaded* as soon as it is complete. The caller can therefore clean
        those files up even if a later download fails.

        Returns:
            List of local tile paths (without duplicates). Stops early and
            returns the tiles collected so far when the user cancels.
        """
        raster_layers = []
        for tile_feature in selected_tiles.getFeatures():
            # Stop the algorithm if cancel button has been clicked
            if feedback.isCanceled():
                break
            # the first attribute of the tile grid is used as tile identifier
            tile_name = tile_feature.attributes()[0]
            file_name = f"ESA_WorldCover_10m_2020_v100_{tile_name}_Map.tif"
            url = f"{self._S3_URL_PREFIX}/v100/2020/map/{file_name}"
            out_fn = os.path.join(file_dir, file_name)
            if out_fn in raster_layers:
                continue  # never merge the same tile twice
            if not os.path.exists(out_fn):
                feedback.pushInfo(f"Downloading data from url {url}")
                if not self._download_file(url, out_fn, feedback):
                    break  # canceled by the user
                downloaded.append(out_fn)
            raster_layers.append(out_fn)
        return raster_layers

    def _download_file(self, url, out_fn, feedback):
        """
        Stream *url* into *out_fn*.

        The data is written to ``out_fn + ".part"`` and renamed to *out_fn*
        only after the whole response was received. A failed or interrupted
        transfer therefore never leaves a truncated file that later runs would
        treat as a cached tile.

        Returns:
            True when the file was downloaded, False when the user canceled.

        Raises:
            QgsProcessingException: on HTTP or network errors.
        """
        part_fn = out_fn + ".part"
        completed = False
        try:
            with requests.get(
                url,
                allow_redirects=True,
                stream=True,
                timeout=self._DOWNLOAD_TIMEOUT_S,
            ) as response:
                # never store an HTTP error page as if it was a GeoTIFF
                response.raise_for_status()
                feedback.pushInfo(f"Writing to output file {out_fn}")
                with open(part_fn, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=self._DOWNLOAD_CHUNK_BYTES):
                        if feedback.isCanceled():
                            return False
                        f.write(chunk)
            os.replace(part_fn, out_fn)
            completed = True
        except requests.RequestException as err:
            raise QgsProcessingException(f"Failed to download {url}: {err}") from err
        finally:
            if not completed and os.path.exists(part_fn):
                os.remove(part_fn)
        return True

    @staticmethod
    def _water_polygon(clipping_layer, context, feedback):
        """
        Return a copy of *clipping_layer* with an integer 'id' field set to 80.

        80 is the water ID. The result is used for far offshore areas that are
        not covered by the raster data.
        """
        return processing.run('native:fieldcalculator', {
            'FIELD_LENGTH': 0,
            'FIELD_NAME': 'id',
            'FIELD_PRECISION': 0,
            'FIELD_TYPE': 1,
            'FORMULA': "80",
            'INPUT': clipping_layer,
            'NEW_FIELD': True,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)["OUTPUT"]

    @staticmethod
    def _clip_raster(raster, mask_layer, context, feedback):
        """Clip *raster* to the polygons of *mask_layer* with GDAL."""
        return processing.run("gdal:cliprasterbymasklayer", {
            'INPUT': raster,
            'MASK': mask_layer,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)['OUTPUT']

    def _landcover_polygons(self, raster_layers, clipping_layer, mask_layer,
                            resampling_res, context, feedback):
        """
        Turn the WorldCover tiles into landcover polygons for the clipping area.

        The tiles are clipped to *mask_layer* and merged when there are several.
        They are then resampled to *resampling_res* in the CRS of
        *clipping_layer*, polygonized into an 'id' field, clipped to
        *clipping_layer* and snapped onto themselves.

        Raises:
            QgsProcessingException: if no polygon remains.
        """
        if len(raster_layers) > 1:
            feedback.pushInfo('Multiple input tiles, so merging rasters to one file before proceeding...')
            clipped_rasters = []
            for ras in raster_layers:
                # use the input mask to clip the raster from each matching tile
                feedback.pushInfo(f'Clip raster {ras} by mask layer...')
                clipped_rasters.append(self._clip_raster(ras, mask_layer, context, feedback))
            # combine the parts that were clipped from each input tile
            clipped_raster = processing.run("gdal:merge", {
                'INPUT': clipped_rasters,
                'PCT': False,
                'SEPARATE': False,
                'OUTPUT': 'TEMPORARY_OUTPUT',
            }, context=context, feedback=feedback)['OUTPUT']
        else:
            # only a single raster found, just clip it to the mask
            feedback.pushInfo('Opening input tile')
            clipped_raster = self._clip_raster(raster_layers[0], mask_layer, context, feedback)

        # resampling to coarser resolution because WAsP will generally choke on
        # very high resolution maps
        coarse_raster = processing.run("gdal:warpreproject", {
            'INPUT': clipped_raster,
            'TARGET_CRS': clipping_layer.crs(),
            'RESAMPLING': 6,  # mode, i.e. most common value within each coarser pixel
            'TARGET_RESOLUTION': resampling_res,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)['OUTPUT']

        # convert raster to polygon vector layer
        feedback.pushInfo('Polygonizing raster...')
        polygons = processing.run("gdal:polygonize", {
            'INPUT': coarse_raster,
            'BAND': 1,
            'FIELD': 'id',
            'EIGHT_CONNECTEDNESS': False,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)['OUTPUT']

        # the mask was buffered earlier, so cut the polygons back to the
        # original bounding box
        feedback.pushInfo("Clipping to bounding box...")
        clipped_polygons = processing.run(
            "native:clip",
            {
                "INPUT": polygons,
                "OUTPUT": "TEMPORARY_OUTPUT",
                "OVERLAY": clipping_layer,
            },
            context=context,
            feedback=feedback,
        )["OUTPUT"]

        # snap geometries to make sure the rotation of the polygons is correct
        # and that they are actually bordering without any gaps
        clipped_polygons_snapped = processing.run(
            "native:snapgeometries",
            {
                "INPUT": clipped_polygons,
                "REFERENCE_LAYER": clipped_polygons,
                "OUTPUT": "TEMPORARY_OUTPUT",
                "BEHAVIOR": 0,
                "TOLERANCE": 1,
            },
            context=context,
            feedback=feedback,
        )["OUTPUT"]

        if clipped_polygons_snapped.featureCount() == 0:
            raise QgsProcessingException(
                "The clipping layer did not contain any usuable landcover areas! Make sure you are in an area that is covered by Worldcover data (https://viewer.esa-worldcover.org/worldcover/)."
            )
        return clipped_polygons_snapped

    def _write_to_sink(self, layer, parameters, context, feedback):
        """
        Copy every feature of *layer* into the OUTPUT sink.

        Stores the sink id in ``self.dest_id``.

        Raises:
            QgsProcessingException: if the sink cannot be created.
        """
        (sink, self.dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            layer.fields(),
            layer.wkbType(),
            layer.sourceCrs(),
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        n_features = layer.featureCount()
        # progress increment per feature; guard against an empty layer (e.g.
        # when the join discarded every polygon)
        step = 100.0 / n_features if n_features > 0 else 0
        feedback.pushInfo("Number of roughness lines to write: " + str(n_features))
        written = 0
        for feature in layer.getFeatures():
            # Stop the algorithm if cancel button has been clicked
            if feedback.isCanceled():
                break
            sink.addFeature(feature, QgsFeatureSink.FastInsert)
            written += 1
            feedback.setProgress(int(written * step))
        feedback.pushInfo(f"nr segments {written}")

    @staticmethod
    def _remove_files(paths, feedback):
        """Delete *paths*, reporting (not raising) files that cannot be removed."""
        for path in paths:
            try:
                os.remove(path)
            except OSError as err:
                feedback.pushWarning(f"Could not remove {path}: {err}")

    def postProcessAlgorithm(self, context, feedback):
        """
        Name and style the output layer and enable project snapping options.

        Project-wide settings are changed here rather than in processAlgorithm,
        because QGIS may run processAlgorithm outside the main thread.
        """
        global renamer
        if not self.dest_id:
            # processAlgorithm stopped before creating the output (e.g. canceled)
            return {}

        if context.willLoadLayerOnCompletion(self.dest_id):
            # Kept in a module-level variable: the layer details do not take
            # ownership of the post-processor, so a local object could be
            # garbage collected before QGIS invokes it.
            renamer = Renamer('WorldCover landcover polygons')
            context.layerToLoadOnCompletionDetails(self.dest_id).setPostProcessor(renamer)

        processed_layer = QgsProcessingUtils.mapLayerFromString(self.dest_id, context)
        if processed_layer is not None:
            processed_layer.loadNamedStyle(
                os.path.join(self._script_dir(), "styles", "z0_polygons.qml")
            )
            processed_layer.triggerRepaint()

        # set project options such that snapping is enabled to avoid errors
        # when people modify the polygons
        set_snapping_options()

        return {self.OUTPUT: self.dest_id}