"""
Implementation of the solution proposed in the README.md of https://github.com/YoannQDQ/qgis-layer-tree-icons
Monkey-patch the default layer tree menu provider to add new actions, especially to Analysis nodes, whose regression
model results can be accessed through the contextual menu.
"""
import inspect
from enum import Enum
from typing import Optional, Callable, List

from PyQt5.QtCore import QCoreApplication
from qgis.PyQt.QtCore import QObject, QEvent
from qgis.PyQt.QtWidgets import QAction, QMenu
from qgis._core import QgsLayerTree, QgsVectorLayer, QgsRasterLayer
from qgis._gui import QgisInterface, QgsLayerTreeView, QgsLayerTreeViewMenuProvider

from landsklim.lk.landsklim_analysis import LandsklimAnalysis
from landsklim.lk.landsklim_interpolation import LandsklimInterpolation, LandsklimInterpolationType


class LayerTreeContextMenuManager(QObject):
    """
    Extend the QGIS layer tree contextual menu with actions supplied by "provider" callables.

    The view's menu provider is monkey patched: its original ``createContextMenu`` is kept as ``_original`` and a
    ``providers`` list is attached to it. Both attributes live on the shared menu provider object, so several
    managers can coexist without wrapping the original method twice.
    """

    def __init__(self, iface: QgisInterface, parent=None):
        super().__init__(parent)
        self._iface: QgisInterface = iface
        self.view: QgsLayerTreeView = iface.layerTreeView()
        self.menuProvider: QgsLayerTreeViewMenuProvider = self.view.menuProvider()

        # Providers registered through this manager; they are unregistered again in __del__
        self.providers: List[Callable] = []

        self.patch()

        # A tree view receives its mouse / context menu events through its viewport widget
        self.view.viewport().installEventFilter(self)

    # Made event argument optional because mocked MenuProvider for tests doesn't pass an event parameter to this function
    def createContextMenu(self, event: Optional[QEvent] = None) -> QMenu:
        """
        Build the contextual menu: the original QGIS menu, extended by every registered provider.

        :param event: context menu event, forwarded to providers that accept a second argument
        :returns: the populated menu
        """
        menu_provider = self._iface.layerTreeView().menuProvider()
        menu: QMenu = menu_provider._original()
        # Iterate over a snapshot so a provider (un)registering itself cannot disturb the loop
        for provider in list(menu_provider.providers):  # type: Callable
            self._call_provider(provider, menu, event)
        return menu

    @staticmethod
    def _call_provider(provider: Callable, menu: QMenu, event: Optional[QEvent]) -> None:
        """
        Call provider as provider(menu, event) when its signature allows it, else as provider(menu).

        Deciding from the signature, rather than retrying after any TypeError, avoids running a provider twice
        (and hiding its error) when the TypeError is raised from inside the provider's body.
        """
        try:
            signature = inspect.signature(provider)
        except (TypeError, ValueError):
            # No introspectable signature: fall back to trial and error
            try:
                provider(menu, event)
            except TypeError:
                provider(menu)
            return
        try:
            signature.bind(menu, event)
        except TypeError:
            provider(menu)
        else:
            provider(menu, event)

    def patch(self):
        """
        Install this manager's createContextMenu on the view's menu provider.

        The original implementation and the shared providers list are only created once, on the provider object.
        """
        if not hasattr(self.menuProvider, "_original"):
            self.menuProvider._original = self.menuProvider.createContextMenu
        if not hasattr(self.menuProvider, "providers"):
            self.menuProvider.providers = []

        self.menuProvider.createContextMenu = self.createContextMenu

    def eventFilter(self, obj: QObject, event: QEvent) -> bool:
        """
        Show the patched contextual menu when the viewport receives a context menu event.

        :returns: True when the event was a context menu request (it is consumed), False otherwise
        """
        if event.type() != QEvent.ContextMenu:
            return False
        menu: Optional[QMenu] = self.menuProvider.createContextMenu(event)
        if menu is not None and not menu.isEmpty():
            # event.pos() is relative to the viewport (the filtered object), not to the view:
            # globalPos() is already in screen coordinates
            menu.exec(event.globalPos())
        return True

    def addProvider(self, provider: Callable):
        """Register provider on the shared menu provider; non-callables and duplicates are ignored."""
        if not callable(provider):
            return
        if provider in self.menuProvider.providers:
            return
        self.providers.append(provider)
        self.menuProvider.providers.append(provider)

    def removeProvider(self, provider: Callable):
        """Unregister provider; do nothing if it is not registered."""
        if provider in self.menuProvider.providers:
            self.menuProvider.providers.remove(provider)
        # Keep this manager's own bookkeeping consistent with the shared list
        if provider in self.providers:
            self.providers.remove(provider)

    def __del__(self):
        # Iterate over a copy: removeProvider() also removes entries from self.providers.
        # getattr guards against __init__ having failed before the attribute was set.
        for provider in list(getattr(self, "providers", [])):
            self.removeProvider(provider)
        try:
            self.view.viewport().removeEventFilter(self)
        except (AttributeError, RuntimeError):
            # __init__ did not complete, or the wrapped Qt view has already been deleted (sip raises RuntimeError)
            pass


class LandsklimMenuProvider:
    """
    Layer tree menu provider adding Landsklim actions to analysis and interpolation group nodes.

    Group nodes are recognised through their ``landsklim_object`` custom property, formatted as
    ``configuration_<i>_analysis_<j>`` for an analysis and ``configuration_<i>_analysis_<j>_interpolation_<k>``
    for an interpolation. Triggered actions forward to the callbacks registered with the ``handle_on_*`` methods.
    """

    def __init__(self, iface: QgisInterface, landsklim_instance: "Landsklim"):
        self.view: QgsLayerTreeView = iface.layerTreeView()
        self.__landsklim_instance: "Landsklim" = landsklim_instance
        # Callbacks invoked by the menu actions; None until registered through the handle_on_* methods
        self.__handle_on_analysis_situation_click: Optional[Callable[[LandsklimAnalysis, int], None]] = None
        self.__handle_on_analysis_delete: Optional[Callable[[LandsklimAnalysis], None]] = None
        self.__handle_on_interpolation_delete: Optional[Callable[[LandsklimAnalysis, LandsklimInterpolation], None]] = None
        # NOTE(review): not used by any action built in __call__ — confirm whether it is still needed
        self.__handle_on_save_as_netcdf_click: Optional[Callable[[str], None]] = None
        self.__handle_on_interpolation_save_as_netcdf_click: Optional[Callable[[LandsklimAnalysis, LandsklimInterpolation, LandsklimInterpolationType], None]] = None

    def tr(self, string: str) -> str:
        """Translate string in the 'LandsklimMenuProvider' context."""
        res = QCoreApplication.translate('LandsklimMenuProvider', string)
        return res

    def handle_on_analysis_situation_click(self, handle: Callable[[LandsklimAnalysis, int], None]):
        self.__handle_on_analysis_situation_click = handle

    def handle_on_save_as_netcdf_click(self, handle: Callable[[str], None]):
        self.__handle_on_save_as_netcdf_click = handle

    def handle_on_interpolation_save_as_netcdf_click(self, handle: Callable[[LandsklimAnalysis, LandsklimInterpolation, LandsklimInterpolationType], None]):
        """Register the callback of the "[<type>] Export as NetCDF" interpolation actions."""
        self.__handle_on_interpolation_save_as_netcdf_click = handle

    def handle_on_analysis_delete_click(self, handle: Callable[[LandsklimAnalysis], None]):
        """Register the callback of the "Delete analysis" action."""
        self.__handle_on_analysis_delete = handle

    def handle_on_interpolation_delete_click(self, handle: Callable[[LandsklimAnalysis, LandsklimInterpolation], None]):
        """Register the callback of the "Delete interpolation" action."""
        self.__handle_on_interpolation_delete = handle

    @staticmethod
    def _parse_analysis_indices(landsklim_object_property: str):
        """
        Extract (configuration index, analysis index) from a ``configuration_<i>_analysis_<j>[...]`` property.

        :raises ValueError: if the indices are not integers
        """
        index_configuration: int = int(landsklim_object_property.replace("configuration_", "").split("_analysis_")[0])
        index_analysis: int = int(landsklim_object_property.replace("configuration_{0}_analysis_".format(index_configuration), "").split("_interpolation_")[0])
        return index_configuration, index_analysis

    def _get_analysis(self, index_configuration: int, index_analysis: int) -> LandsklimAnalysis:
        """Return analysis index_analysis of configuration index_configuration of the current Landsklim project."""
        configurations = self.__landsklim_instance.get_landsklim_project().get_configurations()
        return configurations[index_configuration].get_analysis()[index_analysis]

    def get_analysis_from_property(self, landsklim_object_property: str) -> LandsklimAnalysis:
        """
        Retrieve LandsklimAnalysis from landsklim_object node property
        """
        index_configuration, index_analysis = self._parse_analysis_indices(landsklim_object_property)
        return self._get_analysis(index_configuration, index_analysis)

    def get_interpolation_from_property(self, landsklim_object_property: str) -> LandsklimInterpolation:
        """
        Retrieve LandsklimInterpolation from a ``configuration_<i>_analysis_<j>_interpolation_<k>`` node property
        """
        index_configuration, index_analysis = self._parse_analysis_indices(landsklim_object_property)
        prefix = "configuration_{0}_analysis_{1}_interpolation_".format(index_configuration, index_analysis)
        index_interpolation: int = int(landsklim_object_property.replace(prefix, ""))
        return self._get_analysis(index_configuration, index_analysis).get_interpolations()[index_interpolation]

    @staticmethod
    def _invoke(handler: Optional[Callable], *args) -> None:
        """Call handler with args if one has been registered; do nothing otherwise."""
        if handler is not None:
            handler(*args)

    def __call__(self, menu: QMenu):
        """Append Landsklim actions to menu when the current node is an analysis or interpolation group."""
        current_group = self.view.currentGroupNode()
        # currentGroupNode() may be None (no current node); layers get no Landsklim action
        if current_group is None or QgsLayerTree.isLayer(self.view.currentNode()):
            return
        landsklim_object = current_group.customProperty("landsklim_object", "")
        if not isinstance(landsklim_object, str):
            return
        if "configuration_" not in landsklim_object or "analysis_" not in landsklim_object:
            return

        if "interpolation_" in landsklim_object:
            self._add_interpolation_actions(menu, landsklim_object)
        else:
            self._add_analysis_actions(menu, landsklim_object)

        # self.view.currentLayer() to access the current layer

    def _add_analysis_actions(self, menu: QMenu, landsklim_object: str) -> None:
        """Add the "Delete analysis" action for the analysis node described by landsklim_object."""
        menu.addSeparator()
        analysis: LandsklimAnalysis = self.get_analysis_from_property(landsklim_object)
        # Handlers are looked up when the action fires, so a callback registered later is still used
        lb_ev_delete = lambda _, arg_analysis=analysis: self._invoke(self.__handle_on_analysis_delete, arg_analysis)
        action_ev_delete: QAction = QAction(self.tr("Delete analysis"), menu)
        action_ev_delete.triggered.connect(lb_ev_delete)
        menu.addAction(action_ev_delete)

    def _add_interpolation_actions(self, menu: QMenu, landsklim_object: str) -> None:
        """Add "Delete interpolation" and, for grid interpolations, one NetCDF export action per interpolation type."""
        analysis: LandsklimAnalysis = self.get_analysis_from_property(landsklim_object)
        interpolation: LandsklimInterpolation = self.get_interpolation_from_property(landsklim_object)
        menu.addSeparator()
        lb_ev_delete = lambda _, arg_analysis=analysis, arg_interpolation=interpolation: self._invoke(self.__handle_on_interpolation_delete, arg_analysis, arg_interpolation)
        action_ev_delete: QAction = QAction(self.tr("Delete interpolation"), menu)
        action_ev_delete.triggered.connect(lb_ev_delete)
        menu.addAction(action_ev_delete)

        if not interpolation.is_on_grid():
            return
        menu.addSeparator()
        for phase_type in interpolation.get_interpolation_types():  # type: LandsklimInterpolationType
            # Default arguments bind the loop variable now, avoiding the late-binding closure pitfall
            lb_ev_export = lambda _, arg_analysis=analysis, arg_interpolation=interpolation, arg_phase_type=phase_type: self._invoke(self.__handle_on_interpolation_save_as_netcdf_click, arg_analysis, arg_interpolation, arg_phase_type)
            action_ev_export: QAction = QAction(self.tr("[{0}] Export as NetCDF").format(phase_type.str()), menu)
            action_ev_export.setData('export_netcdf')
            action_ev_export.triggered.connect(lb_ev_export)
            menu.addAction(action_ev_export)
