# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (boilerplate from the QGIS Plugin Builder).
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# (the $Format:...$ placeholder is expanded by git's export-subst
# attribute, presumably configured in .gitattributes - confirm).

__revision__ = '$Format:%H$'


from scipy import ndimage

from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition)

from scipy_filters.scipy_algorithm_baseclasses import (SciPyAlgorithmWithMode,
                                          SciPyStatisticalAlgorithm)

from scipy_filters.helpers import (str_to_int_or_list, 
                       minimumvalue,
                       maximumvalue,
                       bandmean)

from scipy_filters.ui.i18n import tr


class SciPyMedianAlgorithm(SciPyStatisticalAlgorithm):
    """
    Median filter

    Calculated with median_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the band mean.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    # Overwrite constants of base class
    _name = 'median'
    _displayname = tr('Median filter')
    _outputname = None  # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.median_filter

    def fill_nodata(self, array, nodata, band):
        """Replace cells equal to *nodata* in *array*, in place, with the band mean.

        The mean is obtained from ``bandmean(self.ds, band)``; ``self.ds`` is
        not set in this class (presumably by the base class - confirm).
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        array[array == nodata] = bandmean(self.ds, band)

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyMedianAlgorithm()


class SciPyMinimumAlgorithm(SciPyStatisticalAlgorithm):
    """
    Minimum filter

    Calculated with minimum_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the maximum of the data type.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    # Overwrite constants of base class
    _name = 'minimum'
    _displayname = tr('Minimum filter')
    _outputname = None  # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.minimum_filter

    def fill_nodata(self, array, nodata, band=None):
        """Replace cells equal to *nodata* in *array*, in place.

        The fill value is ``maximumvalue(array.dtype)`` (the maximum of the
        data type), so filled cells do not lower the local minimum.
        *band* is accepted for interface compatibility but not used.
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        fillvalue = maximumvalue(array.dtype)
        array[array == nodata] = fillvalue

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyMinimumAlgorithm()
    

class SciPyMaximumAlgorithm(SciPyStatisticalAlgorithm):
    """
    Maximum filter

    Calculated with maximum_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the minimum of the data type.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    # Overwrite constants of base class
    _name = 'maximum'
    _displayname = tr('Maximum filter')
    _outputname = None  # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.maximum_filter

    def fill_nodata(self, array, nodata, band=None):
        """Replace cells equal to *nodata* in *array*, in place.

        The fill value is ``minimumvalue(array.dtype)`` (the minimum of the
        data type), so filled cells do not raise the local maximum.
        *band* is accepted for interface compatibility but not used.
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        fillvalue = minimumvalue(array.dtype)
        array[array == nodata] = fillvalue

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyMaximumAlgorithm()



class SciPyRangeAlgorithm(SciPyStatisticalAlgorithm):
    """
    Range filter, returns the difference between max and min within the neighborhood

    Calculated with minimum_filter and maximum_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the band mean.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    # Overwrite constants of base class
    _name = 'range'
    _displayname = tr('Local range (difference of min and max)')
    _outputname = tr('Range')
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that performs the filtering (self.rangefilter)."""
        return self.rangefilter

    def rangefilter(self, raster, **kwargs):
        """Return local maximum minus local minimum of *raster*.

        All keyword arguments (size/footprint, mode, ...) are forwarded
        unchanged to both ndimage.maximum_filter and ndimage.minimum_filter,
        so both use the same neighborhood and border handling.
        """
        maximum = ndimage.maximum_filter(raster, **kwargs)
        minimum = ndimage.minimum_filter(raster, **kwargs)
        # NOTE(review): both filters return the input dtype, so the subtraction
        # happens in that dtype; for signed integer rasters a range larger than
        # the dtype's maximum would wrap around - confirm the caller's dtype.
        return maximum - minimum

    def fill_nodata(self, array, nodata, band):
        """Replace cells equal to *nodata* in *array*, in place, with the band mean.

        The mean is obtained from ``bandmean(self.ds, band)``; ``self.ds`` is
        not set in this class (presumably by the base class - confirm).
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        array[array == nodata] = bandmean(self.ds, band)

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyRangeAlgorithm()




class SciPyPercentileAlgorithm(SciPyStatisticalAlgorithm):
    """
    Percentile filter

    Calculated with percentile_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the band mean.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Percentile** Percentile from 0 to 100. Negative values 
    are counted from the top, i.e. 100 + given value is used 
    (e.g. -20 is the same as 80).

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    PERCENTILE = 'PERCENTILE'

    # Overwrite constants of base class
    _name = 'percentile'
    _displayname = tr('Percentile filter')
    _outputname = None  # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.percentile_filter

    def fill_nodata(self, array, nodata, band):
        """Replace cells equal to *nodata* in *array*, in place, with the band mean.

        The mean is obtained from ``bandmean(self.ds, band)``; ``self.ds`` is
        not set in this class (presumably by the base class - confirm).
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        array[array == nodata] = bandmean(self.ds, band)

    def insert_parameters(self, config):
        """Add the PERCENTILE parameter, then the base-class parameters."""
        # Integer in [-100, 100]; negative values are passed through to SciPy,
        # which interprets them as 100 + value.
        self.addParameter(QgsProcessingParameterNumber(
            self.PERCENTILE,
            tr('Percentile'),
            QgsProcessingParameterNumber.Type.Integer,
            optional=False,
            minValue=-100,
            maxValue=100,
            defaultValue=80
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Return the base-class kwargs plus ``percentile`` for percentile_filter."""
        kwargs = super().get_parameters(parameters, context)

        kwargs['percentile'] = self.parameterAsInt(parameters, self.PERCENTILE, context)

        return kwargs

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyPercentileAlgorithm()
    

# Disabled: needs parameter checks first, e.g. rank must be smaller than the
# number of elements in the footprint (or in the neighborhood defined by size).

class SciPyRankAlgorithm(SciPyStatisticalAlgorithm):
    """
    Rank filter

    Calculated with rank_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the band mean.

    The filter calculates a histogram for the neighborhood 
    (specified by footprint or size) and returns the value 
    at the position of *rank*. 

    .. note:: a median filter is a special case with rank = 0.5 * size of the footprint.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Rank** Index of the element in the array of the local histogram 
    to be returned. Ranges from 0 (smallest element) to 
    the size of the footprint minus one. 
    If using size instead of footprint, the resulting footprint size 
    is size raised to the power of the number of dimensions. 
    The rank parameter may be less than zero: 
    rank = -1 indicates the largest element, etc.

    **Size** Size of filter in pixels if no footprint is given. Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Footprint** Positions of elements of a flat structuring element 
    used for the filter (as string representation of array). 
    Must have 2 dimensions if *dimension* is set to 2D. 
    Should have 3 dimensions if *dimension* is set to 3D, 
    but a 2D array is also accepted (a new axis is added as first 
    axis and the result is the same as calculating each band 
    separately). Examples can be loaded with the load button. 
    For convenience (e.g. when calling from a script), 
    the following shortcuts are accepted as well: 
    "square", "cross", "cross3D", "ball", "cube".

    **Origin** Shift the origin (hotspot) of the filter.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    RANK = 'RANK'

    # Overwrite constants of base class
    _name = 'rank'
    _displayname = tr('Rank filter')
    _outputname = None  # If set to None, the displayname is used
    _groupid = "statistic"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.rank_filter

    def insert_parameters(self, config):
        """Add the RANK parameter, then the base-class parameters."""
        # No default and no bounds: the valid range depends on the footprint
        # size, which is not checked here (see the comment above the class).
        self.addParameter(QgsProcessingParameterNumber(
            self.RANK,
            tr('Rank'),
            QgsProcessingParameterNumber.Type.Integer,
            optional=False,
            ))

        super().insert_parameters(config)

    def get_parameters(self, parameters, context):
        """Return the base-class kwargs plus ``rank`` for rank_filter."""
        kwargs = super().get_parameters(parameters, context)

        kwargs['rank'] = self.parameterAsInt(parameters, self.RANK, context)

        return kwargs

    def fill_nodata(self, array, nodata, band):
        """Replace cells equal to *nodata* in *array*, in place, with the band mean.

        The mean is obtained from ``bandmean(self.ds, band)``; ``self.ds`` is
        not set in this class (presumably by the base class - confirm).
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        array[array == nodata] = bandmean(self.ds, band)

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyRankAlgorithm()
    



class SciPyUniformAlgorithm(SciPyAlgorithmWithMode):
    """
    Uniform filter (a.k.a. box filter or mean filter)

    Calculated with uniform_filter from 
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with the band mean.

    **Dimension** Calculate for each band separately (2D) 
    or use all bands as a 3D datacube and perform filter in 3D. 

    .. note:: bands will be the first axis of the datacube.

    **Size** Size of filter in pixels, either a single integer 
    or a list of integers (one per dimension). Equivalent 
    to a footprint array of shape size_rows × size_cols (in 2D) or 
    size_bands × size_rows × size_cols (in 3D) 
    filled with ones.

    **Border mode** determines how input is extended around 
    the edges: *Reflect* (input is extended by reflecting at the edge), 
    *Constant* (fill around the edges with a **constant value**), 
    *Nearest* (extend by replicating the nearest pixel), 
    *Mirror* (extend by reflecting about the center of the last pixel), 
    *Wrap* (extend by wrapping around to the opposite edge).
    """

    # Note: Does not have footprint

    SIZE = 'SIZE'
    SIZES = 'SIZES'

    # Overwrite constants of base class
    _name = 'uniform'
    _displayname = tr('Uniform filter (box filter)')
    _outputname = tr('Uniform filter')
    _groupid = "blur"

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the SciPy function that performs the filtering."""
        return ndimage.uniform_filter

    def initAlgorithm(self, config):
        """Add the hidden integer SIZE and the string SIZES parameters."""
        # Call the super function first
        # (otherwise input is not the first parameter in the GUI)
        super().initAlgorithm(config)

        # Integer size, hidden in the GUI; only used when SIZES is empty
        # (see get_parameters).
        size_param = QgsProcessingParameterNumber(
            self.SIZE,
            tr('Size of flat structuring element (either size or footprint must be given, with footprint, size is ignored)'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=1,
            optional=True,
            minValue=1,
            maxValue=20,  # Large sizes are really slow
            )

        size_param.setFlags(size_param.flags() | QgsProcessingParameterDefinition.Flag.FlagHidden)

        self.addParameter(size_param)

        # Size given as a string, parsed with str_to_int_or_list
        # (a single integer or one value per dimension).
        sizes_param = QgsProcessingParameterString(
            self.SIZES,
            tr('Size'),
            defaultValue="",
            optional=True,
            )

        self.addParameter(sizes_param)

    def fill_nodata(self, array, nodata, band):
        """Replace cells equal to *nodata* in *array*, in place, with the band mean.

        The mean is obtained from ``bandmean(self.ds, band)``; ``self.ds`` is
        not set in this class (presumably by the base class - confirm).
        """
        # NOTE(review): if nodata is NaN, `array == nodata` is never True and
        # no cell is replaced - confirm NaN nodata is handled by the caller.
        array[array == nodata] = bandmean(self.ds, band)

    def checkParameterValues(self, parameters, context):
        """Reject a list of sizes whose length differs from the number of dimensions.

        A single integer size is accepted for any number of dimensions.
        """
        dims = self.getDimsForCheck(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        sizes = str_to_int_or_list(sizes)
        if isinstance(sizes, list):
            if len(sizes) != dims:
                return (False, tr("Sizes does not match number of dimensions"))

        return super().checkParameterValues(parameters, context)

    def get_parameters(self, parameters, context):
        """Return the kwargs for uniform_filter and set ``self.margin``.

        The SIZES string takes precedence over the hidden integer SIZE.
        If neither yields a size, 3 is used. ``self.margin`` is set to the
        (largest) filter size.
        """
        kwargs = super().get_parameters(parameters, context)

        sizes = self.parameterAsString(parameters, self.SIZES, context)
        if sizes:
            size = str_to_int_or_list(sizes)
        else:
            size = self.parameterAsInt(parameters, self.SIZE, context)
        if not size:
            # Just in case it is called from python and neither size nor sizes is set
            size = 3
        # str_to_int_or_list presumably returns a plain int for a single value
        # (checkParameterValues tests for a list). max() of an int raises
        # TypeError, so only take the maximum of a sequence.
        self.margin = size if isinstance(size, int) else max(size)
        kwargs['size'] = size

        return kwargs

    def createInstance(self):
        """Return a new instance of this algorithm (QGIS Processing API)."""
        return SciPyUniformAlgorithm()
    

