# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (from the Plugin Builder template this file was generated with).
__author__ = "Florian Neukirchen"
__copyright__ = "(C) 2024 by Florian Neukirchen"
__date__ = "2024-03-03"

# Placeholder: `git archive` substitutes the commit SHA1 for it.
__revision__ = "$Format:%H$"

from typing import Dict, Tuple
from scipy import ndimage
import numpy as np
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingParameterNumber,
                       QgsProcessingParameterDefinition,
                       QgsProcessingException,)

from ..scipy_algorithm_baseclasses import SciPyAlgorithmWithMode

from ..ui.structure_widget import (StructureWidgetWrapper, 
                                  SciPyParameterStructure,)

from ..ui.origin_widget import (OriginWidgetWrapper, 
                               SciPyParameterOrigin,)

from ..helpers import (str_to_int_or_list, 
                      check_structure, 
                      str_to_array, 
                      kernelexamples, 
                      get_np_dtype)

class SciPyConvolveAlgorithm(SciPyAlgorithmWithMode):
    """Convolve a raster with a user-defined kernel via scipy.ndimage.convolve.

    The kernel is parsed from a string, normalized, and passed to scipy
    together with an optional origin shift. While filtering, the input value
    range is recorded. ``checkAndComplain`` uses it to report the expected
    output range and to warn about likely clipping in the output data type.
    """

    KERNEL = 'KERNEL'
    NORMALIZATION = 'NORMALIZATION'
    ORIGIN = 'ORIGIN'

    # Overwrite constants of base class
    _name = 'convolve'
    _displayname = 'Convolve'
    _outputname = None # If set to None, the displayname is used 
    _groupid = 'convolution'

    _default_dtype = 6 # Optionally change default output dtype (value = idx of combobox)
    
    _help = """
            Convolve raster with given kernel. \
            Calculated with convolve from \
            <a href="https://docs.scipy.org/doc/scipy/reference/ndimage.html">scipy.ndimage</a>.

            <b>Dimension</b> Calculate for each band separately (2D) \
            or use all bands as a 3D datacube and perform filter in 3D. \
            Note: bands will be the first axis of the datacube.

            <b>Kernel</b> String representation of array. \
            Must have 2 dimensions if <i>dimension</i> is set to 2D. \
            Should have 3 dimensions if <i>dimension</i> is set to 3D, \
            but a 2D array is also accepted (a new axis is added as first \
            axis and the result is the same as calculating each band \
            separately).

            <b>Normalization</b> Normalize the kernel by dividing through \
            given value; set to 0 to divide through the sum of the absolute \
            values of the kernel.

            <b>Origin</b> Shift the origin (hotspot) of the kernel.

            <b>Border mode</b> determines how input is extended around \
            the edges: <i>Reflect</i> (input is extended by reflecting at the edge), \
            <i>Constant</i> (fill around the edges with a <b>constant value</b>), \
            <i>Nearest</i> (extend by replicating the nearest pixel), \
            <i>Mirror</i> (extend by reflecting about the center of last pixel), \
            <i>Wrap</i> (extend by wrapping around to the opposite edge).

            <b>Dtype</b> Data type of output. Beware of clipping \
            and potential overflow errors if min/max of output does \
            not fit. Default is Float32.
            """
    
    # The function to be called
    def get_fct(self):
        """Return the callable that the base class applies to the raster data."""
        return self.my_fct 
    
    def my_fct(self, a, **kwargs):
        """Record the value range of *a* for feedback, then convolve it.

        :param a: array handed over by the base class.
        :param kwargs: keyword arguments forwarded to ``ndimage.convolve``
            (``weights`` and ``origin`` from ``get_parameters`` plus whatever
            the base class adds).
        :returns: the convolved array.
        """
        # Used for feedback
        self.inmin.append(a.min())
        self.inmax.append(a.max())

        return ndimage.convolve(a, **kwargs)

    def insert_parameters(self, config):
        """Add the kernel, normalization and origin parameters.

        Also initializes the lists that ``my_fct`` fills with per-call input
        minima/maxima.
        """
        super().insert_parameters(config)

        default_kernel = "[[1, 2, 1],\n[2, 4, 2],\n[1, 2, 1]]"

        kernel_param = SciPyParameterStructure(
            self.KERNEL,
            self.tr('Kernel'),
            defaultValue=default_kernel,
            examples=kernelexamples,
            multiLine=True,
            to_int=False,
            optional=False,
            )
        
        kernel_param.setMetadata({
            'widget_wrapper': {
                'class': StructureWidgetWrapper
            }
        })

        self.addParameter(kernel_param)

        # The label must match get_parameters: 0 means "divide by the sum of
        # the ABSOLUTE kernel values".
        self.addParameter(QgsProcessingParameterNumber(
            self.NORMALIZATION,
            self.tr('Normalization (divide kernel values by number). Set to 0 to divide by sum of absolute kernel values.'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0, 
            optional=True, 
            minValue=0, 
            )) 

        origin_param = SciPyParameterOrigin(
            self.ORIGIN,
            self.tr('Origin'),
            defaultValue="0",
            optional=False,
            watch=self.KERNEL,
            )
        
        origin_param.setMetadata({
            'widget_wrapper': {
                'class': OriginWidgetWrapper
            }
        })

        origin_param.setFlags(origin_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)

        self.addParameter(origin_param)

        # Used for feedback
        self.inmax = []
        self.inmin = []

    def get_parameters(self, parameters, context):
        """Build the keyword arguments for ``ndimage.convolve``.

        The kernel is parsed with ``str_to_array`` and divided by the
        normalization value. A value of 0 means "divide by the sum of the
        absolute kernel values". An all-zero kernel is left unchanged instead
        of being divided by zero, which would turn it into NaN.

        :returns: dict with the base-class kwargs plus ``weights`` and
            ``origin``.
        """
        kwargs = super().get_parameters(parameters, context)

        weights = self.parameterAsString(parameters, self.KERNEL, context)
        weights = str_to_array(weights, self._ndim)

        normalization = self.parameterAsDouble(parameters, self.NORMALIZATION, context)
        if normalization == 0:
            # The absolute sum is non-zero for every kernel that is not all
            # zeros, including zero-sum kernels.
            normalization = np.abs(weights).sum()

        if normalization != 0:
            weights = weights / normalization

        kwargs['weights'] = weights

        self.kernel = weights # For feedback

        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        kwargs['origin'] = str_to_int_or_list(origin)

        return kwargs

    def checkParameterValues(self, parameters, context): 
        """Validate the kernel and origin before the algorithm runs.

        scipy.ndimage only accepts an origin ``o`` with
        ``-(n // 2) <= o <= (n - 1) // 2`` for each kernel axis of length
        ``n``. This is checked both for a per-axis list and for a single
        integer, which scipy applies to every axis.

        :returns: ``(ok, message)``; if the checks here pass, the result of
            the base-class check.
        """
        structure = self.parameterAsString(parameters, self.KERNEL, context)

        dims = self.getDimsForCheck(parameters, context)

        ok, s, shape = check_structure(structure, dims, optional=False)
        if not ok:
            return (ok, s)
        
        origin = self.parameterAsString(parameters, self.ORIGIN, context)
        origin = str_to_int_or_list(origin)

        if isinstance(origin, list) and len(origin) != dims:
            return (False, self.tr("Origin does not match number of dimensions"))

        if shape is not None:
            # A 2D kernel is accepted in 3D mode and gets a new first axis of
            # length 1. Pad the shape the same way so every axis is checked
            # (no-op when the shape already has `dims` axes).
            shape = (1,) * (dims - len(shape)) + tuple(shape)

            if isinstance(origin, list):
                origins = origin
            elif isinstance(origin, (int, np.integer)):
                # scipy uses a single integer origin for every axis
                origins = [origin] * len(shape)
            else:
                origins = []

            for axis_origin, axis_len in zip(origins, shape):
                if not (-(axis_len // 2) <= axis_origin <= (axis_len - 1) // 2):
                    return (False, self.tr("Origin out of bounds of structure"))

        return super().checkParameterValues(parameters, context)

    def checkAndComplain(self, feedback):
        """Report input/expected output ranges and warn about clipping.

        Uses the minima/maxima collected by ``my_fct`` and the normalized
        kernel stored by ``get_parameters``.

        :param feedback: QGIS processing feedback used to emit the messages.
        """
        if not self.inmin or not self.inmax:
            # my_fct never ran, so there is no input range to report on.
            return

        # np.min/np.max propagate NaN deterministically. The builtin min/max
        # give an order-dependent result when NaN is present.
        inmin = np.min(self.inmin)
        inmax = np.max(self.inmax)

        msg = self.tr("Input values are in the range {}...{}").format(inmin, inmax)
        feedback.pushInfo(msg)

        # Calculate the possible range after applying the kernel. The input
        # range is widened to include 0, which keeps the bound conservative.
        # np.maximum/np.minimum (unlike the builtins) keep NaN as NaN.
        kernel_pos = np.where(self.kernel > 0, self.kernel, 0).sum()  # positive part of kernel
        kernel_neg = np.where(self.kernel < 0, self.kernel, 0).sum()  # negative part of kernel
        in_hi = np.maximum(0, inmax)  # positive input
        in_lo = np.minimum(0, inmin)  # negative input

        outmax = kernel_pos * in_hi + kernel_neg * in_lo
        outmin = kernel_neg * in_hi + kernel_pos * in_lo
        bounds_finite = bool(np.isfinite(outmin) and np.isfinite(outmax))

        if bounds_finite:
            # Round outwards so the displayed integers enclose the bound.
            msg = self.tr("Expected output range is {}...{}").format(
                int(np.floor(outmin)), int(np.ceil(outmax)))
        else:
            msg = self.tr("Expected output range could not be estimated (non-finite values).")
        feedback.pushInfo(msg)

        if self._outdtype in (1, 2, 4) and np.any(self.kernel < 0):
            msg = self.tr("WARNING: With a kernel containing negative values, output values can be negative. But output data type is unsigned integer!")
            feedback.reportError(msg, fatalError=False)

        if bounds_finite and 1 <= self._outdtype <= 5: # integer types
            info_out = np.iinfo(get_np_dtype(self._outdtype))
            # Compare the unrounded float bounds. Truncating them first could
            # hide a small overshoot, e.g. 255.4 for uint8.
            if outmin < info_out.min or outmax > info_out.max:
                msg = self.tr("WARNING: The possible range of output values is not in the range of the output datatype. Clipping is likely.")
                feedback.reportError(msg, fatalError=False)

    def createInstance(self):
        """Return a fresh instance, as required by QGIS Processing."""
        return SciPyConvolveAlgorithm()

