# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (Plugin Builder boilerplate).
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# NOTE(review): git only substitutes '$Format:%H$' if this file is marked
# with the `export-subst` attribute in .gitattributes — confirm.

__revision__ = '$Format:%H$'

import numpy as np
from osgeo import gdal
from scipy import ndimage, signal
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterBand,
                        )
from .scipy_algorithm_baseclasses import (SciPyAlgorithm,
                                          SciPyAlgorithmWithMode,
                                          SciPyAlgorithmWithModeAxis,
                                          SciPyStatisticalAlgorithm)


def estimate_local_variance(raster, size):
    """
    Estimate local variance within window of size.

    Implementation based on parts of the source code of
    the scipy.signal.wiener filter:
    https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615
    which is (c) 1999-2002 Travis Oliphant under BSD-3-Clause license.

    The variance is computed as mean(x**2) - mean(x)**2 over a moving
    box window. Both means come from ``scipy.signal.correlate`` with
    ``mode='same'``, which zero-pads, so values near the edges are
    biased. The formula is prone to floating-point cancellation and can
    produce tiny negative values. Those are clipped to 0 because a
    variance cannot be negative.

    Parameters
    ==========
    raster : Array of n dimensions.
    size : int or array-like
        Scalar or array/list of length n (numbers of dimensions of the input raster),
        giving a size for each axis.

    Returns
    =======
    out : ndarray
        Local variance with the same shape as ``raster``, cast back to the
        dtype of ``raster``. For integer dtypes the fractional part is
        truncated by this cast.
    """
    size = np.array(size)

    dtype = raster.dtype
    raster = raster.astype("float64")  # avoid overflow

    if size.ndim == 0:
        # Same window length along every axis (e.g. bands, rows, cols in 3D)
        size = np.repeat(size.item(), raster.ndim)

    # Box kernel and number of cells in the window, shared by both means
    kernel = np.ones(size)
    n_cells = np.prod(size, axis=0)

    # Estimate local mean and local mean of squares
    local_mean = signal.correlate(raster, kernel, mode='same') / n_cells
    local_mean_sq = signal.correlate(raster ** 2, kernel, mode='same') / n_cells

    local_variance = local_mean_sq - local_mean ** 2

    # Clip round-off residue below zero. Otherwise np.sqrt would give NaN
    # and a cast to an unsigned dtype would wrap around.
    np.maximum(local_variance, 0, out=local_variance)

    return local_variance.astype(dtype)


class SciPyEstimateVarianceAlgorithm(SciPyAlgorithm):
    """
    Processing algorithm that estimates the local variance of a raster
    with ``estimate_local_variance``.

    It adds a ``Size`` parameter (window size) on top of whatever
    parameters the SciPyAlgorithm base class defines.
    """
    # Overwrite constants of base class
    _name = 'estimate_var'
    _displayname = 'Estimate local variance'
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"
    _help = """
            Estimate local variance. Implementation based on the \
            <a href="https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615">source code of scipy.signal.wiener</a> \
            using correlate from <a href="https://docs.scipy.org/doc/scipy/reference/signal.html">scipy.signal</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Size</b> Size of filter window.
            
            """

    SIZE = 'SIZE'

    def insert_parameters(self, config):
        """Add the window size parameter, then the base class parameters."""
        self.addParameter(QgsProcessingParameterNumber(
            self.SIZE,
            self.tr('Size'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=3,
            optional=False,
            # A window needs at least one cell: size 0 would divide by
            # zero and a negative size makes np.ones() raise.
            minValue=1,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Return the base class kwargs plus ``size`` for the filter function."""
        kwargs = super().get_parameters(parameters, context)

        kwargs["size"] = self.parameterAsInt(parameters, self.SIZE, context)

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        return estimate_local_variance

    def createInstance(self):
        return SciPyEstimateVarianceAlgorithm()



class SciPyEstimateStdAlgorithm(SciPyEstimateVarianceAlgorithm):
    """
    Processing algorithm that estimates the local standard deviation,
    i.e. the square root of the local variance from
    ``estimate_local_variance``.

    The ``Size`` parameter and its handling are inherited from
    SciPyEstimateVarianceAlgorithm.
    """
    # Overwrite constants of base class
    _name = 'estimate_std'
    _displayname = 'Estimate local standard deviation'
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"
    _help = """
            Estimate local standard deviation (square root of the local variance). \
            Implementation based on the \
            <a href="https://github.com/scipy/scipy/blob/v1.8.0/scipy/signal/_signaltools.py#L1541-L1615">source code of scipy.signal.wiener</a> \
            using correlate from <a href="https://docs.scipy.org/doc/scipy/reference/signal.html">scipy.signal</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Size</b> Size of filter window.
            
            """

    # The function to be called, to be overwritten
    def get_fct(self):
        return self.estimate_local_std

    def estimate_local_std(self, raster, size):
        """
        Return the square root of the local variance of ``raster``.

        Tiny negative variances caused by floating-point round-off are
        clipped to 0 before the square root, so they cannot turn into NaN.
        """
        local_variance = estimate_local_variance(raster, size)
        return np.sqrt(np.maximum(local_variance, 0))

    def createInstance(self):
        return SciPyEstimateStdAlgorithm()
    



# Very slow calculation of Std:
    
class SciPyStdAlgorithm(SciPyStatisticalAlgorithm):
    """
    Processing algorithm that computes the local standard deviation with
    ``scipy.ndimage.generic_filter`` and ``numpy.std``.

    This is exact for the chosen window/footprint but very slow, because
    ``np.std`` is called once per output pixel. It adds a ``DDOF``
    parameter on top of whatever the SciPyStatisticalAlgorithm base class
    defines.
    """

    DDOF = 'DDOF'

    # Overwrite constants of base class
    _name = 'std'
    _displayname = 'Local standard deviation (very slow)'
    _outputname = None # If set to None, the displayname is used
    _groupid = "statistic"
    _help = """
            Local standard deviation.\
            Calculated with generic_filter from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a> \
            and numpy.std.

            Warning: Very slow! 

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Delta degrees of freedom</b> of the standard deviation \
            (ddof in numpy). With ddof=0, the std is calculated with \
            1/N, with ddof=1 with 1/(N-1).

            <b>Size</b> Size of filter if no footprint is given. Equivalent \
            to a footprint array of shape size × size [× size in 3D] \
            filled with ones.
            
            <b>Footprint</b> String representation of array, specifying \
            the kernel of the filter. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately).

            <b>Border mode</b> determines how input is extended around \
            the edges: <i>Reflect</i> (input is extended by reflecting at the edge), \
            <i>Constant</i> (fill around the edges with a <b>constant value</b>), \
            <i>Nearest</i> (extend by replicating the nearest pixel), \
            <i>Mirror</i> (extend by reflecting about the center of last pixel), \
            <i>Wrap</i> (extend by wrapping around to the opposite edge).
            """

    def insert_parameters(self, config):
        """Add the ddof parameter (0 or 1), then the base class parameters."""
        self.addParameter(QgsProcessingParameterNumber(
            self.DDOF,
            self.tr('Delta degrees of freedom (ddof)'),
            QgsProcessingParameterNumber.Type.Integer,
            optional=False,
            defaultValue=1,
            minValue=0,
            maxValue=1,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """
        Extend the base class kwargs for ``ndimage.generic_filter``:
        ``np.std`` is the per-window function and ``ddof`` is forwarded to
        it through ``extra_keywords``.
        """
        kwargs = super().get_parameters(parameters, context)

        ddof = self.parameterAsInt(parameters, self.DDOF, context)

        kwargs["function"] = np.std
        kwargs["extra_keywords"] = {"ddof": ddof}

        return kwargs

    # The function to be called, to be overwritten
    def get_fct(self):
        # np.std is passed in get_parameters as kwarg
        return ndimage.generic_filter

    def createInstance(self):
        return SciPyStdAlgorithm()