# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
import numpy as np
import json
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition,
                       QgsProcessingParameterNumber,
                       QgsProcessingException,
                       QgsProcessingLayerPostProcessorInterface,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterRasterDestination)



from ..helpers import (str_to_int_or_list, 
                      check_structure, 
                      str_to_array, 
                      kernelexamples, 
                      get_np_dtype)

from ..scipy_algorithm_baseclasses import groups


class SciPyTransformPcBaseclass(QgsProcessingAlgorithm):
    """
    Transform to principal components

    """

    EIGENVECTORS = 'EIGENVECTORS'
    OUTPUT = 'OUTPUT'
    INPUT = 'INPUT'
    BANDMEAN = 'BANDMEAN'
    DTYPE = 'DTYPE'

    # Constants and defaults that are overridden by the inheriting classes

    _groupid = "pca" 
    _name = ''
    _displayname = ''
    _outputname = ""

    # _outbands = 1
    _help = """
            Baseclass to transform to/from principal components using matrix of eigenvectors
            """

    _inverse = False
    _keepbands = 0
    falsemean = False

    _bandmean = None
    V = None
    abstract = ""

    def initAlgorithm(self, config):
        """
        Declare the parameters shared by the forward and inverse transform.

        Adds, in this order: input raster, eigenvector text field, band mean
        text field, output data type and output raster. The band mean and
        the data type are always flagged as advanced parameters; the
        eigenvector field only for the inverse transform.
        """
        advanced = QgsProcessingParameterDefinition.Flag.FlagAdvanced

        input_param = QgsProcessingParameterRasterLayer(
            self.INPUT,
            self.tr('Input layer'),
        )
        self.addParameter(input_param)

        eigenvector_param = QgsProcessingParameterString(
            self.EIGENVECTORS,
            self.tr('Eigenvectors'),
            defaultValue="",
            multiLine=True,
            optional=True,
        )
        if self._inverse:
            eigenvector_param.setFlags(eigenvector_param.flags() | advanced)
        self.addParameter(eigenvector_param)

        # The same text field carries the original means (inverse transform)
        # or the false means (forward transform); only the label differs
        mean_label = (
            self.tr('Mean of original bands')
            if self._inverse
            else self.tr('False mean for each band')
        )
        bandmean_param = QgsProcessingParameterString(
            self.BANDMEAN,
            mean_label,
            defaultValue="",
            multiLine=False,
            optional=True,
        )
        bandmean_param.setFlags(bandmean_param.flags() | advanced)
        self.addParameter(bandmean_param)

        # The index of the selected option is later mapped to a GDAL data type
        output_type_param = QgsProcessingParameterEnum(
            self.DTYPE,
            self.tr('Output data type'),
            ['Float32 (32 bit float)', 'Float64 (64 bit float)'],
            defaultValue=0,
            optional=True,
        )
        output_type_param.setFlags(output_type_param.flags() | advanced)
        self.addParameter(output_type_param)

        destination_param = QgsProcessingParameterRasterDestination(
            self.OUTPUT,
            self.tr(self._outputname),
        )
        self.addParameter(destination_param)
        
    
    def get_parameters(self, parameters, context):
        """
        Read the parameter values and store them as instance attributes.

        Sets ``inputlayer``, ``output_raster``, ``V`` (matrix of
        eigenvectors), ``outdtype`` (GDAL data type of the output) and, for
        the inverse transform or if a false mean is used, ``_bandmean``.
        Values given in the text fields take precedence over the values
        stored as JSON in the metadata abstract.

        Raises:
            QgsProcessingException: if no 2D matrix of eigenvectors is
                available from the text field or the metadata.
        """
        self.inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        self.output_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        # Eigenvectors from text field
        # NOTE(review): str_to_array presumably returns None for an empty
        # string -- confirm against helpers
        eigenvector_text = self.parameterAsString(parameters, self.EIGENVECTORS, context)
        self.V = str_to_array(eigenvector_text, dims=None, to_int=False)

        # Parameters from metadata abstract
        if self._inverse:
            self.abstract = self.inputlayer.metadata().abstract()
            # The other case is handled in the inheriting class

        eigenvectors, means = self.json_to_parameters(self.abstract)

        if self.V is None:
            self.V = eigenvectors

        # Fail early with a clear message instead of a cryptic error in the
        # matrix product (also catches 0-d arrays built from missing keys)
        if self.V is None or np.ndim(self.V) != 2:
            raise QgsProcessingException(
                self.tr("No valid matrix of eigenvectors found in parameters or layer metadata")
            )

        # Get the mean, start with metadata of layer and eventually overwrite it
        if self._inverse or self.falsemean:
            if means is None:
                means = 0

            # Always keep the means 2D: a row vector broadcasts against the
            # (pixels, bands) array and can be indexed with [0] for reporting,
            # which a scalar fallback can not
            self._bandmean = np.atleast_2d(means)

            # Mean from text field
            bandmean = self.parameterAsString(parameters, self.BANDMEAN, context).strip()

            if bandmean != "":
                # Accept a plain comma separated list as well as a JSON list
                if not (bandmean[0] == "[" and bandmean[-1] == "]"):
                    bandmean = "[" + bandmean + "]"
                try:
                    text_means = np.array(json.loads(bandmean))
                except (json.decoder.JSONDecodeError, ValueError, TypeError):
                    text_means = None

                if text_means is not None:
                    self._bandmean = np.atleast_2d(text_means)

        # Option 0 is Float32, option 1 is Float64
        dtype_index = self.parameterAsInt(parameters, self.DTYPE, context)
        self.outdtype = gdal.GDT_Float64 if dtype_index == 1 else gdal.GDT_Float32

    def checkParameterValues(self, parameters, context):
        """
        Validate eigenvectors and band means before the algorithm is run.

        The eigenvectors are taken from the text field or, if it yields no
        matrix, from the JSON stored in the metadata abstract (of the input
        layer for the inverse transform, of the parameter layer otherwise).
        The means are only checked for the inverse transform or if false
        means are requested; means from the text field override the ones
        from the metadata. An abstract that can not be decoded is only
        reported if a value is actually needed from it.

        Returns:
            Tuple ``(ok, message)``. If all checks of this class pass, the
            result of the parent class implementation is returned.
        """
        inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        if inputlayer is None:
            # Nothing to compare against; the parent implementation reports
            # the invalid input layer
            return super().checkParameterValues(parameters, context)
        bands = inputlayer.bandCount()

        V = self.parameterAsString(parameters, self.EIGENVECTORS, context)
        V = V.strip()

        # Check eigenvectors (text field)
        # NOTE(review): str_to_array presumably returns None for an empty
        # string -- confirm against helpers
        try:
            V = str_to_array(V, dims=None, to_int=False)
        except QgsProcessingException:
            return False, self.tr("Can not parse eigenvectors")

        # Get parameters from metadata and check eigenvectors
        if self._inverse:
            abstract = inputlayer.metadata().abstract()
            checkmean = True
        else:
            # PARAMETERLAYER and FALSEMEAN are declared by the inheriting
            # class for the forward transform
            paramlayer = self.parameterAsRasterLayer(parameters, self.PARAMETERLAYER, context)
            abstract = paramlayer.metadata().abstract() if paramlayer else ""
            checkmean = self.parameterAsBool(parameters, self.FALSEMEAN, context)

        eigenvectors, layermeans = None, None
        metadata_error = self.tr("Could not decode metadata abstract")
        metadata_ok = True

        abstract = abstract.strip()
        if abstract != "":
            try:
                decoded = json.loads(abstract)
            except (json.decoder.JSONDecodeError, ValueError, TypeError):
                decoded = None

            # Valid JSON that is not an object (list, number, ...) has no keys
            if isinstance(decoded, dict):
                eigenvectors = decoded.get("eigenvectors", None)
                layermeans = decoded.get("band mean", None)
            else:
                metadata_ok = False

            if eigenvectors is not None:
                try:
                    eigenvectors = np.array(eigenvectors, dtype=float)
                except (ValueError, TypeError):
                    eigenvectors = None
                    metadata_ok = False

        # Check if eigenvectors are provided one or the other way
        if V is None:
            V = eigenvectors
        if V is None:
            if not metadata_ok:
                return False, metadata_error
            return False, self.tr("The layer does not contain valid eigenvactors and no eigenvectors where provided")

        # Check dimensions and shape of eigenvectors
        if V.ndim != 2 or V.shape[0] != V.shape[1]:
            return False, self.tr("Matrix of eigenvectors must be square (2D)")

        # The inverse transform accepts fewer bands than eigenvectors
        # (not all components were kept), the forward transform does not
        if (self._inverse and V.shape[0] < bands) or ((not self._inverse) and V.shape[0] != bands):
            return False, self.tr("Shape of matrix of eigenvectors does not match number of bands")

        # Check provided means
        if checkmean:
            bandmean = self.parameterAsString(parameters, self.BANDMEAN, context).strip()

            if bandmean != "":
                # If mean is given in text field, do not use the one from metadata
                if not (bandmean[0] == "[" and bandmean[-1] == "]"):
                    bandmean = "[" + bandmean + "]"
                try:
                    means = np.array(json.loads(bandmean), dtype=float)
                except (json.decoder.JSONDecodeError, ValueError, TypeError):
                    return False, self.tr("Could not parse list of means")
            elif layermeans is not None:
                try:
                    means = np.array(layermeans, dtype=float)
                except (ValueError, TypeError):
                    return False, metadata_error
            elif not metadata_ok:
                return False, metadata_error
            else:
                return False, self.tr("No band means were provided and none were found in the layer metadata")

            # Check dimensions (both cases)
            if means.ndim != 1 or means.shape[0] != V.shape[0]:
                return False, self.tr("False shape of means list")

        return super().checkParameterValues(parameters, context)


    def processAlgorithm(self, parameters, context, feedback):
        """
        Transform the input raster with the matrix of eigenvectors.

        Forward transform: the bands are centered (with their own mean or
        the given false mean) and multiplied with the eigenvectors; the
        eigenvectors and means are attached as JSON abstract to the output
        layer if it is loaded on completion. Inverse transform: the scores
        are multiplied with the transposed eigenvectors and the band means
        are added.

        Returns:
            Dictionary with the output raster and the eigenvectors (plus
            the band mean for the forward transform), or an empty
            dictionary if the run was canceled.

        Raises:
            QgsProcessingException: if the input can not be opened or read,
                or the output can not be created.
        """
        self.get_parameters(parameters, context)

        self.ds = gdal.Open(self.inputlayer.source())

        if not self.ds:
            raise QgsProcessingException("Failed to open Raster Layer")

        try:
            self.bandcount = self.ds.RasterCount
            bands = self._keepbands
            if bands == 0 or bands > self.bandcount:
                bands = self.bandcount

            if feedback.isCanceled():
                return {}

            feedback.setProgress(0)

            # Start the actual work
            if self.outdtype == gdal.GDT_Float32:
                np_dtype = np.float32
            else:
                np_dtype = np.float64

            a = self.ds.ReadAsArray()
            if a is None:
                raise QgsProcessingException("Failed to read Raster Layer")
            a = a.astype(np_dtype, copy=False)
            if a.ndim == 2:  # Layer with only 1 band
                a = a[np.newaxis, :]

            orig_shape = a.shape

            # Flatten to (pixels, bands)
            a = a.reshape(orig_shape[0], -1).T

            # Subtract mean
            if not self._inverse:
                if not self.falsemean:
                    self._bandmean = a.mean(axis=0)[np.newaxis, :]
                    feedback.pushInfo(self.tr("\nBand Mean:"))
                else:
                    # Guard against scalar means, they can not be indexed below
                    self._bandmean = np.atleast_2d(self._bandmean)
                    feedback.pushInfo(self.tr("\nFalse (given) band mean:"))
                feedback.pushInfo(str(self._bandmean[0].tolist()) + "\n")

                a = a - self._bandmean

            # Transform to PC (or back)
            if self._inverse:
                components = self.V.T
                # If not all PC bands were kept, only the matching rows of the
                # transposed eigenvectors are used and the output gets the
                # full number of original bands
                if self.bandcount < components.shape[0]:
                    bands = components.shape[0]
                    orig_shape = (bands, orig_shape[1], orig_shape[2])
                    components = components[0:self.bandcount, :]
            else:
                components = self.V

            new_array = a @ components

            if self._inverse:
                new_array = new_array + self._bandmean

            new_array = new_array.T.reshape(orig_shape)

            if feedback.isCanceled():
                return {}

            # Prepare output and write file
            driver = gdal.GetDriverByName('GTiff')
            self.out_ds = driver.Create(self.output_raster,
                                        xsize=self.ds.RasterXSize,
                                        ysize=self.ds.RasterYSize,
                                        bands=bands,
                                        eType=self.outdtype)
            if self.out_ds is None:
                raise QgsProcessingException("Failed to create output raster")

            self.out_ds.SetGeoTransform(self.ds.GetGeoTransform())
            self.out_ds.SetProjection(self.ds.GetProjection())

            self.out_ds.WriteArray(new_array[0:bands, :, :])

            # Calculate and write band statistics (min, max, mean, std)
            for b in range(1, bands + 1):
                band = self.out_ds.GetRasterBand(b)
                stats = band.GetStatistics(0, 1)
                band.SetStatistics(*stats)
        finally:
            # Dropping the references closes the datasets (and writes the
            # output to disk), also if the run was canceled or failed
            self.out_ds = None
            self.ds = None

        if self._inverse:
            return {
                self.OUTPUT: self.output_raster,
                'eigenvectors': self.V,
                }

        encoded = json.dumps({
            'eigenvectors': self.V.tolist(),
            'band mean': self._bandmean[0].tolist(),
        })

        # NOTE(review): the module level reference presumably keeps the post
        # processor alive until QGIS calls it -- confirm before changing
        global updatemetadata
        # Only layers that are loaded on completion have details to attach
        # the post processor to
        if context.willLoadLayerOnCompletion(self.output_raster):
            updatemetadata = self.UpdateMetadata(encoded)
            context.layerToLoadOnCompletionDetails(self.output_raster).setPostProcessor(updatemetadata)

        return {
            self.OUTPUT: self.output_raster,
            'band mean': self._bandmean[0],
            'eigenvectors': self.V,
            }
                

    def json_to_parameters(self, s):
        s = s.strip()
        if s == "":
            return None, None
        try:
            decoded = json.loads(s)
        except (json.decoder.JSONDecodeError, ValueError, TypeError):
            return None, None

        eigenvectors = decoded.get("eigenvectors", None)

        try:
            eigenvectors = np.array(eigenvectors)
        except (ValueError, TypeError):
            eigenvectors = None

        means = decoded.get("band mean", 0)
        try:
            means = np.array(means)
        except (ValueError, TypeError):
            means = None
        return eigenvectors, means


    class UpdateMetadata(QgsProcessingLayerPostProcessorInterface):
        """
        Post processor that writes a given text into the abstract of the
        layer metadata.
        """
        def __init__(self, abstract):
            super().__init__()
            # Text that is written into the metadata abstract of the layer
            self.abstract = abstract

        def postProcessLayer(self, layer, context, feedback):
            """Replace the abstract in the metadata of the layer with the stored text."""
            metadata = layer.metadata()
            metadata.setAbstract(self.abstract)
            layer.setMetadata(metadata)

    def name(self):
        """Return the algorithm name stored in the class attribute ``_name``."""
        return self._name

    def displayName(self):
        """Return the translated display name (class attribute ``_displayname``)."""
        return self.tr(self._displayname)

    def group(self):
        """
        Return the translated display name of the group of this algorithm.

        The name is looked up in the ``groups`` dictionary by ``_groupid``.
        An empty group ID gives an empty string; an unknown group ID gives
        an (untranslated) hint for debugging.
        """
        if self._groupid == "":
            return ""

        label = groups.get(self._groupid)
        if label:
            return self.tr(label)

        # Group ID is not in the groups dictionary, return error message for debugging
        return "Displayname of group must be set in groups dictionary"

    def groupId(self):
        """Return the group ID stored in the class attribute ``_groupid``."""
        return self._groupid
    
    def shortHelpString(self):
        """Return the translated help text (class attribute ``_help``)."""
        return self.tr(self._help)

    def tr(self, string):
        """Translate ``string`` with Qt, using the translation context 'Processing'."""
        return QCoreApplication.translate('Processing', string)


    
class SciPyTransformToPCAlgorithm(SciPyTransformPcBaseclass):
    """
    Forward transform of a raster into principal components.

    Extends the base class with an optional layer whose metadata abstract
    is read for the eigenvectors and means, the number of components to
    keep and a switch to center the data with false means.
    """
    _name = 'transform_to_PC'
    _displayname = 'Transform to principal components'
    _outputname = _displayname

    _help = """
        Transform data into given principal components \
        with a matrix of eigenvectors by taking the \
        dot product with a matrix of weights (after centering the data). \


        The eigenvectors can also be read from the metadata of an \
        existing PCA layer. 

        <b>Eigenvectors</b> Matrix of eigenvectors (as string). \
        Optional if the next parameter is set. \
        

        <b>Read eigenvectors from PCA layer metadata</b> \
        Reads the weights for the transformation from the metadata \
        of a layer that was generated using the PCA algorithm of this plugin. \
        Ignored if the parameter <i>eigenvectors</i> is used. 

        <b>Number of components</b> is only used if the value is greater than 0 and \
        smaller than the count of original bands.

        <b>False mean for each band</b> As first step of PCA, the data of each \
        band is centered by subtracting the means. If false means are provided, \
        these are substracted instead of the real means of the input layer. \
        This allows to transform another raster image into the same space \
        as the principal components of another layer. The result is usefull \
        for comparation of several rasters, but should not be considered to be \
        proper principal components. Only used if "Used false mean" is checked.

        <b>Use false mean</b> See also <i>false mean of each band</i>. The \
        false mean to be used can also be read from the metadata of a PCA layer. \
        
        <b>Output data type</b> Float32 or Float64.
        """

    # Names of the parameters added by this class
    PARAMETERLAYER = "PARAMETERLAYER"
    NCOMPONENTS = 'NCOMPONENTS'
    FALSEMEAN = 'FALSEMEAN'

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyTransformToPCAlgorithm()

    def initAlgorithm(self, config):
        """
        Add the parameters of the base class, followed by the metadata
        layer, the number of components and the false mean switch.
        """
        super().initAlgorithm(config)

        metadata_layer_param = QgsProcessingParameterRasterLayer(
            self.PARAMETERLAYER,
            self.tr('Read eigenvectors from PCA layer metadata'),
            optional=True,
        )
        self.addParameter(metadata_layer_param)

        ncomponents_param = QgsProcessingParameterNumber(
            self.NCOMPONENTS,
            self.tr('Number of components to keep. Set to 0 for all components.'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=0,
            optional=True,
            minValue=0,
        )
        self.addParameter(ncomponents_param)

        falsemean_param = QgsProcessingParameterBoolean(
            self.FALSEMEAN,
            self.tr('Use false mean (provided as parameter) to center data'),
            optional=True,
            defaultValue=False,
        )
        # Shown as advanced parameter
        advanced = QgsProcessingParameterDefinition.Flag.FlagAdvanced
        falsemean_param.setFlags(falsemean_param.flags() | advanced)
        self.addParameter(falsemean_param)

    def get_parameters(self, parameters, context):
        """
        Read the parameters of this class, then those of the base class.

        The abstract of the metadata layer (if given) must be stored before
        the base class decodes ``self.abstract``.
        """
        self._keepbands = self.parameterAsInt(parameters, self.NCOMPONENTS, context)
        self.falsemean = self.parameterAsBool(parameters, self.FALSEMEAN, context)

        metadata_layer = self.parameterAsRasterLayer(parameters, self.PARAMETERLAYER, context)
        if metadata_layer:
            self.abstract = metadata_layer.metadata().abstract()

        super().get_parameters(parameters, context)



class SciPyTransformFromPCAlgorithm(SciPyTransformPcBaseclass):
    """
    Transform from principal components

    Inverse transform; all processing is done by the base class, this
    class only sets the names, the help text and the ``_inverse`` switch.
    """

    _name = 'transform_from_PC'
    _displayname = 'Transform from principal components'
    _outputname = _displayname

    _help = """
        Transform data from principal components (i.e. the PCA scores) \
        back into the original feature space \
        using a matrix of eigenvectors by taking the \
        dot product of the scores the with the transpose of the matrix of eigenvectors \
        and adding the original means to the result.

        The eigenvectors can also be read from the metadata \
        of the input layer, as long as they exist and are complete. \

        <b>Eigenvectors</b> Matrix of eigenvectors (as string). \
        Optional if the next parameter is set. \
        The matrix can be taken from the output of the PCA algorith of this plugin. \

        <b>Mean of original bands</b> As first step of PCA, the data of each \
        band is centered by subtracting the means. These must be added \
        after rotating back into the original feature space. \
        Optional if the meta data of the input layer is complete. \
        (Use false means if they were used for the forward transformation.)
                
        <b>Output data type</b> Float32 or Float64.
        """

    # Switches the base class to the inverse transform (transposed
    # eigenvectors, means are added after the matrix product)
    _inverse = True


    def createInstance(self):
        """Return a new instance of this algorithm."""
        return SciPyTransformFromPCAlgorithm()  