# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from scipy import ndimage
from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterEnum,)

from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithmWithMode
from scipy_filters.ui.i18n import tr

class SciPyAlgorithmWithSigma(SciPyAlgorithmWithMode):
    """
    Base class with mode and sigma for any algorithm using gaussian.

    Registers a mandatory ``SIGMA`` number parameter ahead of the
    parameters inserted by the parent class, and adds its value as
    the ``sigma`` entry of the mapping returned by ``get_parameters``.
    """

    SIGMA = 'SIGMA'

    def insert_parameters(self, config):
        """Add the sigma parameter, then delegate to the parent class."""
        sigma_param = QgsProcessingParameterNumber(
            self.SIGMA,
            tr('Sigma'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5,
            optional=False,
            minValue=0,
            maxValue=100,
        )
        self.addParameter(sigma_param)

        # Delegating last keeps sigma in front of the inherited parameters
        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Return the parent's parameter mapping extended by ``sigma``.

        As a side effect ``self.margin`` is set to ``int(4 * sigma + 1)``.
        """
        kwargs = super().get_parameters(parameters, context)

        sigma = self.parameterAsDouble(parameters, self.SIGMA, context)
        kwargs['sigma'] = sigma

        # The factor 4 is the default value of truncate
        self.margin = int(4 * sigma + 1)
        return kwargs


class SciPyGaussianLaplaceAlgorithm(SciPyAlgorithmWithSigma):
    """
    Laplace filter using Gaussian second derivatives

    Calculated with gaussian_laplace from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 
    
    .. note:: bands will be the first axis of the datacube.

    **Sigma** Standard deviation of the gaussian filter.
    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).

    **Dtype** Data type of output. Beware of clipping 
    and potential overflow errors if min/max of output does 
    not fit. Default is Float32.
    """

    # Class constants replacing the values of the base class
    _name = 'gaussian_laplace'
    _displayname = tr('Gaussian Laplace')
    _outputname = None  # None means: fall back to the display name
    _groupid = "edges"

    # Default output dtype, given as index of the dtype combobox
    _default_dtype = 6

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyGaussianLaplaceAlgorithm()

    def get_fct(self):
        """Return the scipy.ndimage function wrapped by this algorithm."""
        return ndimage.gaussian_laplace

    def checkAndComplain(self, feedback):
        """Report a non-fatal warning for the output dtype indices 1, 2 and 4."""
        # According to the message below these indices are the unsigned integer types
        unsigned_dtypes = (1, 2, 4)
        if self._outdtype not in unsigned_dtypes:
            return

        warning = tr("WARNING: Output contains negative values, but output data type is unsigned integer!")
        feedback.reportError(warning, fatalError=False)


class SciPyGaussianAlgorithm(SciPyAlgorithmWithSigma):
    """
    Gaussian filter (blur with a gaussian kernel)

    Calculated with gaussian_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Sigma** Standard deviation of a gaussian kernel.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).

    **Order** Optionally use first, second or third derivative of gaussian.
    **Truncate** Radius of kernel in standard deviations.
    """

    # Names of the parameters added by this class
    TRUNCATE = 'TRUNCATE'
    ORDER = 'ORDER'

    # Class constants replacing the values of the base class
    _name = 'gaussian'
    _displayname = tr('Gaussian filter (blur)')
    _outputname = tr('Gaussian')
    _groupid = "blur"

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyGaussianAlgorithm()

    def get_fct(self):
        """Return the scipy.ndimage function wrapped by this algorithm."""
        return ndimage.gaussian_filter

    def initAlgorithm(self, config):
        """Add the order and truncate parameters after those of the parent class."""
        # The parent goes first, otherwise the input
        # would not be the first parameter in the GUI
        super().initAlgorithm(config)

        self.order_options = [
            "0 (Gaussian)",
            "1 (First derivative of Gaussian)",
            "2 (Second derivative of Gaussian)",
            "3 (Third derivative of Gaussian)",
        ]
        order_param = QgsProcessingParameterEnum(
            self.ORDER,
            tr('Order'),
            self.order_options,
            defaultValue=0,
        )
        self.addParameter(order_param)

        # Optional parameter, deliberately without an upper bound
        truncate_param = QgsProcessingParameterNumber(
            self.TRUNCATE,
            tr('Truncate filter at x standard deviations'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=4,
            optional=True,
            minValue=1,
        )
        self.addParameter(truncate_param)

    def get_parameters(self, parameters, context):
        """
        Return the parent's parameter mapping extended by order and truncate.

        ``truncate`` is only added (and stored as ``self._truncate``) when the
        parameter evaluates to a truthy value. ``self.margin`` is recalculated
        from truncate and sigma.
        """
        kwargs = super().get_parameters(parameters, context)
        kwargs['order'] = self.parameterAsInt(parameters, self.ORDER, context)

        truncate = self.parameterAsDouble(parameters, self.TRUNCATE, context)
        if truncate:
            kwargs['truncate'] = truncate
            self._truncate = truncate
        else:
            # Nothing is passed on; use the default of scipy for the margin only
            truncate = 4

        # See https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.gaussian_filter.html
        self.margin = int(truncate * kwargs['sigma'] + 1)
        return kwargs
    


class SciPyGaussianGradientMagnitudeAlgorithm(SciPyAlgorithmWithSigma):
    """
    Gradient magnitude using Gaussian derivatives. 
    Calculated with gaussian_gradient_magnitude from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Sigma** Standard deviation of the gaussian filter.
    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).

    **Dtype** Data type of output. Beware of clipping 
    and potential overflow errors if min/max of output does 
    not fit. Default is Float32.
    """

    # Class constants replacing the values of the base class
    _name = 'gaussian_gradient_magnitude'
    _displayname = tr('Gaussian gradient magnitude')
    _outputname = None  # None means: fall back to the display name
    _groupid = "edges"

    # Default output dtype, given as index of the dtype combobox
    _default_dtype = 6

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyGaussianGradientMagnitudeAlgorithm()

    def get_fct(self):
        """Return the scipy.ndimage function wrapped by this algorithm."""
        return ndimage.gaussian_gradient_magnitude

    def checkAndComplain(self, feedback):
        """Report a non-fatal warning for the output dtype indices 1, 2 and 4."""
        # According to the message below these indices are the unsigned integer types
        unsigned_dtypes = (1, 2, 4)
        if self._outdtype not in unsigned_dtypes:
            return

        # NOTE(review): a magnitude itself is never negative; presumably the
        # warning targets the signed intermediate derivatives - confirm wording
        warning = tr("WARNING: Output contains negative values, but output data type is unsigned integer!")
        feedback.reportError(warning, fatalError=False)