# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
from scipy import linalg
import numpy as np
import json
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingLayerPostProcessorInterface,
                       QgsProcessingException,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterDefinition,
                       QgsProcessingParameterBoolean,
                        )

from ..scipy_algorithm_baseclasses import groups


class SciPyPCAAlgorithm(QgsProcessingAlgorithm):
    """
    Principal Component Analysis (PCA) of a multiband raster.

    The PCA is computed with scipy.linalg.svd on the mean-centered band
    values. The output raster holds the PCA scores. Singular values,
    eigenvalues, eigenvectors, loadings and band means are pushed to the
    log, stored as JSON in the metadata abstract of the output layer and
    returned in the result dict.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.

    OUTPUT = 'OUTPUT'
    INPUT = 'INPUT'
    NCOMPONENTS = 'NCOMPONENTS'
    PERCENTVARIANCE = 'PERCENTVARIANCE'
    DTYPE = 'DTYPE'

    _name = 'pca'
    _displayname = 'Principal Component Analysis (PCA)'
    # Output layer name, passed through tr() in initAlgorithm.
    # NOTE(review): this class does not handle None here; a None fallback
    # to the display name would have to be implemented by the caller.
    _outputname = 'PCA'
    _groupid = "pca"
    _help = """
            Principal Component Analysis (PCA), \
            calculated using Singular Value Decomposition (SVD) using svd from \
            <a href="https://docs.scipy.org/doc/scipy/reference/linalg.html">scipy.linalg</a>.

            With default parameters, all components are kept. Optionally, either the \
            <i>number of components</i> to keep or the <i>percentage of variance</i> \
            explained by the kept components can be set. 

            
            <b>Number of components</b> is only used if the value is greater than 0 and \
            smaller than the count of original bands and if percentage of variance is \
            not set.

            <b>Percentage of variance to keep</b> is only used if it is greater than 0 \
            (typical values would be in the range between 90 and 100).

            <b>Output</b> The output raster contains \
            the data projected into the principal components \
            (i.e. the PCA scores).

            <b>Output data type</b> Float32 or Float64

            The following values / vectors are available a) in the log tab of \
            the processing window, b) in JSON format in the "Abstract" field \
            of the metadata of the output raster layer, possibly to be used \
            by subsequent transformations, and c) in the output dict if \
            the tool has been called from the python console or a script:\n
            <ul>
            <li>Singular values (of SVD)</li>
            <li>Variance explained (Eigenvalues)</li>
            <li>Ratio of variance explained</li>
            <li>Cumulated sum of variance explained</li>
            <li>Eigenvectors (V of SVD)</li>
            <li>Loadings (eigenvectors scaled by sqrt(eigenvalues))</li>
            <li>Band Mean</li>
            </ul>

            """

    # Init Algorithm
    def initAlgorithm(self, config):
        """
        Define the inputs and output of the algorithm: input raster,
        number of components, percentage of variance, output data type
        (advanced) and the output raster destination.
        """

        # Add parameters
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT,
                self.tr('Input layer'),
            )
        )

        self.addParameter(QgsProcessingParameterNumber(
            self.NCOMPONENTS,
            self.tr('Number of components to keep. Set to 0 for all components.'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=0,
            optional=True,
            minValue=0,
            ))

        self.addParameter(QgsProcessingParameterNumber(
            self.PERCENTVARIANCE,
            self.tr('Percentage of Variance to keep (if set and > 0: overwrites number of components)'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            maxValue=100
            ))

        dtype_param = QgsProcessingParameterEnum(
            self.DTYPE,
            self.tr('Output data type'),
            ['Float32 (32 bit float)', 'Float64 (64 bit float)'],
            defaultValue=0,
            optional=True)

        # Set as advanced parameter
        dtype_param.setFlags(dtype_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(dtype_param)

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
            self.tr(self._outputname)))

    def processAlgorithm(self, parameters, context, feedback):
        """
        Compute the PCA of the input raster and write the scores.

        Reads all bands with GDAL and centers each band on its mean. It then
        runs an SVD and writes the first ``bands`` principal-component scores
        as a GeoTIFF. Finally it attaches the PCA statistics (JSON) to the
        output layer's metadata abstract through a post-processor.

        :returns: dict with the output path under ``OUTPUT``, the PCA
            statistics as numpy arrays, and their JSON encoding under
            ``'json'``; an empty dict if the user canceled.
        :raises QgsProcessingException: if the input layer is invalid or
            cannot be opened, has fewer pixels than bands, or the output
            file cannot be created.
        """

        # Get Parameters
        self.inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        if self.inputlayer is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT))
        self.output_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        self.ncomponents = self.parameterAsInt(parameters, self.NCOMPONENTS, context)
        self.percentvariance = self.parameterAsDouble(parameters, self.PERCENTVARIANCE, context)

        # DTYPE enum index: 0 -> Float32, 1 -> Float64. The computation is
        # done in the same precision as the output.
        if self.parameterAsInt(parameters, self.DTYPE, context) == 0:
            self.outdtype = gdal.GDT_Float32
            np_dtype = np.float32
        else:
            self.outdtype = gdal.GDT_Float64
            np_dtype = np.float64

        # Open Raster with GDAL
        self.ds = gdal.Open(self.inputlayer.source())
        if not self.ds:
            raise QgsProcessingException("Failed to open Raster Layer")

        self.bandcount = self.ds.RasterCount
        self.indatatype = self.ds.GetRasterBand(1).DataType

        n_pixels = self.ds.RasterXSize * self.ds.RasterYSize
        # With full_matrices=False the SVD yields min(n_pixels, bands)
        # components. Fewer components than bands would break the reshape
        # and band descriptions below, and n_pixels - 1 is used as a divisor.
        if n_pixels < self.bandcount:
            raise QgsProcessingException(
                "PCA needs a raster with at least as many pixels as bands")

        if feedback.isCanceled():
            return {}

        feedback.setProgress(0)

        # NOTE(review): NoData pixels are not masked; they enter the band
        # means and the SVD like regular values.
        # Shape is (bands, RasterYSize, RasterXSize)
        a = self.ds.ReadAsArray().astype(np_dtype)
        orig_shape = a.shape

        # One row per pixel, one column per band: (n_pixels, bands)
        flattened = a.reshape(orig_shape[0], -1).T

        # Subtract the per-band mean
        col_mean = flattened.mean(axis=0)
        centered = flattened - col_mean[np.newaxis, :]

        # Get loadings with SVD

        # For info on relation of SVD and PCA see:
        # https://stats.stackexchange.com/a/134283
        # https://scentellegher.github.io/machine-learning/2020/01/27/pca-loadings-sklearn.html
        # https://stats.stackexchange.com/a/141755

        # Note: U, S, VT = svd(X) followed by S = S * constant
        # is identical to U, S, VT = svd(x * constant)
        # U and VT do not change.
        # The constant used for normalization in PCA is: 1 / sqrt(n_samples)
        # or 1 / sqrt(n_samples - 1)

        _, S, VT = linalg.svd(centered, full_matrices=False)

        loadings = VT.T @ np.diag(S) / np.sqrt(n_pixels - 1)

        # variance_explained = eigenvalues
        # and they can be calculated from the singular values (S)
        # See point 3 in https://stats.stackexchange.com/a/134283
        variance_explained = S * S / (n_pixels - 1)
        variance_ratio = variance_explained / variance_explained.sum()
        variance_explained_cumsum = variance_ratio.cumsum()

        if feedback.isCanceled():
            return {}

        # Flip the sign of every component whose loadings sum to < 0.
        # The sign of an SVD component is arbitrary; without this, dark
        # areas could become bright and vice versa.
        flip = loadings.sum(axis=0) < 0
        loadings[:, flip] *= -1
        VT[flip, :] *= -1
        eigenvectors = VT.T

        # Give feedback
        feedback.pushInfo("Singular values (of SVD):")
        feedback.pushInfo(str(S.tolist()))
        feedback.pushInfo("\nVariance explained (Eigenvalues):")
        feedback.pushInfo(str(variance_explained.tolist()))
        feedback.pushInfo("\nRatio of variance explained:")
        feedback.pushInfo(str(variance_ratio.tolist()))
        feedback.pushInfo("\nCumulated sum of variance explained:")
        feedback.pushInfo(str(variance_explained_cumsum.tolist()))
        feedback.pushInfo("\nEigenvectors (V of SVD):")
        feedback.pushInfo(str(eigenvectors.tolist()))
        feedback.pushInfo("\nLoadings (eigenvectors scaled by sqrt(eigenvalues)):")
        feedback.pushInfo(str(loadings.tolist()))
        feedback.pushInfo("\nBand Mean:")
        feedback.pushInfo(str(col_mean.tolist()) + "\n")

        if feedback.isCanceled():
            return {}

        # Scores, i.e. the data projected into the principal components,
        # reshaped back to (components, RasterYSize, RasterXSize)
        new_array = (centered @ eigenvectors).T.reshape(orig_shape)

        bands = self._bands_to_keep(variance_explained_cumsum, self.bandcount,
                                    self.percentvariance, self.ncomponents)

        # Prepare output and write file
        driver = gdal.GetDriverByName('GTiff')
        self.out_ds = driver.Create(self.output_raster,
                                    xsize=self.ds.RasterXSize,
                                    ysize=self.ds.RasterYSize,
                                    bands=bands,
                                    eType=self.outdtype)
        if self.out_ds is None:
            raise QgsProcessingException(
                f"Failed to create output raster {self.output_raster}")

        self.out_ds.SetGeoTransform(self.ds.GetGeoTransform())
        self.out_ds.SetProjection(self.ds.GetProjection())

        self.out_ds.WriteArray(new_array[0:bands, :, :])

        for i in range(bands):
            band = self.out_ds.GetRasterBand(i + 1)
            # Calculate and write band statistics (min, max, mean, std)
            band.SetStatistics(*band.GetStatistics(0, 1))
            # Description: (% variance of this component; cumulative %)
            band.SetDescription(f"({100*variance_ratio[i]:.1f}%; {100*variance_explained_cumsum[i]:.1f}%)")

        # Close the dataset to write file to disk
        self.out_ds = None

        encoded = json.dumps({
                'singular values': S.tolist(),
                'eigenvectors': eigenvectors.tolist(),
                'loadings': loadings.tolist(),
                'variance explained': variance_explained.tolist(),
                'variance_ratio': variance_ratio.tolist(),
                'variance explained cumsum': variance_explained_cumsum.tolist(),
                'band mean': col_mean.tolist(),
                })

        # Save loadings etc as json in the metadata abstract of the layer.
        # NOTE(review): the module-level reference presumably keeps the
        # post-processor alive until QGIS invokes it after this method
        # returns (a local would be garbage-collected) — keep as is.
        global updatemetadata
        updatemetadata = self.UpdateMetadata(encoded)
        context.layerToLoadOnCompletionDetails(self.output_raster).setPostProcessor(updatemetadata)

        return {self.OUTPUT: self.output_raster,
                'singular values': S,
                'loadings': loadings,
                'variance explained': variance_explained,
                'variance_ratio': variance_ratio,
                'variance explained cumsum': variance_explained_cumsum,
                'band mean': col_mean,
                'eigenvectors': eigenvectors,
                'json': encoded}

    @staticmethod
    def _bands_to_keep(cumsum, bandcount, percentvariance, ncomponents):
        """
        Return the number of principal components to write.

        :param cumsum: cumulative ratio of variance explained per component.
        :param bandcount: number of input bands (= all components).
        :param percentvariance: used if 0 < value < 100; keeps the fewest
            components whose cumulative ratio reaches value / 100.
        :param ncomponents: used otherwise if 0 < value < bandcount.
        :returns: int number of components; bandcount if neither applies.
        """
        if 0 < percentvariance < 100:
            reached = np.flatnonzero(cumsum >= percentvariance / 100)
            # Float rounding can leave the last cumulative value just below
            # the requested fraction. Keep all components then, rather than
            # falling back to a single one.
            if reached.size == 0:
                return bandcount
            # +1: index is zero-based, the band count is not
            return int(reached[0]) + 1
        if 0 < ncomponents < bandcount:
            return ncomponents
        return bandcount

    def checkParameterValues(self, parameters, context):
        """
        Reject single-band input; PCA needs more than one band. An invalid
        or missing input layer is left to the base-class check to report.
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        # PCA only possible with more than 1 bands
        if layer is not None and layer.bandCount() < 2:
            return (False, self.tr("PCA only possible if input layer has more than 1 bands"))

        return super().checkParameterValues(parameters, context)

    class UpdateMetadata(QgsProcessingLayerPostProcessorInterface):
        """
        Post-processor that writes a string (here: the PCA results as
        JSON) into the metadata abstract of the loaded output layer.
        """
        def __init__(self, abstract):
            super().__init__()
            self.abstract = abstract

        def postProcessLayer(self, layer, context, feedback):
            meta = layer.metadata()
            meta.setAbstract(self.abstract)
            layer.setMetadata(meta)

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._name

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr(self._displayname)

    def group(self):
        """
        Returns the translated display name of this algorithm's group,
        looked up in the shared ``groups`` dictionary by ``_groupid``.
        """
        if self._groupid == "":
            return ""
        s = groups.get(self._groupid)
        if not s:
            # If group ID is not in dictionary group, return error message for debugging
            return "Displayname of group must be set in groups dictionary"
        return self.tr(s)

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._groupid

    def shortHelpString(self):
        """
        Returns the help string that is shown on the right side of the
        user interface.
        """
        return self.tr(self._help)

    def tr(self, string):
        """Translate ``string`` in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Return a fresh instance, as required by the processing framework."""
        return SciPyPCAAlgorithm()
    