# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

import json
import numpy as np
from osgeo import gdal
from scipy import ndimage
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingParameterString,
                       QgsProcessingParameterBoolean,
                       QgsProcessingException,
                        )

from .scipy_algorithm_baseclasses import SciPyAlgorithm

from .ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)

from .helpers import (array_to_str, 
                      str_to_int_or_list, 
                      check_structure, 
                      str_to_array, 
                      morphostructexamples,
                      footprintexamples)

class SciPyMorphologicalBaseAlgorithm(SciPyAlgorithm):
    """
    Common base for the morphological filter algorithms.

    Registers an enum parameter for choosing the filter variant (labels
    come from getAlgs()) and a structuring-element parameter that is
    edited through StructureWidgetWrapper. The selected enum index is
    stored in ``self.alg`` by get_parameters(); subclasses map it to a
    SciPy function in their get_fct().
    """

    ALGORITHM = 'ALGORITHM'
    STRUCTURE = 'STRUCTURE'

    _groupid = 'morphological'

    def getAlgs(self):
        """Return the filter labels, in the order used by the enum index."""
        return ['Dilation', 'Erosion', 'Closing', 'Opening']

    def insert_parameters(self, config):
        """Add the filter enum and the structure parameter, then the base parameters."""
        self.algorithms = self.getAlgs()

        filter_choice = QgsProcessingParameterEnum(
            self.ALGORITHM,
            self.tr('Filter'),
            self.algorithms,
            defaultValue=0,
        )
        self.addParameter(filter_choice)

        structure_param = SciPyParameterStructure(
            self.STRUCTURE,
            self.tr('Structure'),
            defaultValue="[[0, 1, 0],\n[1, 1, 1],\n[0, 1, 0]]",
            examples=morphostructexamples,
            multiLine=True,
            to_int=True,
            optional=True,
        )
        # Use the custom widget (with example loader) for this parameter.
        wrapper_metadata = {'widget_wrapper': {'class': StructureWidgetWrapper}}
        structure_param.setMetadata(wrapper_metadata)
        self.addParameter(structure_param)

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Extend the base kwargs with the parsed ``structure`` array.

        Side effect: stores the selected filter index in ``self.alg``.
        """
        kwargs = super().get_parameters(parameters, context)

        self.alg = self.parameterAsInt(parameters, self.ALGORITHM, context)

        structure_text = self.parameterAsString(parameters, self.STRUCTURE, context)
        kwargs.update(structure=str_to_array(structure_text, self._ndim))
        return kwargs

    def checkParameterValues(self, parameters, context):
        """Reject a structure that check_structure() reports as invalid, then defer to the base check."""
        expected_dims = self.getDimsForCheck(parameters, context)

        structure_text = self.parameterAsString(parameters, self.STRUCTURE, context)
        is_valid, message = check_structure(structure_text, expected_dims)
        if is_valid:
            return super().checkParameterValues(parameters, context)
        return (is_valid, message)


class SciPyBinaryMorphologicalAlgorithm(SciPyMorphologicalBaseAlgorithm):
    """
    Binary dilation, erosion, closing and opening via scipy.ndimage.

    Adds iteration count, border value and an optional mask layer to the
    parameters of the morphological base class.
    """

    ITERATIONS = 'ITERATIONS'
    MASK = 'MASK'
    BORDERVALUE = 'BORDERVALUE'

    # Overwrite constants of base class
    _name = 'binary_morphology'
    _displayname = 'Binary dilation, erosion, closing, opening'
    _outputname = 'Binary morphology' # If set to None, the displayname is used 
    _help = """
            Binary morphological filters: dilation, erosion, closing, and opening. \
            Calculated with binary_dilation, \
            binary_erosion, binary_closing, binary_opening respectively from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Dilation</b> Set pixel to maximum value of neighborhood. Remaining shapes are larger, lines are thicker.
            <b>Erosion</b> Set pixel to minimum value of neighborhood. Remaining shapes are smaller, lines are thinner.
            <b>Closing</b> Perform dilation and then erosion. Fills small holes, large shapes are preserved.
            <b>Opening</b> Perform erosion and then dilation. Removes small shapes, large shapes are preserved.
            

            <b>Structure</b> String representation of array. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately). Examples can be loaded with the load button. \
            For convenience (e.g. when calling from a script), \
            the following shortcuts are accepted as well: \
            "square", "cross", "cross3D", "ball", "cube".

            <b>Iterations</b> Each step of filter is repeated this number of times.
            <b>Border value</b> Value at border of output array, defaults to 0. 
            <b>Mask</b> Optional mask layer.
            """
    
    # The function to be called
    def get_fct(self):
        """Return the scipy.ndimage binary function matching the selected filter index."""
        if self.alg == 1:
            fct = ndimage.binary_erosion
        elif self.alg == 2:
            fct = ndimage.binary_closing
        elif self.alg == 3:
            fct = ndimage.binary_opening
        else:
            fct = ndimage.binary_dilation
        
        return fct

    def initAlgorithm(self, config):
        """Add iterations, border value and mask parameters after the base parameters."""
        super().initAlgorithm(config)

        self.addParameter(QgsProcessingParameterNumber(
            self.ITERATIONS,
            self.tr('Iterations'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=1, 
            optional=True, 
            minValue=1, 
            ))    
        
        self.addParameter(QgsProcessingParameterEnum(
            self.BORDERVALUE,
            self.tr('Border value (value at border of output array)'),
            ["0","1"],
            optional=True,
            defaultValue=0))
        
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.MASK,
                self.tr('Mask layer'),
                optional=True,
            )
        )

    def get_parameters(self, parameters, context):
        """
        Collect keyword arguments for the binary morphology function.

        Side effects: stores the mask layer in ``self.masklayer`` and sets
        the output layer name from the selected filter.
        """
        kwargs = super().get_parameters(parameters, context)

        self.masklayer = self.parameterAsRasterLayer(parameters, self.MASK, context)

        # Read the iteration count from the ITERATIONS parameter (previously
        # it was wrongly read from MASK, so the user's value was ignored).
        iterations = self.parameterAsInt(parameters, self.ITERATIONS, context)
        if iterations:
            kwargs['iterations'] = iterations

        # Enum index 0/1 equals the border value itself; 0 is SciPy's default.
        bordervalue = self.parameterAsInt(parameters, self.BORDERVALUE, context)
        if bordervalue:
            kwargs['border_value'] = bordervalue

        self._outputname = 'Binary ' + self.algorithms[self.alg]

        return kwargs
    
    def createInstance(self):
        return SciPyBinaryMorphologicalAlgorithm()


class SciPyGreyMorphologicalAlgorithm(SciPyMorphologicalBaseAlgorithm):
    """
    Grey dilation, erosion, closing and opening via scipy.ndimage.

    Adds size, border mode, constant value and footprint parameters to the
    parameters of the morphological base class.
    """

    SIZE = 'SIZE'
    MODE = 'MODE'
    CVAL = 'CVAL'
    FOOTPRINT = 'FOOTPRINT'

    # Overwrite constants of base class
    _name = 'grey_morphology'
    _displayname = 'Grey dilation, erosion, closing, opening'
    _outputname = 'Grey morphology' # If set to None, the displayname is used 
    _help = """
            Grey morphological filters: dilation, erosion, closing, and opening. \
            Calculated with grey_dilation, \
            grey_erosion, grey_closing or grey_opening, respectively from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Dilation</b> Set pixel to maximum value of neighborhood. Remaining shapes are larger, lines are thicker.
            <b>Erosion</b> Set pixel to minimum value of neighborhood. Remaining shapes are smaller, lines are thinner.
            <b>Closing</b> Perform dilation and then erosion. Fills small holes, large shapes are preserved.
            <b>Opening</b> Perform erosion and then dilation. Removes small shapes, large shapes are preserved.

            <b>Structure</b> Structuring element, as string representation of array. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately). Examples can be loaded with the load button. \
            For convenience (e.g. when calling from a script), \
            the following shortcuts are accepted as well: \
            "square", "cross", "cross3D", "ball", "cube".

            <b>Size</b> Size of flat and full structuring element, optional if footprint or structure is provided.
            <b>Border mode</b> determines how input is extended around \
            the edges: <i>Reflect</i> (input is extended by reflecting at the edge), \
            <i>Constant</i> (fill around the edges with a <b>constant value</b>), \
            <i>Nearest</i> (extend by replicating the nearest pixel), \
            <i>Mirror</i> (extend by reflecting about the center of last pixel), \
            <i>Wrap</i> (extend by wrapping around to the opposite edge).
            <b>Constant value</b> Value used past the edges if border mode is <i>Constant</i>.

            <b>Footprint</b> Positions of elements of a flat structuring element \
            used for the filter (string representation of array). Leave empty \
            to use <i>size</i> instead. Examples can be loaded with the load button.
            """
    
    # The function to be called
    def get_fct(self):
        """Return the scipy.ndimage grey function matching the selected filter index."""
        if self.alg == 1:
            fct = ndimage.grey_erosion
        elif self.alg == 2:
            fct = ndimage.grey_closing
        elif self.alg == 3:
            fct = ndimage.grey_opening
        else:
            fct = ndimage.grey_dilation
        
        return fct

    def initAlgorithm(self, config):
        """Add size, border mode, constant value and footprint parameters."""
        super().initAlgorithm(config)

        self.addParameter(QgsProcessingParameterNumber(
            self.SIZE,
            self.tr('Size of flat structuring element (Ignored if footprint or structure provided)'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=0, 
            optional=True, 
            minValue=0, 
            ))    
        
        self.addParameter(QgsProcessingParameterEnum(
            self.MODE,
            self.tr('Border Mode'),
            [mode.capitalize() for mode in self.modes],
            defaultValue=0)) 
        
        # No minValue: the constant fill value may legitimately be negative
        # (e.g. elevations below sea level); SciPy accepts any scalar cval.
        self.addParameter(QgsProcessingParameterNumber(
            self.CVAL,
            self.tr('Constant value past edges for border mode "constant"'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0, 
            optional=True, 
            ))          
        
        struct_param = SciPyParameterStructure(
            self.FOOTPRINT,
            self.tr('Footprint array'),
            defaultValue="[[1, 1, 1],\n[1, 1, 1],\n[1, 1, 1]]",
            examples=footprintexamples,
            multiLine=True,
            to_int=True,
            optional=True,
            )
        
        struct_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(struct_param)

    def checkParameterValues(self, parameters, context): 
        """Validate a non-blank footprint with check_structure(), then defer to the base check."""
        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context)
        # Blank footprints are ignored by get_parameters(), so skip them here too.
        if footprint and footprint.strip():
            dims = self.getDimsForCheck(parameters, context)

            ok, s = check_structure(footprint, dims)
            if not ok:
                # Translate only the literal prefix; s is a dynamic message.
                return (ok, self.tr('Footprint: ') + s)
        
        return super().checkParameterValues(parameters, context)

    def get_parameters(self, parameters, context):
        """
        Collect keyword arguments for the grey morphology function.

        Adds ``footprint`` when the footprint string is not blank; otherwise
        adds ``size`` (the SIZE value, or 1 if it is 0/unset, since either a
        size or a footprint must be given). Also adds ``mode`` and a non-zero
        ``cval``, and sets the output layer name.
        """
        kwargs = super().get_parameters(parameters, context)

        size = self.parameterAsInt(parameters, self.SIZE, context)
        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context) or ""
        if footprint.strip():
            # SciPy ignores size (with a warning) when a footprint is set,
            # so only the footprint is passed.
            kwargs['footprint'] = str_to_array(footprint, self._ndim)
        else:
            # Either size or footprint must be set
            kwargs['size'] = size if size else 1

        mode = self.parameterAsInt(parameters, self.MODE, context) 
        kwargs['mode'] = self.modes[mode]

        cval = self.parameterAsDouble(parameters, self.CVAL, context)
        if cval:
            kwargs['cval'] = cval

        # Tophat filters reuse this method but have their own labels.
        if isinstance(self, SciPyTophatAlgorithm):
            self._outputname = self.algorithms[self.alg]
        else:
            self._outputname = 'Grey ' + self.algorithms[self.alg]

        return kwargs

    def createInstance(self):
        return SciPyGreyMorphologicalAlgorithm()

    
class SciPyTophatAlgorithm(SciPyGreyMorphologicalAlgorithm):
    """
    White/black tophat and morphological gradient/laplace via scipy.ndimage.

    Reuses all parameters of the grey morphology algorithm; only the filter
    labels and the selected SciPy function differ.
    """

    # Overwrite constants of base class
    _name = 'tophat'
    _displayname = 'Tophat or morphological gradient/laplace'
    _outputname = 'Tophat' # If set to None, the displayname is used 
    _help = """
            Morphological filters: black/white tophat, morphological gradient/laplace. \
            Calculated with black_tophat, \
            white_tophat, morphological_gradient or morphological_laplace, respectively from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>White tophat</b> Difference between input raster and its opening. Extracts white spots smaller than the structuring element.
            <b>Black tophat</b> Difference between input raster and its closing. Extracts black spots smaller than the structuring element.
            <b>Morphological gradient</b> Difference between dilation and erosion.
            <b>Morphological laplace</b> Difference between internal and external gradient.

            <b>Structure</b> String representation of array. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately). Examples can be loaded with the load button. \
            For convenience (e.g. when calling from a script), \
            the following shortcuts are accepted as well: \
            "square", "cross", "cross3D", "ball", "cube".

            <b>Size</b> Size of flat and full structuring element, optional if footprint or structure is provided.
            <b>Border mode</b> determines how input is extended around \
            the edges: <i>Reflect</i> (input is extended by reflecting at the edge), \
            <i>Constant</i> (fill around the edges with a <b>constant value</b>), \
            <i>Nearest</i> (extend by replicating the nearest pixel), \
            <i>Mirror</i> (extend by reflecting about the center of last pixel), \
            <i>Wrap</i> (extend by wrapping around to the opposite edge).
            <b>Constant value</b> Value used past the edges if border mode is <i>Constant</i>.
            
            <b>Footprint</b> Positions of elements of a flat structuring element \
            used for the filter (as string representation of array). \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately). Examples can be loaded with the load button. \
            For convenience (e.g. when calling from a script), \
            the following shortcuts are accepted as well: \
            "square", "cross", "cross3D", "ball", "cube".
            """
    
    def getAlgs(self):
        """Return the filter labels, in the order used by get_fct()."""
        return ['White Tophat', 'Black Tophat', 'Morphological Gradient', 'Morphological Laplace']

    # The function to be called
    def get_fct(self):
        """Return the scipy.ndimage function matching the selected filter index."""
        if self.alg == 1:
            fct = ndimage.black_tophat
        elif self.alg == 2:
            fct = ndimage.morphological_gradient
        elif self.alg == 3:
            fct = ndimage.morphological_laplace
        else:
            fct = ndimage.white_tophat
        
        return fct

    def createInstance(self):
        return SciPyTophatAlgorithm()


