from math import floor, ceil
from typing import List

import numpy as np
from qgis._core import QgsProject, QgsRasterBandStats, QgsColorRamp, QgsStyle, QgsColorRampShader, QgsRasterShader, \
    QgsSingleBandPseudoColorRenderer
from qgis.core import QgsMapLayerType

from landsklim.landsklim import Landsklim
from landsklim.lk.landsklim_analysis import LandsklimAnalysis
from landsklim.lk.phase_kriging import PhaseKriging
from landsklim.lk.phase_multiple_regression import PhaseMultipleRegression
from landsklim.lk.utils import LandsklimUtils


class QuickCommands:
    """
    Developer "quick commands", meant to be run from the QGIS Python console, e.g.::

        from qgis.utils import plugins
        plugins['landsklim'].command("style")

    Every command receives the Landsklim plugin instance followed by the raw command arguments,
    where ``args[0]`` is always the command name.
    """

    disabled: bool = False  # Make commands unusable

    @staticmethod
    def command(lk_instance: Landsklim, *args: str):
        """
        Dispatch a quick command by name.

        :param lk_instance: Landsklim plugin instance
        :param args: ``args[0]`` is the command name ("integer", "float", "style" or "force_model").
            The whole ``args`` tuple, name included, is forwarded to the selected command.
            Nothing happens when ``args`` is empty or when ``QuickCommands.disabled`` is set.
        """
        if len(args) == 0 or QuickCommands.disabled:
            return
        handlers = {
            "integer": QuickCommands.command_integer,
            "float": QuickCommands.command_float,
            "style": QuickCommands.command_style,
            "force_model": QuickCommands.command_force_model,
        }
        command_name: str = str(args[0])
        handler = handlers.get(command_name)
        if handler is None:
            print(f"Unknown command '{command_name}' (available: {', '.join(handlers)})")
            return
        handler(lk_instance, *args)

    @staticmethod
    def _set_predictors_are_integers(lk_instance: Landsklim, are_integers: bool):
        """
        Set ``_predictors_are_integers`` on every model of every multiple regression phase,
        for every situation of every analysis of the current Landsklim project.

        :param lk_instance: Landsklim plugin instance
        :param are_integers: Value assigned to each model's ``_predictors_are_integers`` flag
        """
        print("Execute command ...")
        for analysis in lk_instance.get_landsklim_project().get_analysis():  # type: LandsklimAnalysis
            for situation in analysis.get_station_situations():  # type: int
                for phase in analysis.get_phases(situation):  # type: IPhase
                    if phase.class_name() != PhaseMultipleRegression.class_name():
                        continue
                    for model in phase.get_model():
                        model._predictors_are_integers = are_integers
                        print("[model]")

    @staticmethod
    def command_integer(lk_instance: Landsklim, *args):
        """
        Mark the predictors of every multiple regression model as integers.

        :param lk_instance: Landsklim plugin instance
        :param args: Command arguments (unused besides the command name)
        """
        QuickCommands._set_predictors_are_integers(lk_instance, True)

    @staticmethod
    def command_float(lk_instance: Landsklim, *args):
        """
        Mark the predictors of every multiple regression model as non-integers.

        :param lk_instance: Landsklim plugin instance
        :param args: Command arguments (unused besides the command name)
        """
        QuickCommands._set_predictors_are_integers(lk_instance, False)

    @staticmethod
    def command_style(lk_instance: Landsklim, *args):
        """
        Apply a discrete pseudo-color style, built from the reversed "Spectral" color ramp, to band 1
        of every raster layer of the current QGIS project.

        from qgis.utils import plugins
        plugins['landsklim'].command("style")

        :param lk_instance: Landsklim plugin instance (unused, kept for a uniform command signature)
        :param args: ``args[1]``, if given, is the number of color classes. Otherwise one class is
            used per integer unit between the floored band minimum and the ceiled band maximum.
            At least one class is always used.
        """
        print("Execute command ...")
        # Parse the optional class count once, before touching any layer
        requested_steps = None
        if len(args) >= 2:
            try:
                requested_steps = int(args[1])
            except ValueError:
                print(f"style: invalid number of steps '{args[1]}'")
                return

        # QgsProject.instance() is static: no need to construct a throwaway QgsProject
        raster_layers = [layer for layer in QgsProject.instance().mapLayers().values()
                         if layer.type() == QgsMapLayerType.RasterLayer]
        for layer in raster_layers:  # type: QgsRasterLayer
            provider = layer.dataProvider()
            stats = provider.bandStatistics(1, QgsRasterBandStats.All)
            # floor()/ceil() raise on NaN or infinite values, so such layers are skipped
            if not (np.isfinite(stats.minimumValue) and np.isfinite(stats.maximumValue)):
                print("[layer] skipped (non-finite band statistics)", layer)
                continue
            p_min, p_max = floor(stats.minimumValue), ceil(stats.maximumValue)
            delta = p_max - p_min
            p_steps = max(1, delta if requested_steps is None else requested_steps)

            # A ramp is fetched for each layer because the shader function below receives it through
            # setSourceColorRamp(). Per the QGIS API docs, that call takes ownership of the ramp.
            ramp: QgsColorRamp = QgsStyle.defaultStyle().colorRamp("Spectral")
            if ramp is None:
                print("style: color ramp 'Spectral' not found in the default style")
                return

            shader_function = QgsColorRampShader()
            shader_function.setColorRampType(QgsColorRampShader.Discrete)
            shader_function.setClassificationMode(QgsColorRampShader.EqualInterval)
            shader_function.setSourceColorRamp(ramp)

            # p_steps + 1 evenly spaced breaks from p_min to p_max; 1 - f reverses the ramp direction
            fractional_steps = [i / p_steps for i in range(p_steps + 1)]
            colors = [ramp.color(1 - f) for f in fractional_steps]
            steps = [p_min + f * delta for f in fractional_steps]
            color_ramp = [QgsColorRampShader.ColorRampItem(step, color, str(step))
                          for step, color in zip(steps, colors)]
            shader_function.setColorRampItemList(color_ramp)

            shader: QgsRasterShader = QgsRasterShader()
            shader.setRasterShaderFunction(shader_function)
            renderer: QgsSingleBandPseudoColorRenderer = QgsSingleBandPseudoColorRenderer(provider, 1, shader)
            layer.setRenderer(renderer)
            print("[layer]", layer)
            layer.triggerRepaint()

    @staticmethod
    def _rebuild_kriging(analysis: LandsklimAnalysis, situation: int, kriging_phase: PhaseKriging):
        """
        Reconstruct a kriging phase from the residuals of the first phase of the same situation.

        NOTE(review): assumes ``get_phases(situation)[0]`` is the (forced) multiple regression
        phase, as the original code did. TODO confirm the phase ordering.

        :param analysis: Analysis owning the phase
        :param situation: Situation whose dataset and residuals are used
        :param kriging_phase: Kriging phase whose model is rebuilt
        """
        # Imported lazily, so they are only loaded when a kriging phase is actually rebuilt
        import pandas as pd
        from landsklim.lk.landsklim_constants import DATASET_COLUMN_X, DATASET_COLUMN_Y, \
            DATASET_RESPONSE_VARIABLE

        residuals: np.ndarray = analysis.get_phases(situation)[0].get_residuals()
        dataset = analysis._datasets[situation].dropna()
        dataframe_residuals: pd.DataFrame = pd.DataFrame({
            DATASET_COLUMN_X: dataset[DATASET_COLUMN_X].values,
            DATASET_COLUMN_Y: dataset[DATASET_COLUMN_Y].values,
            DATASET_RESPONSE_VARIABLE: residuals,
        })
        kriging_phase.construct_model(dataframe_residuals)

    @staticmethod
    def command_force_model(lk_instance: Landsklim, *args):
        """
        Force the models of the multiple regression phases of one analysis from files. Then rebuild
        that analysis' kriging phases from the regression residuals.

        :param lk_instance: Landsklim plugin instance
        :param args: ``args[1]`` is the analysis name; ``args[2:]`` are the files given to
            ``LandsklimUtils.force_local_analysis_models``. Files are consumed in order, one per
            multiple regression phase. When there are more phases than files, the last file is reused.
        """
        print("Execute command ...")
        if len(args) < 3:
            print("force_model needs 3 arguments")
            return
        analysis_name: str = str(args[1])
        files: List[str] = [str(file) for file in args[2:]]
        file_index = 0
        for analysis in lk_instance.get_landsklim_project().get_analysis():  # type: LandsklimAnalysis
            if analysis.get_name() != analysis_name:
                continue
            for situation in analysis.get_station_situations():
                for phase in analysis.get_phases(situation):
                    if phase.class_name() == PhaseMultipleRegression.class_name():
                        LandsklimUtils.force_local_analysis_models(phase, files[file_index])
                        # Advance, but never past the last file. The former `i < len(files)` guard
                        # let the index reach len(files), so the next phase raised IndexError.
                        file_index = min(file_index + 1, len(files) - 1)
                        print("[phase]")
                    if phase.class_name() == PhaseKriging.class_name():
                        QuickCommands._rebuild_kriging(analysis, situation, phase)
