# -*- coding: utf-8 -*-

"""
/***************************************************************************
 SciPyFilters
                                 A QGIS plugin
 Filter collection implemented with SciPy
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-03-03
        copyright            : (C) 2024 by Florian Neukirchen
        email                : mail@riannek.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Florian Neukirchen'
__date__ = '2024-03-03'
__copyright__ = '(C) 2024 by Florian Neukirchen'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from osgeo import gdal
import numpy as np
import json
from qgis.PyQt.QtCore import QCoreApplication
from qgis._core import QgsProcessingContext, QgsProcessingFeedback
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterString,
                       QgsProcessingParameterDefinition,
                       QgsProcessingParameterNumber,
                       QgsProcessingException,
                       QgsProcessingLayerPostProcessorInterface,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterFileDestination)

from processing.core.ProcessingConfig import ProcessingConfig


from scipy_filters.helpers import (str_to_array, 
                      convert_docstring_to_html,
                      bandmean,
                      MAXSIZE)

from scipy_filters.scipy_algorithm_baseclasses import groups
from scipy_filters.ui.i18n import tr


import numpy as np
import plotly.graph_objects as go

gdal.UseExceptions()

def pca_biplot(pc1, pc2, loadings, labels=None):
    """
    Create a PCA biplot from scores and loadings.

    The scores are drawn as contours of their 2D histogram. The loadings
    are drawn as one arrow per band, starting at the origin and ending at
    the band's loading on the first (x) and second (y) principal component.

    :param pc1: The data (raster band) of the first principal component.
    :type pc1: numpy.ndarray
    :param pc2: The data (raster band) of the second principal component.
    :type pc2: numpy.ndarray
    :param loadings: The loadings of the principal components, with one row
        per band of the original raster and one column per component.
        At least the first two components (columns) are required.
    :type loadings: numpy.ndarray
    :param labels: The labels for the bands of the original raster, one per
        row of ``loadings``. Defaults to "Band 1", "Band 2", ...
    :type labels: list, optional
    :return: The biplot figure.
    :rtype: plotly.graph_objects.Figure
    :raises ValueError: If ``loadings`` is not a 2D array with at least two
        columns, or if the number of labels does not match the number of
        rows of ``loadings``.
    """

    pc1 = pc1.flatten()
    pc2 = pc2.flatten()

    loadings = np.asarray(loadings)
    if loadings.ndim != 2 or loadings.shape[1] < 2:
        raise ValueError('The loadings must be a 2D array with at least two components (columns).')

    # The arrows below are read row by row (loadings[i, 0], loadings[i, 1]),
    # so the number of bands is the number of rows, not of columns.
    n_bands = loadings.shape[0]

    # Explicit None/length test: "not labels" is ambiguous for numpy arrays.
    if labels is None or len(labels) == 0:
        labels = [f'Band {i + 1}' for i in range(n_bands)]
    elif len(labels) != n_bands:
        raise ValueError('The number of labels must match the number of bands in the loadings.')

    # https://plotly.com/python/pca-visualization/
    # https://statisticsglobe.com/biplot-pca-explained

    fig = go.Figure()

    # Scores as density contours
    fig.add_trace(go.Histogram2dContour(
        x=pc1,
        y=pc2,
    ))

    # Same scale for both axes
    fig.update_yaxes(
        scaleanchor="x",
        scaleratio=1,
    )

    fig.update_layout(
        title="PCA Biplot",
        xaxis_title="PC1",
        yaxis_title="PC2",
    )

    for label, (x, y) in zip(labels, loadings[:, :2]):
        # Arrow from the origin (ax, ay in data coordinates) to the loading
        fig.add_annotation(
            ax=0,
            ay=0,
            axref="x",
            ayref="y",
            x=x,
            y=y,
            showarrow=True,
            arrowhead=2,
            arrowsize=2,
            xanchor="right",
            yanchor="top",
        )

        # Band label, placed slightly above the arrow head
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            xanchor="center",
            yanchor="bottom",
            yshift=10,
        )

    return fig


class SciPyPCABiplot(QgsProcessingAlgorithm):
    """
    Plot an interactive biplot of the principal components 
    
    The input raster must be the result of the PCA algorithm of this plugin,
    with at least the first two components (2 bands).
    The loadings are read from the metadata.

    The scores (PC1 and PC2 of each pixel) are plotted as contours (there would be too many points for a scatter plot).

    The loadings are plotted as vectors, with the length and direction indicating the contribution of the bands
    to the first and second principal component.

    Use the zoom tool to explore.

    .. note:: Requires Plotly.

    **Input** PCA layer
    
    **Plot** The plot is saved as html and can be opened in a browser.
    """
    # NOTE: The class docstring above is passed to convert_docstring_to_html()
    # in shortHelpString(), i.e. it is user-visible text.

    # Parameter names
    OUTPUT = 'OUTPUT'
    INPUT = 'INPUT'

    _name = 'biplot'
    _displayname = tr('Biplot')

    # Init Algorithm
    def initAlgorithm(self, config=None):
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.

        :param config: Configuration passed in by the caller; not used here.
        """

        # Add parameters
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT,
                tr('Input layer'),
            )
        )

        self.addParameter(QgsProcessingParameterFileDestination(
                self.OUTPUT,
                tr('PCA Biplot'),
                tr('HTML files (*.html)'),
            ))

    def processAlgorithm(self, parameters, context, feedback):
        """
        Read PC1 and PC2 from the input raster, read the loadings from the
        layer metadata, and write the biplot as html file.

        :param parameters: The parameter values of this run.
        :param context: The processing context.
        :param feedback: The processing feedback object (not used here).
        :return: Dictionary mapping ``OUTPUT`` to the path of the html file.
        :rtype: dict
        :raises QgsProcessingException: If the input layer is invalid, has no
            usable loadings in its metadata, cannot be opened with GDAL or
            has less than two bands.
        """

        # Get Parameters
        self.inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        self.output = self.parameterAsFileOutput(parameters, self.OUTPUT, context)

        if self.inputlayer is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT))

        abstract = self.inputlayer.metadata().abstract()
        loadings = self.json_to_parameters(abstract)

        if loadings is None:
            raise QgsProcessingException('Loadings not found in metadata.')

        # Open Raster with GDAL.
        # gdal.UseExceptions() is called at module level, so a failing
        # gdal.Open raises RuntimeError; the falsy check below is kept as a
        # second guard.
        try:
            self.ds = gdal.Open(self.inputlayer.source())
        except RuntimeError as err:
            raise QgsProcessingException("Failed to open Raster Layer") from err

        if not self.ds:
            raise QgsProcessingException("Failed to open Raster Layer")

        try:
            if self.ds.RasterCount < 2:
                raise QgsProcessingException(
                    'The input layer must have at least two bands (PC1 and PC2).'
                )
            pc1 = self.ds.GetRasterBand(1).ReadAsArray()
            pc2 = self.ds.GetRasterBand(2).ReadAsArray()
        finally:
            # Drop the reference to the GDAL dataset as soon as the data
            # is read, instead of keeping the file open on the instance.
            self.ds = None

        fig = pca_biplot(pc1, pc2, loadings, None)
        fig.write_html(self.output)

        return {self.OUTPUT: self.output}

    def json_to_parameters(self, s):
        """
        Get loadings from json string

        The string must decode to a JSON object with a key "loadings" that
        holds a numeric 2D array with at least two columns.

        :param s: The json string
        :type s: str
        :return: The PCA loadings, or None if the string is empty, is not
            valid JSON, or does not contain usable loadings.
        :rtype: numpy.ndarray or None
        """
        if not s:
            return None

        s = s.strip()
        if s == "":
            return None

        try:
            decoded = json.loads(s)
        except (json.decoder.JSONDecodeError, ValueError, TypeError):
            return None

        # Valid JSON may also be a list, number, string, ...
        if not isinstance(decoded, dict):
            return None

        raw = decoded.get("loadings", None)

        # np.array(None) is a 0-d object array and not None, so a missing
        # key must be caught before the conversion.
        if raw is None:
            return None

        try:
            loadings = np.array(raw, dtype=float)
        except (ValueError, TypeError):
            # Ragged or non-numeric content
            return None

        # The biplot reads columns 0 and 1 of a 2D array
        if loadings.ndim != 2 or loadings.shape[1] < 2:
            return None

        return loadings

    def checkParameterValues(self, parameters, context):
        """
        Check that the input is a GDAL raster layer with at least two bands
        and with loadings in its metadata.

        :param parameters: The parameter values to check.
        :param context: The processing context.
        :return: Tuple of (ok, message).
        :rtype: tuple
        """

        inputlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)

        if inputlayer is None:
            # Missing or invalid layer: let the base class report it.
            return super().checkParameterValues(parameters, context)

        if inputlayer.providerType() != "gdal":
            # Translate the template first, then fill in the provider name;
            # otherwise the lookup of the translation can never match.
            return False, tr("Raster provider {} is not supported, must be raster layer with a local file").format(inputlayer.providerType())

        if inputlayer.bandCount() < 2:
            return False, tr('The input layer must have at least two bands (PC1 and PC2).')

        abstract = inputlayer.metadata().abstract()
        loadings = self.json_to_parameters(abstract)

        if loadings is None:
            return False, tr('Loadings not found in metadata. Please use a PCA layer as input.')

        return super().checkParameterValues(parameters, context)

    def shortHelpString(self):
        """
        Returns the help string that is shown on the right side of the 
        user interface.
        """
        docstring = self.__doc__
        return convert_docstring_to_html(docstring)

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return self._name

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return tr(self._displayname)

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        s = groups.get("pca")
        return tr(s)

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return "pca"

    def createInstance(self):
        """
        Returns a new instance of this algorithm.
        """
        return SciPyPCABiplot()