# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from scipy import ndimage, fft, signal
import numpy as np

from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition,
                        )

from scipy_filters.ui.sizes_widget import SizesWidgetWrapper
from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithm, Dimensions

from scipy_filters.helpers import (check_structure, 
                       str_to_array, 
                       kernelexamples)

from scipy_filters.ui.i18n import tr

from scipy_filters.ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)

from scipy_filters.helpers import str_to_int_or_list, get_np_dtype

class SciPyFourierGaussianAlgorithm(SciPyAlgorithm):
    """
    Gaussian fourier filter

    Gaussian filter calculated by multiplication in the frequency domain. 
    This is faster with large kernels (large sigma).

    The input band is transformed with fast fourier transform (FFT) 
    using fft2 (for 2D) or fftn (for 3D) from 
    `scipy.fft <https://docs.scipy.org/doc/scipy/reference/fft.html>`_.
    The multiplication with the fourier transform of a gaussian kernel 
    is calculated with fourier_gaussian from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_. 
    The product is transformed back with ifft2 or ifftn, respectively. 
    Only the real part of the resulting complex 
    numbers is returned.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Sigma** Standard deviation of the gaussian filter.
    """

    # Constants overriding those of the base class
    _name = 'fourier_gaussian'
    _displayname = tr('Fourier Gaussian')
    _outputname = None  # None -> the display name is used as output name
    _groupid = "blur"

    # Key of the sigma parameter
    SIGMA = 'SIGMA'

    def insert_parameters(self, config):
        """Register the sigma parameter, then let the base class add its own."""
        sigma_param = QgsProcessingParameterNumber(
            self.SIGMA,
            tr('Sigma'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5,
            optional=False,
            minValue=0,
            maxValue=100,
        )
        self.addParameter(sigma_param)

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Collect the filter keyword arguments and set the margin from sigma."""
        kwargs = super().get_parameters(parameters, context)
        sigma = self.parameterAsDouble(parameters, self.SIGMA, context)
        kwargs['sigma'] = sigma
        # Margin covers roughly four standard deviations of the kernel
        self.margin = int(4 * sigma + 1)
        return kwargs

    def get_fct(self):
        """Return the callable that filters one array (3D cube or 2D band)."""
        return self.my_fct_3D if self._dimension == Dimensions.threeD else self.my_fct_2D

    def _frequency_filter(self, input_raster, forward, inverse, kwargs):
        """
        Apply fourier_gaussian in the frequency domain.

        :param forward: forward FFT function (fft2 or fftn)
        :param inverse: matching inverse FFT function (ifft2 or ifftn)
        :param kwargs: keyword dict; its "output" entry (target dtype) is
            removed here and the rest is passed to fourier_gaussian
        :returns: real part of the filtered array, clipped to the range of
            the target dtype when that dtype is an integer type
        """
        target_dtype = kwargs.pop("output")
        spectrum = ndimage.fourier_gaussian(forward(input_raster), **kwargs)
        result = inverse(spectrum).real

        if not np.issubdtype(target_dtype, np.integer):
            return result

        limits = np.iinfo(target_dtype)
        if result.min() < limits.min or result.max() > limits.max:
            result = np.clip(result, limits.min, limits.max)
        return result

    def my_fct_2D(self, input_raster, **kwargs):
        """Filter a single band using fft2 / ifft2."""
        return self._frequency_filter(input_raster, fft.fft2, fft.ifft2, kwargs)

    def my_fct_3D(self, input_raster, **kwargs):
        """Filter the whole datacube using fftn / ifftn."""
        return self._frequency_filter(input_raster, fft.fftn, fft.ifftn, kwargs)

    def createInstance(self):
        return SciPyFourierGaussianAlgorithm()
    


class SciPyFourierEllipsoidAlgorithm(SciPyAlgorithm):
    """
    Ellipsoid fourier filter

    Ellipsoidal box filter calculated by multiplication 
    with a circular or ellipsoidal kernel in the frequency domain. 

    The input band is transformed with fast fourier transform (FFT) 
    using fft2 (for 2D) or fftn (for 3D) from 
    `scipy.fft <https://docs.scipy.org/doc/scipy/reference/fft.html>`_.
    The multiplication with the fourier transform of an ellipsoidal kernel 
    is calculated with fourier_ellipsoid from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.
    The product is transformed back with ifft2 or ifftn, respectively. 
    Only the real part of the resulting complex 
    numbers is returned.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of the circular or ellipsoidal box.
    """

    # Overwrite constants of base class
    _name = 'fourier_ellipsoid'
    _displayname = tr('Fourier ellipsoid')
    _outputname = None # If set to None, the displayname is used 
    _groupid = "blur" 
    
    SIZE = 'SIZE'   # hidden numeric size; only used when SIZES is empty
    SIZES = 'SIZES' # size(s) as string, edited with the sizes widget


    def insert_parameters(self, config):
        """Register the (hidden) SIZE and the SIZES parameters."""
        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            tr('Size'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5, 
            optional=True, 
            minValue=0, 
            )
        
        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)  

        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size'),
            defaultValue="", 
            optional=True, 
            )
        
        sizes_param.setMetadata({
            'widget_wrapper': {
                'class': SizesWidgetWrapper
            }
        })

        self.addParameter(sizes_param)

        super().insert_parameters(config)

    
    def get_parameters(self, parameters, context):
        """
        Collect the keyword arguments for fourier_ellipsoid.

        SIZES (string) takes precedence over the hidden SIZE (number).
        If neither yields a truthy size, 3 is used. Also sets self.margin
        to the largest size (rounded up to an int).
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            # May return a single number or a list (one entry per axis)
            size = str_to_int_or_list(sizes)
        else:
            size = self.parameterAsDouble(parameters, self.SIZE, context)
        if not size:
            # Just in case it is called from python and neither size nor sizes is set
            size = 3
        kwargs['size'] = size

        # np.max handles both a scalar and a list; keep the margin an int
        # like the other algorithms in this module do.
        self.margin = int(np.ceil(np.max(size)))

        return kwargs
    
    def checkParameterValues(self, parameters, context): 
        """Reject a SIZES list whose length differs from the number of dimensions."""
        dims = self.getDimsForCheck(parameters, context)
        
        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            sizes = str_to_int_or_list(sizes)
            if isinstance(sizes, list) and len(sizes) != dims:
                return (False, tr("Sizes does not match number of dimensions"))

        return super().checkParameterValues(parameters, context)
    

    # The function to be called, to be overwritten
    def get_fct(self):
        if self._dimension == Dimensions.threeD:
            return self.my_fct_3D
        else:
            return self.my_fct_2D

    @staticmethod
    def _clip_to_dtype(a, dtype):
        """Clip a to the value range of dtype if dtype is an integer type."""
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            if a.min() < info.min or a.max() > info.max:
                a = np.clip(a, info.min, info.max)
        return a

    def my_fct_2D(self, input_raster, **kwargs):
        """Filter one band: fft2 -> fourier_ellipsoid -> ifft2 (real part)."""
        dtype = kwargs.pop("output")
        a = fft.fft2(input_raster)
        a = ndimage.fourier_ellipsoid(a, **kwargs)
        a = fft.ifft2(a).real
        return self._clip_to_dtype(a, dtype)

    def my_fct_3D(self, input_raster, **kwargs):
        """Filter the datacube: fftn -> fourier_ellipsoid -> ifftn (real part)."""
        dtype = kwargs.pop("output")
        a = fft.fftn(input_raster)
        a = ndimage.fourier_ellipsoid(a, **kwargs)
        a = fft.ifftn(a).real
        return self._clip_to_dtype(a, dtype)

    def createInstance(self):
        return SciPyFourierEllipsoidAlgorithm()
    



class SciPyFourierUniformAlgorithm(SciPyAlgorithm):
    """
    Fourier uniform (i.e. mean) filter 

    Uniform filter calculated by multiplication 
    with a box kernel in the frequency domain. 

    The input band is transformed with fast fourier transform (FFT) 
    using fft2 (for 2D) or fftn (for 3D) from 
    `scipy.fft <https://docs.scipy.org/doc/scipy/reference/fft.html>`_.
    The multiplication with the fourier transform of a box kernel 
    is calculated with fourier_uniform from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.
    The product is transformed back with ifft2 or ifftn, respectively. 
    Only the real part of the resulting complex 
    numbers is returned.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of the box.
    """

    # Overwrite constants of base class
    _name = 'fourier_uniform'
    _displayname = tr('Fourier uniform (box filter)')
    _outputname = tr('Fourier uniform') # If set to None, the displayname is used 
    _groupid = "blur" 
    
    SIZE = 'SIZE'   # hidden numeric size; only used when SIZES is empty
    SIZES = 'SIZES' # size(s) as string, edited with the sizes widget


    def insert_parameters(self, config):
        """Register the (hidden) SIZE and the SIZES parameters."""
        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            tr('Size'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5, 
            optional=True, 
            minValue=0, 
            )
        
        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)  

        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size'),
            defaultValue="", 
            optional=True, 
            )
        
        sizes_param.setMetadata({
            'widget_wrapper': {
                'class': SizesWidgetWrapper
            }
        })

        self.addParameter(sizes_param)
        
        super().insert_parameters(config)

    
    def get_parameters(self, parameters, context):
        """
        Collect the keyword arguments for fourier_uniform.

        SIZES (string) takes precedence over the hidden SIZE (number).
        If neither yields a truthy size, 3 is used. Also sets self.margin
        to the largest size (rounded up to an int).
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            # May return a single number or a list (one entry per axis)
            size = str_to_int_or_list(sizes)
        else:
            size = self.parameterAsDouble(parameters, self.SIZE, context)
        if not size:
            # Just in case it is called from python and neither size nor sizes is set
            size = 3
        kwargs['size'] = size

        # np.max handles both a scalar and a list; keep the margin an int
        # like the other algorithms in this module do.
        self.margin = int(np.ceil(np.max(size)))

        return kwargs
    
    def checkParameterValues(self, parameters, context): 
        """Reject a SIZES list whose length differs from the number of dimensions."""
        dims = self.getDimsForCheck(parameters, context)
        
        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            sizes = str_to_int_or_list(sizes)
            if isinstance(sizes, list) and len(sizes) != dims:
                return (False, tr("Sizes does not match number of dimensions"))

        return super().checkParameterValues(parameters, context)
    
    
    # The function to be called, to be overwritten
    def get_fct(self):
        if self._dimension == Dimensions.threeD:
            return self.my_fct_3D
        else:
            return self.my_fct_2D

    @staticmethod
    def _clip_to_dtype(a, dtype):
        """Clip a to the value range of dtype if dtype is an integer type."""
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            if a.min() < info.min or a.max() > info.max:
                a = np.clip(a, info.min, info.max)
        return a

    def my_fct_2D(self, input_raster, **kwargs):
        """Filter one band: fft2 -> fourier_uniform -> ifft2 (real part)."""
        dtype = kwargs.pop("output")
        a = fft.fft2(input_raster)
        a = ndimage.fourier_uniform(a, **kwargs)
        a = fft.ifft2(a).real
        return self._clip_to_dtype(a, dtype)

    def my_fct_3D(self, input_raster, **kwargs):
        """Filter the datacube: fftn -> fourier_uniform -> ifftn (real part)."""
        dtype = kwargs.pop("output")
        a = fft.fftn(input_raster)
        a = ndimage.fourier_uniform(a, **kwargs)
        a = fft.ifftn(a).real
        return self._clip_to_dtype(a, dtype)

    def createInstance(self):
        return SciPyFourierUniformAlgorithm()



class SciPyFFTConvolveAlgorithm(SciPyAlgorithm):
    """
    Convolve raster band(s) with custom kernel using FFT. 
    
    This is faster for large kernels. 
    Both, raster band(s) and kernel are transformed into the frequency domain 
    with fast fourier transform (FFT), the results are multiplied and the product 
    is converted back using inverse FFT.

    Calculated using fftconvolve from 
    `scipy.signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Kernel** String representation of array. 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately).

    **Normalization** Normalize the kernel by dividing through 
    given value; set to 0 to divide through the sum of the absolute 
    values of the kernel.

    **Dtype** Data type of output. Beware of clipping 
    and potential overflow errors if min/max of output does 
    not fit. Default is Float32.
    """

    KERNEL = 'KERNEL'
    NORMALIZATION = 'NORMALIZATION'

    # Overwrite constants of base class
    _name = 'fft_convolve'
    _displayname = tr('FFT Convolve')
    _outputname = None # If set to None, the displayname is used 
    _groupid = "convolution" 

    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)

    
    def initAlgorithm(self, config):
        """Fix the dimension to 2D and reset the per-run input statistics."""
        # Set dimensions to 2
        self._dimension = Dimensions.twoD

        # Min/max of every processed input array, used for feedback
        self.inmax = []
        self.inmin = []

        super().initAlgorithm(config)


    def insert_parameters(self, config):
        """Register the kernel and normalization parameters."""
        default_kernel = "[[1, 2, 1],\n[2, 4, 2],\n[1, 2, 1]]"

        kernel_param = SciPyParameterStructure(
            self.KERNEL,
            tr('Kernel'),
            defaultValue=default_kernel,
            examples=kernelexamples,
            multiLine=True,
            to_int=False,
            optional=False
            )
        
        kernel_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(kernel_param)
        
        self.addParameter(QgsProcessingParameterNumber(
            self.NORMALIZATION,
            tr('Normalization (devide kernel values by given number). Set to 0 to devide by sum of absolute values of the kernel.'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0, 
            optional=True, 
            minValue=0, 
            )) 

        super().insert_parameters(config)

    
    def get_parameters(self, parameters, context):
        """
        Build the keyword arguments for the convolution.

        Parses and normalizes the kernel (stored as self.kernel for feedback),
        requests mode 'same' and sets self.margin to half the largest kernel
        extent, rounded up.
        """
        kwargs = super().get_parameters(parameters, context)

        kernel = self.parameterAsString(parameters, self.KERNEL, context)
        kernel = str_to_array(kernel, 2)

        normalization = self.parameterAsDouble(parameters, self.NORMALIZATION, context)

        if normalization == 0:
            abs_sum = np.abs(kernel).sum()
            # An all-zero kernel cannot be normalized by its sum; dividing
            # by 0 would turn the kernel (and the whole output) into NaN.
            if abs_sum != 0:
                kernel = kernel / abs_sum
        else:
            kernel = kernel / normalization

        kwargs['in2'] = kernel

        self.kernel = kernel # For feedback

        kwargs['mode'] = 'same' # size must be the same as input raster

        self.margin = int(np.ceil(max(kernel.shape) / 2))

        return kwargs
    
    def checkParameterValues(self, parameters, context): 
        """Validate the kernel string before running."""
        structure = self.parameterAsString(parameters, self.KERNEL, context)

        ok, s, shape = check_structure(structure, 2)
        if not ok:
            return (ok, s)
        
        return super().checkParameterValues(parameters, context)
    
    def my_fct(self, a, **kwargs):
        """
        Convolve array a with the kernel via scipy.signal.fftconvolve.

        The "output" entry of kwargs is the target dtype. For integer dtypes,
        out-of-range values are clipped and a non-fatal error is recorded in
        self.error.
        """
        dtype = kwargs.pop("output")

        # Used for feedback
        self.inmin.append(a.min())
        self.inmax.append(a.max())

        a = signal.fftconvolve(a, **kwargs)

        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            if a.min() < info.min or a.max() > info.max:
                self.error = (tr("Values ({}...{}) are out of bounds of new dtype, clipping to {}...{}").format(a.min().round(1), a.max().round(1), info.min, info.max), False)
                a = np.clip(a, info.min, info.max)
        return a

    def checkAndComplain(self, feedback):
        """Report the input range and warn if the output may not fit the dtype."""
        if not self.inmin or not self.inmax:
            # Nothing has been processed, so there is no range to report
            return

        inmin = min(self.inmin)
        inmax = max(self.inmax)

        msg = tr("Input values are in the range {}...{}").format(inmin, inmax)
        feedback.pushInfo(msg)

        # Calculate the possible range after applying the kernel
        outmax = ((np.where(self.kernel < 0, 0, self.kernel)    # positive part of kernel
                   * max(0, inmax)).sum()                       # multiplied with positive input
                  + (np.where(self.kernel > 0, 0, self.kernel)  # negative part of kernel
                     * min(0, inmin)).sum()).astype("int")      # multiplied with negative input

        outmin = ((np.where(self.kernel > 0, 0, self.kernel)    # negative part of kernel
                   * max(0, inmax)).sum()                       # multiplied with positive input
                  + (np.where(self.kernel < 0, 0, self.kernel)  # positive part of kernel
                     * min(0, inmin)).sum()).astype("int")      # multiplied with negative input
        
        msg = tr("Expected output range is {}...{}").format(outmin, outmax)
        feedback.pushInfo(msg)
        
        # NOTE(review): dtype indices 1, 2, 4 are assumed to be the unsigned
        # integer entries of the output dtype combobox
        if self._outdtype in (1,2,4) and np.any(self.kernel < 0):
            msg = tr("WARNING: With a kernel containing negative values, output values can be negative. But output data type is unsigned integer!")
            feedback.reportError(msg, fatalError = False)

        if 1 <= self._outdtype <= 5: # integer types
            info_out = np.iinfo(get_np_dtype(self._outdtype))
            if outmin < info_out.min or outmax > info_out.max:
                msg = tr("WARNING: The possible range of output values is not in the range of the output datatype. Clipping is likely.")
                feedback.reportError(msg, fatalError=False)

    # The function to be called, to be overwritten
    def get_fct(self):
        return self.my_fct


    def createInstance(self):
        return SciPyFFTConvolveAlgorithm()
    





class SciPyFFTCorrelateAlgorithm(SciPyFFTConvolveAlgorithm):
    """
    Correlate raster band(s) with custom kernel using FFT
    
    This is faster for large kernels. 
    Both, raster band(s) and kernel are transformed into the frequency domain 
    with fast fourier transform (FFT), the results are multiplied and the product 
    is converted back using inverse FFT.

    Calculated using correlate from 
    `scipy.signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_
    with method "fft".

    .. note:: No data cells within the filter radius are filled with 0.

    **Kernel** String representation of array. 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately).
    
    **Normalization** Normalize the kernel by dividing through 
    given value; set to 0 to divide through the sum of the absolute 
    values of the kernel.

    **Dtype** Data type of output. Beware of clipping 
    and potential overflow errors if min/max of output does 
    not fit. Default is Float32.
    """

    # Overwrite constants of base class
    _name = 'fft_correlate'
    _displayname = tr('FFT Correlate')
    _outputname = None # If set to None, the displayname is used 
    _groupid = "convolution" 

    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)

    
    def my_fct(self, a, **kwargs):
        """
        Correlate array a with the kernel via scipy.signal.correlate (FFT method).

        Parameters and kernel handling are inherited from the convolve
        algorithm; only the SciPy call differs. The "output" entry of kwargs
        is the target dtype. For integer dtypes, out-of-range values are
        clipped and a non-fatal error is recorded in self.error.
        """
        dtype = kwargs.pop("output")

        # Used for feedback
        self.inmin.append(a.min())
        self.inmax.append(a.max())

        # kwargs is a local copy, so adding the method does not leak out
        kwargs["method"] = "fft"
        
        a = signal.correlate(a, **kwargs)

        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            if a.min() < info.min or a.max() > info.max:
                self.error = (tr("Values ({}...{}) are out of bounds of new dtype, clipping to {}...{}").format(a.min().round(1), a.max().round(1), info.min, info.max), False)
                a = np.clip(a, info.min, info.max)
        return a

    def createInstance(self):
        return SciPyFFTCorrelateAlgorithm()