import numpy as np
import pyqtgraph as pg
import pyqtgraph.parametertree.parameterTypes as pTypes
from pyqtgraph.graphicsItems.UIGraphicsItem import UIGraphicsItem
from qgis.PyQt import QtCore, QtGui
from qgis.PyQt.QtGui import QColor, QPen
from qgis.PyQt.QtWidgets import QDialog

from openlog.gui.assay_visualization.cross_symbology.cross_symbology_widget import (
    CrossSymbologyWidget,
)
from openlog.gui.pyqtgraph.BoldBoolParameter import BoldBoolParameter
from openlog.gui.pyqtgraph.ColorRampPreviewParameter import ColorRampPreviewParameter
from openlog.gui.pyqtgraph.CustomActionParameter import CustomActionParameter


class CustomGradientLegend(pg.GradientLegend):
    """
    Draggable GradientLegend with a title drawn above it.

    The legend can be moved with the left mouse button. A drag only starts when
    the click falls inside the legend background drawn by paint().
    """

    def __init__(self, size, offset):
        super().__init__(size, offset)
        self.title = ""
        self.titleFont = QtGui.QFont("SansSerif", 8, QtGui.QFont.Bold)
        self.titleColor = QtGui.QColor("black")
        self.setAcceptedMouseButtons(QtCore.Qt.LeftButton)
        # drag state: pointer position and legend offset captured on press
        self._dragStart = None
        self._startOffset = None
        # legend background rectangle in ViewBox pixel coordinates, set by paint()
        self.background_rect = None

    @staticmethod
    def _pointer_in_view(event, view) -> QtCore.QPointF:
        """
        Return the event position relative to the ViewBox origin.
        Args:
            - event : mouse event providing scenePos()
            - view : ViewBox the legend is drawn in
        """
        return event.scenePos() - view.mapToScene(QtCore.QPointF(0, 0))

    def mousePressEvent(self, event):
        """
        Start a drag when the click is inside the legend background,
        otherwise ignore the event so it propagates to other items.
        """
        view = self.getViewBox()
        # background_rect is unknown until paint() has run with a ViewBox
        if view is None or self.background_rect is None:
            event.ignore()
            return

        # catch only if click is inside legend
        pointer_pos = self._pointer_in_view(event, view)
        if not self.background_rect.contains(pointer_pos):
            event.ignore()
            return

        self._dragStart = pointer_pos
        self._startOffset = self.offset
        event.accept()

    def mouseMoveEvent(self, event):
        """
        Move the legend with the pointer while a drag is in progress.

        Moves that would put the legend anchor outside the ViewBox, or switch it
        away from the right/top anchoring, are discarded.
        """
        if self._dragStart is None:
            return
        view = self.getViewBox()
        if view is None:
            return

        vb_size = view.size()
        current_pos = self._pointer_in_view(event, view)
        new_offset = QtCore.QPointF(*self._startOffset) + (
            current_pos - self._dragStart
        )
        off_x, off_y = new_offset.x(), new_offset.y()
        # paint() anchors on the right edge for a negative x offset and on the
        # top edge for a non-negative y offset: keep both, inside the viewbox
        if off_x >= 0 or off_x < -vb_size.width():
            return
        if off_y < 0 or off_y > vb_size.height():
            return

        # update legend position
        self.offset = (off_x, off_y)
        self.update()
        event.accept()

    def set_title(self, title: str) -> None:
        """
        Set the title drawn above the gradient legend.
        Calling update() schedules a call to paint().
        Args:
            - title : text of the title
        """
        self.title = title
        self.update()

    def _get_title_width(self, p: QtGui.QPainter) -> float:
        """
        Return title width in pixels, using the painter's current font.
        Args:
            - p : QPainter
        """
        return p.fontMetrics().boundingRect(self.title).width()

    def paint(self, p, opt, widget):
        """
        Draw background, title, color bar and labels in ViewBox pixel coordinates.

        Mirrors pyqtgraph's GradientLegend.paint() and adds the title. It also
        records the background rectangle in self.background_rect for hit testing.
        """
        UIGraphicsItem.paint(self, p, opt, widget)

        view = self.getViewBox()
        if view is None:
            return
        p.save()  # save painter state before we change transformation
        trans = view.sceneTransform()
        p.setTransform(trans)  # draw in ViewBox pixel coordinates
        rect = view.rect()

        ## determine max width of all labels
        labelWidth = 0
        labelHeight = 0
        for k in self.labels:
            b = p.boundingRect(
                QtCore.QRectF(0, 0, 0, 0),
                QtCore.Qt.AlignmentFlag.AlignLeft
                | QtCore.Qt.AlignmentFlag.AlignVCenter,
                str(k),
            )
            labelWidth = max(labelWidth, b.width())
            labelHeight = max(labelHeight, b.height())

        textPadding = 2  # in px
        titleHeight = 20  # in px, height of the title area above the background

        xR = rect.right()
        xL = rect.left()
        yT = rect.top()
        yB = rect.bottom()

        # coordinates describe edges of text and bar, additional margins will be added for background
        if self.offset[0] < 0:
            # right edge from right edge of view, offset is negative!
            x3 = xR + self.offset[0]
            x2 = x3 - labelWidth - 2 * textPadding  # right side of color bar
            x1 = x2 - self.size[0]  # left side of color bar
        else:
            x1 = xL + self.offset[0]  # left edge from left edge of view
            x2 = x1 + self.size[0]
            # leave room for 2x textpadding between bar and text
            x3 = x2 + labelWidth + 2 * textPadding
        if self.offset[1] < 0:
            # bottom edge from bottom of view, offset is negative!
            y2 = yB + self.offset[1]
            y1 = y2 - self.size[1]
        else:
            y1 = yT + self.offset[1]  # top edge from top of view
            y2 = y1 + self.size[1]
        self.b = [x1, x2, x3, y1, y2, labelWidth]

        ## Draw background
        p.setPen(self.pen)
        p.setBrush(self.brush)  # background color
        background_top = y1 - labelHeight / 2 - textPadding
        rect = QtCore.QRectF(
            QtCore.QPointF(x1 - textPadding, background_top),  # extra left/top padding
            QtCore.QPointF(
                x3 + textPadding, y2 + labelHeight / 2 + textPadding
            ),  # extra bottom/right padding
        )
        p.drawRect(rect)
        # remembered for mousePressEvent hit testing
        self.background_rect = rect

        ## Draw title, centered above the background
        p.save()
        background_width = abs((x1 - textPadding) - (x3 + textPadding))
        p.setPen(QPen(self.titleColor))
        p.setFont(self.titleFont)
        title_width = self._get_title_width(p)
        # widen the title box symmetrically when the title is wider than the legend
        extra_width = max(0, title_width - background_width)
        title_text_rect = QtCore.QRectF(
            QtCore.QPointF(
                x1 - textPadding - extra_width / 2, background_top - titleHeight
            ),
            QtCore.QPointF(x3 + textPadding + extra_width / 2, background_top),
        )
        p.drawText(title_text_rect, QtCore.Qt.AlignCenter, self.title)
        p.restore()

        ## Draw color bar
        self.gradient.setStart(0, y2)
        self.gradient.setFinalStop(0, y1)
        p.setBrush(self.gradient)
        rect = QtCore.QRectF(QtCore.QPointF(x1, y1), QtCore.QPointF(x2, y2))
        p.drawRect(rect)

        ## draw labels
        p.setPen(self.textPen)

        tx = x2 + 2 * textPadding  # margin between bar and text
        lh = labelHeight
        lw = labelWidth
        for k in self.labels:
            # label value is a relative position: 0 at bar bottom, 1 at bar top
            y = y2 - self.labels[k] * (y2 - y1)
            p.drawText(
                QtCore.QRectF(tx, y - lh / 2, lw, lh),
                QtCore.Qt.AlignmentFlag.AlignLeft
                | QtCore.Qt.AlignmentFlag.AlignVCenter,
                str(k),
            )

        p.restore()  # restore QPainter transform to original state


class CrossSymbologyHandler:
    """
    Keep a pyqtgraph parameter group and a CrossSymbologyWidget in sync.

    Instantiated as a config attribute. The checkable group ``self.param`` holds
    child parameters (variable, classifier, palette, breaks, min, max, legend)
    that mirror the widget controls. Changes made on either side are propagated
    to the other. ``self.param.sigValueChanged`` is emitted whenever the
    symbology must be re-applied.
    """

    def __init__(self, config, icon: str):
        """
        Args:
            - config : owning visualization config (read for hole_id, column,
              assay_name, domain, plot_widget and plot_item)
            - icon : icon of the "Manual edition" action button
        """
        self.config = config
        self.param = BoldBoolParameter(name="", type="bool", value=False, default=False)
        self.param.setOpts(expanded=False)
        self.symbology_widget = CrossSymbologyWidget()
        self.symbology_widget.config = config
        # legend currently added to the plot widget, if any
        self.legend = None
        self.symbology_widget.parameterChanged.connect(
            self._synchronize_params_from_widget
        )
        self.symbology_widget.variableListChanged.connect(
            self._synchronize_variable_list
        )
        # manual edition button
        self.btn = CustomActionParameter(
            name="",
            type="action",
            width=60,
            icon=icon,
            tooltip="Manual edition",
        )
        self.param.addChild(self.btn)
        self.btn.sigActivated.connect(self.open_dialog)

        # parameters
        self.mapped_variable_param = pTypes.ListParameter(
            name="Variable", type="str", value=""
        )
        self.param.addChild(self.mapped_variable_param)
        # hide parameter if combobox is disabled
        if not self.symbology_widget.variable_cbox.isEnabled():
            self.mapped_variable_param.hide()

        self.method_param = pTypes.ListParameter(
            name="Classifier",
            type="str",
            value="Linear",
            limits=["Linear", "Equal", "Quantile"],
        )
        self.param.addChild(self.method_param)

        # sorted on swapcase() (names starting lowercase come first);
        # "CET-C*" and "CET-I*" maps are excluded from the palette choices
        list_of_maps = [
            cmap
            for cmap in sorted(pg.colormap.listMaps(), key=lambda x: x.swapcase())
            if "CET-C" not in cmap and "CET-I" not in cmap
        ]

        self.palette_param = ColorRampPreviewParameter(
            name="Palette",
            type="str",
            value="PAL-relaxed_bright",
            default="PAL-relaxed_bright",
            limits=list_of_maps,
        )
        self.param.addChild(self.palette_param)

        self.breaks_param = pTypes.SimpleParameter(
            name="Breaks",
            type="int",
            value=5,
            default=5,
            min=2,
            max=300,
        )
        self.param.addChild(self.breaks_param)

        self.min_param = pTypes.SimpleParameter(name="Min", type="float")
        self.param.addChild(self.min_param)

        self.max_param = pTypes.SimpleParameter(name="Max", type="float")
        self.param.addChild(self.max_param)

        self.legend_param = pTypes.SimpleParameter(name="Legend", type="bool")
        self.param.addChild(self.legend_param)

        # connect signals
        self.mapped_variable_param.sigValueChanged.connect(
            self._synchronize_variable_changed
        )
        for classification_param in (
            self.method_param,
            self.palette_param,
            self.breaks_param,
            self.min_param,
            self.max_param,
        ):
            classification_param.sigValueChanged.connect(
                self._synchronize_params_to_widget
            )
        self.legend_param.sigValueChanged.connect(self.display_legend)

    def get_parameters(self) -> dict:
        """
        Return the widget parameters after saving its current state.
        """
        self.symbology_widget.save_current_parameters()
        return self.symbology_widget.parameters

    def set_parameters(self, params: dict) -> None:
        """
        Apply parameters to the widget.
        Args:
            - params : parameters, as returned by get_parameters()
        """
        self.symbology_widget.parameters = params
        self.symbology_widget.restore_parameters()

    def _synchronize_variable_changed(self) -> None:
        """
        Push the variable selected in the parameter tree to the widget.
        Bounds depend on the variable, so they are read back from the widget
        afterwards. The change is then announced through self.param.
        """
        self.symbology_widget.parameterChanged.disconnect(
            self._synchronize_params_from_widget
        )
        try:
            self.symbology_widget.variable_cbox.setCurrentText(
                self.mapped_variable_param.value()
            )
        finally:
            # always reconnect, otherwise widget and parameters stop syncing
            self.symbology_widget.parameterChanged.connect(
                self._synchronize_params_from_widget
            )
        # get new bounds
        self._synchronize_params_from_widget()
        # apply changes
        self.param.sigValueChanged.emit(self.param, self.param.value())

    def _synchronize_params_to_widget(self) -> None:
        """
        Push classification parameters (method, palette, breaks, min, max) to
        the widget. Called when one of those parameters changes.
        If min > max, min is clamped to max.
        """
        widget = self.symbology_widget
        # widget feedback is muted while its controls are updated from here
        widget.parameterChanged.disconnect(self._synchronize_params_from_widget)
        try:
            widget.discretize_method_cbox.setCurrentText(self.method_param.value())
            widget.palette_cbox.setCurrentText(self.palette_param.value())
            widget.breaks_sb.setValue(self.breaks_param.value())
            min_val = self.min_param.value()
            max_val = self.max_param.value()
            if min_val is not None and max_val is not None:
                if min_val > max_val:
                    # clamp without re-entering this slot
                    self.min_param.setValue(
                        max_val,
                        blockSignal=self._synchronize_params_to_widget,
                    )
                widget.min_val_sb.setValue(self.min_param.value())
                widget.max_val_sb.setValue(self.max_param.value())
        finally:
            # always reconnect, otherwise widget and parameters stop syncing
            widget.parameterChanged.connect(self._synchronize_params_from_widget)
        # apply changes
        self.param.sigValueChanged.emit(self.param, self.param.value())

    def _synchronize_params_from_widget(self) -> None:
        """
        Copy the widget's current values into the parameters.
        Called when the widget emits parameterChanged. Each setValue blocks the
        slot that would otherwise push the value straight back to the widget.
        """
        self.mapped_variable_param.setValue(
            self.symbology_widget.variable_cbox.currentText(),
            blockSignal=self._synchronize_variable_changed,
        )
        self.method_param.setValue(
            self.symbology_widget.discretize_method_cbox.currentText(),
            blockSignal=self._synchronize_params_to_widget,
        )
        self.palette_param.setValue(
            self.symbology_widget.palette_cbox.currentText(),
            blockSignal=self._synchronize_params_to_widget,
        )
        self.breaks_param.setValue(
            self.symbology_widget.breaks_sb.value(),
            blockSignal=self._synchronize_params_to_widget,
        )
        self.min_param.setValue(
            self.symbology_widget.min_val_sb.value(),
            blockSignal=self._synchronize_params_to_widget,
        )
        self.max_param.setValue(
            self.symbology_widget.max_val_sb.value(),
            blockSignal=self._synchronize_params_to_widget,
        )

    def _synchronize_variable_list(self) -> None:
        """
        Update the variable parameter limits and value from the widget combobox.
        """
        cbox = self.symbology_widget.variable_cbox
        list_items = [cbox.itemText(i) for i in range(cbox.count())]
        self.mapped_variable_param.sigValueChanged.disconnect(
            self._synchronize_variable_changed
        )
        try:
            self.mapped_variable_param.setLimits(list_items)
            self.mapped_variable_param.setValue(cbox.currentText())
        finally:
            # always reconnect, otherwise variable changes stop being pushed
            self.mapped_variable_param.sigValueChanged.connect(
                self._synchronize_variable_changed
            )

    def open_dialog(self) -> None:
        """
        Open the symbology configuration dialog.
        On accept, self.param.sigValueChanged is emitted to apply changes.
        Otherwise, the parameters saved before opening are restored.
        """
        self.symbology_widget.save_current_parameters()
        res = self.symbology_widget.exec()
        if res == QDialog.Accepted:
            self.param.sigValueChanged.emit(self.param, self.param.value())
        else:
            self.symbology_widget.restore_parameters()

    def get_pyqtgraph_param(self) -> BoldBoolParameter:
        """
        Return the root parameter to insert into a parameter tree.
        """
        return self.param

    def set_assay_iface(self, assay_iface) -> None:
        """
        Set all available variables for cross symbology.
        Does nothing if the widget already has an assay interface.
        Args:
            - assay_iface : assay interface given to the widget
        """
        if self.symbology_widget.assay_iface:
            return
        self.symbology_widget.hole_id = self.config.hole_id
        self.symbology_widget.column = self.config.column
        self.symbology_widget.assay_display_name = self.config.assay_name
        self.symbology_widget.assay_iface = assay_iface
        self.symbology_widget.domain = self.config.domain
        self.symbology_widget.update_variables()

    def set_hole_id(self, hole_id: str) -> None:
        """
        Change the widget hole id and refresh its variables.
        Args:
            - hole_id : new hole id
        """
        self.symbology_widget.hole_id = hole_id
        self.symbology_widget.update_variables()

    def update_global_min_max(self) -> None:
        """
        Update min max values for global config.
        """
        self.symbology_widget._update_min_max_from_children()
        self.symbology_widget._update_table()

    def copy_lut_from_config(self, handler) -> None:
        """
        Copy symbology settings from another handler.
        With the "Defined LUT" propagation mode, the LUT itself is copied.
        Otherwise, classifier, palette and breaks are copied.
        Args:
            - handler : source CrossSymbologyHandler
        """
        source_widget = handler.symbology_widget
        propagation = source_widget.propagate_cbox.currentText()
        self.symbology_widget.variable_cbox.setCurrentText(
            source_widget.variable_cbox.currentText()
        )
        if propagation == "Defined LUT":
            from_, to_, colors = handler.get_lut()
            self.symbology_widget.lut_model.set_lut(from_, to_, colors)
        else:
            self.symbology_widget.discretize_method_cbox.setCurrentText(
                source_widget.discretize_method_cbox.currentText()
            )
            self.symbology_widget.palette_cbox.setCurrentText(
                source_widget.palette_cbox.currentText()
            )
            self.symbology_widget.breaks_sb.setValue(source_widget.breaks_sb.value())

    def get_lut(self) -> tuple:
        """
        Return defined LUT as returned by the LUT model (from_, to_, colors).
        """
        return self.symbology_widget.lut_model.get_lut()

    def get_colors(self) -> list:
        """
        Return colors corresponding to the widget y_values.
        Discretized symbology looks each value up in the LUT. Continuous
        symbology normalizes values to [0, 1] between the LUT min and max.
        """
        is_discretized = self.symbology_widget._is_dicretized()
        cm = self.symbology_widget.lut_model.get_colormap(is_discretized)

        if is_discretized:
            return [cm.get_color(val) for val in self.symbology_widget.y_values]

        min_val = self.symbology_widget.lut_min_value
        max_val = self.symbology_widget.lut_max_value
        # a degenerate range (min == max) would divide by zero
        span = (max_val - min_val) or 1
        return [
            cm.map((val - min_val) / span, mode=cm.QCOLOR)
            for val in self.symbology_widget.y_values
        ]

    @staticmethod
    def _format_label(value):
        """
        Format a legend break value.
        Zero and magnitudes above 0.1 are rounded to 3 decimals. Smaller
        magnitudes use scientific notation with 3 digits of precision.
        Args:
            - value : break value
        """
        if value == 0 or abs(value) > 0.1:
            return np.round(value, decimals=3)
        return np.format_float_scientific(value, precision=3)

    def _get_discretized_legend(self) -> pg.GradientLegend:
        """
        Build a legend with one equally sized, hard-edged color step per class.
        """
        gl = CustomGradientLegend((10, 200), (-30, 30))
        gl.pen = QPen(QColor("white"))

        # get lut from DiscretizedColorMap for filling gap and sorting
        cmap = self.symbology_widget.lut_model.get_colormap(True)
        cmap.fill_gaps()
        from_, to_, colors = cmap.get_lut()

        step_count = len(from_)
        positions = [i / step_count for i in range(step_count)]

        # duplicate each stop around its position (+/- epsilon) so the
        # interpolated gradient shows sharp transitions between classes
        augmented_breaks = []
        augmented_colors = []
        previous_color = None
        epsilon = 1 / 1000
        for pos, (_, color) in zip(positions, zip(from_, colors)):
            if previous_color is not None:
                augmented_breaks.append(pos - epsilon)
                augmented_colors.append(previous_color)
            augmented_breaks.append(pos)
            augmented_colors.append(color)
            augmented_breaks.append(pos + epsilon)
            augmented_colors.append(color)
            previous_color = color
        # add last break
        augmented_breaks.append(1)
        augmented_colors.append(colors[-1])

        cm = pg.ColorMap(pos=augmented_breaks, color=augmented_colors)

        # labels: lower bound of each class, plus the upper bound of the last one
        labels = {self._format_label(f): pos for pos, f in zip(positions, from_)}
        labels[self._format_label(to_[-1])] = 1

        gl.setColorMap(cm)
        gl.setLabels(labels)
        return gl

    def _get_continuous_legend(self) -> pg.GradientLegend:
        """
        Build a legend with colors interpolated between the LUT breaks.
        """
        gl = CustomGradientLegend((10, 200), (-50, 30))
        gl.pen = QPen(QColor("white"))
        from_, _, colors = self.get_lut()
        min_, max_ = min(from_), max(from_)
        # a single break, or identical breaks, would divide by zero
        span = (max_ - min_) or 1
        positions = [(elt - min_) / span for elt in from_]
        labels = {self._format_label(f): pos for pos, f in zip(positions, from_)}
        cm = pg.ColorMap(pos=positions, color=colors)
        gl.setColorMap(cm)
        gl.setLabels(labels)
        return gl

    def get_legend(self) -> pg.GradientLegend:
        """
        Return a discretized or continuous legend, matching the widget mode.
        """
        if self.symbology_widget._is_dicretized():
            return self._get_discretized_legend()
        return self._get_continuous_legend()

    def _remove_legend(self) -> None:
        """
        Remove the current legend, if any, from the plot widget.
        """
        if self.legend is None:
            return
        try:
            self.config.plot_widget.removeItem(self.legend)
        except Exception:
            # best effort: the plot widget or the legend item may already be gone
            pass
        self.legend = None

    def display_legend(self) -> None:
        """
        Refresh the legend on the plot widget.
        Any previous legend is removed first, so repeated calls never stack
        legends. A new legend is added only when cross symbology is enabled,
        the "Legend" parameter is checked and the config has a plot item.
        """
        self._remove_legend()
        if not self.param.value():
            return

        if self.legend_param.value() and self.config.plot_item is not None:
            gl = self.get_legend()
            gl.set_title(self.symbology_widget.variable_cbox.currentText())
            self.legend = gl
            self.config.plot_widget.addItem(gl)


# class LineCrossSymbologyHandler(CrossSymbologyHandler):
#     """
#     Class for coloring a line with colors depending on another variable.
#     Continuous and discretized gradients are supported.
#     """
