#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#   Copyright (C) 2018 Oslandia <infos@oslandia.com>
#
#   This file is a piece of free software; you can redistribute it and/or
#   modify it under the terms of the GNU Library General Public
#   License as published by the Free Software Foundation; either
#   version 2 of the License, or (at your option) any later version.
#
#   This library is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   Library General Public License for more details.
#   You should have received a copy of the GNU Library General Public
#   License along with this library; if not, see <http://www.gnu.org/licenses/>.
#

import os
import json
from qgis.core import QgsProject, QgsVectorLayer
from PyQt5.QtXml import QDomDocument

from .edit_plot import EditPlot
from .qgeologis.plot_item import PlotItem

from .qgeologis.units import AVAILABLE_UNITS

# Keys of a root-layer configuration dict. Each one maps to a list of plot
# configuration dicts (see LayerConfig._wrap).
SUBKEYS = ("stratigraphy_config", "log_measures", "timeseries", "imagery_data")

# Value returned by PlotConfig.get_plot_size() when a plot configuration has
# no "plot_size" entry.
DEFAULT_PLOT_SIZE = 150


class PlotConfig:
    """Configuration of a single plot (log or timeseries).

    For the time being this is a thin wrapper around the plot's
    configuration dictionary, plus two transient "filter" helper values
    that are kept on the instance and never written to the dictionary.
    """

    def __init__(self, config, parent=None):
        """
        Parameters
        ----------
        config: dict
          The plot configuration. It is kept by reference, not copied.
        parent: LayerConfig
          The LayerConfig owning this PlotConfig. When set, it is notified
          through config_modified() after a symbology change.
        """
        self.__config = config
        self.__parent = parent
        self.__filter_unique_values = []
        self.__filter_value = None

    def get_layerid(self):
        """Returns the "source" entry of the configuration"""
        return self.__config["source"]

    def get_uom(self):
        """Gets the unit of measure

        For an "instantaneous" plot without an explicit "uom" entry, the
        name of the unit column prefixed by "@" is returned instead.
        """
        cfg = self.__config
        if cfg["type"] != "instantaneous" or "uom" in cfg:
            return cfg["uom"]
        return "@" + cfg["uom_column"]

    def get_style_file(self):
        """Returns the path of the style file of this plot

        The file name comes from the "style" entry, or is
        "stratigraphy_style.xml" when that entry is missing or empty. The
        path is built relative to the directory of this module.
        """
        style_name = self.__config.get("style") or "stratigraphy_style.xml"
        module_dir = os.path.dirname(__file__)
        return os.path.join(module_dir, "qgeologis/styles", style_name)

    def get(self, key, default=None):
        """dict.get() on the underlying configuration"""
        return self.__config.get(key, default)

    def __getitem__(self, key):
        """Item access on the underlying configuration"""
        return self.__config[key]

    # TODO: the filter_value and filter_unique_values accessors below
    # are only helper methods and should not be part of the configuration
    def set_filter_value(self, value):
        self.__filter_value = value

    def get_filter_value(self):
        return self.__filter_value

    def set_filter_unique_values(self, values):
        self.__filter_unique_values = values

    def get_filter_unique_values(self):
        return self.__filter_unique_values

    def _get_dict(self):
        """Returns the wrapped configuration dictionary itself"""
        return self.__config

    def get_symbology(self):
        """Returns the associated QGIS symbology

        Return
        ------
        A tuple (QDomDocument, int)
          The QDomDocument can be loaded by QgsFeatureRenderer.load()
          The int gives the renderer type
        or None, None when no symbology is stored
        """
        serialized = self.__config.get("symbology")
        if serialized is None:
            return None, None
        document = QDomDocument()
        document.setContent(serialized)
        renderer_type = self.__config.get("symbology_type")
        return (document, renderer_type)

    def set_symbology(self, symbology, symbology_type):
        """Sets the associated QGIS symbology

        Nothing is changed when symbology is None.

        Parameters
        ----------
        symbology: QDomDocument or None
        symbology_type: Literal[0,1,2]
          Type of renderer
          0: points
          1: lines
          2: polygons
        """
        if symbology is None:
            return
        self.__config["symbology"] = symbology.toString()
        self.__config["symbology_type"] = symbology_type
        if self.__parent:
            self.__parent.config_modified()

    def get_plot_size(self):
        """Returns the "plot_size" entry, or DEFAULT_PLOT_SIZE if absent"""
        if "plot_size" in self.__config:
            return self.__config["plot_size"]
        return DEFAULT_PLOT_SIZE

    def __repr__(self):
        return repr(self.__config)


class LayerConfig:
    """Holds the configuration of a "root" layer (the layer where stations or collars are stored).

    The raw configuration is a dict whose SUBKEYS entries are lists of plot
    configuration dicts. Those are wrapped into PlotConfig objects, split
    into "displayed" and "hidden" lists that are rebuilt after every change.
    A missing sub-key is treated as an empty list everywhere.
    """

    def __init__(self, config, layer_id):
        """
        Parameters
        ----------
        config: dict
          A dict {layer_id : dict of plot configuration} that will be updated if needed
        layer_id: str
          The main layer id. It must be a key of config: the entry is looked
          up with config.get() and used right away.
        """

        # The "main" or "parent" configuration
        self.__parent_config = config

        self.__layer_id = layer_id

        # Part of the main configuration, for a given layer id
        self.__config = config.get(layer_id)

        self._wrap()

    def _wrap(self):
        """Rebuilds the PlotConfig lists from the raw configuration.

        Stratigraphy and imagery plots are displayed when their "display"
        flag is "1". Log and timeseries plots are also displayed when they
        have a non-empty "displayed_cat" list, so such a plot can be in both
        the displayed and the hidden ("display" == "0") lists.
        """

        def displayed(plot):
            return plot["display"] == "1"

        def displayed_or_has_categories(plot):
            # "displayed_cat" may be absent (e.g. non categorical plots)
            return plot["display"] == "1" or bool(plot.get("displayed_cat"))

        def hidden(plot):
            return plot["display"] == "0"

        def wrap(data_type, keep):
            return [
                PlotConfig(plot, self)
                for plot in self.__config.get(data_type, [])
                if keep(plot)
            ]

        self.__stratigraphy_plots = wrap("stratigraphy_config", displayed)
        self.__log_plots = wrap("log_measures", displayed_or_has_categories)
        self.__timeseries = wrap("timeseries", displayed_or_has_categories)
        self.__imageries = wrap("imagery_data", displayed)

        self.__hided_stratigraphy_plots = wrap("stratigraphy_config", hidden)
        self.__hided_log_plots = wrap("log_measures", hidden)
        self.__hided_timeseries = wrap("timeseries", hidden)
        self.__hided_imageries = wrap("imagery_data", hidden)

    def get(self, key, default=None):
        """dict.get() on the raw configuration of this layer"""
        return self.__config.get(key, default)

    def __getitem__(self, key):
        """Item access on the raw configuration of this layer"""
        return self.__config[key]

    def get_stratigraphy_plots(self, display=True):
        """Returns the displayed (or hidden if display is False) stratigraphy plots"""
        return self.__stratigraphy_plots if display else self.__hided_stratigraphy_plots

    def get_log_plots(self, display=True):
        """Returns the displayed (or hidden if display is False) log plots"""
        return self.__log_plots if display else self.__hided_log_plots

    def get_imageries(self, display=True):
        """Returns the displayed (or hidden if display is False) imagery plots"""
        return self.__imageries if display else self.__hided_imageries

    def get_timeseries(self, display=True):
        """Returns the displayed (or hidden if display is False) timeseries plots"""
        return self.__timeseries if display else self.__hided_timeseries

    def get_vertical_plots(self, display=True):
        """Returns stratigraphy plots followed by log plots.

        As for the other getters, display=False selects the hidden plots.
        """
        if display:
            return self.__stratigraphy_plots + self.__log_plots
        return self.__hided_stratigraphy_plots + self.__hided_log_plots

    def get_layer_id(self):
        """Returns the id of the root layer"""
        return self.__layer_id

    def add_plot_config(self, config_type, plot_config):
        """Appends a plot configuration and saves the configuration.

        Parameters
        ----------
        config_type: Literal["stratigraphy_config", "log_measures", "timeseries"]
        plot_config: PlotConfig
        """
        self.__config.setdefault(config_type, []).append(plot_config._get_dict())
        self._wrap()
        self.config_modified()

    def remove_plot_config(self, item):
        """Removes every plot configuration whose source is item["source"].

        The layer is also removed from the "stacked" entry of the remaining
        plot configurations.

        Parameters
        ----------
        item: dict
        """
        layer_id = item["source"]
        for data_type in SUBKEYS:
            confs = self.__config.get(data_type, [])
            # Rebuild the list in place rather than removing elements while
            # iterating over it, which would skip the following element.
            confs[:] = [conf for conf in confs if conf["source"] != layer_id]
            for conf in confs:
                conf.get("stacked", {}).pop(layer_id, None)
        self._wrap()
        self.config_modified()

    def hide_plot(self, item, legend):
        """Hides a plot, or one of its categories.

        A plot without displayed categories gets its "display" flag set to
        "0". Otherwise the category named legend.title() is removed from its
        "displayed_cat" list. In both cases "stacked" is reset.

        Parameters
        ----------
        item: PlotItem
        legend:
          Object whose title() gives the category to hide
        """
        layer_id = item.layer().id()
        for data_type in SUBKEYS:
            for conf in self.__config.get(data_type, []):
                if conf["source"] != layer_id:
                    continue
                categories = conf.get("displayed_cat")
                if not categories:
                    conf["display"] = "0"
                else:
                    title = legend.title()
                    if title in categories:
                        categories.remove(title)
                conf["stacked"] = {}
        self._wrap()
        self.config_modified()

    def edit_plot(self, item, legend):
        """Opens the EditPlot dialog for a plot and stores the result.

        For a PlotItem, scale type, displayed unit, min/max, uncertainty
        column and plot size are stored on the first matching configuration
        of each sub-key. For any other item only the plot size is stored.
        Nothing happens when a PlotItem has no configuration entry declaring
        a unit ("uom" or "uom_column").

        Parameters
        ----------
        item:
          The plot item; item.layer().id() identifies its configuration
        legend:
          Passed as is to the EditPlot dialog
        """
        layer_id = item.layer().id()

        if not isinstance(item, PlotItem):
            dlg = EditPlot(item, legend, "unused")
            if dlg.exec_():
                # Only the last value (plot size) applies to such items
                *_ignored, plot_size = dlg.plot_config()
                for data_type in SUBKEYS:
                    for conf in self.__config.get(data_type, []):
                        if conf["source"] == layer_id:
                            conf["plot_size"] = plot_size
        else:
            unit_found = False
            uom = None
            displayed_uom = None
            for data_type in SUBKEYS:
                for conf in self.__config.get(data_type, []):
                    if conf["source"] != layer_id:
                        continue
                    if "uom" in conf:
                        unit_found = True
                        uom = conf["uom"]
                        displayed_uom = conf.get("displayed_uom")
                    elif "uom_column" in conf:
                        unit_found = True
                        uom = "Categorical"
                        displayed_uom = None
            if not unit_found:
                # No matching configuration: there is nothing to edit
                return

            dlg = EditPlot(item, legend, uom, displayed_uom=displayed_uom)
            if dlg.exec_():
                scale_type, displayed_uom, min_y_value, max_y_value, uncertainty, plot_size = (
                    dlg.plot_config()
                )
                if min_y_value is not None:
                    min_y_value = float(min_y_value)

                if max_y_value is not None:
                    max_y_value = float(max_y_value)

                for data_type in SUBKEYS:
                    for conf in self.__config.get(data_type, []):
                        if conf["source"] == layer_id:
                            conf["scale_type"] = "log" if scale_type else "linear"
                            if displayed_uom is not None:
                                conf["displayed_uom"] = displayed_uom
                            conf["min"] = min_y_value
                            conf["max"] = max_y_value
                            conf["uncertainty_column"] = uncertainty
                            conf["plot_size"] = plot_size
                            break
        self._wrap()
        self.config_modified()

    def zoom(self, dt, relative_pos, item):
        """Scales the data window of a plot and stores the new bounds.

        The new window height is the current one multiplied by dt;
        relative_pos sets how the height difference is split around it.

        Returns
        -------
        A tuple (min_y, max_y), or (None, None) if the plot is not configured
        """
        min_y, max_y = None, None
        layer_id = item.layer().id()
        for data_type in SUBKEYS:
            for conf in self.__config.get(data_type, []):
                if conf["source"] == layer_id:
                    window = item.data_window()
                    new_height = window.height() * dt
                    offset = relative_pos * (window.height() - new_height)
                    min_y = window.y() + offset
                    max_y = min_y + new_height
                    conf["min"] = min_y
                    conf["max"] = max_y
                    break
        self._wrap()
        self.config_modified()
        return min_y, max_y

    def pan(self, translation_y, item):
        """Translates the data window of a plot and stores the new bounds.

        The translation is translation_y times the current window height.

        Returns
        -------
        A tuple (min_y, max_y), or (None, None) if the plot is not configured
        """
        min_y, max_y = None, None
        layer_id = item.layer().id()
        for data_type in SUBKEYS:
            for conf in self.__config.get(data_type, []):
                if conf["source"] == layer_id:
                    window = item.data_window()
                    shift = window.height() * translation_y
                    min_y = window.y() + shift
                    max_y = window.y() + window.height() + shift
                    conf["min"] = min_y
                    conf["max"] = max_y
                    break
        self._wrap()
        self.config_modified()
        return min_y, max_y

    def get_layer_config_from_layer_id(self, id):
        """Finds the first plot configuration whose source is the given id.

        Returns
        -------
        A tuple (data_type, configuration dict), or None if not found
        """
        for data_type in SUBKEYS:
            for conf in self.__config.get(data_type, []):
                if conf["source"] == id:
                    return (data_type, conf)
        return None

    def get_layer_config(self, data_type, unit_type=None):
        """Lists the plots of a data type, optionally restricted to a unit type.

        Plots without a "uom" entry are looked up as "Categorical".

        Parameters
        ----------
        data_type: str
          One of the SUBKEYS
        unit_type: str or None
          Key of AVAILABLE_UNITS to restrict to, or None for any unit type

        Returns
        -------
        A dict {source layer id: plot name}
        """
        config_list = {}
        for conf in self.__config.get(data_type, []):
            uom = conf["uom"] if "uom" in conf else "Categorical"
            for kind, units in AVAILABLE_UNITS.items():
                if uom in units and (unit_type is None or kind == unit_type):
                    config_list[conf["source"]] = conf["name"]
                    break
        return config_list

    def display_plot(self, item):
        """Displays a plot, or one more of its categories.

        Without filter value the "display" flag of the plot is set to "1";
        otherwise the filter value is appended to its "displayed_cat" list.

        Parameters
        ----------
        item: PlotConfig
          Read through item["source"] and item.get_filter_value()
        """
        source = item["source"]
        filter_value = item.get_filter_value()
        for data_type in SUBKEYS:
            for conf in self.__config.get(data_type, []):
                if conf["source"] != source:
                    continue
                if filter_value is None:
                    conf["display"] = "1"
                else:
                    conf.setdefault("displayed_cat", []).append(filter_value)
        self._wrap()
        self.config_modified()

    def config_modified(self):
        """Stores the whole parent configuration, as JSON, in the QGIS project"""
        json_config = json.dumps(self.__parent_config)
        QgsProject.instance().writeEntry("QGeoloGIS", "config", json_config)


def export_config(main_config, filename):
    """Exports the given project configuration to a filename.
    Layer IDs stored in the configuration are converted into triple (source, name, provider)

    Filters on layers or virtual fields  will then be lost during the translation

    Root layers that are not found in the current project are left out of
    the export. A plot whose source layer is not found keeps its layer ID
    as "source". The input configuration is not modified.

    Parameters
    ----------
    main_config: dict
      The configuration as a dict
    filename: str
      Name of the file where to export the configuration to
    """
    from copy import deepcopy

    exported = {}

    # Work on a copy, since "source" entries are rewritten below
    for root_layer_id, config in deepcopy(main_config).items():
        root_layer = QgsProject.instance().mapLayer(root_layer_id)
        if not root_layer:
            continue

        # replace "source" keys
        for subkey in SUBKEYS:
            # a configuration is not required to hold every sub-key
            for layer_cfg in config.get(subkey, []):
                source = QgsProject.instance().mapLayer(layer_cfg["source"])
                if not source:
                    continue

                layer_cfg["source"] = {
                    "source": source.source(),
                    "name": source.name(),
                    "provider": source.providerType(),
                }

        # root layers are keyed by "source#name#provider"
        root_key = "{}#{}#{}".format(
            root_layer.source(), root_layer.name(), root_layer.providerType()
        )
        exported[root_key] = dict(config)

    # write to the output file
    with open(filename, "w", encoding="utf-8") as fo:
        json.dump(exported, fo, ensure_ascii=False, indent=4)


def import_config(filename, overwrite_existing=False):
    """Import the configuration from a given filename

    Layers are created and added to the current project.

    Root layers are read from keys of the form "source#name#provider". The
    key is split from the right, so that a data source containing "#" is
    read back correctly (a layer name containing "#" is not supported).

    Parameters
    ----------
    filename: str
      Name of the file where to import the configuration from
    overwrite_existing: bool
      Whether to try to overwrite existing layers that have
      the same data source definition

    Returns
    -------
    The configuration as a dict
    """
    with open(filename, "r", encoding="utf-8") as fi:
        config_json = json.load(fi)

    new_config = {}

    def find_existing_layer_or_create(source, name, provider, do_overwrite):
        """Returns a project layer for the given definition, creating it if needed"""
        if do_overwrite:
            for layer in QgsProject.instance().mapLayers().values():
                if layer.source() == source and layer.providerType() == provider:
                    # reuse the existing layer, renamed as in the file
                    layer.setName(name)
                    return layer
        # layer not found, create it then !
        layer = QgsVectorLayer(source, name, provider)
        QgsProject.instance().addMapLayer(layer)
        return layer

    # root layers at the beginning of the dict
    for root_key, config in config_json.items():
        root_source, root_name, root_provider = root_key.rsplit("#", 2)
        root_layer = find_existing_layer_or_create(
            root_source, root_name, root_provider, overwrite_existing
        )

        for subkey in SUBKEYS:
            # a configuration is not required to hold every sub-key
            for layer_cfg in config.get(subkey, []):
                source = layer_cfg["source"]
                layer = find_existing_layer_or_create(
                    source["source"], source["name"], source["provider"], overwrite_existing
                )
                layer_cfg["source"] = layer.id()

        # change the main dict key
        new_config[root_layer.id()] = dict(config)

    return new_config


def remove_layer_from_config(config, layer_id):
    """Drop the entry of a layer from a configuration object

    Only the top-level entry keyed by layer_id is removed. Nothing happens
    if there is no such entry.

    Parameters
    ----------
    config: dict
      The main plot configuration. It is modified in place.
    layer_id: str
      The layer whose entry is to be removed from the config

    Returns
    -------
    The same config object
    """
    config.pop(layer_id, None)
    return config
