from typing import Dict, List

import pyqtgraph as pg
from qgis.PyQt import QtCore

from openlog.core.pint_utilities import unit_conversion
from openlog.datamodel.assay.generic_assay import AssayDomainType
from openlog.gui.assay_visualization.assay_bar_graph_item import AssayBarGraphItem
from openlog.gui.assay_visualization.assay_inspector_line import AssayInspectorLine
from openlog.gui.assay_visualization.assay_plot_data_item import AssayPlotDataItem
from openlog.gui.assay_visualization.config.assay_column_visualization_config import (
    AssayColumnVisualizationConfig,
)
from openlog.gui.assay_visualization.config.assay_visualization_config import (
    AssayVisualizationConfig,
)
from openlog.gui.assay_visualization.discrete.categorical_scatterplot_item import (
    CategoricalScatterPlotItem,
)
from openlog.gui.assay_visualization.stacked.stacked_config import StackedConfiguration


class CustomItemSample(pg.ItemSample):
    """
    Custom pg.ItemSample used to override mouseClickEvent on legend items.

    This class is meant to be passed as ``sampleType`` when instantiating a
    pg.LegendItem. When a legend entry is left-clicked, the visibility of the
    associated item is toggled by pg.ItemSample and then propagated:

    - for a scatter plot legend item, to the matching spots of the scatter plot;
    - otherwise, to the visibility parameter of the item's
      AssayColumnVisualizationConfig.
    """

    def __init__(self, item):
        """
        Args:
            item: plot item represented by this legend sample
        """
        super().__init__(item)

    def _synchronize_scatterplot_visibility(self):
        """
        Show/hide the scatter plot spots that belong to this legend entry.

        Each category has its own brush, so every pg.SpotItem whose brush equals
        this legend item's brush receives the legend item's current visibility.
        """
        is_visible = self.item.isVisible()
        brush = self.item.opts.get("brush")
        # NOTE(review): assumes the scatter plot is the first item of the
        # top-level plot item - confirm if other items can be inserted first.
        plot_item = self.topLevelWidget().items[0]
        for spot in plot_item.points():
            if spot.brush() == brush:
                spot.setVisible(is_visible)

    def mouseClickEvent(self, event):
        """
        Toggle visibility on left click and keep the related state in sync.

        For a pg.ScatterPlotItem (categorical scatter legend entry), the
        concerned spots are shown/hidden manually. Otherwise the item's config
        visibility parameter is toggled.
        """
        super().mouseClickEvent(event)

        # pg.ItemSample only toggles the item's visibility on a left click:
        # ignore other buttons so the config/spots do not drift out of sync
        # with the item.
        if event.button() != QtCore.Qt.MouseButton.LeftButton:
            return

        if isinstance(self.item, pg.ScatterPlotItem):
            self._synchronize_scatterplot_visibility()
        else:
            is_visible = self.item.config.visibility_param.value()
            self.item.config.visibility_param.setValue(not is_visible)


class AssayPlotItem(pg.PlotItem):
    """pg.PlotItem override that keeps a legend in sync with its assay items."""

    def __init__(
        self,
        parent=None,
        name=None,
        labels=None,
        title=None,
        viewBox=None,
        axisItems=None,
        enableMenu=True,
        collar_stack=False,
        **kargs,
    ):
        """
        pg.PlotItem override for automatic legend of plot graph

        Args:
            parent : see pg documentation
            name : see pg documentation
            labels : see pg documentation
            title : see pg documentation
            viewBox : see pg documentation
            axisItems : see pg documentation
            enableMenu : see pg documentation
            collar_stack (bool) : True if plotItem is for collar stack (multiple collar available), default False
            **kargs : forwarded to pg.PlotItem
        """
        super().__init__(
            parent, name, labels, title, viewBox, axisItems, enableMenu, **kargs
        )
        self.collar_stack = collar_stack

        # No legend until the first legendable item is inserted (see addItem)
        self.legend = None

        # only "legendable" items get a legend entry
        self.legendable_types = (
            AssayPlotDataItem,
            AssayBarGraphItem,
            CategoricalScatterPlotItem,
        )

    def _legendable_items(self) -> List:
        """Return the items of this plot that are instances of legendable_types."""
        return [item for item in self.items if isinstance(item, self.legendable_types)]

    def _mapColor(self) -> None:
        """
        Assign a distinct pen color to each legendable item via its configuration.

        Colors come from pg.intColor(index), in the order the items appear in
        self.items. The modified pen is written back to the config and the
        config's color ramp is refreshed.
        """
        for index, item in enumerate(self._legendable_items()):
            current_pen = item.config.pen_params.pen
            current_pen.setColor(pg.intColor(index))
            item.config.pen_params.pen = current_pen
            item.config._update_plot_item_color_ramp()

    def addItem(self, item, *args, **kargs):
        """
        Override pg.PlotItem addItem for automatic legend definition.

        Legendable items are registered in the legend (created on first use):

        - CategoricalScatterPlotItem: one entry per (legend_item, name) pair of
          ``item.legend_items``; the legend is always displayed.
        - other legendable items: one entry named after the column (prefixed by
          the collar display name for a collar stack); the legend is displayed
          and colors are re-assigned only when several legendable items exist.
        """
        super().addItem(item, *args, **kargs)

        if not isinstance(item, self.legendable_types):
            return

        # Add legend if not available
        if self.legend is None:
            self.legend = pg.LegendItem(
                offset=(-1, 50),
                brush="white",
                pen="black",
                sampleType=CustomItemSample,
            )

        if isinstance(item, CategoricalScatterPlotItem):
            self.legend.setParentItem(self)
            for legend_item, name in item.legend_items:
                self.legend.addItem(legend_item, name)
            return

        # legend is displayed (and colors distinguished) only with multiple items
        if len(self._legendable_items()) > 1:
            self._mapColor()
            self.legend.setParentItem(self)

        collar_display_name, column_name = (
            item.config.hole_display_name,
            item.config.column.name,
        )
        # for stacked plot, label is "collar_id - column name"
        if self.collar_stack:
            full_name = f"{collar_display_name} - {column_name}"
        else:
            full_name = column_name

        # add legend item
        self.legend.addItem(item, full_name)

    def removeItem(self, item):
        """
        Override pg.PlotItem removeItem for automatic item deletion from legend.

        Mirrors addItem: a CategoricalScatterPlotItem registered one legend entry
        per category, so each of them is removed; other legendable items have a
        single entry keyed on the item itself.
        """
        super().removeItem(item)

        if self.legend is None or not isinstance(item, self.legendable_types):
            return

        if isinstance(item, CategoricalScatterPlotItem):
            for legend_item, _name in item.legend_items:
                self.legend.removeItem(legend_item)
        else:
            self.legend.removeItem(item)


class AssayPlotWidget(pg.PlotWidget):
    """
    pg.PlotWidget specialised for assay display.

    Stores the assay domain, configures axes accordingly, builds the plot title
    from the attached configuration and manages an AssayInspectorLine,
    including optional synchronisation with another plot's inspector line.
    """

    def __init__(
        self,
        domain: AssayDomainType,
        parent=None,
        background="default",
        plotItem=None,
        collar_stack=False,
        **kargs,
    ):
        """
        pg.PlotWidget override to store domain and extra functions to define limits and display AssayInspectorLine

        Args:
            domain: AssayDomainType of the displayed assay (TIME or depth)
            parent: parent widget, forwarded to pg.PlotWidget
            background: forwarded to pg.PlotWidget (background is then forced to white)
            plotItem: plot item to display; an AssayPlotItem is created when None
            collar_stack (bool) : True if PlotWidget is for collar stack (multiple collar available), default False
            **kargs: forwarded to the created AssayPlotItem and to pg.PlotWidget
        """
        if plotItem is None:
            plotItem = AssayPlotItem(collar_stack=collar_stack, **kargs)
        super().__init__(parent, background, plotItem, **kargs)

        self.domain = domain
        self.collar_stack = collar_stack
        self.setBackground("white")

        if self.domain == AssayDomainType.TIME:
            # Time assay: date axis at the bottom, no vertical mouse interaction
            self.setAxisItems({"bottom": pg.DateAxisItem()})
            self.getViewBox().setMouseEnabled(y=False)
        else:
            # Depth assay: depth on the left axis, growing downwards
            label_style = {"color": "black", "font-size": "15px"}
            self.setLabel("left", self.tr("Depth"), "m", **label_style)
            self.showAxis("top")

            depth_view_box = self.getViewBox()
            depth_view_box.invertY(True)
            depth_view_box.setMouseEnabled(x=False)

        # Inspector line is created now but only attached on demand
        self.inspector = AssayInspectorLine(domain=self.domain)
        self._sync_inspector = None
        self._connection_to_sync_inspector = None
        self._connection_from_sync_inspector = None

        # Configurations used to build the title (checked in this order)
        self.assay_config = None
        self.assay_column_config = None
        self.stacked_config = None

        self.disable_plot_option()

        # Unit currently displayed; empty string means no unit suffix in titles
        self._displayed_unit = ""

    def disable_plot_option(self):
        """Hide the plot item menu and the "Mouse Mode" contextual menu entry."""
        self.plotItem.getMenu().menuAction().setVisible(False)

        mouse_mode_actions = (
            action
            for action in self.getViewBox().menu.actions()
            if action.text() == "Mouse Mode"
        )
        for action in mouse_mode_actions:
            action.setVisible(False)

    def set_assay_config(self, assay_config: AssayVisualizationConfig) -> None:
        """Attach an assay configuration, adopt its unit and refresh the title."""
        self.assay_config = assay_config
        self._displayed_unit = assay_config.unit_parameter.value()
        self._update_title()

    def set_assay_column_config(
        self, assay_column_config: AssayColumnVisualizationConfig
    ) -> None:
        """Attach an assay column configuration, adopt its unit and refresh the title."""
        self.assay_column_config = assay_column_config
        self._displayed_unit = assay_column_config.unit_parameter.value()
        self._update_title()

    def set_stacked_config(self, stacked_config: StackedConfiguration) -> None:
        """Attach a stacked configuration, adopt its unit and refresh the title."""
        self.stacked_config = stacked_config
        self._displayed_unit = stacked_config.unit_parameter.value()
        self._update_title()

    def update_displayed_unit(self, new_unit: str) -> None:
        """
        Switch the displayed unit, converting the visible value range.

        The value axis is Y for time plots and X for depth plots. When the unit
        is unchanged nothing is converted; the title is refreshed in all cases.

        Args:
            new_unit: unit to display
        """
        if self._displayed_unit != new_unit:
            view_box = self.getViewBox()
            view_range = view_box.viewRange()
            is_time = self.domain == AssayDomainType.TIME

            value_range = view_range[1] if is_time else view_range[0]
            converted_range = unit_conversion(
                from_unit=self._displayed_unit, to_unit=new_unit, array=value_range
            )
            if is_time:
                view_box.setRange(yRange=converted_range)
            else:
                view_box.setRange(xRange=converted_range)

            self._displayed_unit = new_unit
        self._update_title()

    @property
    def _unit_str_suffix(self) -> str:
        """Return " (<unit>)" when a unit is displayed, else an empty string."""
        return f" ({self._displayed_unit})" if self._displayed_unit else ""

    @property
    def _time_assay_title(self) -> str:
        """Assay name plus unit suffix, used as left axis label of time plots."""
        if self.assay_config:
            title = self.assay_config.assay_display_name
        elif self.assay_column_config:
            title = self.assay_column_config.assay_display_name
        elif self.stacked_config:
            title = self.stacked_config.name
        else:
            return ""
        return f"{title}{self._unit_str_suffix}"

    @property
    def _depth_assay_title(self) -> str:
        """Assay (and, if different, column) name plus unit suffix for depth plots."""
        if self.assay_config:
            title = self.assay_config.assay_display_name
        elif self.assay_column_config:
            # Column name is shown only when it differs from the assay name
            column_name = self.assay_column_config.column.name
            assay_name = self.assay_column_config.assay_display_name
            if column_name != assay_name:
                title = f"{assay_name}/{column_name}"
            else:
                title = assay_name
        elif self.stacked_config:
            title = self.stacked_config.name
        else:
            return ""
        return f"{title}{self._unit_str_suffix}"

    @property
    def _plot_title(self) -> str:
        """Title matching the current domain."""
        return (
            self._time_plot_title
            if self.domain == AssayDomainType.TIME
            else self._depth_plot_title
        )

    @property
    def _depth_plot_title(self) -> str:
        """Depth plot title: collar name, then assay name and unit on a new line."""
        if self.assay_config:
            hole_name = self.assay_config.hole_display_name
        elif self.assay_column_config:
            hole_name = self.assay_column_config.hole_display_name
        elif self.stacked_config:
            return f"{self.stacked_config.name}{self._unit_str_suffix}"
        else:
            return ""
        return f"{hole_name}<br>{self._depth_assay_title}"

    @property
    def _time_plot_title(self) -> str:
        """Time plot title: collar name only (plus column name when relevant)."""
        if self.assay_config:
            return self.assay_config.hole_display_name
        if self.assay_column_config:
            # Column name is shown only when it differs from the assay name
            column_name = self.assay_column_config.column.name
            assay_name = self.assay_column_config.assay_display_name
            hole_name = self.assay_column_config.hole_display_name
            if column_name == assay_name:
                return hole_name
            return f"{hole_name}<br>{column_name}"
        if self.stacked_config:
            return self.stacked_config.name
        return ""

    def _update_title(self) -> None:
        """Refresh the plot title and, for time plots, the left axis label."""
        self.setTitle(self._plot_title, color="black", size="15px")
        if self.domain != AssayDomainType.TIME:
            return
        axis_style = {"color": "black", "font-size": "15px"}
        self.setLabel("left", self._time_assay_title, "", **axis_style)

    def set_limits(self, min_domain_axis: float, max_domain_axis: float) -> None:
        """
        Restrict the view box to the domain range along the domain axis.

        The domain axis is X for time plots and Y for depth plots. On the first
        call the inspector line is placed in the middle of the range.

        Args:
            min_domain_axis: float
            max_domain_axis: float
        """
        view_box = self.getViewBox()
        if self.domain == AssayDomainType.TIME:
            view_box.setLimits(xMin=min_domain_axis, xMax=max_domain_axis)
        else:
            view_box.setLimits(yMin=min_domain_axis, yMax=max_domain_axis)

        if not self.inspector.valueDefined:
            middle = min_domain_axis + (max_domain_axis - min_domain_axis) / 2.0
            self.inspector.setPos(middle)
            self.inspector.valueDefined = True

    def enable_inspector_line(self, enable: bool) -> None:
        """
        Attach the inspector line to this plot, or detach it.

        Args:
            enable: (bool) True to attach, False to detach
        """
        if not enable:
            self.inspector.dettach()
            return
        self.inspector.attachToPlotItem(self.getPlotItem())

    def enable_inspector_sync(
        self, sync_inspector: AssayInspectorLine, enable: bool
    ) -> None:
        """
        Enable inspector line synchronisation

        Any existing synchronisation is removed first. When enabled (and the
        given line is not this plot's own inspector), position changes are
        mirrored in both directions between the two lines.

        Args:
            sync_inspector: (AssayInspectorLine) inspector line to be synchronized
            enable: (bool) enable or disable sync
        """
        # Tear down the current synchronisation, if any
        if self._sync_inspector and self._connection_to_sync_inspector:
            self._sync_inspector.disconnect(self._connection_to_sync_inspector)
            self._sync_inspector = None
            self._connection_to_sync_inspector = None

        if self._connection_from_sync_inspector:
            self.inspector.sigPositionChanged.disconnect(
                self._connection_from_sync_inspector
            )
            self._connection_from_sync_inspector = None

        if enable and sync_inspector != self.inspector:
            self._sync_inspector = sync_inspector
            # other line -> this line
            self._connection_to_sync_inspector = (
                sync_inspector.sigPositionChanged.connect(
                    lambda line: self.inspector.setPos(line.getPos())
                )
            )
            # this line -> other line
            self._connection_from_sync_inspector = (
                self.inspector.sigPositionChanged.connect(
                    lambda line: self._sync_inspector.setPos(line.getPos())
                )
            )

    def minimumSizeHint(self):
        """
        Return minimum size hint with width calculation from displayed title and axes

        Width is the title minimum width plus left and right axis widths;
        height is derived from the left axis label plus top and bottom axis
        heights.

        Returns: (QSize) minimum size hint
        """
        hint = super().minimumSizeHint()

        title_label = self.plotItem.titleLabel
        if title_label:
            title_size = title_label.sizeHint(QtCore.Qt.SizeHint.MinimumSize, 0)
            hint.setWidth(int(title_size.width()))

        left_axis = self.getAxis("left")
        if left_axis:
            hint.setWidth(hint.width() + int(left_axis.minimumSize().width()))
            if left_axis.label:
                # label is rotated: its width is the vertical extent;
                # bounding rect is usually an overestimate
                hint.setHeight(int(left_axis.label.boundingRect().width() * 0.8))

        right_axis = self.getAxis("right")
        if right_axis:
            hint.setWidth(hint.width() + int(right_axis.minimumSize().width()))

        for axis_name in ("top", "bottom"):
            axis = self.getAxis(axis_name)
            if axis:
                hint.setHeight(hint.height() + int(axis.minimumSize().height()))

        return hint
