# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'


from scipy import ndimage
from qgis.core import (QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition,
                       QgsProcessingException,
                        )

from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithm

from scipy_filters.ui.structure_widget import SciPyParameterStructure

from scipy_filters.ui.origin_widget import SciPyParameterOrigin

from scipy_filters.helpers import (str_to_int_or_list, 
                      check_structure, 
                      str_to_array, 
                      morphostructexamples,
                      footprintexamples)

from scipy_filters.ui.i18n import tr


class SciPyMorphologicalBaseAlgorithm(SciPyAlgorithm):
    """
    Base class for morphological filters.

    Adds the parameters shared by the morphological filters (filter
    selection, structure and origin) and validates structure and origin.
    """

    ALGORITHM = 'ALGORITHM'
    STRUCTURE = 'STRUCTURE'
    ORIGIN = 'ORIGIN'

    _groupid = 'morphological'

    def getAlgs(self):
        """
        Return the translated labels of the 'Filter' dropdown.

        The index of a label is the value later stored in ``self.alg``.
        """
        return [tr('Dilation'), tr('Erosion'), tr('Closing'), tr('Opening')]

    def insert_parameters(self, config):
        """
        Add the filter, structure and origin parameters, then let the
        parent class add its own parameters.
        """
        # Kept on the instance: subclasses use the labels for the output name.
        self.algorithms = self.getAlgs()

        self.addParameter(QgsProcessingParameterEnum(
            self.ALGORITHM,
            tr('Filter'),
            self.algorithms,
            defaultValue=0))

        struct_param = SciPyParameterStructure(
            self.STRUCTURE,
            tr('Structure'),
            defaultValue="[[0, 1, 0],\n[1, 1, 1],\n[0, 1, 0]]",
            examples=morphostructexamples,
            multiLine=True,
            to_int=True,
            optional=True,
            )

        self.addParameter(struct_param)

        origin_param = SciPyParameterOrigin(
            self.ORIGIN,
            tr('Origin'),
            defaultValue="0",
            optional=False,
            watch="STRUCTURE"
            )

        # Origin is rarely needed, so it is listed with the advanced parameters.
        origin_param.setFlags(origin_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)

        self.addParameter(origin_param)

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Extend the keyword arguments collected by the parent class with
        ``structure`` and ``origin`` and remember the selected filter
        index in ``self.alg``.
        """
        kwargs = super().get_parameters(parameters, context)

        self.alg = self.parameterAsInt(parameters, self.ALGORITHM, context)

        structure = self.parameterAsString(parameters, self.STRUCTURE, context)
        kwargs['structure'] = str_to_array(structure, self._ndim)

        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        kwargs['origin'] = str_to_int_or_list(origin)

        return kwargs

    def checkParameterValues(self, parameters, context):
        """
        Validate structure and origin before the algorithm is run.

        The origin is either a list with one entry per axis or a single
        integer, which scipy applies to every axis. In both cases each
        entry must satisfy ``-(n // 2) <= origin <= (n - 1) // 2`` for a
        structure axis of length ``n``.

        Returns a tuple ``(ok, message)``. If the checks done here pass,
        the result of the parent class check is returned.
        """
        dims = self.getDimsForCheck(parameters, context)

        structure = self.parameterAsString(parameters, self.STRUCTURE, context)
        ok, s, shape = check_structure(structure, dims)
        if not ok:
            return (ok, s)

        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        origin = str_to_int_or_list(origin)

        if isinstance(origin, list):
            if len(origin) != dims:
                return (False, tr("Origin does not match number of dimensions"))
            offsets = origin
        elif isinstance(origin, int):
            # A scalar origin is used for all axes, so every axis must allow it.
            offsets = [origin] * dims
        else:
            offsets = []

        # NOTE(review): presumably check_structure may report no shape (e.g.
        # for an empty, optional structure); then there is nothing to check.
        for extent, offset in zip(shape or (), offsets):
            # An extent of 0 is skipped, as in the original list-only check.
            if extent != 0 and not (-(extent // 2) <= offset <= (extent - 1) // 2):
                return (False, tr("Origin out of bounds of structure"))

        return super().checkParameterValues(parameters, context)


class SciPyBinaryMorphologicalAlgorithm(SciPyMorphologicalBaseAlgorithm):
    """
    Binary morphological filters: dilation, erosion, closing, and opening.
    Calculated with binary_dilation,
    binary_erosion, binary_closing, binary_opening respectively from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Dilation** Set pixel to maximum value of neighborhood. Remaining shapes are larger, lines are thicker.
    **Erosion** Set pixel to minimum value of neighborhood. Remaining shapes are smaller, lines are thinner.
    **Closing** Perform dilation and then erosion. Fills small holes, large shapes are preserved.
    **Opening** Perform erosion and then dilation. Removes small shapes, large shapes are preserved.

    **Structure** String representation of array.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately). Examples can be loaded with the load button.
    For convenience (i.e. when calling from a script),
    the following shortcuts are accepted as well:
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Iterations** Each step of filter is repeated this number of times.
    **Border value** Value at border of output array, defaults to 0.
    **Mask** Optional mask layer.
    """

    ITERATIONS = 'ITERATIONS'
    MASK = 'MASK'
    BORDERVALUE = 'BORDERVALUE'

    # Values overriding the defaults of the base class
    _name = 'binary_morphology'
    _displayname = tr('Binary dilation, erosion, closing, opening')
    _outputname = tr('Binary morphology')  # With None, the displayname would be used

    def get_fct(self):
        """
        Return the scipy.ndimage function for the selected filter.

        ``self.alg`` is the index into the 'Filter' dropdown; any index
        without an entry below (including 0) selects dilation.
        """
        by_index = {
            1: ndimage.binary_erosion,
            2: ndimage.binary_closing,
            3: ndimage.binary_opening,
        }
        return by_index.get(self.alg, ndimage.binary_dilation)

    def initAlgorithm(self, config):
        """
        Add the parameters specific to the binary filters (iterations,
        border value, mask) after those of the parent class.
        """
        super().initAlgorithm(config)

        iterations_param = QgsProcessingParameterNumber(
            self.ITERATIONS,
            tr('Iterations'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=1,
            optional=True,
            minValue=1,
        )
        self.addParameter(iterations_param)

        # The enum index is at the same time the border value (0 or 1).
        border_param = QgsProcessingParameterEnum(
            self.BORDERVALUE,
            tr('Border value (value at border of output array)'),
            ["0", "1"],
            optional=True,
            defaultValue=0,
        )
        self.addParameter(border_param)

        mask_param = QgsProcessingParameterRasterLayer(
            self.MASK,
            tr('Mask layer'),
            optional=True,
        )
        self.addParameter(mask_param)

    def get_parameters(self, parameters, context):
        """
        Extend the keyword arguments of the parent class with
        ``iterations`` and ``border_value`` (only if they are non-zero),
        and set mask layer, output name and margin on the instance.
        """
        kwargs = super().get_parameters(parameters, context)

        self.masklayer = self.parameterAsRasterLayer(parameters, self.MASK, context)

        # Integer parameters that are passed on only when they are truthy
        optional_ints = (
            ('iterations', self.ITERATIONS),
            ('border_value', self.BORDERVALUE),
        )
        for keyword, parameter_name in optional_ints:
            value = self.parameterAsInt(parameters, parameter_name, context)
            if value:
                kwargs[keyword] = value

        self._outputname = tr('Binary ') + self.algorithms[self.alg]

        # NOTE(review): the margin only depends on the structure, not on the
        # number of iterations -- confirm that this is sufficient.
        self.margin = max(kwargs['structure'].shape)
        return kwargs

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyBinaryMorphologicalAlgorithm()


class SciPyGreyMorphologicalAlgorithm(SciPyMorphologicalBaseAlgorithm):
    """
    Grey morphological filters: dilation, erosion, closing, and opening.
    Calculated for every band with grey_dilation,
    grey_erosion, grey_closing or grey_opening, respectively from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dilation** Set pixel to maximum value of neighborhood. Remaining shapes are larger, lines are thicker.
    **Erosion** Set pixel to minimum value of neighborhood. Remaining shapes are smaller, lines are thinner.
    **Closing** Perform dilation and then erosion. Fills small holes, large shapes are preserved.
    **Opening** Perform erosion and then dilation. Removes small shapes, large shapes are preserved.

    **Structure** String representation of array.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately). Examples can be loaded with the load button.
    For convenience (i.e. when calling from a script),
    the following shortcuts are accepted as well:
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Size** Size of flat and full structuring element, optional if footprint or structure is provided.
    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).
    **Footprint** Positions of elements of a flat structuring element used for the filter (string representation of array, only used if checkbox is checked).
    """

    SIZE = 'SIZE'
    SIZES = 'SIZES'
    MODE = 'MODE'
    CVAL = 'CVAL'
    FOOTPRINT = 'FOOTPRINT'

    # Overwrite constants of base class
    _name = 'grey_morphology'
    _displayname = tr('Grey dilation, erosion, closing, opening')
    _outputname = tr('Grey morphology') # If set to None, the displayname is used

    def get_fct(self):
        """
        Return the scipy.ndimage function for the selected filter.

        ``self.alg`` is the index into the 'Filter' dropdown; every index
        other than 1, 2 or 3 selects dilation.
        """
        if self.alg == 1:
            fct = ndimage.grey_erosion
        elif self.alg == 2:
            fct = ndimage.grey_closing
        elif self.alg == 3:
            fct = ndimage.grey_opening
        else:
            fct = ndimage.grey_dilation

        return fct

    def initAlgorithm(self, config):
        """
        Add the parameters specific to the grey filters (size, border
        mode, constant value, footprint) after those of the parent class.
        """
        super().initAlgorithm(config)

        # Integer size parameter; hidden in the dialog, where the string
        # parameter SIZES is shown instead. Still read as fallback in
        # get_parameters().
        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            tr('Size of flat structuring element (Ignored if footprint or structure provided)'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=1,
            optional=True,
            minValue=1,
            maxValue=20, # Large sizes are really slow
            )

        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)

        # Size as string: one integer or a list with one entry per axis
        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size'),
            defaultValue="",
            optional=True,
            )

        self.addParameter(sizes_param)

        self.addParameter(QgsProcessingParameterEnum(
            self.MODE,
            tr('Border Mode'),
            self.mode_labels,
            defaultValue=0))

        self.addParameter(QgsProcessingParameterNumber(
            self.CVAL,
            tr('Constant value past edges for border mode "constant"'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            ))

        struct_param = SciPyParameterStructure(
            self.FOOTPRINT,
            tr('Footprint array'),
            defaultValue="",
            examples=footprintexamples,
            multiLine=True,
            to_int=True,
            optional=True,
            )

        self.addParameter(struct_param)

    def checkParameterValues(self, parameters, context):
        """
        Validate footprint and sizes, then run the checks of the parent
        class (structure and origin).

        Returns a tuple ``(ok, message)``.
        """
        dims = self.getDimsForCheck(parameters, context)
        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context)
        if footprint:
            ok, s, shape = check_structure(footprint, dims)
            if not ok:
                return (ok, tr('Footprint: ') + s)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        sizes = str_to_int_or_list(sizes)
        # A single integer is valid for any number of dimensions;
        # a list needs one entry per axis.
        if isinstance(sizes, list):
            if len(sizes) != dims:
                return (False, tr("Sizes does not match number of dimensions"))

        return super().checkParameterValues(parameters, context)

    def get_parameters(self, parameters, context):
        """
        Extend the keyword arguments of the parent class with ``size``,
        ``footprint``, ``mode`` and ``cval``, and set margin, wrapping flag
        and output name on the instance.

        The margin is the largest extent found in structure, size and
        footprint.
        """
        kwargs = super().get_parameters(parameters, context)

        # Collect every extent that takes part in the filter.
        # NOTE(review): the structure parameter is optional; presumably
        # str_to_array may return None for an empty string, which must not
        # break the margin calculation -- verify against the helper.
        structure = kwargs.get('structure')
        sizelist = list(structure.shape) if structure is not None else []

        size = self.parameterAsString(parameters, self.SIZES, context)
        size = str_to_int_or_list(size)
        if not size:
            # Fall back to the (hidden) integer parameter
            size = self.parameterAsInt(parameters, self.SIZE, context)
        kwargs['size'] = size

        # SIZES may be one integer (used for all axes) or a list with one
        # entry per axis. list.extend() only works for the latter, so a
        # single integer has to be appended.
        if isinstance(size, int):
            sizelist.append(size)
        else:
            sizelist.extend(size)

        footprint = self.parameterAsString(parameters, self.FOOTPRINT, context)
        if footprint.strip() != "":
            kwargs['footprint'] = str_to_array(footprint, self._ndim)
            sizelist.extend(kwargs['footprint'].shape)
        elif not size:
            # Either size or footprint must be set
            kwargs['size'] = 1

        self.margin = max(sizelist)

        mode = self.parameterAsInt(parameters, self.MODE, context)
        kwargs['mode'] = self.modes[mode]
        if kwargs['mode'] == 'wrap':
            self.wrapping = True

        cval = self.parameterAsDouble(parameters, self.CVAL, context)
        if cval:
            kwargs['cval'] = cval

        # The tophat subclass uses labels that already are complete names.
        if isinstance(self, SciPyTophatAlgorithm):
            self._outputname = self.algorithms[self.alg]
        else:
            self._outputname = tr('Grey ') + self.algorithms[self.alg]

        return kwargs

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyGreyMorphologicalAlgorithm()

    
class SciPyTophatAlgorithm(SciPyGreyMorphologicalAlgorithm):
    """
    Morphological filters: black/white tophat, morphological gradient/laplace.

    Calculated with black_tophat,
    white_tophat, morphological_gradient or morphological_laplace, respectively from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **White tophat** Difference between input raster and its opening. Extracts white spots smaller than the structural element.
    **Black tophat** Difference between input raster and its closing. Extracts black spots smaller than the structural element.
    **Morphological gradient** Difference between dilation and erosion.
    **Morphological laplace** Difference between internal and external gradient.

    **Structure** String representation of array.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately). Examples can be loaded with the load button.
    For convenience (i.e. when calling from a script),
    the following shortcuts are accepted as well:
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Size** Size of flat and full structuring element, optional if footprint or structure is provided.
    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).

    **Footprint** Positions of elements of a flat structuring element
    used for the filter (as string representation of array).
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately). Examples can be loaded with the load button.
    For convenience (i.e. when calling from a script),
    the following shortcuts are accepted as well:
    "square", "cross", "cross3D", "ball", "cube".
    """

    # Values overriding the defaults of the base class
    _name = 'tophat'
    _displayname = tr('Tophat or morphological gradient/laplace')
    _outputname = tr('Tophat')  # With None, the displayname would be used

    def getAlgs(self):
        """
        Return the translated labels of the 'Filter' dropdown.

        The index of a label corresponds to the index used in get_fct().
        """
        labels = [
            tr('White Tophat'),
            tr('Black Tophat'),
            tr('Morphological Gradient'),
            tr('Morphological Laplace'),
        ]
        return labels

    def get_fct(self):
        """
        Return the scipy.ndimage function for the selected filter.

        Any index without an entry below (including 0) selects the
        white tophat.
        """
        by_index = {
            1: ndimage.black_tophat,
            2: ndimage.morphological_gradient,
            3: ndimage.morphological_laplace,
        }
        return by_index.get(self.alg, ndimage.white_tophat)

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyTophatAlgorithm()


