# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# NOTE(review): git only expands $Format:...$ placeholders in files that carry
# the `export-subst` attribute (.gitattributes) — verify that it is set,
# otherwise __revision__ stays the literal placeholder string.

__revision__ = '$Format:%H$'

import numpy as np
from scipy import ndimage, signal
from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition,
                        )

from scipy_filters.helpers import str_to_int_or_list
from scipy_filters.ui.i18n import tr

from scipy_filters.scipy_algorithm_baseclasses import (SciPyAlgorithm,
                                          SciPyStatisticalAlgorithm)

def estimate_local_variance(raster, size, output=None):
    """
    Estimate local variance within window of size.

    Implementation based on parts of the source code of
    the scipy.signal.wiener filter:
    `https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615 <https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615>`_.

    which is (c) 1999-2002 Travis Oliphant under BSD-3-Clause license.

    The variance is computed as E[x^2] - E[x]^2, where both expectations are
    moving-window means obtained by correlating with a box kernel of ones.
    With ``mode='same'`` the raster is zero padded, so cells outside the
    raster count as 0. The formula suffers from floating point cancellation,
    and the FFT path of ``signal.correlate`` adds rounding noise. Tiny negative
    results can therefore occur; they are clipped to 0.

    Parameters
    ==========
    raster : Array of n dimensions.
    size : int or array-like
        Scalar or array/list of length n (numbers of dimensions of the input raster),
        giving a size for each axis.
    output : dtype of output, optional
        Any value accepted by ``ndarray.astype``. If None (default), the
        float64 result is returned without casting.

    Returns
    =======
    out : ndarray
        Local variance, same shape as ``raster``.
    """
    size = np.array(size)

    # Work in float64: avoids integer overflow when squaring and limits
    # rounding error in the window sums.
    raster = raster.astype("float64")

    if size.ndim == 0:
        # Scalar size: use the same window length along every axis.
        size = np.repeat(size.item(), raster.ndim)

    kernel = np.ones(size)
    n_cells = np.prod(size, axis=0)

    # Estimate local mean and local mean of squares
    local_mean = signal.correlate(raster, kernel, mode='same') / n_cells
    local_sq_mean = signal.correlate(raster ** 2, kernel, mode='same') / n_cells

    local_variance = local_sq_mean - local_mean ** 2
    # A variance cannot be negative; negative values are round-off artifacts.
    np.maximum(local_variance, 0, out=local_variance)

    # Explicit None check: dtype objects can be falsy (len() == 0 fields).
    if output is not None:
        return local_variance.astype(output)
    return local_variance


class SciPyEstimateVarianceAlgorithm(SciPyAlgorithm):
    """
    Estimate local variance

    Implementation based on the
    `source code of scipy.signal.wiener <https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615>`_,
    using correlate from `scipy.signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter window. Either a single value or one value
    per dimension.
    """
    # Overwrite constants of base class
    _name = 'estimate_var'
    _displayname = tr('Estimate local variance')
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"

    # Parameter keys
    SIZE = 'SIZE'    # hidden integer size, fallback when SIZES is empty
    SIZES = 'SIZES'  # string: single size or one size per dimension

    def insert_parameters(self, config):
        """Add the size parameters, then those of the base class."""
        # Integer size, hidden in the dialog; get_parameters falls back to it
        # when the SIZES string is empty.
        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            tr('Size of filter'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=3,
            optional=True,
            minValue=1,
            maxValue=20, # Large sizes are really slow
            )

        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)

        # Visible size parameter, parsed with str_to_int_or_list.
        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size'),
            defaultValue="",
            optional=True,
            )

        # Custom metadata; presumably read by the plugin's own widget
        # wrapper for validation — TODO confirm.
        sizes_param.setMetadata({
            'my_flags': {
                'positive': True # greater than zero
            }
        })

        self.addParameter(sizes_param)

        super().insert_parameters(config)

    def checkParameterValues(self, parameters, context):
        """Reject a per-axis size list whose length differs from the number of dimensions."""
        dims = self.getDimsForCheck(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        # Empty string is the default (= use SIZE); only parse real input,
        # as get_parameters does.
        if sizes:
            sizes = str_to_int_or_list(sizes)
            if isinstance(sizes, list) and len(sizes) != dims:
                return (False, tr("Sizes does not match number of dimensions"))

        return super().checkParameterValues(parameters, context)

    def get_parameters(self, parameters, context):
        """
        Build the keyword arguments for the filter function.

        Adds ``size`` (int or per-axis list) to the base class kwargs and
        sets ``self.margin`` to the largest window extent.
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            size = str_to_int_or_list(sizes)
        else:
            size = self.parameterAsInt(parameters, self.SIZE, context)

        if not size:
            # Just in case it is called from python and neither size nor sizes is set
            size = 3

        # size may be a scalar or a per-axis sequence; the margin has to
        # cover the largest window extent along any axis.
        if isinstance(size, (list, tuple)):
            self.margin = max(size)
        else:
            self.margin = size

        kwargs['size'] = size

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        return estimate_local_variance

    def createInstance(self):
        return SciPyEstimateVarianceAlgorithm()



class SciPyEstimateStdAlgorithm(SciPyEstimateVarianceAlgorithm):
    """
    Estimate local standard deviation. Implementation based on the
    `source code of scipy.signal.wiener <https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615>`_,
    using correlate from `scipy.signal <https://docs.scipy.org/doc/scipy/reference/signal.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter window. Either a single value or one value
    per dimension.
    """
    # Overwrite constants of base class
    _name = 'estimate_std'
    _displayname = tr('Estimate local standard deviation')
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        return self.estimate_local_std

    def estimate_local_std(self, raster, size, output=None):
        """
        Local standard deviation: square root of estimate_local_variance.

        Parameters
        ==========
        raster : Array of n dimensions.
        size : int or array-like, window size (scalar or one per axis).
        output : dtype of output, optional; None returns float64.

        Returns
        =======
        out : ndarray, same shape as ``raster``.
        """
        local_variance = estimate_local_variance(raster, size, None)
        # Round-off can make the variance slightly negative; clip so the
        # square root gives 0 instead of NaN.
        local_std = np.sqrt(np.maximum(local_variance, 0))
        if output is not None:
            return local_std.astype(output)
        return local_std

    def createInstance(self):
        return SciPyEstimateStdAlgorithm()
    



# Very slow calculation of Std, disabled
    
class SciPyStdAlgorithm(SciPyStatisticalAlgorithm):
    """
    Local standard deviation.
    Calculated with generic_filter from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_
    and numpy.std.

    **Warning: Very slow!**

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Delta degrees of freedom** of the standard deviation
    (ddof in numpy). With ddof=0, the std is calculated with
    1/N, with ddof=1 with 1/(N-1).

    **Size** Size of filter if no footprint is given. Equivalent
    to a footprint array of shape size × size [× size in 3D]
    filled with ones.

    **Footprint** String representation of array, specifying
    the kernel of the filter.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately).

    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    DDOF = 'DDOF'  # parameter key: delta degrees of freedom for np.std

    # Overwrite constants of base class
    _name = 'std'
    _displayname = tr('Local standard deviation (very slow)')
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"

    def insert_parameters(self, config):
        """Add the ddof parameter (0 or 1), then those of the base class."""
        self.addParameter(QgsProcessingParameterNumber(
            self.DDOF,
            tr('Delta degrees of freedom (ddof)'),
            QgsProcessingParameterNumber.Type.Integer,
            optional=False,
            defaultValue=1,
            minValue=0,
            maxValue=1,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Add np.std and its ddof keyword as generic_filter arguments."""
        kwargs = super().get_parameters(parameters, context)

        ddof = self.parameterAsInt(parameters, self.DDOF, context)

        # generic_filter calls `function` on each window's values and
        # forwards `extra_keywords` to it.
        kwargs["function"] = np.std
        kwargs["extra_keywords"] = {"ddof": ddof}

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        # np.std is passed in get_parameters as kwarg
        return ndimage.generic_filter

    def createInstance(self):
        return SciPyStdAlgorithm()