# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# (git expands '$Format:...$' placeholders during `git archive` only for
# files that carry the export-subst attribute; otherwise it stays literal)

__revision__ = '$Format:%H$'

import numpy as np
                      
from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithm, Dimensions
from scipy_filters.ui.i18n import tr


class SciPyPixelMinAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Minimum Filter

    Returns minimum of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_min'
    _displayname = tr('Pixel minimum filter')
    _outputname = tr('Pixel minimum')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Minimum of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.min``
        :return: per-pixel minimum cast to the ``output`` dtype
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.min(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelMinAlgorithm()
    


class SciPyPixelMaxAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Maximum Filter

    Returns maximum of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_max'
    _displayname = tr('Pixel maximum filter')
    _outputname = tr('Pixel maximum')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Maximum of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.max``
        :return: per-pixel maximum cast to the ``output`` dtype
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.max(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelMaxAlgorithm()
    


class SciPyPixelMeanAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Mean Filter

    Returns mean of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_mean'
    _displayname = tr('Pixel mean filter')
    _outputname = tr('Pixel mean')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Arithmetic mean of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.mean``
        :return: per-pixel mean cast to the ``output`` dtype (an integer
            output dtype truncates the fractional part)
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.mean(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelMeanAlgorithm()
    

class SciPyPixelMedianAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Median Filter

    Returns median of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_median'
    _displayname = tr('Pixel median filter')
    _outputname = tr('Pixel median')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Median of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.median``
        :return: per-pixel median cast to the ``output`` dtype (with an even
            band count the median is the mean of the two middle values)
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.median(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelMedianAlgorithm()
    

class SciPyPixelStdAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Standard Deviation

    Returns standard deviation of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_std'
    _displayname = tr('Pixel standard deviation')
    _outputname = tr('Pixel std')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Standard deviation of ``a`` along axis 0 (the band axis).

        Uses numpy's default ``ddof=0``, i.e. the population standard
        deviation, unless ``ddof`` is supplied through ``kwargs``.

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.std``
        :return: per-pixel standard deviation cast to the ``output`` dtype
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.std(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelStdAlgorithm()


class SciPyPixelVarAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Variance

    Returns variance of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_variance'
    _displayname = tr('Pixel variance')
    # NOTE(review): unlike the other single-band pixel algorithms this is None;
    # presumably the base class substitutes a default name -- confirm.
    _outputname = None
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Variance of ``a`` along axis 0 (the band axis).

        Uses numpy's default ``ddof=0``, i.e. the population variance,
        unless ``ddof`` is supplied through ``kwargs``.

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.var``
        :return: per-pixel variance cast to the ``output`` dtype
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        return np.var(a, **kwargs).astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelVarAlgorithm()

class SciPyPixelRangeAlgorithm(SciPyAlgorithm):
    """
    Pixel Statistics Range Filter

    Returns difference of max and min of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_range'
    _displayname = tr('Pixel range filter')
    _outputname = tr('Pixel range')
    _groupid = "pixel"
    _outbands = 1

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistic."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Range (max - min) of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to ``np.min`` / ``np.max``
        :return: per-pixel range cast to the ``output`` dtype
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")
        minimum = np.min(a, **kwargs)
        maximum = np.max(a, **kwargs)
        diff = np.subtract(maximum, minimum)
        if np.issubdtype(diff.dtype, np.signedinteger):
            # For signed integers max - min can exceed the type's positive
            # range and wrap around (int16: 32767 - (-32768) -> -1). The true
            # range always lies in [0, 2**bits - 1], so reinterpreting the
            # wrapped two's-complement bits as the unsigned type of the same
            # width recovers the exact value.
            diff = diff.view(np.dtype("u%d" % diff.dtype.itemsize))
        return diff.astype(dtype)

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelRangeAlgorithm()
    

    
class SciPyPixelMinMaxMeanAlgorithm(SciPyAlgorithm):
    """
    Complete pixel statistics

    Returns min, max, mean, median and std of all bands for each individual pixel
    """

    # Overwrite constants of base class
    _name = 'pixel_all'
    _displayname = tr('Complete pixel statistics')
    _outputname = None
    _groupid = "pixel"
    _outbands = 5
    # Band descriptions, in the same order as the planes written by myfnct
    _band_desc = ["Min", "Max", "Mean", "Median", "Std"]

    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable that computes the per-pixel statistics."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Min, max, mean, median and std of ``a`` along axis 0 (the band axis).

        :param a: array whose first axis holds the bands
            (presumably (bands, rows, cols) -- TODO confirm with the base class)
        :param kwargs: must contain ``output``, the dtype of the result; all
            other entries are forwarded to each numpy reduction
        :return: array of shape ``(5,) + a.shape[1:]`` in the ``output`` dtype;
            planes are ordered Min, Max, Mean, Median, Std (std uses numpy's
            default ddof=0). Float results are cast on assignment, so an
            integer output dtype truncates them.
        """
        kwargs["axis"] = 0
        dtype = kwargs.pop("output")

        # np.empty is sufficient: every one of the 5 planes is overwritten below.
        out = np.empty((5,) + a.shape[1:], dtype)

        out[0] = np.min(a, **kwargs)
        out[1] = np.max(a, **kwargs)
        out[2] = np.mean(a, **kwargs)
        out[3] = np.median(a, **kwargs)
        out[4] = np.std(a, **kwargs)

        return out

    def initAlgorithm(self, config):
        """Switch to 3D processing, then let the base class set up parameters."""
        # Set dimensions to 3: statistics are taken across bands, so the
        # band axis has to be part of the data handed to myfnct.
        self._dimension = Dimensions.threeD
        super().initAlgorithm(config)

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input layers, then defer to the base-class checks.

        :return: ``(ok, message)`` tuple as expected by QGIS Processing
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer can return None (missing/invalid input);
        # let the base-class validation report that instead of crashing here.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Pixel statistics only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelMinMaxMeanAlgorithm()