# -*- coding: utf-8 -*-

"""
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""
from qgis.PyQt.QtCore import QCoreApplication
from qgis.PyQt.QtWidgets import QMessageBox
from qgis.core import (
    QgsProcessingLayerPostProcessorInterface,
    QgsProcessing,
    QgsProcessingUtils,
    QgsFeatureSink,
    QgsRasterLayer,
    QgsFeature,
    QgsProcessingException,
    QgsProcessingAlgorithm,
    QgsProcessingParameterVectorDestination,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterNumber,
)

from qgis import processing
import os, inspect
import numpy as np

class DownloadGWAElevationCOG(QgsProcessingAlgorithm):
    """
    This script downloads elevation data for a clipping polygon from the
    Global Wind Atlas SRTM/viewfinder raster and converts it to elevation
    contour lines carrying an "ELEV" attribute.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    dest_id = None  # Save a reference to the output layer id
    INPUT = "INPUT"
    INTERVAL = "contouring_interval"
    OFFSET = "contouring_offset"
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """
        Translate ``string`` using the "Processing" translation context
        and return the translated text.
        """
        translation_context = "Processing"
        return QCoreApplication.translate(translation_context, string)

    def createInstance(self):
        return DownloadGWAElevationCOG()

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "download_gwa_elev_cog"

    def displayName(self):
        """
        Return the translated, user-visible name of this algorithm.
        """
        visible_name = self.tr("Get GWA elevation lines")
        return visible_name

    def group(self):
        """
        Return the translated, user-visible name of the group this
        algorithm is listed under.
        """
        group_label = self.tr("WAsP scripts")
        return group_label

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "wasp_scripts"

    def shortHelpString(self):
        """
        Return the translated short help text of the algorithm.

        The text describes what the algorithm does, its parameters and its
        output, and credits the source of the elevation data.
        """
        help_text = self.tr(
            """Extracts elevation lines from viewfinder/SRTM

            The input is a 'Bounding box' layer with a single Polygon where we want to retrieve the data. Based on this layer, the tiles that cover this area are downloaded and the data is clipped out of these tiles. Note that the projection of the clipping layer determines the projection of the output.

            The lines are then converted using an interval and offset. The interval denotes the interval between which the contour lines are drawn and the offset denotes the starting point from where the contouring is started from.

            The returned vector layer contains an attribute "ELEV" which can be used later on when saving it as a WAsP .map file.

            The elevation data are obtained from <a href='http://viewfinderpanoramas.org'>viewfinderpanoramas</a> and we acknowledge all contributors mentioned there. Most of the data is based on the 3'' SRTM data, but other sources are used as well as described <a href='http://viewfinderpanoramas.org/technical.htm'>here</a>.

            <i>de Ferranti, Jonathan. Viewfinder Mountain Top Horizon Maps, viewfinderpanoramas.org/. Accessed 4 July 2023.</i>
            """
        )
        return help_text

    def initAlgorithm(self, config=None):
        """
        Declare the parameters of the algorithm.

        Four parameters are registered, in this order: the polygon clipping
        layer, the output vector layer, the contouring interval and the
        contouring offset.
        """
        # Polygon layer describing the area of interest.
        clipping_param = QgsProcessingParameterFeatureSource(
            self.INPUT,
            self.tr("Clipping layer"),
            [QgsProcessing.TypeVectorPolygon],
        )

        # Destination of the resulting elevation lines.
        output_param = QgsProcessingParameterVectorDestination(
            self.OUTPUT, self.tr("Output layer")
        )

        # Vertical distance between consecutive contour lines.
        interval_param = QgsProcessingParameterNumber(
            name=self.INTERVAL,
            description=self.tr("Contouring interval (m)"),
            type=QgsProcessingParameterNumber.Double,
            defaultValue=10,
            optional=False,
            minValue=1,
            maxValue=10000,
        )

        # Elevation from which the contouring starts.
        offset_param = QgsProcessingParameterNumber(
            name=self.OFFSET,
            description=self.tr("Contouring offset (m)"),
            type=QgsProcessingParameterNumber.Double,
            defaultValue=0,
            optional=False
        )

        # Register in the same order as the parameters should be presented.
        for definition in (clipping_param, output_param, interval_param, offset_param):
            self.addParameter(definition)


    def processAlgorithm(self, parameters, context, feedback):
        """
        Download elevation data for the clipping polygon and contour it.

        Steps:
          1. Validate the clipping layer (projected CRS, a single feature,
             area not above 3.6E9 square map units).
          2. Buffer the clipping polygon, reproject it to the CRS of the
             remote elevation raster and clip the raster with it.
          3. Contour the clipped raster with the requested interval and
             offset, reproject the lines back to the CRS of the clipping
             layer and clip them to the original (unbuffered) polygon.
             If the terrain variation is too small to produce contours, the
             outline of the clipping polygon is used instead, carrying the
             mean elevation in the "ELEV" field.
          4. If the lines do not span the full extent of the clipping layer,
             add the polygon outline as a line without elevation.
          5. Write all lines to the output sink.

        :param parameters: dictionary with the algorithm parameters
        :param context: processing context of the current run
        :param feedback: feedback object used for messages and progress
        :return: dictionary mapping ``self.OUTPUT`` to the id of the output
        :raises QgsProcessingException: if an input is invalid, the
            elevation raster cannot be opened, or the sink cannot be created
        """
        source = self.parameterAsSource(parameters, self.INPUT, context)

        # If the source was not found, abort with the standard helper text
        # for sources that cannot be evaluated.
        if source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )

        # Send some information to the user
        feedback.pushInfo("CRS is {}".format(source.sourceCrs().authid()))
        clipping_layer = self.parameterAsVectorLayer(parameters, self.INPUT, context)
        if clipping_layer is None:
            # The input could be evaluated as a source but not as a layer;
            # everything below needs a real layer object.
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )

        # read the contouring settings
        interval = self.parameterAsDouble(parameters, self.INTERVAL, context)
        offset = self.parameterAsDouble(parameters, self.OFFSET, context)

        if clipping_layer.crs().isGeographic():
            raise QgsProcessingException("WAsP files cannot be saved in geographic coordinates! Please choose a projection with a metric coordinate system!")

        nlines = clipping_layer.featureCount()
        if nlines > 1:
            raise QgsProcessingException(
                f"The clipping layer must consist of a single feature only to avoid downloading too many files! Yours contains {nlines}"
            )
        if nlines == 0:
            # Without a polygon there is nothing to clip the raster with.
            raise QgsProcessingException(
                "The clipping layer does not contain any features!"
            )

        for f in clipping_layer.getFeatures():
            size = f.geometry().area()
            if size > 3.6E9:
                raise QgsProcessingException(
                    "The clipping layer must be smaller than 60x60 km!"
                )

        # we slightly extend the mask layer so we have enough samples even in
        # corner pixels
        buffered_layer = processing.run(
            "native:buffer",
            {
                "DISSOLVE": False,
                "DISTANCE": 90,
                "END_CAP_STYLE": 2,
                "INPUT": clipping_layer,
                "JOIN_STYLE": 0,
                "MITER_LIMIT": 2,
                "OUTPUT": "TEMPORARY_OUTPUT",
                "SEGMENTS": 5,
            },
            context=context,
            feedback=feedback,
        )["OUTPUT"]

        # open the remote elevation raster
        feedback.pushInfo('Opening SRTM tiles...')
        tiles = QgsRasterLayer("/vsicurl/https://api.globalwindatlas.info/cogs/srtmGL3003_plus_viewfinder_corrected_cog.tif")
        if not tiles.isValid():
            # Fail early with a clear message instead of clipping an
            # unreadable raster with an invalid CRS.
            raise QgsProcessingException(
                "Could not open the remote elevation raster. Please check your internet connection."
            )

        feedback.pushInfo('Reprojecting mask layer to the projection of the elevation raster...')
        clipping_layer_reproj = processing.run("native:reprojectlayer", {
            'INPUT': buffered_layer,
            'OUTPUT': 'TEMPORARY_OUTPUT',
            'TARGET_CRS': tiles.crs()
        }, context=context, feedback=feedback)['OUTPUT']

        # clip the elevation raster to the (buffered) mask
        feedback.pushInfo('Clipping raster from tiles')
        clipped_raster = processing.run("gdal:cliprasterbymasklayer", {
            'INPUT': tiles,
            'MASK': clipping_layer_reproj,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }, context=context, feedback=feedback)['OUTPUT']

        # convert 0 to slightly negative 0 so that the 0 contour line is also drawn
        alg_params = {
            'BAND_A': 1,
            'FORMULA': '(A == 0) * -0.0001 + (A != 0) * A',
            'INPUT_A': clipped_raster,
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }
        clipped_raster_negative_zero = processing.run('gdal:rastercalculator', alg_params, context=context, feedback=feedback)['OUTPUT']

        # The second constructor argument is the layer name and the third one
        # the provider key.
        ters = QgsRasterLayer(clipped_raster_negative_zero, "clipped_elevation", "gdal")
        if not ters.isValid():
            raise QgsProcessingException(
                "Could not read the clipped elevation raster."
            )
        stats = ters.dataProvider().bandStatistics(1)
        feedback.pushInfo(f"Data range: {stats.range}")
        feedback.pushInfo(f"Data mean: {stats.mean}")
        feedback.pushInfo(f"Requested interval: {interval}")

        if stats.range > interval or (offset > stats.minimumValue and offset < stats.maximumValue):
            # convert raster to contour lines
            feedback.pushInfo('Contouring raster...')
            contour_lines = processing.run("gdal:contour", {
                'FIELD_NAME': 'ELEV',
                'INTERVAL': interval,
                'OFFSET': offset,
                'INPUT': clipped_raster_negative_zero,
                'BAND': 1,
                'OUTPUT': 'TEMPORARY_OUTPUT',
            }, context=context, feedback=feedback)['OUTPUT']

            feedback.pushInfo('Reprojecting contour lines to the projection of the clipping layer...')
            contour_lines_reproj = processing.run("native:reprojectlayer", {
                'INPUT': contour_lines,
                'OUTPUT': 'TEMPORARY_OUTPUT',
                'TARGET_CRS': clipping_layer.crs()
            }, context=context, feedback=feedback)['OUTPUT']

            # we slightly extended the original mask to make sure all pixels
            # fall within the area; remove that extension again.
            feedback.pushInfo("Clipping to bounding box...")
            clipped_buffer_removed = processing.run(
                "native:clip",
                {
                    "INPUT": contour_lines_reproj,
                    "OUTPUT": "TEMPORARY_OUTPUT",
                    "OVERLAY": clipping_layer,
                },
                context=context,
                feedback=feedback,
            )["OUTPUT"]
        else:
            # Too little terrain variation for any contour: fall back to the
            # outline of the clipping polygon with the mean elevation.
            mean_elev = (0 if np.isnan(stats.mean) else stats.mean)
            feedback.pushInfo(f"The terrain variation ({stats.range} m) in the map does not exceed the contouring interval ({interval} m). A single line around the chosen area with the mean {mean_elev} of the area is chosen.")
            clipping_line_elev = processing.run(
                "native:polygonstolines",
                {
                    "INPUT": clipping_layer,
                    "OUTPUT": "TEMPORARY_OUTPUT"
                },
                context=context,
                feedback=feedback,
            )["OUTPUT"]
            alg_params = {
                'FIELD_LENGTH': 0,
                'FIELD_NAME': 'ELEV',
                'FIELD_PRECISION': 0,
                'FIELD_TYPE': 0,
                'FORMULA': f"{mean_elev}",
                'INPUT': clipping_line_elev,
                'NEW_FIELD': True,
                'OUTPUT': 'TEMPORARY_OUTPUT',
            }
            clipped_buffer_removed = processing.run('qgis:fieldcalculator', alg_params, context=context, feedback=feedback)["OUTPUT"]

        # convert bbox polygon to lines and add as no attribute layer to elevation lines
        # this makes sure that the offshore area is also included in the bounding box
        width_map = round(clipped_buffer_removed.extent().width() - clipping_layer.extent().width(), 1)
        height_map = round(clipped_buffer_removed.extent().height() - clipping_layer.extent().height(), 1)
        if width_map != 0 or height_map != 0:
            feedback.pushInfo("Adding no-attribute line to make sure the map includes the offshore area.")
            clipped_line = processing.run(
                "native:polygonstolines",
                {
                    'INPUT': clipping_layer,
                    'OUTPUT': 'TEMPORARY_OUTPUT',
                },
                context=context,
                feedback=feedback,
            )["OUTPUT"]

            line_fields = clipped_buffer_removed.fields()
            fid_index = line_fields.indexFromName('fid')
            next_fid = None
            if fid_index >= 0:
                # Continue numbering after the largest existing fid; an empty
                # layer yields a null maximum, in which case we start at 1.
                try:
                    next_fid = int(clipped_buffer_removed.maximumValue(fid_index)) + 1
                except (TypeError, ValueError):
                    next_fid = 1

            clipped_buffer_removed.startEditing()
            prov = clipped_buffer_removed.dataProvider()
            feats = []
            for outline in clipped_line.getFeatures():
                # Creating the feature from the layer fields leaves every
                # attribute (including ELEV) null, whatever the field layout.
                feat = QgsFeature(line_fields)
                feat.setGeometry(outline.geometry())
                if next_fid is not None:
                    feat.setAttribute(fid_index, next_fid)
                    next_fid += 1
                feats.append(feat)
            prov.addFeatures(feats)
            clipped_buffer_removed.updateExtents()
            clipped_buffer_removed.commitChanges()

        (sink, self.dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            clipped_buffer_removed.fields(),
            clipped_buffer_removed.wkbType(),
            clipped_buffer_removed.sourceCrs(),
        )

        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        # Progress step per written feature; guarded so that an empty result
        # does not cause a division by zero.
        line_count = clipped_buffer_removed.featureCount()
        step = 100.0 / line_count if line_count > 0 else 0
        feedback.pushInfo("Number of elevation lines to write: " + str(line_count))
        written = 0
        for current, feature in enumerate(clipped_buffer_removed.getFeatures()):
            # Stop the algorithm if cancel button has been clicked
            if feedback.isCanceled():
                break

            # Add a feature in the sink
            sink.addFeature(feature, QgsFeatureSink.FastInsert)
            written += 1
            feedback.setProgress(int(current * step))
        feedback.pushInfo(f"nr segments {written}")

        # Only attach the renaming post-processor when the output will
        # actually be loaded into the project once the algorithm finishes.
        if context.willLoadLayerOnCompletion(self.dest_id):
            # NOTE(review): presumably kept as a module-level global so the
            # post-processor object stays alive until QGIS invokes it - confirm.
            global renamer
            renamer = Renamer('Viewfinder elevation lines')
            context.layerToLoadOnCompletionDetails(self.dest_id).setPostProcessor(renamer)

        # The only result of the algorithm is the id of the output layer.
        return {self.OUTPUT: self.dest_id}


    def postProcessAlgorithm(self, context, feedback):
        processed_layer = QgsProcessingUtils.mapLayerFromString(self.dest_id, context)
        cmd_folder = os.path.split(inspect.getfile(inspect.currentframe()))[0]
        processed_layer.loadNamedStyle(os.path.join(cmd_folder,"styles","elev_lines.qml"))
        processed_layer.triggerRepaint()
        return {self.OUTPUT: self.dest_id}


class Renamer (QgsProcessingLayerPostProcessorInterface):
    def __init__(self, layer_name):
        self.name = layer_name
        super().__init__()

    def postProcessLayer(self, layer, context, feedback):
        layer.setName(self.name)