# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
from scipy import linalg
import numpy as np
import json
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterRasterDestination,
                       QgsProcessingLayerPostProcessorInterface,
                       QgsProcessingException,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterDefinition,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterFileDestination
                        )

from processing.core.ProcessingConfig import ProcessingConfig

from scipy_filters.scipy_algorithm_baseclasses import groups
from scipy_filters.helpers import convert_docstring_to_html, MAXSIZE

from scipy_filters.ui.i18n import tr


try:
    import plotly.graph_objects as go
    withplotly = True
except ImportError:
    withplotly = False

class SciPyPCAAlgorithm(QgsProcessingAlgorithm):
    """
    Principal Component Analysis (PCA)

    calculated via Singular Value Decomposition (SVD) with svd from
    `scipy.linalg <https://docs.scipy.org/doc/scipy/reference/linalg.html>`_.

    With default parameters, all components are kept. Optionally, either the
    *number of components* to keep or the *percentage of variance*
    explained by the kept components can be set.

    **Standard Scaler** Optionally scale each band to unit variance (std of 1) before performing PCA (std of each band is reported in the output). Otherwise, the absolute values of the scores will be in a similar range as the original data. Constant bands (std of 0) are not scaled. (Added in version 1.5)

    **Number of components** to keep. 0 for all components. If negative:
    number of components to remove.
    Ignored if percentage of variance is set.

    **Percentage of variance to keep** is only used if it is greater than 0
    (typical values would be in the range between 90 and 100).

    **Plot** If plotly is available, a plot of the variance explained by
    the principal components is created and saved as HTML file.

    **Output** The output raster contains
    the data projected into the principal components
    (i.e. the PCA scores). Pixels that are no data (or NaN) in any
    input band are set to no data in the output.

    **Output data type** Float32 or Float64

    The following values / vectors are available a) in the log tab of
    the processing window, b) in JSON format in the "Abstract" field
    of the metadata of the output raster layer (if it is loaded on
    completion), e.g. to be used by subsequent transformations, and c) in
    the output dict if the tool has been called from the python console
    or a script:
    Singular values (of SVD),
    Variance explained (Eigenvalues),
    Ratio of variance explained,
    Cumulated sum of variance explained,
    Eigenvectors (V of SVD),
    Loadings (eigenvectors scaled by sqrt(eigenvalues)),
    Band Mean.

    The plugin should give the same results as sklearn.decomposition.PCA
    from `scikit-learn <https://scikit-learn.org/>`_:
    'singular values' is `pca.singular_values_`
    'eigenvectors' is `pca.components_`
    'variance explained' is `pca.explained_variance_` in sklearn.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.

    OUTPUT = 'OUTPUT'
    PLOT = 'PLOT'
    INPUT = 'INPUT'
    STANDARDSCALER = 'STANDARDSCALER'
    NCOMPONENTS = 'NCOMPONENTS'
    PERCENTVARIANCE = 'PERCENTVARIANCE'
    DTYPE = 'DTYPE'
    BANDSTATS = 'BANDSTATS'

    # No data value written to the output raster for masked pixels
    NODATA = -9999

    _name = 'pca'
    _outputname = tr('PCA')

    # Init Algorithm
    def initAlgorithm(self, config):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.
        """
        # MAXSIZE (megapixels) from the processing settings; fall back to the
        # plugin default if the setting is missing (None) or not an integer.
        try:
            self.maxsize = int(ProcessingConfig.getSetting('MAXSIZE'))
        except (TypeError, ValueError):
            self.maxsize = MAXSIZE

        # Add parameters
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT,
                tr('Input layer'),
            )
        )

        self.addParameter(QgsProcessingParameterBoolean(
            self.STANDARDSCALER,
            tr('Standard Scaler: Scale each band to unit variance (std of 1) before PCA'),
            optional=True,
            defaultValue=False,
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.NCOMPONENTS,
            tr('Number of components to keep. Set to 0 for all components; negative for number of components to remove.'),
            QgsProcessingParameterNumber.Type.Integer,
            defaultValue=0,
            optional=True,
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.PERCENTVARIANCE,
            tr('Percentage of Variance to keep (if set and > 0: overwrites number of components)'),
            QgsProcessingParameterNumber.Type.Double,
            defaultValue=0,
            optional=True,
            minValue=0,
            maxValue=100
        ))

        stats_param = QgsProcessingParameterBoolean(
            self.BANDSTATS,
            tr('Calculate band statistics'),
            optional=True,
            defaultValue=True,
        )
        stats_param.setFlags(stats_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(stats_param)

        dtype_param = QgsProcessingParameterEnum(
            self.DTYPE,
            tr('Output data type'),
            ['Float32 (32 bit float)', 'Float64 (64 bit float)'],
            defaultValue=0,
            optional=True)

        # Set as advanced parameter
        dtype_param.setFlags(dtype_param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(dtype_param)

        if withplotly:
            self.addParameter(QgsProcessingParameterFileDestination(
                self.PLOT,
                tr('Plot of Variance Explained (PCA)'),
                tr('HTML files (*.html)'),
            ))

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
                tr(self._outputname)))

    def processAlgorithm(self, parameters, context, feedback):
        """
        Run the PCA and write the PCA scores to the output raster.

        Returns a dict with the output raster path, the PCA results
        (numpy arrays), the same results encoded as JSON under 'json',
        the band std if the standard scaler was used and, if plotly is
        available, the path of the HTML plot. Returns an empty dict if
        the user cancels.

        Raises QgsProcessingException if the input cannot be opened, there
        are not enough valid pixels, or the output cannot be created.
        """
        # Get Parameters
        self.inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        self.output_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        self.standardscaler = self.parameterAsBool(parameters, self.STANDARDSCALER, context)
        self.ncomponents = self.parameterAsInt(parameters, self.NCOMPONENTS, context)
        self.percentvariance = self.parameterAsDouble(parameters, self.PERCENTVARIANCE, context)

        # Enum index 0 / 1 -> GDAL Float32 / Float64
        dtype_index = self.parameterAsInt(parameters, self.DTYPE, context)
        self.outdtype = (gdal.GDT_Float32, gdal.GDT_Float64)[dtype_index]
        np_dtype = np.float32 if self.outdtype == gdal.GDT_Float32 else np.float64

        self.bandstats = self.parameterAsBool(parameters, self.BANDSTATS, context)

        # Open Raster with GDAL
        self.ds = gdal.Open(self.inputlayer.source())

        if not self.ds:
            raise QgsProcessingException("Failed to open Raster Layer")

        self.bandcount = self.ds.RasterCount
        self.indatatype = self.ds.GetRasterBand(1).DataType

        # Get no data value from band 1.
        # Geotiff only has one no data value, other formats could have different ones
        # per band, so this is not optimal
        nodatavalue = self.ds.GetRasterBand(1).GetNoDataValue()

        if feedback.isCanceled():
            return {}
        feedback.setProgress(0)

        # Start the actual work
        a = self.ds.ReadAsArray().astype(np_dtype)

        # shape is bands, RasterYSize, RasterXSize
        orig_shape = a.shape
        # One row per pixel, one column per band
        flattened = a.reshape(orig_shape[0], -1).T

        a = None  # Save memory

        # A pixel is excluded from the PCA if any band is non-finite
        # (NaN / inf would make the SVD fail) or equals the no data value
        nodata_mask = ~np.isfinite(flattened).all(axis=1)
        if nodatavalue is not None:
            nodata_mask |= np.any(flattened == nodatavalue, axis=1)
        valid = ~nodata_mask

        n_pixels = int(np.count_nonzero(valid))

        # Need >= 2 pixels for the (n - 1) normalisation and >= bandcount
        # pixels so that the SVD returns one component per band
        if n_pixels < max(2, self.bandcount):
            raise QgsProcessingException(
                tr("Not enough valid pixels (without no data) to calculate the PCA"))

        # Subtract the band mean (calculated on valid pixels only)
        col_mean = flattened[valid].mean(axis=0)
        centered = flattened - col_mean[np.newaxis, :]

        flattened = None

        # Optionally scale each band to unit variance (std of 1)
        col_std = None
        if self.standardscaler:
            col_std = centered[valid].std(axis=0)
            # As in sklearn's StandardScaler: constant bands (std 0) are not
            # scaled, dividing by 0 would produce inf / nan and break the SVD.
            # The reported std keeps the true value (0).
            scale = col_std.copy()
            scale[scale == 0] = 1
            centered = centered / scale[np.newaxis, :]

        feedback.setProgress(5)
        if feedback.isCanceled():
            return {}

        # Get eigenvectors, eigenvalues, loadings with SVD

        # For info on relation of SVD and PCA see:
        # https://stats.stackexchange.com/a/134283
        # https://scentellegher.github.io/machine-learning/2020/01/27/pca-loadings-sklearn.html
        # https://stats.stackexchange.com/a/141755

        # Note: U, S, VT = svd(X) followed by S = S * constant
        # is identical to U, S, VT = svd(x * constant)
        # U and VT do not change.
        # The factor usually used for normalization in PCA is: 1 / sqrt(n_samples - 1)

        U, S, VT = linalg.svd(centered[valid], full_matrices=False)

        del U  # Save memory, not needed anymore
        feedback.setProgress(15)

        # loadings = eigenvectors scaled by sqrt(eigenvalues);
        # broadcasting S over the columns equals VT.T @ diag(S)
        loadings = VT.T * S / np.sqrt(n_pixels - 1)

        # variance_explained = eigenvalues
        # and they can be calculated from the singular values (S)
        # See point 3 in https://stats.stackexchange.com/a/134283

        variance_explained = S * S / (n_pixels - 1)
        variance_ratio = variance_explained / variance_explained.sum()
        variance_explained_cumsum = variance_ratio.cumsum()

        if feedback.isCanceled():
            return {}

        # Flip the sign of components whose loadings sum to < 0,
        # otherwise dark will be bright, and vice versa
        flip = loadings.sum(axis=0) < 0
        loadings[:, flip] *= -1
        VT[flip, :] *= -1

        # Give feedback

        feedback.pushInfo("Singular values (of SVD):")
        feedback.pushInfo(str(S.tolist()))
        feedback.pushInfo("\nVariance explained (Eigenvalues):")
        feedback.pushInfo(str(variance_explained.tolist()))
        feedback.pushInfo("\nRatio of variance explained:")
        feedback.pushInfo(str(variance_ratio.tolist()))
        feedback.pushInfo("\nCumulated sum of variance explained:")
        feedback.pushInfo(str(variance_explained_cumsum.tolist()))
        feedback.pushInfo("\nEigenvectors (V of SVD):")
        feedback.pushInfo(str(VT.T.tolist()))
        feedback.pushInfo("\nLoadings (eigenvectors scaled by sqrt(eigenvalues)):")
        feedback.pushInfo(str(loadings.tolist()))
        feedback.pushInfo("\nBand Mean:")
        feedback.pushInfo(str(col_mean.tolist()) + "\n")
        if self.standardscaler:
            feedback.pushInfo("Standard deviation of each band:")
            feedback.pushInfo(str(col_std.tolist()) + "\n")

        feedback.setProgress(30)
        if feedback.isCanceled():
            return {}

        # How many bands to keep?
        bands = self._bands_to_keep(variance_explained_cumsum)

        # Get the scores, i.e. the data in principal components.
        # Only the kept components are projected to save memory.
        new_array = centered @ VT[:bands].T
        centered = None

        # Set no data value
        new_array[nodata_mask] = self.NODATA

        # Reshape to (bands, RasterYSize, RasterXSize)
        new_array = new_array.T.reshape((bands, orig_shape[1], orig_shape[2]))

        # Prepare output and write file
        driver = gdal.GetDriverByName('GTiff')
        self.out_ds = driver.Create(self.output_raster,
                                    xsize=self.ds.RasterXSize,
                                    ysize=self.ds.RasterYSize,
                                    bands=bands,
                                    eType=self.outdtype)

        if self.out_ds is None:
            raise QgsProcessingException(
                tr("Failed to create output raster {}").format(self.output_raster))

        self.out_ds.SetGeoTransform(self.ds.GetGeoTransform())
        self.out_ds.SetProjection(self.ds.GetProjection())

        self.out_ds.WriteArray(new_array)

        # Set no data value if the input has one or if pixels were masked
        # (e.g. NaN values in an input without no data value)
        if nodatavalue is not None or nodata_mask.any():
            for b in range(1, bands + 1):
                self.out_ds.GetRasterBand(b).SetNoDataValue(self.NODATA)

        feedback.setProgress(80)
        # Calculate and write band statistics (min, max, mean, std)
        if self.bandstats:
            for b in range(1, bands + 1):
                band = self.out_ds.GetRasterBand(b)
                stats = band.GetStatistics(0, 1)
                band.SetStatistics(*stats)

        # Add band description
        for i in range(bands):
            band = self.out_ds.GetRasterBand(i + 1)
            band.SetDescription(f"({100*variance_ratio[i]:.1f}%; {100*variance_explained_cumsum[i]:.1f}%)")

        # Close the dataset to write file to disk
        self.out_ds = None

        # Free some memory
        self.ds = None
        new_array = None

        # Save loadings etc as json in the metadata abstract of the layer
        json_dict = {
                'singular values': S.tolist(),
                'eigenvectors': VT.T.tolist(),
                'loadings': loadings.tolist(),
                'variance explained': variance_explained.tolist(),
                'variance_ratio': variance_ratio.tolist(),
                'variance explained cumsum': variance_explained_cumsum.tolist(),
                'band mean': col_mean.tolist(),
                }

        if self.standardscaler:
            json_dict['band std'] = col_std.tolist()

        encoded = json.dumps(json_dict)

        # Keep a Python reference to the post processor at module level,
        # otherwise it may be garbage collected before QGIS calls it.
        global updatemetadata
        updatemetadata = self.UpdateMetadata(encoded)
        # layerToLoadOnCompletionDetails() would add the layer to the
        # load-on-completion list if it is not there yet, so check first.
        if context.willLoadLayerOnCompletion(self.output_raster):
            context.layerToLoadOnCompletionDetails(self.output_raster).setPostProcessor(updatemetadata)

        output = {self.OUTPUT: self.output_raster,
                'singular values': S,
                'loadings': loadings,
                'variance explained': variance_explained,
                'variance_ratio': variance_ratio,
                'variance explained cumsum': variance_explained_cumsum,
                'band mean': col_mean,
                'eigenvectors': VT.T,
                'json': encoded}

        if self.standardscaler:
            output['band std'] = col_std

        if withplotly:
            feedback.setProgress(90)
            html_file = self.parameterAsFileOutput(parameters, self.PLOT, context)
            fig = plot_pca_variance(variance_ratio, variance_explained_cumsum)
            fig.write_html(html_file)
            output[self.PLOT] = html_file

        feedback.setProgress(100)

        return output

    def _bands_to_keep(self, variance_explained_cumsum):
        """
        Return the number of principal components written to the output.

        A percentage of variance in (0, 100) takes precedence: the smallest
        number of components whose cumulated ratio reaches it. Otherwise
        self.ncomponents is used, where 0 means all components and a
        negative value means "remove that many". Requests outside
        1..bandcount-1 keep all bands.
        """
        if 0 < self.percentvariance < 100:
            fraction = self.percentvariance / 100
            # cumsum is non-decreasing, so searchsorted finds the first index
            # with cumsum >= fraction. If rounding leaves the last value just
            # below the fraction it returns len(); clip to keep all bands.
            idx = int(np.searchsorted(variance_explained_cumsum, fraction, side='left'))
            return min(idx + 1, self.bandcount)

        ncomponents = self.ncomponents
        if ncomponents <= 0:
            ncomponents = self.bandcount + ncomponents

        if 0 < ncomponents < self.bandcount:
            return ncomponents
        return self.bandcount

    def checkParameterValues(self, parameters, context):
        """
        Validate the input layer: local GDAL raster, not larger than the
        configured maxsize (megapixels) and with at least 2 bands.
        """
        layer = self.parameterAsRasterLayer(parameters, self.INPUT, context)

        # Missing / invalid layer: let the base class report it
        if layer is None:
            return super().checkParameterValues(parameters, context)

        if layer.providerType() != 'gdal':
            # Translate the template first, then format, so the translation
            # catalog lookup can match
            return False, tr("Raster provider {} is not supported, must be raster layer with a local file").format(layer.providerType())

        # Check maxsize
        size = layer.width() * layer.height() / 1000000  # megapixels
        if size > self.maxsize:
            return False, tr("Raster size is larger than maxsize (see settings).")

        # PCA only possible with more than 1 bands
        if layer.bandCount() < 2:
            return (False, tr("PCA only possible if input layer has more than 1 bands"))

        return super().checkParameterValues(parameters, context)

    class UpdateMetadata(QgsProcessingLayerPostProcessorInterface):
        """
        To add metadata in the postprocessing step.

        Writes the given string into the abstract of the layer metadata.
        """
        def __init__(self, abstract):
            # Initialise the wrapped C++ base before setting attributes
            super().__init__()
            self.abstract = abstract

        def postProcessLayer(self, layer, context, feedback):
            meta = layer.metadata()
            meta.setAbstract(self.abstract)
            layer.setMetadata(meta)

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._name

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return tr('Principal Component Analysis (PCA)')

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        s = groups.get("pca")
        return tr(s)

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "pca"

    def shortHelpString(self):
        """
        Returns the help string that is shown on the right side of the
        user interface.
        """
        docstring = self.__doc__
        return convert_docstring_to_html(docstring)

    def createInstance(self):
        return SciPyPCAAlgorithm()
    

def plot_pca_variance(ratio_variance_explained, cumsum_variance_explained):
    """
    Build a plotly figure showing the variance explained per principal component.

    :param ratio_variance_explained: fraction of total variance explained by
        each component, one value per component.
    :param cumsum_variance_explained: cumulated sum of those fractions.
    :returns: a ``plotly.graph_objects.Figure`` with one trace per input;
        the cumulated-sum trace starts hidden (toggle it via the legend).

    Requires plotly.
    """
    # Components are numbered from 1 on the x axis
    components = [idx + 1 for idx in range(len(ratio_variance_explained))]
    hover = 'Component %{x}<br>%{y:.5f}'

    figure = go.Figure(layout_title_text="Variance Explained")

    ratio_trace = go.Scatter(
        x=components,
        y=ratio_variance_explained,
        name='Fraction of Variance Explained',
        hovertemplate=hover,
        mode='lines+markers',
    )
    cumsum_trace = go.Scatter(
        x=components,
        y=cumsum_variance_explained,
        name='Cumulated Sum of Variance Explained',
        hovertemplate=hover,
        mode='lines+markers',
        visible='legendonly',
    )
    figure.add_trace(ratio_trace)
    figure.add_trace(cumsum_trace)

    figure.update_layout(
        title='Variance Explained by Principal Components',
        xaxis_title="Principal Component",
        yaxis_title="Fraction of Variance Explained",
        xaxis_dtick=1,
    )
    return figure

    