"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""


import json
import numpy as np
from scipy import ndimage
from qgis.core import QgsProcessingException
from collections import OrderedDict


def str_to_array(s, dims=2, to_int=True):
    """
    Parse a string into a numpy array.

    s: either one of the names in generate_binary_structure_options
       ("square", "cross", "cross3D", "ball", "cube") or a JSON nested list.
    dims: expected number of dimensions. A 2D array is expanded to 3D
       (leading axis of length 1) when dims == 3. None disables the check.
    to_int: cast JSON input to int, otherwise to float32. Named structures
       are returned as produced by ndimage.generate_binary_structure.

    Returns None for an empty (or whitespace-only) string.
    Raises QgsProcessingException if the string can not be parsed or the
    array has the wrong number of dimensions.
    """
    s = s.strip()
    if not s:
        return None

    if s in generate_binary_structure_options:
        rank, connectivity = generate_binary_structure_options[s]
        a = ndimage.generate_binary_structure(rank, connectivity)
    else:
        try:
            decoded = json.loads(s)
            a = np.array(decoded, dtype="int" if to_int else np.float32)
        except (json.decoder.JSONDecodeError, ValueError, TypeError, OverflowError) as err:
            # OverflowError: integers out of range for the target dtype
            raise QgsProcessingException('Can not parse string to array!') from err

    if dims is None or a.ndim == dims:
        return a
    if a.ndim == 2 and dims == 3:
        # Apply a 2D structure to 3D data: add a leading axis of length 1
        return a[np.newaxis, :]
    raise QgsProcessingException('Array has wrong number of dimensions!')


# Named structures -> (rank, connectivity) arguments
# for ndimage.generate_binary_structure
generate_binary_structure_options = dict(
    square=(2, 2),
    cross=(2, 1),
    cross3D=(3, 1),
    ball=(3, 2),
    cube=(3, 3),
)

def check_structure(s, dims=2, odd=False, optional=True):
    """
    Check if a structure / footprint / size string is valid.

    s: string as entered by the user, either one of the named structures
       ("square", "cross", "cross3D", "ball", "cube") or a JSON nested list.
    dims: number of dimensions of the filter input.
    odd: if True, every element must be odd (Wiener filter size).
    optional: if True, an empty string is accepted.

    Returns tuple: (ok: bool, message: str, shape: tuple | None)
    Shape is required for origin. It is the shape the structure has when
    applied: a 2D structure used on 3D input gets a leading axis of length 1.
    For an empty optional string a placeholder of zeros is returned.
    """
    s = s.strip()
    if not s:
        if optional:
            return (True, "", (0, 0) if dims == 2 else (0, 0, 0))
        return (False, "Argument is not optional", None)

    if s in ("square", "cross"):
        # 2D named structure: like a 2D array below, it gets an extra
        # leading axis when used on 3D input
        return (True, "", (1, 3, 3) if dims == 3 else (3, 3))
    if s in ("cross3D", "ball", "cube"):
        if dims == 3:
            return (True, "", (3, 3, 3))
        return (False, f'{s} not possible in 2D', None)

    # Get it as array
    try:
        a = np.array(json.loads(s), dtype=np.float32)
    except (json.decoder.JSONDecodeError, ValueError, TypeError):
        return (False, 'Can not parse string to array', None)

    # Array must have same number of dims as the filter input,
    # but for 3D input and 2D structure one axis is added automatically
    if not (a.ndim == 2 or a.ndim == dims):
        return (False, 'Array has wrong number of dimensions', None)

    # A zero-size structure (e.g. "[[]]") can not be used by a filter
    if a.size == 0:
        return (False, 'Array must not be empty', None)

    # Wiener filter: values must be odd
    if odd and np.any(a % 2 == 0):
        return (False, 'Every element in size must be odd.', None)

    if a.ndim == 2 and dims == 3:
        a = a[np.newaxis, :]

    return (True, "", a.shape)


def check_origin(s, shape):
    """
    Check if an origin string is valid for a structure of the given shape.

    s: origin as string, either a single int (applied to every axis) or a
       list with one int per axis, parsed with str_to_int_or_list.
    shape: shape of the structure, e.g. the shape returned by check_structure.

    Returns tuple: (ok: bool, message: str)
    """
    try:
        origin = str_to_int_or_list(s)
    except ValueError:
        return (False, 'Invalid origin')

    # An empty string parses to None, which is not a usable origin
    # (previously this crashed with a TypeError on len(None))
    if origin is None:
        return (False, 'Invalid origin')

    # origin must satisfy -(weights.shape[k] // 2) <= origin[k] <= (weights.shape[k]-1) // 2
    # NOTE(review): the zero placeholder shape that check_structure returns for
    # an empty optional structure makes this range empty, so every origin is
    # rejected -- confirm the caller skips this check in that case.
    if isinstance(origin, int):
        # A single int is used for every axis, so the smallest axis limits it
        n = min(shape)
        ok = -(n // 2) <= origin <= (n - 1) // 2
    else:
        ok = len(origin) == len(shape) and all(
            -(n // 2) <= o <= (n - 1) // 2 for o, n in zip(origin, shape)
        )

    return (True, '') if ok else (False, 'Invalid origin')



def str_to_int_or_list(s):
    """
    Parse a parameter that is either a single int or a list of ints.

    Allows parameters for axes (one or several) or size (for all or each
    dimension). Examples: "5" -> 5, "0" -> 0, "1, 2" or "[1, 2]" -> [1, 2],
    "" -> None. The surrounding brackets of a list are optional.

    Raises ValueError if the string can not be parsed to a 1D list of ints.
    """
    s = s.strip()
    if not s:
        return None

    # A plain integer (including 0, "00" or " 0 ") is returned as int
    try:
        return int(s)
    except ValueError:
        pass

    if not (s[0] == "[" and s[-1] == "]"):
        s = "[" + s + "]"

    try:
        decoded = json.loads(s)
        a = np.array(decoded, dtype=np.int32)
    except (json.decoder.JSONDecodeError, ValueError, TypeError, OverflowError) as err:
        # OverflowError: values out of int32 range (NumPy 2 raises here)
        raise ValueError('Can not parse string to array!') from err

    if a.ndim != 1:
        raise ValueError('Wrong dimensions!')

    return a.tolist()



def array_to_str(a):
    """
    Convert numpy array a to its nested-list string form.

    A line break is inserted between adjacent sub-lists at any depth, so
    a 2D array is shown with one row per line.
    """
    nested = a.tolist()
    return "],\n[".join(str(nested).split("], ["))




# Example footprints by display name. Entries whose value is "---"
# appear to be separators (presumably for a selection menu; confirm with caller).
footprintexamples = OrderedDict(
    [(f"{n} × {n} Square", np.ones((n, n))) for n in (3, 5, 7)]
    + [
        ("Cross", ndimage.generate_binary_structure(2, 1)),
        ("sep1", "---"),
    ]
    + [(f"{n} × {n} × {n} Cube", np.ones((n, n, n))) for n in (3, 5, 7)]
)


# Example kernels by display name. Values are either numpy arrays or
# ready-formatted strings; "---" entries appear to be separators.
kernelexamples = OrderedDict(
    [(f"{n} × {n} Square", np.ones((n, n))) for n in (3, 5, 7)]
    + [
        ("sep1", "---"),
        ("3 × 3 Gaussian", "[[1, 2, 1],\n[2, 4, 2],\n[1, 2, 1]]"),
        ("5 × 5 Gaussian", "[[0,1,2,1,0],\n[1,3,5,3,1],\n[2,5,9,5,2],\n[1,3,5,3,1],\n[0,1,2,1,0]]"),
        ("5 × 5 Laplacian of Gaussian", "[[0,0,-1,0,0],\n[0,-1,-2,-1,0],\n[-1,-2,16,-2,-1],\n[0,-1,-2,-1,0],\n[0,0,-1,0,0]]"),
        ("3 × 3 Sobel horizontal edges", "[[1, 2, 1],\n[0, 0, 0],\n[-1, -2, -1]]"),
        ("3 × 3 Sobel vertical edges", "[[1, 0, -1],\n[2, 0, -2],\n[1, 0, -1]]"),
        ("sep2", "---"),
    ]
    + [(f"{n} × {n} × {n} Cube", np.ones((n, n, n))) for n in (3, 5, 7)]
    + [("3 × 1 × 1 Across bands of pixel", np.ones((3, 1, 1)))]
)

# Example structuring elements for morphology by display name.
# "---" entries appear to be separators; each needs a unique key, because a
# repeated key in an OrderedDict silently collapses into the first entry
# (the second separator was lost when both were named "sep1").
morphostructexamples = OrderedDict([
    ("Cross 2D", ndimage.generate_binary_structure(2, 1)),
    ("Square 2D", ndimage.generate_binary_structure(2, 2)),
    ("sep1", "---"),
    ("Cross 3D", ndimage.generate_binary_structure(3, 1)),
    ("Ball 3D", ndimage.generate_binary_structure(3, 2)),
    ("Cube 3D", ndimage.generate_binary_structure(3, 3)),
    ("sep2", "---"),
    ("5 × 5 Square", np.ones((5, 5))),
    ("7 × 7 Square", np.ones((7, 7))),
    ("5 × 5 × 5 Cube", np.ones((5, 5, 5))),
])