import base64
import json
import os
from io import StringIO
from typing import Dict, Type, Any

import numpy as np
import pandas as pd
from pykrige import OrdinaryKriging
from sklearn.linear_model import LinearRegression

from landsklim.lk.cache import qgis_project_cache
from landsklim.lk.landsklim_constants import DATASET_RESPONSE_VARIABLE
from landsklim.lk.landsklim_analysis import LandsklimAnalysis
from landsklim.lk.landsklim_configuration import LandsklimConfiguration
from landsklim.lk.landsklim_interpolation import LandsklimRectangle, LandsklimInterpolation, LandsklimInterpolationType
from landsklim.lk.landsklim_project import LandsklimProject
from landsklim.lk.map_layer import MapLayer, VectorLayer, RasterLayer, MapLayerCollection
from landsklim.lk.phase import IPhase
from landsklim.lk.phase_composite import PhaseComposite
from landsklim.lk.phase_multiple_regression import PhaseMultipleRegression
from landsklim.lk.polygons_definition import PolygonsDefinition
from landsklim.lk.regression_model import MultipleRegressionModel
from landsklim.lk.regressor import Regressor
from landsklim.serialization.json_encoder import PUBLIC_ENUMS
from landsklim.serialization.landsklim_unpickler import PearsonRResultDummyClass


class LandsklimDecoder(json.JSONDecoder):
    """
    Deserialize Landsklim classes from JSON
    """

    remapping = {
        'LISDQSProject': 'LandsklimProject',
        'LISDQSConfiguration': 'LandsklimConfiguration',
        'LISDQSAnalysis': 'LandsklimAnalysis',
        'LISDQSInterpolation': 'LandsklimInterpolation'
    }

    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
        self.lk_class_map = {
            'np.int': np.integer,
            'np.float': np.floating,
            "LandsklimRectangle": LandsklimRectangle,
            'PearsonRResultDummyClass': PearsonRResultDummyClass,
            'LinearRegression': LinearRegression,
            'OrdinaryKriging': OrdinaryKriging,
            'LISDQSProject': LandsklimProject,  # Old name of LandsklimProject
            'LandsklimProject': LandsklimProject,
            'LISDQSConfiguration': LandsklimConfiguration,  # Old name of LandsklimConfiguration
            'LandsklimConfiguration': LandsklimConfiguration,
            'LISDQSAnalysis': LandsklimAnalysis,  # Old name of LandkslimAnalysis
            'LandsklimAnalysis': LandsklimAnalysis,
            'MultipleRegressionModel': MultipleRegressionModel,
            'LISDQSInterpolation': LandsklimInterpolation,  # Old name of LandkslimInterpolation
            'LandsklimInterpolation': LandsklimInterpolation,
            'PolygonsDefinition': PolygonsDefinition
        }
        subclasses = Regressor.__subclasses__() + IPhase.__subclasses__()
        for sc in subclasses:
            self.lk_class_map[sc.__name__] = sc

        maplayer_subclasses = MapLayer.__subclasses__()
        self.maplayer_class_map = {}
        for sc in maplayer_subclasses:
            self.maplayer_class_map[sc.__name__] = sc

    def rebuild_layer(self, dct: Dict, map_layer: MapLayer):
        map_layer._id = dct["id"]
        map_layer._landsklim_code = dct["landsklim_code"]
        map_layer.resolve()

    def __list_to_maplayercollection(self, lst: list):
        mlc = MapLayerCollection()
        for l in lst:
            mlc.append(l)
        return mlc

    def rebuild_maplayercollection(self, instance: LandsklimProject):
        instance._stations = self.__list_to_maplayercollection(instance._stations)
        instance._variables = self.__list_to_maplayercollection(instance._variables)
        instance._additional_layers = self.__list_to_maplayercollection(instance._additional_layers)

    def rebuild_polygons_path(self, instance: LandsklimProject):
        for analysis in instance.get_analysis():  # type: LandsklimAnalysis

            if analysis.is_local():
                # Update polygons path in case of the project was moved or the analysis was renamed ...
                analysis._polygons.set_polygon_path(analysis.get_polygons_raster_path())

            # Repopulate IPhase polygons from LandsklimAnalysis
            for situation in analysis.get_station_situations():  # type: int
                for phase in analysis.get_phases(situation):  # type: IPhase
                    phase._polygons = analysis.get_polygons()

    def get_json_dependencies(self, analysis: LandsklimAnalysis):
        path = os.path.join(qgis_project_cache().homePath(), os.path.normpath(analysis.json_polygons).replace("\\", "/"))
        with open(path, "rb") as json_file:
            analysis._polygons = json.load(json_file, cls=LandsklimDecoder)
        delattr(analysis, "json_polygons")

        path = os.path.join(qgis_project_cache().homePath(), os.path.normpath(analysis.json_models).replace("\\", "/"))
        with open(path, "rb") as json_file:
            analysis._models = json.load(json_file, cls=LandsklimDecoder)
        delattr(analysis, "json_models")

    def rebuild_hierarchy(self, instance: object):
        """
        Recreate cyclic links in the object graph between
         LandsklimProject, LandsklimConfiguration, LandsklimAnalysis and LandsklimInterpolation
        """
        if isinstance(instance, LandsklimProject):
            for configuration in instance.get_configurations():
                configuration._lk_project = instance
        if isinstance(instance, LandsklimConfiguration):
            for analysis in instance.get_analysis():
                analysis._configuration = instance
        if isinstance(instance, LandsklimAnalysis):
            for interpolation in instance.get_interpolations():
                interpolation._analysis = instance

    def object_hook(self, dct: Dict) -> object:
        """
        Each parsed JSON dictionary is sent to this method.
        """

        if '__class__' in dct:
            cls_item = dct['__class__']
            if cls_item == 'tuple':
                return tuple(dct['array'])
            elif cls_item == "np.array":
                return np.array(dct['array'], dtype=dct['type'])
            elif cls_item == 'pd.DataFrame':
                return pd.read_csv(StringIO(dct['data']), index_col=0)
                # return pd.read_json(dct['data'], orient='split')
            elif cls_item == 'pd.Series':
                dct.pop('__class__')
                dtype = dct.pop('type')
                return pd.Series(dct, dtype=dtype)
            elif cls_item in self.maplayer_class_map:
                cls: Type = self.maplayer_class_map[cls_item]
                instance: MapLayer = cls.__new__(cls)
                if "lisdqs_code" in dct:
                    dct["landsklim_code"] = dct["lisdqs_code"]
                    dct.pop("lisdqs_code")
                self.rebuild_layer(dct, instance)
                return instance
            else:  # Every other types with a serialized __class__ attribute is a "Landsklim object"
                cls: Type = self.lk_class_map[cls_item]
                instance: object =  self.__new_instance(cls, dct)
                self.postprocessing(instance)
                return instance

        if '__enum__' in dct:
            enum_value: str = dct['__enum__']
            enum_name, member = enum_value.split('.')  # format of __enum__ is 'ENUM_CLASS.VALUE'
            return getattr(PUBLIC_ENUMS[enum_name], member)
        return dct

    def __process_attr_name(self, attr_name: str) -> str:
        """
        Process attribute name by replacing old class identifier by a new class identifier
        Ex: _LISDQSAnalysis__name -> _LandsklimAnalysis__name

        :rtype: str
        """
        new_name: str = attr_name
        for oldclass, newclass in self.remapping.items():
            oldclass_identifier = '_{0}_'.format(oldclass)
            newclass_identifier = '_{0}_'.format(newclass)
            if attr_name.startswith(oldclass_identifier):
                new_name = attr_name.replace(oldclass_identifier, newclass_identifier)
        return new_name

    def __new_instance(self, cls: Type, state: Dict[str, Any]) -> Any:
        """
        Create an instance of any class by passing its type and its state

        :param cls: Type of the object to instantiate
        :type cls: Type

        :param state: The attributes of the object to be instantiated
        :type: state: Dict[str, Any]

        :returns: A new object
        :rtype: Any

        """
        state.pop('__class__')  # Remove the __class__ attribute who was there to identify the type of instantiate
        obj = cls.__new__(cls)
        for attr, value in state.items():
            attr_name = self.__process_attr_name(attr)
            setattr(obj, attr_name, value)
        return obj

    def postprocessing(self, instance:  Any):
        """
        Correct by hand what can't be deserialized by the JSONDecoder
        Especially :
         - Rebuild the hierarchy in the object graph
         - Rebuild MapLayerCollection objects from LandsklimProject
         - Convert dicts string keys to int keys when needed
         - IPhase recovers polygons definition from LandsklimAnalysis which they have dropped
         - PhaseComposite recovers phase_1 and phase_2 which they have dropped
        """
        self.rebuild_hierarchy(instance)

        if isinstance(instance, OrdinaryKriging):
            self.postprocessing_ordinary_kriging(instance)

        if isinstance(instance, PolygonsDefinition):
            self.postprocessing_polygons_definition(instance)

        if isinstance(instance, LandsklimProject):
            self.postprocessing_project(instance)

        if isinstance(instance, LandsklimAnalysis):
            self.postprocessing_analysis(instance)

        if isinstance(instance, IPhase):
            self.postprocessing_iphase(instance)

        if isinstance(instance, PhaseComposite):
            self.postprocessing_phase_composite(instance)

        if isinstance(instance, LandsklimInterpolation):
            self.postprocessing_interpolation(instance)

        if isinstance(instance, PhaseMultipleRegression):
            self.postprocessing_multiple_regression(instance)

        if isinstance(instance, MultipleRegressionModel):
            self.postprocessing_multiple_regression_model(instance)

    def postprocessing_ordinary_kriging(self, instance: OrdinaryKriging):
        # OrdinaryKriging.variogram_function is a lambda which cannot be serialized
        instance.variogram_function = instance.variogram_dict[instance.variogram_model]

    def postprocessing_project(self, instance: LandsklimProject):
        self.rebuild_maplayercollection(instance)
        self.rebuild_polygons_path(instance)

    def postprocessing_polygons_definition(self, instance: PolygonsDefinition):
        # PolygonsDefinition.__polygons_path was added from 0.8
        if not hasattr(instance, '_PolygonsDefinition__polygons_path'):
            setattr(instance, '_PolygonsDefinition__polygons_path', None)
        # PolygonsDefinition.__polygons_layer was removed from 0.8
        if hasattr(instance, '_PolygonsDefinition__polygons_layer'):
            delattr(instance, '_PolygonsDefinition__polygons_layer')

    def postprocessing_analysis(self, instance: LandsklimAnalysis):
        """
        The nested dict keys are regressors,
        which were to be converted into their index to be serialized : rebuild the dict
        """

        self.get_json_dependencies(instance)

        rebuild_pc = {}
        for key, value in getattr(instance, '_pearson_correlation').items():
            rebuild_pc[int(key)] = {}
            for reg_index, pc in value.items():
                reg = instance._regressors[int(reg_index)]
                rebuild_pc[int(key)][reg] = pc
        instance._pearson_correlation = rebuild_pc

        new_datasets = {}
        for k, v in getattr(instance, '_datasets').items():
            new_datasets[int(k)] = v
        instance._datasets = new_datasets

        new_kriging_layers = {}
        for k, v in getattr(instance, '_kriging_layers').items():
            new_kriging_layers[int(k)] = v
        instance._kriging_layers = new_kriging_layers

        new_models = {}
        for k, v in getattr(instance, '_models').items():
            new_models[int(k)] = v
        instance._models = new_models

        # Refill PhaseComposite phases
        for situation in instance.get_station_situations():
            for phase in instance.get_phases(situation):
                if isinstance(phase, PhaseComposite):
                    phase._phases = [instance.get_phases(situation)[0], instance.get_phases(situation)[1]]

    def postprocessing_iphase(self, instance: IPhase):
        new_corr = {}
        for k, v in getattr(instance, '_variables_correlation').items():
            new_corr[k] = {}
            for vk, vv in v.items():
                new_corr[k][int(vk)] = vv
        instance._variables_correlation = new_corr

    def postprocessing_phase_composite(self, instance: PhaseComposite):
        setattr(instance, "_phases_raster_path", {})

    def postprocessing_interpolation(self, instance: LandsklimInterpolation):
        new_layers = {}
        for k, v in instance._layers.items():
            _, member = k.split('.')  # format of str is 'LandsklimInterpolationType.VALUE'
            nk = getattr(LandsklimInterpolationType, member)
            new_layers[nk] = {}
            for vk, vv in v.items():
                new_layers[nk][int(vk)] = vv
        instance._layers = new_layers

    def postprocessing_multiple_regression(self, instance: PhaseMultipleRegression):
        """
        Restore scale_means and scale_std in MultipleRegressionModel objects
        """
        snames = instance.scale_column_names
        for model, (_, smeans), (_, sstd), dmins, dmaxs in zip(instance._model, instance.scale_means.iterrows(),
                                                               instance.scale_std.iterrows(), instance.dataset_mins,
                                                               instance.dataset_maxs):
            setattr(model, "_scale_means", smeans)
            setattr(model, "_scale_std", sstd)
            setattr(model, "_scale_column_names", snames)
            setattr(model, "_dataset_mins", dmins)
            setattr(model, "_dataset_maxs", dmaxs)

            # MultipleRegressionModel._pearson_correlations are serialized as array (and not dict)
            new_pc = {}
            for i, c in enumerate(snames):
                if c != DATASET_RESPONSE_VARIABLE:
                    new_pc[c] = [model.pearson_correlations[2 * i], model.pearson_correlations[2 * i + 1]]
            setattr(model, "_pearson_correlations", new_pc)
            delattr(model, "pearson_correlations")

        delattr(instance, "scale_means")
        delattr(instance, "scale_std")
        delattr(instance, "scale_column_names")
        delattr(instance, "dataset_mins")
        delattr(instance, "dataset_maxs")

    def postprocessing_multiple_regression_model(self, instance: MultipleRegressionModel):
        """
        MultipleRegressionModel._predictors_are_integers has been removed on serialisation
        MultipleRegressionModel._standardize has been removed on serialisation
        """
        setattr(instance, "_predictors_are_integers", False)
        setattr(instance, "_standardize", MultipleRegressionModel.STANDARDIZE_DEFAULT_VALUE)
        setattr(instance, "_MultipleRegressionModel__residuals", None)
        setattr(instance, "debug", -1)
