# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (author, date and copyright of this file)
__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive
# NOTE(review): git only expands the $Format:%H$ placeholder for files that
# carry the export-subst attribute; confirm .gitattributes sets it.

__revision__ = '$Format:%H$'

from typing import Dict, Tuple
from scipy import ndimage
import numpy as np
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterDefinition,
                       QgsProcessingException,)

from scipy_filters.scipy_algorithm_baseclasses import SciPyAlgorithmWithMode

from scipy_filters.ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)

from scipy_filters.ui.origin_widget import (OriginWidgetWrapper, 
                               SciPyParameterOrigin,)

from scipy_filters.helpers import (str_to_int_or_list, 
                      check_structure, 
                      str_to_array, 
                      kernelexamples, 
                      get_np_dtype)

from scipy_filters.ui.i18n import tr

class SciPyConvolveAlgorithm(SciPyAlgorithmWithMode):
    """
    Convolve raster with given kernel.
    Calculated with convolve from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Kernel** String representation of array.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately).

    **Normalization** Normalize the kernel by dividing by the
    given value; set to 0 to divide by the sum of the absolute
    values of the kernel (a kernel containing only zeros is left unchanged).

    **Origin** Shift the origin (hotspot) of the kernel.

    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).

    **Dtype** Data type of output. Beware of clipping
    and potential overflow errors if min/max of output does
    not fit. Default is Float32.
    """

    KERNEL = 'KERNEL'
    NORMALIZATION = 'NORMALIZATION'
    ORIGIN = 'ORIGIN'

    # Overwrite constants of base class
    _name = 'convolve'
    _displayname = tr('Convolve')
    _outputname = None # If set to None, the displayname is used
    _groupid = 'convolution'

    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)


    # The function to be called
    def get_fct(self):
        """Return the callable that the base class applies to the raster data."""
        return self.my_fct

    def my_fct(self, a, **kwargs):
        """
        Convolve array ``a`` with ``scipy.ndimage.convolve``.

        The minimum and maximum of ``a`` are appended to ``self.inmin`` /
        ``self.inmax`` first, so that ``checkAndComplain`` can estimate the
        output range afterwards.
        NOTE(review): presumably called once per band or data chunk by the
        base class; confirm.

        :param a: numpy array to filter
        :param kwargs: keyword arguments forwarded to ``ndimage.convolve``
            (``weights`` and ``origin`` come from ``get_parameters``)
        :returns: the convolved array
        """
        # Used for feedback
        self.inmin.append(a.min())
        self.inmax.append(a.max())

        return ndimage.convolve(a, **kwargs)


    def insert_parameters(self, config):
        """
        Add the kernel, normalization and origin parameters to the
        parameters defined by the base class, and reset the feedback lists.
        """
        super().insert_parameters(config)

        default_kernel = "[[1, 2, 1],\n[2, 4, 2],\n[1, 2, 1]]"

        kernel_param = SciPyParameterStructure(
            self.KERNEL,
            tr('Kernel'),
            defaultValue=default_kernel,
            examples=kernelexamples,
            multiLine=True,
            to_int=False,
            optional=False,
            )

        kernel_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(kernel_param)


        self.addParameter(QgsProcessingParameterNumber(
            self.NORMALIZATION,
            tr('Normalization (devide kernel values by given number). Set to 0 to devide by sum of absolute values of the kernel.'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            # maxValue=100
            ))


        origin_param = SciPyParameterOrigin(
            self.ORIGIN,
            tr('Origin'),
            defaultValue="0",
            optional=False,
            watch="KERNEL"  # the origin widget follows the kernel parameter
            )

        origin_param.setMetadata({
            'widget_wrapper': {
                'class': OriginWidgetWrapper
            }
        })

        origin_param.setFlags(origin_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)


        self.addParameter(origin_param)

        # Used for feedback (filled by my_fct, read by checkAndComplain)
        self.inmax = []
        self.inmin = []


    def get_parameters(self, parameters, context):
        """
        Build the keyword arguments for ``ndimage.convolve``.

        Extends the base class kwargs with the normalized ``weights`` and the
        ``origin``. Side effects: stores the kernel in ``self.kernel`` (for
        feedback) and sets ``self.margin`` to half the largest kernel extent
        (rounded up) plus the largest absolute origin shift.

        :returns: dict of keyword arguments
        """
        kwargs = super().get_parameters(parameters, context)

        weights = self.parameterAsString(parameters, self.KERNEL, context)
        weights = str_to_array(weights, self._ndim)

        normalization = self.parameterAsDouble(parameters, self.NORMALIZATION, context)

        if normalization == 0:
            abs_sum = np.abs(weights).sum()
            # An all-zero kernel has an absolute sum of 0; dividing by it would
            # turn every weight (and thus the whole output) into NaN.
            if abs_sum != 0:
                weights = weights / abs_sum
        else:
            weights = weights / normalization

        kwargs['weights'] = weights

        self.kernel = weights # For feedback

        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        kwargs['origin'] = str_to_int_or_list(origin)

        # NOTE(review): presumably the overlap the base class needs around a
        # processed window; confirm against SciPyAlgorithmWithMode.
        self.margin = int(np.ceil(max(weights.shape) / 2) + np.abs(kwargs['origin']).max())

        return kwargs


    def checkParameterValues(self, parameters, context):
        """
        Validate the kernel and origin before the algorithm runs.

        The kernel must parse and fit the selected dimensions (checked by
        ``check_structure``). The origin, either one integer for all axes or
        a list with one integer per axis, must satisfy
        ``-(size // 2) <= origin <= (size - 1) // 2`` for the kernel size on
        every axis, as required by ``scipy.ndimage``.

        :returns: tuple ``(ok, message)``
        """
        structure = self.parameterAsString(parameters, self.KERNEL, context)

        dims = self.getDimsForCheck(parameters, context)

        ok, s, shape = check_structure(structure, dims, optional=False)
        if not ok:
            return (ok, s)

        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        origin = str_to_int_or_list(origin)

        if isinstance(origin, list):
            if len(origin) != dims:
                return (False, tr("Origin does not match number of dimensions"))
            origins = origin
        elif isinstance(origin, (int, np.integer)) and origin != 0:
            # A single integer is applied to every axis by scipy.ndimage,
            # so it must be valid for each axis of the kernel.
            origins = [origin] * dims
        else:
            # An origin of 0 is valid for any kernel size
            origins = []

        if origins:
            # A 2D kernel is accepted in 3D mode (a new first axis is added),
            # so pad the shape with leading axes of size 1 where needed.
            shape = (1,) * (dims - len(shape)) + tuple(shape)
            for axis_origin, size in zip(origins, shape):
                if not (-(size // 2) <= axis_origin <= (size - 1) // 2):
                    return (False, tr("Origin out of bounds of structure"))

        return super().checkParameterValues(parameters, context)


    def checkAndComplain(self, feedback):
        """
        Report the input and estimated output ranges; warn about likely clipping.

        The output range is a conservative estimate computed from the kernel
        and the recorded input extremes. Warnings are pushed as non-fatal
        errors when an unsigned integer output meets a kernel with negative
        weights, or when the estimated range exceeds the integer output dtype.

        :param feedback: processing feedback object used for messages
        """
        # Nothing has been filtered yet, so there is no range to report
        if not self.inmin or not self.inmax:
            return

        inmin = min(self.inmin)
        inmax = max(self.inmax)

        msg = tr("Input values are in the range {}...{}").format(inmin, inmax)
        feedback.pushInfo(msg)

        positive = np.where(self.kernel < 0, 0, self.kernel)  # positive part of kernel
        negative = np.where(self.kernel > 0, 0, self.kernel)  # negative part of kernel

        # Calculate the possible range after applying the kernel: positive
        # weights paired with positive input and negative weights with
        # negative input give the maximum, and the other way round gives the
        # minimum. astype("int") truncates towards zero.
        outmax = ((positive * max(0, inmax)).sum()
                  + (negative * min(0, inmin)).sum()).astype("int")

        outmin = ((negative * max(0, inmax)).sum()
                  + (positive * min(0, inmin)).sum()).astype("int")

        msg = tr("Expected output range is {}...{}").format(outmin, outmax)
        feedback.pushInfo(msg)

        if self._outdtype in (1,2,4) and np.any(self.kernel < 0):
            msg = tr("WARNING: With a kernel containing negative values, output values can be negative. But output data type is unsigned integer!")
            feedback.reportError(msg, fatalError = False)

        if 1 <= self._outdtype <= 5: # integer types
            info_out = np.iinfo(get_np_dtype(self._outdtype))
            if outmin < info_out.min or outmax > info_out.max:
                msg = tr("WARNING: The possible range of output values is not in the range of the output datatype. Clipping is likely.")
                feedback.reportError(msg, fatalError=False)


    def createInstance(self):
        """Return a new instance of this algorithm (required by QGIS Processing)."""
        return SciPyConvolveAlgorithm()



class SciPyCorrelateAlgorithm(SciPyConvolveAlgorithm):
    """
    Correlate raster with given kernel.
    Calculated with correlate from
    `scipy.ndimage <https://docs.scipy.org/doc/scipy/reference/ndimage.html>`_.

    .. note:: No data cells within the filter radius are filled with 0.

    **Dimension** Calculate for each band separately (2D)
    or use all bands as a 3D datacube and perform filter in 3D.

    .. note:: bands will be the first axis of the datacube.

    **Kernel** String representation of array.
    Must have 2 dimensions if *dimension* is set to 2D.
    Should have 3 dimensions if *dimension* is set to 3D,
    but a 2D array is also accepted (a new axis is added as first
    axis and the result is the same as calculating each band
    separately).

    **Normalization** Normalize the kernel by dividing by the
    given value; set to 0 to divide by the sum of the absolute
    values of the kernel.

    **Origin** Shift the origin (hotspot) of the kernel.

    **Border mode** determines how input is extended around
    the edges: *Reflect* (input is extended by reflecting at the edge),
    *Constant* (fill around the edges with a **constant value**),
    *Nearest* (extend by replicating the nearest pixel),
    *Mirror* (extend by reflecting about the center of last pixel),
    *Wrap* (extend by wrapping around to the opposite edge).

    **Dtype** Data type of output. Beware of clipping
    and potential overflow errors if min/max of output does
    not fit. Default is Float32.
    """

    # Constants of the base class, overridden for this algorithm
    _name = 'correlate'
    _displayname = tr('Correlate')
    _outputname = None  # None: the display name is used as output name
    _groupid = 'convolution'

    _default_dtype = 6  # index of the default output dtype in the combobox

    def my_fct(self, a, **kwargs):
        """
        Correlate array ``a`` with ``scipy.ndimage.correlate``.

        The extremes of ``a`` are recorded in ``self.inmin`` / ``self.inmax``
        before filtering, so the inherited ``checkAndComplain`` can estimate
        the output range.

        :param a: numpy array to filter
        :param kwargs: keyword arguments forwarded to ``ndimage.correlate``
        :returns: the correlated array
        """
        smallest = a.min()
        self.inmin.append(smallest)
        largest = a.max()
        self.inmax.append(largest)

        filtered = ndimage.correlate(a, **kwargs)
        return filtered

    def createInstance(self):
        """Return a new instance of this algorithm (required by QGIS Processing)."""
        return SciPyCorrelateAlgorithm()
