# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
from scipy import ndimage, fft, signal
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterBand,
                       QgsProcessingParameterString
                        )

from .scipy_algorithm_baseclasses import SciPyAlgorithm
from .helpers import check_structure, str_to_array, kernelexamples
from .ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)

class SciPyFourierGaussianAlgorithm(SciPyAlgorithm):
    """
    Gaussian filter applied in the frequency domain.

    The input is transformed with ``scipy.fft`` (``fft2`` or ``fftn``),
    passed through ``scipy.ndimage.fourier_gaussian`` and transformed back
    (``ifft2`` or ``ifftn``). Only the real part of the result is returned.
    """

    # Key of the sigma parameter
    SIGMA = 'SIGMA'

    # Overwrite constants of base class
    _name = 'fourier_gaussian'
    _displayname = 'Fourier Gaussian'
    _outputname = None  # If set to None, the displayname is used
    _groupid = "blur"
    _help = """
            Gaussian filter calculated by multiplication in the frequency domain. \
            This is faster with large kernels (large sigma).

            The input band is transformed with fast fourier transform (FFT) \
            using fft2 (for 2D) or fftn (for 3D) from \
            <a href="https://docs.scipy.org/doc/scipy/reference/fft.html">scipy.fft</a>.
            The multiplication with the fourier transform of a gaussian kernel \
            is calculated with fourier_gaussian from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>. \
            The product is transformed back with ifft2 or ifftn, respectively. \
            Only the real part of the resulting complex \
            numbers is returned.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.
        
            <b>Sigma</b> Standard deviation of the gaussian filter.
            """

    def insert_parameters(self, config):
        """
        Register the sigma parameter, then let the base class add its own.

        ``config`` is handed on unchanged to the base class implementation.
        """
        sigma_param = QgsProcessingParameterNumber(
            self.SIGMA,
            self.tr('Sigma'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5,
            optional=False,
            minValue=0,
            maxValue=100,
        )
        self.addParameter(sigma_param)

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Return the keyword arguments of the base class extended by ``sigma``.

        The sigma value is read as a double from the ``SIGMA`` parameter.
        """
        fct_kwargs = super().get_parameters(parameters, context)
        sigma = self.parameterAsDouble(parameters, self.SIGMA, context)
        fct_kwargs['sigma'] = sigma
        return fct_kwargs

    def get_fct(self):
        """Return the 3D variant if the dimension is 3D, else the 2D variant."""
        use_datacube = self._dimension == self.Dimensions.threeD
        return self.my_fct_3D if use_datacube else self.my_fct_2D

    def my_fct_2D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fft2`` / ``ifft2``.

        ``kwargs`` are passed on to ``ndimage.fourier_gaussian``; the real
        part of the inverse transform is returned.
        """
        spectrum = ndimage.fourier_gaussian(fft.fft2(input_raster), **kwargs)
        return fft.ifft2(spectrum).real

    def my_fct_3D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fftn`` / ``ifftn``.

        ``kwargs`` are passed on to ``ndimage.fourier_gaussian``; the real
        part of the inverse transform is returned.
        """
        spectrum = ndimage.fourier_gaussian(fft.fftn(input_raster), **kwargs)
        return fft.ifftn(spectrum).real

    def createInstance(self):
        """Return a new instance of this algorithm class."""
        return SciPyFourierGaussianAlgorithm()
    




class SciPyFourierEllipsoidAlgorithm(SciPyAlgorithm):
    """
    Ellipsoid fourier filter.

    The input is transformed with ``scipy.fft`` (``fft2`` or ``fftn``),
    passed through ``scipy.ndimage.fourier_ellipsoid`` and transformed back
    (``ifft2`` or ``ifftn``). Only the real part of the result is returned.
    """

    # Overwrite constants of base class
    _name = 'fourier_ellipsoid'
    _displayname = 'Fourier ellipsoid'
    # Matches the display name; this filter is not gaussian.
    _outputname = 'Fourier ellipsoid'  # If set to None, the displayname is used
    _groupid = "blur"
    _help = """
            Ellipsoidal box filter calculated by multiplication \
            with a circular or ellipsoidal kernel in the frequency domain. \

            The input band is transformed with fast fourier transform (FFT) \
            using fft2 (for 2D) or fftn (for 3D) from \
            <a href="https://docs.scipy.org/doc/scipy/reference/fft.html">scipy.fft</a>.
            The multiplication with the fourier transform of an ellipsoidal kernel \
            is calculated with fourier_ellipsoid from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>. \
            The product is transformed back with ifft2 or ifftn, respectively. \
            Only the real part of the resulting complex \
            numbers is returned.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.
        
            <b>Size</b> Size of the box (for now only circular size).
            """

    SIZE = 'SIZE'  # can be float or int or tuple of int but not tuple of float

    def insert_parameters(self, config):
        """
        Register the size parameter, then let the base class add its own.

        ``config`` is handed on unchanged to the base class implementation.
        """
        # No maxValue is given, so the size has no upper bound.
        self.addParameter(QgsProcessingParameterNumber(
            self.SIZE,
            self.tr('Size'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5,
            optional=False,
            minValue=0,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Return the keyword arguments of the base class extended by ``size``.

        The size is read as a single double from the ``SIZE`` parameter, so
        the same size is used for every axis.
        """
        kwargs = super().get_parameters(parameters, context)
        kwargs['size'] = self.parameterAsDouble(parameters, self.SIZE, context)
        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the 3D variant if the dimension is 3D, else the 2D variant."""
        if self._dimension == self.Dimensions.threeD:
            return self.my_fct_3D
        return self.my_fct_2D

    def my_fct_2D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fft2`` / ``ifft2``.

        ``kwargs`` are passed on to ``ndimage.fourier_ellipsoid``; the real
        part of the inverse transform is returned.
        """
        input_fft = fft.fft2(input_raster)
        result = ndimage.fourier_ellipsoid(input_fft, **kwargs)
        result = fft.ifft2(result)
        return result.real

    def my_fct_3D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fftn`` / ``ifftn``.

        ``kwargs`` are passed on to ``ndimage.fourier_ellipsoid``; the real
        part of the inverse transform is returned.
        """
        input_fft = fft.fftn(input_raster)
        result = ndimage.fourier_ellipsoid(input_fft, **kwargs)
        result = fft.ifftn(result)
        return result.real

    def createInstance(self):
        """Return a new instance of this algorithm class."""
        return SciPyFourierEllipsoidAlgorithm()
    



class SciPyFourierUniformAlgorithm(SciPyAlgorithm):
    """
    Uniform (box) fourier filter.

    The input is transformed with ``scipy.fft`` (``fft2`` or ``fftn``),
    passed through ``scipy.ndimage.fourier_uniform`` and transformed back
    (``ifft2`` or ``ifftn``). Only the real part of the result is returned.
    """

    # Overwrite constants of base class
    _name = 'fourier_uniform'
    _displayname = 'Fourier uniform (box filter)'
    _outputname = 'Fourier uniform'  # If set to None, the displayname is used
    _groupid = "blur"
    _help = """
            Uniform filter calculated by multiplication \
            with a box kernel in the frequency domain. \

            The input band is transformed with fast fourier transform (FFT) \
            using fft2 (for 2D) or fftn (for 3D) from \
            <a href="https://docs.scipy.org/doc/scipy/reference/fft.html">scipy.fft</a>.
            The multiplication with the fourier transform of a box kernel \
            is calculated with fourier_uniform from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>. \
            The product is transformed back with ifft2 or ifftn, respectively. \
            Only the real part of the resulting complex \
            numbers is returned.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.
        
            <b>Size</b> Size of the box.
            """

    SIZE = 'SIZE'  # can be float or int or tuple of int but not tuple of float

    def insert_parameters(self, config):
        """
        Register the size parameter, then let the base class add its own.

        ``config`` is handed on unchanged to the base class implementation.
        """
        # No maxValue is given, so the size has no upper bound.
        self.addParameter(QgsProcessingParameterNumber(
            self.SIZE,
            self.tr('Size'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=5,
            optional=False,
            minValue=0,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Return the keyword arguments of the base class extended by ``size``.

        The size is read as a single double from the ``SIZE`` parameter, so
        the same size is used for every axis.
        """
        kwargs = super().get_parameters(parameters, context)
        kwargs['size'] = self.parameterAsDouble(parameters, self.SIZE, context)
        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the 3D variant if the dimension is 3D, else the 2D variant."""
        if self._dimension == self.Dimensions.threeD:
            return self.my_fct_3D
        return self.my_fct_2D

    def my_fct_2D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fft2`` / ``ifft2``.

        ``kwargs`` are passed on to ``ndimage.fourier_uniform``; the real
        part of the inverse transform is returned.
        """
        input_fft = fft.fft2(input_raster)
        result = ndimage.fourier_uniform(input_fft, **kwargs)
        result = fft.ifft2(result)
        return result.real

    def my_fct_3D(self, input_raster, **kwargs):
        """
        Filter ``input_raster`` using ``fftn`` / ``ifftn``.

        ``kwargs`` are passed on to ``ndimage.fourier_uniform``; the real
        part of the inverse transform is returned.
        """
        input_fft = fft.fftn(input_raster)
        result = ndimage.fourier_uniform(input_fft, **kwargs)
        result = fft.ifftn(result)
        return result.real

    def createInstance(self):
        """Return a new instance of this algorithm class."""
        return SciPyFourierUniformAlgorithm()



class SciPyFFTConvolveAlgorithm(SciPyAlgorithm):
    """
    Convolve raster band(s) with custom kernel using FFT.

    The kernel is read from a string parameter, optionally normalized and
    handed to ``scipy.signal.fftconvolve`` as ``in2`` with ``mode='same'``.
    """

    KERNEL = 'KERNEL'
    NORMALIZATION = 'NORMALIZATION'

    # Overwrite constants of base class
    _name = 'fft_convolve'
    _displayname = 'FFT Convolve'
    _outputname = None  # If set to None, the displayname is used
    _groupid = "convolution"
    # NOTE(review): the help text mentions 3D kernels, but this class sets the
    # dimension to 2D in initAlgorithm and parses/validates the kernel with a
    # hard-coded dimension of 2 -- confirm whether 3D is meant to be supported.
    _help = """
            Convolve raster band(s) with custom kernel using FFT. This is faster for large kernels. \
            Both, raster band(s) and kernel are transformed into the frequency domain \
            with fast fourier transform (FFT), the results are multiplied and the product \
            is converted back using FFT.

            Calculated using fftconvolve from \
            <a href="https://docs.scipy.org/doc/scipy/reference/signal.html">scipy.signal</a>.

            <b>Kernel</b> String representation of array. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately).

            <b>Normalization</b> Normalize the kernel by dividing through given value; \
            set to 0 to divide through the sum of kernel values \
            (a kernel whose values sum to 0 is used without normalization).
            """

    def initAlgorithm(self, config):
        """
        Fix the dimension to 2D and set the list of modes, then run the
        base class initialisation with ``config``.
        """
        # Set dimensions to 2
        self._dimension = self.Dimensions.twoD

        # Set modes
        # NOTE(review): presumably read by the base class; get_parameters
        # below always passes mode='same' -- confirm.
        self.modes = ['full', 'valid', 'same']

        super().initAlgorithm(config)

    def insert_parameters(self, config):
        """
        Register the kernel and normalization parameters, then let the base
        class add its own parameters.
        """
        default_kernel = "[[1, 2, 1],\n[2, 4, 2],\n[1, 2, 1]]"

        kernel_param = SciPyParameterStructure(
            self.KERNEL,
            self.tr('Kernel'),
            defaultValue=default_kernel,
            examples=kernelexamples,
            multiLine=True,
            to_int=False,
            optional=False
            )

        kernel_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(kernel_param)

        # NOTE(review): the label below misspells "divide" as "devide". It is
        # left unchanged because it is passed to self.tr() and may serve as a
        # translation lookup key -- confirm before correcting it.
        self.addParameter(QgsProcessingParameterNumber(
            self.NORMALIZATION,
            self.tr('Normalization (devide kernel values by number). Set to 0 to devide by sum of kernel values.'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Return the keyword arguments of the base class extended by the
        kernel (``in2``) and ``mode='same'``.

        The kernel string is converted with ``str_to_array`` and divided by
        the normalization value. A normalization of 0 selects the sum of the
        kernel values as divisor. If that sum is 0 as well, the kernel is
        left unscaled instead of being divided by zero.
        """
        kwargs = super().get_parameters(parameters, context)

        kernel = self.parameterAsString(parameters, self.KERNEL, context)
        kernel = str_to_array(kernel, 2)

        divisor = self.parameterAsDouble(parameters, self.NORMALIZATION, context)

        if divisor == 0:
            # 0 is the sentinel for "normalize by the sum of the kernel"
            divisor = kernel.sum()

        # A zero-sum kernel cannot be normalized by its sum: dividing would
        # fill the kernel with inf/NaN and ruin the output.
        if divisor != 0:
            kernel = kernel / divisor

        kwargs['in2'] = kernel
        kwargs['mode'] = 'same'  # size must be the same as input raster

        return kwargs

    def checkParameterValues(self, parameters, context):
        """
        Validate the kernel string with ``check_structure``.

        Returns the ``(ok, message)`` tuple of ``check_structure`` if the
        kernel is rejected, otherwise the result of the base class check.
        """
        structure = self.parameterAsString(parameters, self.KERNEL, context)

        ok, s = check_structure(structure, 2)
        if not ok:
            return (ok, s)

        return super().checkParameterValues(parameters, context)

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return ``scipy.signal.fftconvolve``."""
        return signal.fftconvolve

    def createInstance(self):
        """Return a new instance of this algorithm class."""
        return SciPyFFTConvolveAlgorithm()