# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata, as emitted by the QGIS Plugin Builder template.
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# (keep the placeholder literal unchanged, it is substituted textually).

__revision__ = '$Format:%H$'

import numpy as np
from scipy import ndimage, signal

from qgis.core import QgsProcessingParameterNumber
                        
from scipy_filters.scipy_algorithm_baseclasses import (SciPyAlgorithm,
                                          QgsProcessingParameterString,
                                          Dimensions)

from scipy_filters.algs.scipy_gaussian_algorithm import SciPyAlgorithmWithSigma

from scipy_filters.ui.sizes_widget import (OddSizesWidgetWrapper)


from scipy_filters.helpers import str_to_int_or_list 
from scipy_filters.ui.i18n import tr


class SciPyWienerAlgorithm(SciPyAlgorithm):
    """
    Wiener filter (noise reduction).
    Calculated with wiener from
    `scipy.signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels. All values must
    be odd.

    **Noise** The noise-power to use. If not set (or 0), estimate noise
    from local variance.
    """

    SIZES = 'SIZES'
    NOISE = 'NOISE'

    # Overwrite constants of base class
    _name = 'wiener'
    _displayname = tr('Wiener filter')
    _outputname = None
    _groupid = "enhance"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that is applied to the raster data."""
        return self.my_fct

    def my_fct(self, a, **kwargs):
        """
        Apply ``scipy.signal.wiener`` to ``a``.

        ``kwargs`` must contain ``output`` (the target dtype). It is removed,
        and all remaining keyword arguments (``mysize``, ``noise``) are
        forwarded to ``signal.wiener``. For an integer target dtype, values
        outside its range are clipped, and ``self.error`` is set to a
        ``(message, False)`` tuple describing the clipping.
        """
        dtype = kwargs.pop("output")

        a = signal.wiener(a, **kwargs)

        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            amin, amax = a.min(), a.max()
            if amin < info.min or amax > info.max:
                self.error = (f"Values ({amin.round(1)}...{amax.round(1)}) are out of bounds of new dtype, clipping to {info.min}...{info.max}", False)
                a = np.clip(a, info.min, info.max)
        return a

    def initAlgorithm(self, config):
        """Register the SIZES (odd sizes widget) and optional NOISE parameters."""
        # Call the super function first
        # (otherwise input is not the first parameter in the GUI)
        super().initAlgorithm(config)

        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size: integer (odd) or array of odd integers with sizes for every dimension'),
            defaultValue="5",
            optional=False,
            )

        sizes_param.setMetadata({
            'widget_wrapper': {
                'class': OddSizesWidgetWrapper
            }
        })

        self.addParameter(sizes_param)

        # No default: leaving it empty means "let scipy estimate the noise".
        self.addParameter(QgsProcessingParameterNumber(
            self.NOISE,
            tr('Noise'),
            QgsProcessingParameterNumber.Type.Double,
            optional=True,
            minValue=0,
            ))

    def get_parameters(self, parameters, context):
        """
        Build the keyword arguments for ``signal.wiener``.

        Adds ``mysize`` (parsed from the SIZES string) and ``noise`` to the
        base-class kwargs. Also sets ``self.margin`` to half of the largest
        window size, rounded up.
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        sizes = str_to_int_or_list(sizes)

        kwargs['mysize'] = sizes

        # NOISE is optional. An unset value reads back as 0.0 from
        # parameterAsDouble, but signal.wiener only estimates the noise
        # itself when noise is None. With noise=0 it returns the input
        # unchanged (and NaN where the local variance is 0), so map both
        # "unset" and 0 to None.
        noise = None
        if parameters.get(self.NOISE) is not None:
            noise = self.parameterAsDouble(parameters, self.NOISE, context)
        kwargs['noise'] = noise if noise else None

        # str_to_int_or_list returns either a single int or a list;
        # np.max handles both (builtin max() raises TypeError on an int).
        self.margin = int(np.ceil(np.max(sizes) / 2))

        return kwargs

    def checkParameterValues(self, parameters, context):
        """
        Validate the SIZES string before the algorithm runs.

        The sizes must parse. They must contain either a single value or one
        value per dimension (2, or 3 when the dimension option is 1). Every
        value must be a positive odd integer.
        """
        sizes = self.parameterAsString(parameters, self.SIZES, context)

        dims = 2
        if self._dimension == Dimensions.nD:
            dim_option = self.parameterAsInt(parameters, self.DIMENSION, context)
            if dim_option == 1:
                dims = 3

        try:
            sizes = str_to_int_or_list(sizes)
        except ValueError:
            return (False, tr("Can not parse size."))

        sizes = np.array(sizes)

        if not (sizes.size == 1 or sizes.size == dims):
            return (False, tr('Number of elements in array must match the number of dimensions'))

        # Negative odd sizes would pass the odd check below but fail in scipy.
        if np.any(sizes < 1):
            return (False, tr('Every element in size must be a positive integer.'))

        if np.any(sizes % 2 == 0):
            return (False, tr('Every element in size must be odd.'))

        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a new instance of this algorithm (required by QGIS)."""
        return SciPyWienerAlgorithm()
    


class SciPyUnsharpMaskAlgorithm(SciPyAlgorithmWithSigma):
    """
    Sharpen the image with an unsharp mask filter

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Sigma** Standard deviation of the gaussian filter.
    The radius of the gaussian kernel is 4 * sigma.

    **Amount** Amplification factor applied to the difference
    between the image and its blurred version.

    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    AMOUNT = 'AMOUNT'

    # Overwrite constants of base class
    _name = 'unsharp_mask'
    _displayname = tr('Unsharp mask')
    _outputname = None # If set to None, the displayname is used
    _groupid = "enhance"

    def insert_parameters(self, config):
        """Add the AMOUNT parameter, then the base-class (sigma, ...) parameters."""
        self.addParameter(QgsProcessingParameterNumber(
            self.AMOUNT,
            tr('Amount'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=1.0,
            optional=False,
            minValue=0,
            maxValue=5,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Extend the base-class kwargs with ``amount``."""
        kwargs = super().get_parameters(parameters, context)

        kwargs['amount'] = self.parameterAsDouble(parameters, self.AMOUNT, context)

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that is applied to the raster data."""
        return self.unsharpmask

    def unsharpmask(self, raster, **kwargs):
        """
        Sharpen ``raster``: ``raster + (raster - gaussian(raster)) * amount``.

        ``kwargs`` must contain ``output`` (target dtype) and ``amount``.
        Both are removed. Everything else (``sigma`` and the border-mode
        options collected by the base class, presumably ``mode``/``cval``,
        to be confirmed) is forwarded to ``ndimage.gaussian_filter``.
        Integer target dtypes are clipped to their representable range
        before the cast.
        """
        dtype = kwargs.pop("output")
        amount = kwargs.pop("amount")

        # most likely raster is of dtype uint, but we need negative values
        raster = raster.astype("float64")

        # Forward all remaining options, not only sigma, so that the border
        # mode chosen by the user is actually applied.
        blurred = ndimage.gaussian_filter(raster, **kwargs)

        a = raster + (raster - blurred) * amount

        # Clip according to datatype (np.clip is a no-op when in range)
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            a = np.clip(a, info.min, info.max)

        return a.astype(dtype)

    def createInstance(self):
        """Return a new instance of this algorithm (required by QGIS)."""
        return SciPyUnsharpMaskAlgorithm()
