# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (author, date, copyright) as set up by the Plugin Builder template.
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# NOTE(review): git only substitutes `$Format:%H$` when this file is marked
# `export-subst` in .gitattributes — confirm that attribute is set.

__revision__ = '$Format:%H$'

import numpy as np
from osgeo import gdal
from scipy import ndimage, signal
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterBand,)
                        
from .scipy_algorithm_baseclasses import (SciPyAlgorithm,
                                          SciPyAlgorithmWithMode,
                                          SciPyAlgorithmWithModeAxis,
                                          SciPyStatisticalAlgorithm,
                                          QgsProcessingParameterString)

from .scipy_gaussian_algorithm import SciPyAlgorithmWithSigma

from .ui.sizes_widget import (OddSizesWidgetWrapper)


from .helpers import str_to_int_or_list

class SciPyWienerAlgorithm(SciPyAlgorithm):
    """
    Wiener filter (noise reduction) calling scipy.signal.wiener.

    Adds two parameters to the base algorithm:
    SIZES -- filter size, one odd integer or one odd integer per dimension
             (passed to scipy as ``mysize``).
    NOISE -- optional noise power; when unset (or 0) ``None`` is passed so
             scipy estimates the noise from the local variance.
    """

    SIZES = 'SIZES'
    NOISE = 'NOISE'

    # Overwrite constants of base class
    _name = 'wiener'
    _displayname = 'Wiener filter'
    _outputname = None
    _groupid = "enhance"
    _help = """
            Wiener filter (noise reduction). \
            Calculated with wiener from \
            <a href="https://docs.scipy.org/doc/scipy/reference/signal.html">scipy.signal</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Size</b> Size of filter in pixels. All values must \
            be positive and odd.

            <b>Noise</b> The noise-power to use. If not set (or 0), estimate noise from \
            local variance.
            """

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the filter function applied to the raster: scipy.signal.wiener."""
        return signal.wiener

    def initAlgorithm(self, config):
        """Register the SIZES and NOISE parameters after the base-class parameters."""
        # Call the super function first
        # (otherwise input is not the first parameter in the GUI)
        super().initAlgorithm(config)

        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            self.tr('Size: integer (odd) or array of odd integers with sizes for every dimension'),
            defaultValue="5",
            optional=False,
            )

        # Custom widget for entering (odd) sizes in the GUI
        sizes_param.setMetadata({
            'widget_wrapper': {
                'class': OddSizesWidgetWrapper
            }
        })

        self.addParameter(sizes_param)

        # No default on purpose: leaving NOISE empty lets scipy estimate it.
        self.addParameter(QgsProcessingParameterNumber(
            self.NOISE,
            self.tr('Noise'),
            QgsProcessingParameterNumber.Type.Double,
            optional=True,
            minValue=0,
            ))

    def get_parameters(self, parameters, context):
        """
        Build the keyword arguments passed to scipy.signal.wiener.

        Extends the base-class kwargs with ``mysize`` (parsed SIZES) and
        ``noise``. An unset optional NOISE comes back from parameterAsDouble
        as 0.0 (NOTE(review): QGIS behaviour for unset numbers), and a
        noise power of 0 makes the Wiener filter a no-op that yields NaN
        where the local variance is 0. Therefore any non-positive value is
        passed as ``None``, which makes scipy estimate the noise power from
        the local variance, as promised in the help text.
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        kwargs['mysize'] = str_to_int_or_list(sizes)

        noise = self.parameterAsDouble(parameters, self.NOISE, context)
        # `noise > 0` is also False for NaN, so that case falls back to estimation too
        kwargs['noise'] = noise if noise > 0 else None

        return kwargs

    def checkParameterValues(self, parameters, context):
        """
        Validate SIZES before the algorithm runs.

        The size string must parse to an integer or a list of integers. It
        must have either one element or one element per dimension (2 by
        default, 3 when the DIMENSION option 1 is selected). Every element
        must be a positive odd integer.

        :returns: (ok, message) tuple as expected by QGIS Processing.
        """
        sizes = self.parameterAsString(parameters, self.SIZES, context)

        dims = 2
        if self._dimension == self.Dimensions.nD:
            dim_option = self.parameterAsInt(parameters, self.DIMENSION, context)
            if dim_option == 1:
                dims = 3

        try:
            sizes = str_to_int_or_list(sizes)
        except ValueError:
            return (False, self.tr("Can not parse size."))

        sizes = np.array(sizes)

        if not (sizes.size == 1 or sizes.size == dims):
            return (False, self.tr('Number of elements in array must match the number of dimensions'))

        # Negative odd values (e.g. -3) would otherwise pass the parity check
        if np.any(sizes < 1):
            return (False, self.tr('Every element in size must be a positive integer.'))

        if np.any(sizes % 2 == 0):
            return (False, self.tr('Every element in size must be odd.'))

        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyWienerAlgorithm()
    


class SciPyUnsharpMaskAlgorithm(SciPyAlgorithmWithSigma):
    """
    Unsharp mask based on scipy.ndimage.gaussian_filter.

    out = raster + (raster - gaussian_blur(raster)) * amount

    Adds an AMOUNT parameter (amplification factor) on top of the
    sigma/border-mode parameters provided by the base class.
    """

    AMOUNT = 'AMOUNT'

    # Overwrite constants of base class
    _name = 'unsharp_mask'
    _displayname = 'Unsharp mask'
    _outputname = None # If set to None, the displayname is used
    _groupid = "enhance"
    _help = """
            Sharpen the image with an unsharp mask filter. 

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Sigma</b> Radius of the filter (standard deviation of the gaussian filter).

            <b>Amount</b> Amplification factor.

            <b>Border mode</b> determines how input is extended around \
            the edges: <i>Reflect</i> (input is extended by reflecting at the edge), \
            <i>Constant</i> (fill around the edges with a <b>constant value</b>), \
            <i>Nearest</i> (extend by replicating the nearest pixel), \
            <i>Mirror</i> (extend by reflecting about the center of last pixel), \
            <i>Wrap</i> (extend by wrapping around to the opposite edge).
            """

    def insert_parameters(self, config):
        """Add the AMOUNT parameter, then let the base class add its own parameters."""
        self.addParameter(QgsProcessingParameterNumber(
            self.AMOUNT,
            self.tr('Amount'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=1.0,
            optional=False,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Extend the base-class kwargs with ``amount`` (the AMOUNT parameter)."""
        kwargs = super().get_parameters(parameters, context)

        kwargs['amount'] = self.parameterAsDouble(parameters, self.AMOUNT, context)

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the filter function applied to the raster (the bound unsharpmask)."""
        return self.unsharpmask

    def unsharpmask(self, raster, **kwargs):
        """
        Sharpen ``raster`` with an unsharp mask.

        :param raster: numpy array, presumably one band (2D) or the band
            stack (3D) — TODO confirm against the base class.
        :param kwargs: must contain ``sigma`` and ``amount``. ``mode`` and
            ``cval`` (border handling, same names as scipy's keywords) are
            forwarded to the gaussian filter when present; otherwise scipy's
            defaults ``'reflect'`` / ``0.0`` are used.
        :returns: array with the same dtype as ``raster``. For integer
            dtypes the result is rounded and clipped to the dtype's range,
            so over-/undershoots saturate instead of wrapping around
            (e.g. 260 in uint8 would otherwise become 4).
        """
        dtype = raster.dtype
        # Work in float: raster - blurred can be negative even for unsigned input
        data = raster.astype("float64")
        blurred = ndimage.gaussian_filter(
            data,
            sigma=kwargs['sigma'],
            mode=kwargs.get('mode', 'reflect'),
            cval=kwargs.get('cval', 0.0),
        )
        out = data + (data - blurred) * kwargs['amount']

        if np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            out = np.clip(np.rint(out), limits.min, limits.max)
        return out.astype(dtype)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyUnsharpMaskAlgorithm()
