# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = "Florian Neukirchen"
__date__ = "2024-03-03"
__copyright__ = "(C) 2024 by Florian Neukirchen"

# Placeholder that `git archive` expands into the commit SHA1 (presumably via
# an export-subst entry in .gitattributes, which is not visible from here).
__revision__ = "$Format:%H$"

import numpy as np

from qgis.core import (QgsProcessingParameterBoolean,
                QgsProcessingParameterEnum,)
               
from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithm, Dimensions
from scipy_filters.helpers import bandmean
from scipy_filters.ui.i18n import tr

class SciPyGradientAlgorithm(SciPyAlgorithm):
    """
    Gradient filter. Returns gradient along x-axis, y-axis or the maximum gradient. Calculated with
    `numpy.gradient <https://numpy.org/doc/stable/reference/generated/numpy.gradient.html>`_.

    See also Pixel Gradient in Pixel Based Filters.

    .. note:: No data cells within the filter radius are filled with 0.

    **Axis** Calculate along x-axis, y-axis or both (returning the hypotenuse of both vectors.)

    **Use map units** If checked, the gradient is divided by the pixel size taken from the
    geotransform. The pixel height is usually negative for north-up rasters, so the y gradient
    is then positive where values increase towards the top of the raster.

    **Return absolute values** Gradient is calculated left to right along x-axis
    or top to bottom of the raster, possibly returning negative values. If checked,
    the absolute values are returned.
    """

    AXIS = "AXIS"
    MAPUNITS = "MAPUNITS"
    ABSOLUTE = "ABSOLUTE"


    # Overwrite constants of base class
    _name = 'gradient'
    _displayname = tr('Gradient filter')
    _outputname = tr('Gradient')
    _groupid = "edges"
    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)


    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable applied to the raster data (``myfnct``)."""
        return self.myfnct


    def getArgs(self, mode):
        """
        Return ``(axis, spacing)`` to pass to ``numpy.gradient``.

        :param mode: 0 = x axis (last array axis, spacing = pixel width ``gt[1]``),
            1 = y axis (second-to-last array axis, spacing = pixel height ``gt[5]``).
        :returns: ``(axis, spacing)``, or ``(None, None)`` for any other mode.
            Without map units the spacing is 1, i.e. one pixel.
        """
        if self.mapunits:
            # https://gdal.org/tutorials/geotransforms_tut.html
            # gt[5] is used with its sign; for north-up rasters it is usually
            # negative, which reverses the sign of the y gradient.
            # NOTE(review): rotation terms gt[2]/gt[4] are ignored here.
            gt = self.ds.GetGeoTransform()
        else:
            # Spacing of 1 means "one pixel" for both axes.
            gt = np.ones(6)

        if mode == 0: # x axis: differences along each row (columns vary)
            return -1, gt[1]
        if mode == 1: # y axis: differences along each column (rows vary)
            return -2, gt[5]

        return None, None

    def myfnct(self, a, **kwargs):
        """
        Compute the gradient of ``a`` according to ``self.mode``.

        :param a: raster block; x/y are taken from the last two axes.
        :param kwargs: must contain ``"output"``, used as a numpy dtype
            (presumably the selected output data type).
        :returns: gradient along x (mode 0), along y (mode 1), or the
            hypotenuse of both (any other mode); absolute values if
            ``self.absolute`` is set.
        """
        dtype = kwargs.pop("output")
        # Differentiate in floating point. Casting the input straight to an
        # integer output dtype would truncate (or wrap, for unsigned types)
        # the input values before the gradient is taken. For float output
        # dtypes the working dtype equals the output dtype.
        a = a.astype(np.result_type(dtype, np.float32))

        if self.mode in (0, 1):
            axis, spacing = self.getArgs(self.mode)
            a = np.gradient(a, spacing, axis=axis)

        else:
            axis, spacing = self.getArgs(0)
            dx = np.gradient(a, spacing, axis=axis)
            axis, spacing = self.getArgs(1)
            dy = np.gradient(a, spacing, axis=axis)
            a = np.hypot(dx, dy)

        if self.absolute:
            a = np.abs(a)

        return a



    def initAlgorithm(self, config):
        """Configure the base class for 3D input and a margin of 3, then init."""
        # Set dimensions to 3, even if we calculate along axis
        self._dimension = Dimensions.threeD
        self.margin = 3
        super().initAlgorithm(config)

    def getModes(self):
        """Return the translated axis labels; the list index is the mode value."""
        return [tr('x axis'), tr('y axis'), tr('Both')]


    def insert_parameters(self, config):
        """Add the axis, map-units and absolute-value parameters."""

        self.modes = self.getModes()

        self.addParameter(QgsProcessingParameterEnum(
            self.AXIS,
            tr('Axis'),
            self.modes,
            defaultValue=2))

        self.addParameter(QgsProcessingParameterBoolean(
            self.MAPUNITS,
            tr('Use map units (if false: 1 = pixel size)'),
            optional=True,
            defaultValue=True,
        ))

        self.addParameter(QgsProcessingParameterBoolean(
            self.ABSOLUTE,
            tr('Return absolute values'),
            optional=True,
            defaultValue=False,
        ))


    def get_parameters(self, parameters, context):
        """Read base parameters plus mode, map-units and absolute flags."""
        kwargs = super().get_parameters(parameters, context)

        self.mode = self.parameterAsInt(parameters, self.AXIS, context)
        self.mapunits = self.parameterAsBool(parameters, self.MAPUNITS, context)
        self.absolute = self.parameterAsBool(parameters, self.ABSOLUTE, context)

        return kwargs

    def checkAndComplain(self, feedback):
        """Warn when negative results may be written to an unsigned output."""
        # Mode 2 ("Both") returns np.hypot, which is never negative, so the
        # warning only applies to the signed single-axis modes.
        # _outdtype values 1, 2, 4 are presumably the unsigned integer entries.
        if self._outdtype in (1,2,4) and not self.absolute and self.mode != 2:
            msg = tr("WARNING: Output contains negative values, but output data type is unsigned integer!")
            feedback.reportError(msg, fatalError = False)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyGradientAlgorithm()


class SciPyPixelGradientAlgorithm(SciPyAlgorithm):
    """
    Pixel gradient filter

    Returns band to band gradient for each pixel, calculated with
    `numpy.gradient <https://numpy.org/doc/stable/reference/generated/numpy.gradient.html>`_.

    **Return absolute values** Gradient is calculated band to band, starting with band 1.
    The result may also contain negative values; optionally the absolute values are returned.
    """

    ABSOLUTE = "ABSOLUTE"

    # Overwrite constants of base class
    _name = 'pixel_gradient'
    _displayname = tr('Pixel gradient filter')
    _outputname = tr('Pixel gradient')
    _groupid = "pixel"
    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)


    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable applied to the raster data (``myfnct``)."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Compute the gradient of ``a`` along axis 0 (presumably the band axis).

        :param a: 3D raster block with at least two entries along axis 0.
        :param kwargs: must contain ``"output"``, used as a numpy dtype
            (presumably the selected output data type).
        :returns: band-to-band gradient, absolute if ``self.absolute`` is set.
        """
        dtype = kwargs.pop("output")
        # Differentiate in floating point: casting the input straight to an
        # integer output dtype would truncate/wrap the values first. For float
        # output dtypes the working dtype equals the output dtype.
        a = a.astype(np.result_type(dtype, np.float32))

        a = np.gradient(a, axis=0)

        if self.absolute:
            a = np.abs(a)

        return a


    def initAlgorithm(self, config):
        """Configure the base class for 3D input and add the absolute-value flag."""
        # Set dimensions to 3
        self._dimension = Dimensions.threeD

        super().initAlgorithm(config)

        self.addParameter(QgsProcessingParameterBoolean(
            self.ABSOLUTE,
            tr('Return absolute values'),
            optional=True,
            defaultValue=False,
        ))


    def get_parameters(self, parameters, context):
        """Read base parameters plus the absolute-value flag."""
        kwargs = super().get_parameters(parameters, context)
        self.absolute = self.parameterAsBool(parameters, self.ABSOLUTE, context)
        return kwargs

    def checkParameterValues(self, parameters, context):
        """Reject single-band input; np.gradient needs at least two bands."""
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer may return None (missing/invalid input);
        # leave that case to the base class instead of raising AttributeError.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def checkAndComplain(self, feedback):
        """Warn when negative results may be written to an unsigned output."""
        # _outdtype values 1, 2, 4 are presumably the unsigned integer entries.
        if self._outdtype in (1,2,4) and not self.absolute:
            msg = tr("WARNING: Output contains negative values, but output data type is unsigned integer!")
            feedback.reportError(msg, fatalError = False)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelGradientAlgorithm()
    

class SciPyPixelDiffAlgorithm(SciPyAlgorithm):
    """
    Difference band to band

    Returns band to band difference for each pixel, calculated with
    `numpy.diff <https://numpy.org/doc/stable/reference/generated/numpy.diff.html>`_.

    The number of bands in the output is the number of input bands minus one.

    **Return absolute values** Difference is calculated band to band, starting with band 1.
    The result may also contain negative values; optionally the absolute values are returned.
    """

    ABSOLUTE = "ABSOLUTE"

    # Overwrite constants of base class
    _name = 'pixel_difference'
    _displayname = tr('Difference band to band')
    _outputname = tr('Pixel difference')
    _groupid = "pixel"
    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)


    # The function to be called, to be overwritten
    def get_fct(self):
        """Return the callable applied to the raster data (``myfnct``)."""
        return self.myfnct

    def myfnct(self, a, **kwargs):
        """
        Compute successive differences of ``a`` along axis 0 (presumably bands).

        :param a: 3D raster block with at least two entries along axis 0.
        :param kwargs: must contain ``"output"``, used as a numpy dtype
            (presumably the selected output data type).
        :returns: array with one entry fewer along axis 0, absolute if
            ``self.absolute`` is set.
        """
        dtype = kwargs.pop("output")
        # Subtract in floating point. Doing np.diff in the output dtype is
        # wrong for unsigned integers: it wraps around (5 - 10 -> 251 for
        # uint8), and np.abs cannot undo that, so "absolute values" would be
        # silently wrong. For float output dtypes nothing changes.
        # NOTE(review): for integer outputs the result is now a float array;
        # the base class is assumed to convert it on write — confirm.
        a = a.astype(np.result_type(dtype, np.float32))

        a = np.diff(a, axis=0)

        if self.absolute:
            a = np.abs(a)

        return a


    def initAlgorithm(self, config):
        """Configure the base class for 3D input and add the absolute-value flag."""
        # Set dimensions to 3
        self._dimension = Dimensions.threeD

        super().initAlgorithm(config)

        self.addParameter(QgsProcessingParameterBoolean(
            self.ABSOLUTE,
            tr('Return absolute values'),
            optional=True,
            defaultValue=False,
        ))


    def get_parameters(self, parameters, context):
        """Read base parameters, set output band count, read absolute flag."""
        kwargs = super().get_parameters(parameters, context)
        # np.diff along axis 0 yields one band fewer than the input.
        self._outbands = self.inputlayer.bandCount() - 1

        self.absolute = self.parameterAsBool(parameters, self.ABSOLUTE, context)
        return kwargs

    def checkParameterValues(self, parameters, context):
        """Reject single-band input; the difference needs at least two bands."""
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # parameterAsRasterLayer may return None (missing/invalid input);
        # leave that case to the base class instead of raising AttributeError.
        if layer is not None and layer.bandCount() < 2:
            return (False, tr("Only possible if layer has more than 1 band."))
        return super().checkParameterValues(parameters, context)

    def checkAndComplain(self, feedback):
        """Warn when negative results may be written to an unsigned output."""
        # _outdtype values 1, 2, 4 are presumably the unsigned integer entries.
        if self._outdtype in (1,2,4) and not self.absolute:
            msg = tr("WARNING: Output contains negative values, but output data type is unsigned integer!")
            feedback.reportError(msg, fatalError = False)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyPixelDiffAlgorithm()
    
