# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
from scipy import ndimage
import numpy as np
import json
import enum
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterBand,
                       QgsProcessingParameterString,
                       QgsProcessingLayerPostProcessorInterface,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterDefinition,
                       QgsProcessingException,
                        )

from .ui.sizes_widget import (SizesWidgetWrapper)
from .ui.dim_widget import (DimsWidgetWrapper, SciPyParameterDims)
from .ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)



from .helpers import (array_to_str, 
                      str_to_int_or_list, 
                      check_structure, str_to_array, 
                      footprintexamples)

# Group IDs and group names
# Lookup table from group ID (returned by groupId()) to the user-visible
# group name (translated and returned by group()).
groups = dict(
    edges='Edges',
    morphological="Morphological Filters",
    statistic='Statistical Filters',
    blur='Blur',
    convolution='Convolution',
    enhance='Enhance',
)


class SciPyAlgorithm(QgsProcessingAlgorithm):
    """
    Lowest level base class for algorithms based on SciPy.

    The function to be called is returned by the function
    get_fct(), needs to be overwritten by inheriting classes.
    Name, displayname, help are set as class variables, to be 
    overwritten by inheriting classes.

    In inheriting classes, parameters can be added by:
    1) setting constants as class variables
    2) overwriting either insert_parameters (added below input parameter)
       or initAlgorithm (added above output parameter). 
       Don't forget to call the same function on super().
    3) overwriting the function get_parameters (don't forget to call
       the same function on super())

    Inheriting classes must implement createInstance, returning an
    instance of the class, e.g.:

    def createInstance(self):
        return SciPyGaussianAlgorithm()
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.

    OUTPUT = 'OUTPUT'
    INPUT = 'INPUT'
    DIMENSION = 'DIMENSION'
    
    # The following constants are supposed to be overwritten
    _name = 'name, short, lowercase without spaces'
    _displayname = 'User-visible name'
    # Output layer name: If set to None, the displayname is used 
    # Can be changed while getting the parameters.
    _outputname = None 
    _groupid = "" 
    _help = """
            Help
            """
    

    # Border modes understood by scipy.ndimage; the GUI enum uses the same order.
    modes = ['reflect', 'constant', 'nearest', 'mirror', 'wrap']

    # Return the function to be called, to be overwritten
    def get_fct(self):
        return ndimage.laplace

    
    # Dimensions the algorithm is working on. 
    # The numbers of twoD and threeD match the index in the list of GUI
    # strings (_dimension_options, below).
    # Default during init is nD (for algorithms based on scipy.ndimage),
    # giving users a choice (and calling from a script without setting 
    # DIMENSION defaults to 2D in this case). Setting twoD or threeD in an
    # inheriting class allows to write filters that are not working in
    # n dimensions.
    # Calling from a script, DIMENSION must be 0 (2D) or 1 (3D), matching
    # the values of the enum!
    class Dimensions(enum.Enum):
        nD = 2         # users can decide between 2D and 3D
        twoD = 0       # Separate for each band
        threeD = 1     # 3D filter in data cube

    _dimension = Dimensions.nD
    _ndim = None # to be set while getting parameters

    # Strings for the GUI
    _dimension_options = ['2D (Separate for each band)',
                          '3D (All bands as a 3D data cube)']

    # Function to insert Parameters (overwrite in inherited classes)
    def insert_parameters(self, config):
        return

    # Init Algorithm
    def initAlgorithm(self, config):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.

        Order in the GUI: input layer, dimension (only for nD algorithms),
        parameters from insert_parameters(), parameters added by overriding
        initAlgorithm (after calling super), output layer.
        """

        # Some Algorithms will add a masklayer
        self.masklayer = None

        # Add parameters
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT,
                self.tr('Input layer'),
            )
        )

        if self._dimension == self.Dimensions.nD:
            # Algorithms based on scipy.ndimage can have any number of dimensions
            dim_param = SciPyParameterDims(
                self.DIMENSION,
                self.tr('Dimension'),
                self._dimension_options,
                defaultValue=0,
                optional=False,)

            dim_param.setMetadata({
                'widget_wrapper': {
                    'class': DimsWidgetWrapper
                }
            })

            self.addParameter(dim_param)


        # Insert Parameters 
        self.insert_parameters(config)

        if not self._outputname:
            self._outputname = self._displayname

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
                self.tr(self._outputname)))
        
    def get_parameters(self, parameters, context):
        """
        Factoring this out of processAlgorithm allows to add Parameters
        in classes inheriting from this base class by overwriting this
        function. But always call kwargs = super().get_parameters(...)
        first!

        Returns kwargs dictionary and sets variables self.variable for 
        non-keyword arguments.

        This is the most basic base class and kwargs is empty {}.

        Sets self.inputlayer, self.output_raster, resolves an nD
        self._dimension to twoD/threeD according to the DIMENSION
        parameter, and sets self._ndim (2 or 3) for every algorithm,
        including those with a fixed dimension.
        """
        self.inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        self.output_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        if self._dimension == self.Dimensions.nD:
            dimension = self.parameterAsInt(parameters, self.DIMENSION, context)
            if dimension == 1:
                self._dimension = self.Dimensions.threeD
            else:
                # Default to 2D
                self._dimension = self.Dimensions.twoD

        # Number of array dimensions the filter works on; also set for
        # algorithms with a fixed (non-nD) dimension.
        self._ndim = self.dims

        return {}


    def processAlgorithm(self, parameters, context, feedback):
        """
        Here is where the processing itself takes place.

        Applies self.fct (from get_fct()) with the kwargs from
        get_parameters either band by band (2D) or on all bands at once
        (3D), writes the result into a GeoTIFF copy of the input and
        returns {OUTPUT: path}. Returns {} if the user cancels.

        Raises QgsProcessingException if the input, mask or output
        dataset cannot be opened or created.
        """

        # Get Parameters
        kwargs = self.get_parameters(parameters, context)

        self.fct = self.get_fct()

        if self.inputlayer is None:
            raise QgsProcessingException(self.tr("Invalid input layer"))

        # Open Raster with GDAL
        self.ds = gdal.Open(self.inputlayer.source())

        if not self.ds:
            raise QgsProcessingException(self.tr("Failed to open Raster Layer"))
        
        self.bandcount = self.ds.RasterCount

        # Set to 2D if layer has only one band
        if self.bandcount == 1:
            self._dimension = self.Dimensions.twoD
            self._ndim = 2

        # Eventually add the mask array (only algorithms setting masklayer)
        if getattr(self, 'masklayer', None):
            kwargs['mask'] = self._read_mask(feedback)

        # Prepare output
        driver = gdal.GetDriverByName('GTiff')
        self.out_ds = driver.CreateCopy(self.output_raster, self.ds, strict=0)
        if not self.out_ds:
            raise QgsProcessingException(self.tr("Failed to create output raster"))

        if feedback.isCanceled():
            # Close the output dataset before leaving
            self.out_ds = None
            return {}
        
        feedback.setProgress(0)

        # Start the actual work

        if self._dimension == self.Dimensions.twoD:
            # Iterate over bands and calculate 
            for i in range(1, self.bandcount + 1):
                a = self.ds.GetRasterBand(i).ReadAsArray()

                # The actual function
                filtered = self.fct(a, **kwargs)

                self.out_ds.GetRasterBand(i).WriteArray(filtered)

                feedback.setProgress(i * 100 / self.bandcount)
                if feedback.isCanceled():
                    self.out_ds = None
                    return {}
                
        elif self._dimension == self.Dimensions.threeD:
            a = self.ds.ReadAsArray()

            # The actual function
            filtered = self.fct(a, **kwargs)

            self.out_ds.WriteArray(filtered)

        self._write_statistics()

        # Close the dataset to write file to disk
        self.out_ds = None 

        feedback.setProgress(100)

        # Optionally rename the output layer. The QGIS API requires checking
        # willLoadLayerOnCompletion() before requesting the load details
        # (not the case e.g. when run from a script without loading).
        if self._outputname and context.willLoadLayerOnCompletion(self.output_raster):
            # Module-level reference, presumably so that the post processor
            # is not garbage collected before QGIS calls it.
            global renamer
            renamer = self.Renamer(self._outputname)
            context.layerToLoadOnCompletionDetails(self.output_raster).setPostProcessor(renamer)


        return {self.OUTPUT: self.output_raster}

    def _read_mask(self, feedback):
        """
        Open self.masklayer with GDAL (kept as self.mask_ds) and return its
        first band as an array. If projection, size or geotransform differ
        from the input dataset, the mask is first warped (nearest neighbour)
        onto the grid of the input.

        Raises QgsProcessingException if the mask cannot be opened or warped.
        """
        self.mask_ds = gdal.Open(self.masklayer.source())
        if not self.mask_ds:
            raise QgsProcessingException(self.tr("Failed to open Mask Layer"))

        # Mask must have same crs etc.
        matches = (self.mask_ds.GetProjection() == self.ds.GetProjection()
                   and self.mask_ds.RasterXSize == self.ds.RasterXSize
                   and self.mask_ds.RasterYSize == self.ds.RasterYSize
                   and self.mask_ds.GetGeoTransform() == self.ds.GetGeoTransform())

        if matches:
            feedback.pushInfo("Mask layer does match input layer.")
            return self.mask_ds.GetRasterBand(1).ReadAsArray()

        feedback.pushInfo("Mask layer does not match input layer, reprojecting mask.")

        geoTransform = self.ds.GetGeoTransform()

        kwargs_w = {"format": "GTiff", 'resampleAlg': 'near'}
        kwargs_w["xRes"] = geoTransform[1]
        kwargs_w["yRes"] = abs(geoTransform[5])

        minx = geoTransform[0]
        maxy = geoTransform[3]
        maxx = minx + geoTransform[1] * self.ds.RasterXSize
        miny = maxy + geoTransform[5] * self.ds.RasterYSize

        kwargs_w["outputBounds"] = (minx, miny, maxx, maxy)

        tmp_path = "/vsimem/tmpmask"
        warped_mask = gdal.Warp(tmp_path, self.mask_ds, **kwargs_w)
        if not warped_mask:
            gdal.Unlink(tmp_path)
            raise QgsProcessingException(self.tr("Failed to reproject Mask Layer"))
        try:
            return warped_mask.GetRasterBand(1).ReadAsArray()
        finally:
            # Close and delete the in-memory file, otherwise it stays in /vsimem
            warped_mask = None
            gdal.Unlink(tmp_path)

    def _write_statistics(self):
        """Calculate and write band statistics (min, max, mean, std) of the output."""
        for b in range(1, self.bandcount + 1):
            band = self.out_ds.GetRasterBand(b)
            # approx_ok=0, force=1: exact statistics
            stats = band.GetStatistics(0, 1)
            band.SetStatistics(*stats)


    def checkParameterValues(self, parameters, context):
        """
        Reject 3D processing (DIMENSION option 1) of a single-band layer,
        then delegate to the QGIS default checks.
        """
        dim_option = self.parameterAsInt(parameters, self.DIMENSION, context)
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # 3D only possible with more than 1 bands; an invalid layer is
        # reported by the default checks of super().
        if dim_option == 1 and layer is not None and layer.bandCount() == 1:
            return (False, self.tr("3D only possible if input layer has more than 1 bands"))
            
        return super().checkParameterValues(parameters, context)
    
    @property
    def dims(self):
        """Number of dimensions (3 for threeD, otherwise 2)."""
        d = 2
        if self._dimension == self.Dimensions.threeD:
            d = 3
        return d

    def getDimsForCheck(self, parameters, context):
        """
        Number of dimensions (2 or 3) to validate parameters against:
        taken from the DIMENSION parameter for nD algorithms, otherwise
        from the (fixed or already resolved) self._dimension.
        """
        if self._dimension == self.Dimensions.nD:
            dim_option = self.parameterAsInt(parameters, self.DIMENSION, context)
            return 3 if dim_option == 1 else 2
        return self.dims


    class Renamer(QgsProcessingLayerPostProcessorInterface):
        """
        To rename output layer name in the postprocessing step.
        """
        def __init__(self, layer_name):
            self.name = layer_name
            super().__init__()
            
        def postProcessLayer(self, layer, context, feedback):
            layer.setName(self.name)

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._name

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr(self._displayname)

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        if self._groupid == "":
            return ""
        s = groups.get(self._groupid)
        if not s:
            # If group ID is not in dictionary group, return error message for debugging
            return "Displayname of group must be set in groups dictionary"
        return self.tr(s)

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._groupid
    
    def shortHelpString(self):
        """
        Returns the help string that is shown on the right side of the 
        user interface.
        """
        return self.tr(self._help)

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)



class SciPyAlgorithmWithMode(SciPyAlgorithm):
    """
    Base class with added mode and cval; used by laplace etc.

    Adds a border mode enum (MODE, options from self.modes) and the
    constant used past the edges (CVAL) and passes them on to the SciPy
    function as the 'mode' and 'cval' keyword arguments.
    """

    MODE = 'MODE'
    CVAL = 'CVAL'

    def initAlgorithm(self, config):
        # The base class must add its parameters first, so that the
        # input layer stays the first parameter in the GUI.
        super().initAlgorithm(config)

        mode_labels = [name.capitalize() for name in self.modes]
        mode_param = QgsProcessingParameterEnum(
            self.MODE,
            self.tr('Border Mode'),
            mode_labels,
            defaultValue=0)
        self.addParameter(mode_param)

        cval_param = QgsProcessingParameterNumber(
            self.CVAL,
            self.tr('Constant value past edges for border mode "constant"'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            )
        self.addParameter(cval_param)

    def get_parameters(self, parameters, context):
        """
        Extend the kwargs of the base class with 'mode' (a string from
        self.modes) and, for a non-zero constant value, 'cval'.
        """
        kwargs = super().get_parameters(parameters, context)

        mode_index = self.parameterAsInt(parameters, self.MODE, context)
        kwargs['mode'] = self.modes[mode_index]

        constant = self.parameterAsDouble(parameters, self.CVAL, context)
        # Zero is not passed on (scipy.ndimage defaults cval to 0.0 anyway)
        if constant:
            kwargs['cval'] = constant

        return kwargs
    

class SciPyAlgorithmWithModeAxis(SciPyAlgorithmWithMode):
    """
    Base class with added mode and cval and axis; used by sobel etc.

    get_parameters sets self.axis_mode (index into axis_modes) and
    self.axis (the array axis passed on by inheriting classes). The
    'Magnitude' option (index 3) has to be implemented by the inheriting
    class; for it self.axis is left at -1.
    """

    AXIS = 'AXIS'
    axis_modes = ['Horizontal edges', 'Vertical edges', 'Band axis edges', 'Magnitude']

    def insert_parameters(self, config):
        
        self.addParameter(QgsProcessingParameterEnum(
            self.AXIS,
            self.tr('Axis'),
            self.axis_modes,
            defaultValue=0)) 
        
        super().insert_parameters(config)
           
    def get_parameters(self, parameters, context):
        """Axis parameter must be set in inheriting class to implement magnitude"""
        kwargs = super().get_parameters(parameters, context)

        self.axis_mode = self.parameterAsInt(parameters, self.AXIS, context) 

        if self.axis_mode == 0:
            # Horizontal edges: second to last axis
            self.axis = -2
        elif self.axis_mode == 2 and self._dimension == self.Dimensions.threeD:
            # Band axis edges: third to last axis, only present in the
            # 3D case (presumably the band axis of the data cube)
            self.axis = -3
        else:
            # Vertical edges, Magnitude (handled by inheriting class) and
            # band axis outside of 3D (rejected in checkParameterValues)
            self.axis = -1

        return kwargs

    def _is_3d(self, parameters, context):
        """
        True if the algorithm will run in 3D: chosen by the DIMENSION
        parameter for nD algorithms, otherwise fixed by self._dimension.
        """
        if self._dimension == self.Dimensions.nD:
            return self.parameterAsInt(parameters, self.DIMENSION, context) == 1
        return self._dimension == self.Dimensions.threeD
    
    def checkParameterValues(self, parameters, context):
        """Reject the band axis option unless the algorithm runs in 3D."""
        axis_mode = self.parameterAsInt(parameters, self.AXIS, context)
        if axis_mode == 2 and not self._is_3d(parameters, context):
            # Band (not in 2D)
            return (False, self.tr("Band axis not possible in 2D case"))
            
        return super().checkParameterValues(parameters, context)
    


class SciPyStatisticalAlgorithm(SciPyAlgorithmWithMode):
    """
    Base class for median, minimum etc.

    Adds the filter size (hidden integer SIZE and the SIZES string edited
    by the sizes widget) and an optional FOOTPRINT array. get_parameters
    always passes 'size' and, if given, 'footprint' to the SciPy function.
    """
    SIZE = 'SIZE'
    SIZES = 'SIZES'
    FOOTPRINT = 'FOOTPRINT'

    def initAlgorithm(self, config):
        super().initAlgorithm(config)

        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            self.tr('Size of flat structuring element (either size or footprint must be given, with footprint, size is ignored)'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=1, 
            optional=True, 
            minValue=1, 
            maxValue=20, # Large sizes are really slow
            )
        
        # Hidden in the GUI; the SIZES parameter below is shown instead
        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)


        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            self.tr('Size'),
            defaultValue="", 
            optional=True, 
            )
        
        sizes_param.setMetadata({
            'widget_wrapper': {
                'class': SizesWidgetWrapper
            }
        })

        self.addParameter(sizes_param)   

        footprint_param = SciPyParameterStructure(
            self.FOOTPRINT,
            self.tr('Footprint array'),
            defaultValue="",
            examples=footprintexamples,
            multiLine=True,
            optional=True,
            to_int=True,
            )
                
        footprint_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(footprint_param)

    def _size_from_parameters(self, parameters, context):
        """
        Return the filter size: parsed from SIZES (int or list, via
        str_to_int_or_list) if that is set, otherwise SIZE. Falls back
        to 3 if neither yields a truthy value.
        """
        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            size = str_to_int_or_list(sizes)
        else:
            size = self.parameterAsInt(parameters, self.SIZE, context)
        if not size:
            # Just in case it is called from python and neither size nor sizes is set
            size = 3
        return size

    def checkParameterValues(self, parameters, context): 
        """
        Validate the footprint against the number of dimensions and, for
        the rank filter, check that the rank lies within the footprint.
        """
        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context)

        dims = self.getDimsForCheck(parameters, context)

        ok, s = check_structure(footprint, dims)
        if not ok:
            s = 'Footprint: ' + s
            return (ok, self.tr(s))

        # Extra check for rank_filter: rank must be smaller than the number
        # of elements in the footprint. It is easier to do it here as we
        # already have the footprint checked.
        # Imported locally, presumably to avoid a circular import.
        from .scipy_statistical_algorithms import SciPyRankAlgorithm 

        if isinstance(self, SciPyRankAlgorithm):
            rank = self.parameterAsInt(parameters, self.RANK, context)
            if footprint:
                footprint = str_to_array(footprint, dims=dims)
                # scipy.ndimage.rank_filter only counts non-zero footprint
                # elements, so zeros must not be counted here either.
                footprintsize = np.count_nonzero(footprint)
            else:
                # Use the same size resolution as get_parameters
                # (SIZES first, then the hidden SIZE parameter).
                try:
                    size = self._size_from_parameters(parameters, context)
                except ValueError:
                    return (False, self.tr('Size: invalid value'))
                if np.ndim(size):
                    # One size per axis
                    footprintsize = int(np.prod(size))
                else:
                    # Same size along every axis
                    footprintsize = np.power(size, dims)
            if rank >= footprintsize:
                return (False, self.tr('Rank must be smaller than the size of the footprint'))
            
        return super().checkParameterValues(parameters, context)
    

    def get_parameters(self, parameters, context):
        """
        Add 'size' (always) and 'footprint' (if given) to the kwargs of
        SciPyAlgorithmWithMode.get_parameters, which already contain
        'mode' and, if non-zero, 'cval'. SciPy ignores size when a
        footprint is given.
        """
        kwargs = super().get_parameters(parameters, context)

        kwargs['size'] = self._size_from_parameters(parameters, context)

        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context)
        if footprint:
            kwargs['footprint'] = str_to_array(footprint, self._ndim)

        return kwargs
    

