import os

import numpy as np
import pyqtgraph as pg
from qgis.core import QgsApplication
from qgis.PyQt import uic
from qgis.PyQt.QtCore import pyqtSignal
from qgis.PyQt.QtGui import QColor, QIcon, QLinearGradient, QPainter, QPixmap
from qgis.PyQt.QtWidgets import QDialog, QHeaderView

from openlog.datamodel.assay.generic_assay import AssayDomainType, AssaySeriesType
from openlog.gui.assay_visualization.categorical_symbology_dialog import (
    HtmlColorItemDelegate,
)
from openlog.gui.assay_visualization.cross_symbology.lut_model import (
    LutModel,
    NumericDelegate,
)


def split_variable_column(name: str) -> tuple:
    """
    Split a "<column> [<variable>]" label into its column and variable parts.

    The text after the last space is the variable, with any square brackets
    removed; everything before that last space is the column name (empty when
    the label contains no space).

    Returns:
        tuple: (column, variable)
    """
    *column_parts, last_token = name.split(" ")
    variable = last_token.replace("[", "").replace("]", "")
    return " ".join(column_parts), variable


def get_color_ramp_icon(cr_name: str, width=100, height=100, n_colors=20) -> QIcon:
    """
    For a given color ramp name, return a QIcon previewing it.

    The pixmap is split into ``n_colors`` vertical stripes; stripe ``i`` is
    filled with the color sampled at ``i / n_colors`` along the color map.

    :param cr_name: color map name, as accepted by ``pg.colormap.get``.
    :param width: icon width in pixels.
    :param height: icon height in pixels.
    :param n_colors: number of color stripes drawn.
    :return: icon showing the color ramp.
    """
    cmap = pg.colormap.get(cr_name)
    pixmap = QPixmap(width, height)
    # a freshly constructed QPixmap holds uninitialized data: clear it before
    # painting so nothing random can show through
    pixmap.fill(QColor(0, 0, 0, 0))

    painter = QPainter(pixmap)
    try:
        # transparent outline: only the stripe fill is visible
        painter.setPen(QColor(0, 0, 0, 0))
        for i in range(n_colors):
            painter.setBrush(QColor(*cmap.map(i / n_colors)))
            # compute both edges per stripe so the stripes tile the whole
            # width even when width is not a multiple of n_colors
            left = i * width // n_colors
            right = (i + 1) * width // n_colors
            painter.drawRect(left, 0, right - left, height)
    finally:
        painter.end()

    return QIcon(pixmap)


def intervals_overlap(intervals: list) -> tuple[bool, list]:
    """
    Check whether any two intervals overlap.

    Intervals are ordered by their lower bound. If any overlap exists, some
    pair of neighbours in that order overlaps, so only adjacent pairs are
    compared. Touching intervals (max of one equal to min of the next) are not
    considered overlapping.

    :param intervals: list of (min, max) tuples.
    :return: (True, [i, j]) for the first overlapping pair found, where i and j
        are indexes into ``intervals`` (i.e. the caller's row order), or
        (False, []) when there is no overlap.
    """
    # sort positions rather than values so the reported indexes refer to the
    # caller's rows; indexing the sorted copy would report positions in the
    # sorted order and map duplicate intervals to the same row
    order = sorted(range(len(intervals)), key=lambda i: intervals[i][0])
    for prev, curr in zip(order, order[1:]):
        if intervals[curr][0] < intervals[prev][1]:
            return True, [prev, curr]
    return False, []


class CrossSymbologyWidget(QDialog):
    """
    Base class for cross symbology.

    Dialog that maps a numerical assay column (possibly another variable than
    the displayed assay) onto colors through a look-up table (LUT). The LUT is
    either continuous ("Linear": color breaks) or discretized ("Equal" or
    "Quantile": from/to classes), and can also be edited by hand in the table.
    """

    # emitted whenever the LUT or its generating parameters change
    parameterChanged = pyqtSignal()
    # emitted after the variable combobox has been refilled
    variableListChanged = pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent)
        uic.loadUi(
            os.path.join(os.path.dirname(__file__), "cross_symbology_config.ui"), self
        )
        self.config = None
        self.assay_iface = None
        self.assay_definitions = []
        self.hole_id = None
        self.column = None
        self.assay_display_name = None
        # assay values (not editable)
        self.min_value = None
        self.max_value = None
        # user values (editable)
        self.lut_min_value = None
        self.lut_max_value = None

        self.x_values = None
        self.altitude_values = None
        self.planned_altitude_values = None
        # x values used by get_gradient: "length", "effective", otherwise planned
        self.altitude = "length"
        self.y_values = None
        self.domain = None
        # snapshot filled by save_current_parameters()
        self.parameters = {}
        # set model
        self.lut_model = LutModel()
        self.lut_table_view.setModel(self.lut_model)
        self.lut_table_view.setItemDelegateForColumn(
            self.lut_model.COLOR_COL, HtmlColorItemDelegate(self)
        )
        self.lut_table_view.setItemDelegateForColumn(
            self.lut_model.FROM_COL, NumericDelegate(self)
        )
        self.lut_table_view.setItemDelegateForColumn(
            self.lut_model.TO_COL, NumericDelegate(self)
        )
        self.lut_table_view.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)

        # add palette choices
        for name in pg.colormap.listMaps():
            icon = get_color_ramp_icon(cr_name=name)
            self.palette_cbox.addItem(icon, name)
        self.palette_cbox.setCurrentText("CET-R4")

        # discretization choices
        self.discretize_method_cbox.addItems(["Linear", "Equal", "Quantile"])

        # buttons icons
        self.add_btn.setIcon(QIcon(QgsApplication.iconPath("mActionAdd.svg")))
        self.del_btn.setIcon(QIcon(QgsApplication.iconPath("mActionRemove.svg")))

        # Propagation
        self.propagate_cbox.addItems(["Parameters", "Defined LUT"])

        self.variable_cbox.currentTextChanged.connect(self._on_variable_change)
        self.palette_cbox.currentTextChanged.connect(self._update_table)
        self.breaks_sb.valueChanged.connect(self._update_table)
        self.min_val_sb.valueChanged.connect(self._on_bound_change)
        self.max_val_sb.valueChanged.connect(self._on_bound_change)
        self.discretize_method_cbox.currentTextChanged.connect(self._update_table)
        self.del_btn.pressed.connect(self._delete_classe)
        self.add_btn.pressed.connect(self._add_classe)
        self.lut_model.dataChanged.connect(self._on_table_change)

        # enable/disable variable choice depending plugin
        # (imported here rather than at module level; presumably to avoid a
        # circular import — confirm before moving it)
        from openlog.plugins.manager import get_plugin_manager

        self.variable_cbox.setEnabled(
            get_plugin_manager().get_cross_symbology_plugin().enable
        )

    def is_original_variable(self) -> bool:
        """
        Return True if the mapped variable is the assay's own column,
        False if another variable has been selected.
        """
        column_name = f"{self.config.column.name} [{self.config.assay_name}]"
        current_variable = self.variable_cbox.currentText()

        return column_name == current_variable

    def _is_dicretized(self) -> bool:
        """
        Return True if a discretization method ("Equal" or "Quantile") is
        selected, False for the continuous ("Linear") mode.
        """
        return self.discretize_method_cbox.currentText() in ["Equal", "Quantile"]

    def _on_table_change(self) -> None:
        """
        Slot called after table manual edition.

        Syncs the range and breaks spinboxes with the table content, validates
        the LUT and notifies listeners.
        """
        min_, max_ = self.lut_model.get_range()
        if min_ is not None:
            self._update_range_sb(min_, max_)

        self._update_breaks_sb()
        self.validate()
        self.parameterChanged.emit()

    def validate(self) -> None:
        """
        Slot called to enable/disable Ok button.

        An invalid LUT has undefined values or, when discretized, classes
        whose From is not smaller than To, or overlapping classes.
        Problematic rows are handed to the model through ``colored_rows``.
        """
        self.lut_model.colored_rows.clear()
        if self._is_dicretized():
            is_valid, indexes = self._validate_discretized()
        else:
            is_valid, indexes = self._validate_continuous()

        self.lut_model.colored_rows = indexes
        # disable the OK button while the LUT is invalid
        ok_btn = self.buttonBox.button(self.buttonBox.Ok)
        ok_btn.setDisabled(not is_valid)

    def _validate_continuous(self) -> tuple[bool, list]:
        """
        Check continuous LUT is valid: every break value must be defined.

        Updates the message label and returns a tuple with True if valid and
        a list of problematic rows.
        """
        from_ = self.lut_model.get_lut()[0]
        undefined_rows = [i for i, value in enumerate(from_) if value is None]
        if undefined_rows:
            self.message_label.setText("Values must be defined")
            return False, undefined_rows

        self.message_label.setText("")
        return True, []

    def _validate_discretized(self) -> tuple[bool, list]:
        """
        Check categorized LUT is valid.

        Rules, checked in order: both bounds defined, From < To for each
        class, no overlapping classes. Updates the message label and returns
        a tuple with True if valid and a list of problematic rows.
        """
        from_, to_, _ = self.lut_model.get_lut()
        bounds = list(zip(from_, to_))

        # check None
        undefined_rows = [
            i
            for i, (val_f, val_t) in enumerate(bounds)
            if val_f is None or val_t is None
        ]
        if undefined_rows:
            self.message_label.setText("Values must be defined")
            return False, undefined_rows

        # check from is smaller than to
        reversed_rows = [
            i for i, (val_f, val_t) in enumerate(bounds) if val_f >= val_t
        ]
        if reversed_rows:
            self.message_label.setText("From values must be smaller than To")
            return False, reversed_rows

        # overlap
        overlaps, indexes = intervals_overlap(bounds)
        if overlaps:
            self.message_label.setText("Presence of overlaps")
            return False, indexes

        self.message_label.setText("")
        return True, []

    def save_current_parameters(self) -> None:
        """
        Save the current widget state in ``self.parameters`` (a dictionary).
        """
        self.parameters = {
            "variable": self.variable_cbox.currentText(),
            "discretize": self._is_dicretized(),
            "discretize_method": self.discretize_method_cbox.currentText(),
            "palette": self.palette_cbox.currentText(),
            "n_breaks": self.breaks_sb.value(),
            "lut_min_value": self.min_val_sb.value(),
            "lut_max_value": self.max_val_sb.value(),
            "manual": self.def_gpx.isChecked(),
            "lut": self.lut_model.get_lut(),
        }

    def restore_parameters(self) -> None:
        """
        Restore parameters saved by save_current_parameters().

        Does nothing if nothing was saved. The saved LUT is applied last so it
        overrides any table regenerated by the preceding widget updates.
        """
        if not self.parameters:
            return

        self.variable_cbox.setCurrentText(self.parameters.get("variable"))
        self.discretize_method_cbox.setCurrentText(
            self.parameters.get("discretize_method")
        )
        self.palette_cbox.setCurrentText(self.parameters.get("palette"))
        self.breaks_sb.setValue(self.parameters.get("n_breaks"))
        self.min_val_sb.setValue(self.parameters.get("lut_min_value"))
        self.max_val_sb.setValue(self.parameters.get("lut_max_value"))
        self.def_gpx.setChecked(self.parameters.get("manual"))
        self.lut_model.set_lut(*self.parameters.get("lut"))

    def _add_classe(self) -> None:
        """
        Add a new classe.
        """
        self.lut_model.add_classe()
        # _on_table_change() also validates the LUT
        self._on_table_change()

    def _delete_classe(self) -> None:
        """
        Delete the classes of the selected table rows.
        """
        selected_indexes = self.lut_table_view.selectionModel().selectedRows()
        self.lut_model.delete_classe(selected_indexes)
        # _on_table_change() also validates the LUT
        self._on_table_change()

    def get_lut_from_parameters(self) -> tuple:
        """
        Return 3 lists : breaks (from and to) and corresponding colors,
        computed from the palette, number of breaks and discretization method.
        """
        palette_name = self.palette_cbox.currentText()
        cmap = pg.colormap.get(palette_name)
        n_breaks = self.breaks_sb.value()
        if self._is_dicretized():
            return self._get_discretized_lut_from_parameters(cmap, n_breaks)
        return self._get_continuous_lut_from_parameters(cmap, n_breaks)

    @staticmethod
    def _sample_colors(color_map: pg.ColorMap, n_colors: int) -> list:
        """
        Return ``n_colors`` color names (``QColor.name()``) sampled at evenly
        spaced positions from 0 to 1 along ``color_map``.
        """
        # positions are computed directly in [0, 1] instead of normalizing
        # values by (max - min), which would divide by zero when min == max
        return [
            color_map.map(position, mode=pg.ColorMap.QCOLOR).name()
            for position in np.linspace(0.0, 1.0, n_colors)
        ]

    def _get_discretized_lut_from_parameters(
        self, color_map: pg.ColorMap, n_breaks: int
    ) -> tuple:
        """
        Build a discretized LUT of ``n_breaks`` classes.

        "Equal" splits [lut_min_value, lut_max_value] into equal-width
        classes; otherwise ("Quantile") class bounds are quantiles of
        ``self.y_values``, ignoring NaN samples.

        :return: (from, to, colors) lists; three empty lists if the quantiles
            cannot be computed.
        """
        method = self.discretize_method_cbox.currentText()
        if method == "Equal":
            breaks = np.linspace(
                start=self.lut_min_value, stop=self.lut_max_value, num=n_breaks + 1
            ).tolist()
        else:
            try:
                quantiles = np.linspace(0, 1, n_breaks + 1)
                # nanquantile: with np.quantile a single NaN sample would turn
                # every break into NaN
                breaks = np.nanquantile(self.y_values, q=quantiles).tolist()
            except Exception:
                return [], [], []
        from_ = breaks[:-1]
        to_ = breaks[1:]
        colors = self._sample_colors(color_map, n_breaks)

        return from_, to_, colors

    def _get_continuous_lut_from_parameters(
        self, color_map: pg.ColorMap, n_breaks: int
    ) -> tuple:
        """
        Build a continuous LUT: ``n_breaks`` evenly spaced break values from
        lut_min_value to lut_max_value, each with a color from ``color_map``.

        :return: (breaks, to, colors) where ``to`` is all None (continuous
            LUTs have no upper bound per break).
        """
        breaks = np.linspace(
            start=self.lut_min_value, stop=self.lut_max_value, num=n_breaks
        ).tolist()
        colors = self._sample_colors(color_map, len(breaks))
        to = [None] * len(breaks)
        return breaks, to, colors

    def _update_table(self) -> None:
        """
        Update LUT.

        Shows the "To" column only in discretized mode, then regenerates the
        table from the current parameters (once assay values are known).
        """
        # hide/show To column
        if self._is_dicretized():
            self.lut_model.setHorizontalHeaderLabels(["From", "To", "Color"])
            self.lut_table_view.showColumn(1)
        else:
            self.lut_model.setHorizontalHeaderLabels(["Break", "To", "Color"])
            self.lut_table_view.hideColumn(1)
        # fill table
        if self.min_value is None:
            return

        try:
            lut = self.get_lut_from_parameters()
        except Exception:
            # best effort: keep the current table if the LUT cannot be built
            return
        self.lut_model.set_lut(*lut)

        self.parameterChanged.emit()

    def _on_bound_change(self):
        """
        Slot for the min/max spinboxes: store the new LUT bounds and
        regenerate the table.
        """
        self.lut_min_value = self.min_val_sb.value()
        self.lut_max_value = self.max_val_sb.value()
        self._update_table()

    def _update_range_sb(self, min_: float, max_: float) -> None:
        """
        Update range spinboxes (and LUT bounds) without regenerating the table.
        """
        # disconnect so these programmatic changes do not trigger
        # _on_bound_change; reconnect even if setValue raises
        self.min_val_sb.valueChanged.disconnect(self._on_bound_change)
        self.max_val_sb.valueChanged.disconnect(self._on_bound_change)
        try:
            self.min_val_sb.setValue(min_)
            self.max_val_sb.setValue(max_)
            self.lut_min_value = min_
            self.lut_max_value = max_
        finally:
            self.min_val_sb.valueChanged.connect(self._on_bound_change)
            self.max_val_sb.valueChanged.connect(self._on_bound_change)

    def _update_breaks_sb(self) -> None:
        """
        Set the number-of-breaks spinbox to the table row count without
        regenerating the table.
        """
        self.breaks_sb.valueChanged.disconnect(self._update_table)
        try:
            self.breaks_sb.setValue(self.lut_model.rowCount())
        finally:
            self.breaks_sb.valueChanged.connect(self._update_table)

    def _on_variable_change(self):
        """
        Slot for the variable combobox: refresh unit, value range and table.
        """
        self._update_unit_label()
        self._update_min_max_values()
        self._update_table()

    def _update_min_max_from_children(self) -> None:
        """
        Update min max values from children configs.
        Used when hole_id == "": values of every child hole are aggregated.
        """
        children = self.config.child_configs
        min_ = []
        max_ = []
        y_values = []
        # get mapped variable
        value = self.variable_cbox.currentText()
        column_name, variable = split_variable_column(value)
        for config in children:
            assay = self.assay_iface.get_assay(
                hole_id=config.hole_id, variable=variable
            )
            _, y = assay.get_all_values(column=column_name)
            if len(y) == 0:
                continue
            y = y.astype(float)
            min_.append(np.nanmin(y))
            max_.append(np.nanmax(y))
            y_values += y.tolist()

        if min_:
            self.y_values = y_values
            self.min_value = np.nanmin(min_)
            # global maximum is the largest per-hole maximum
            self.max_value = np.nanmax(max_)
            # always update values for global config
            self._update_range_sb(self.min_value, self.max_value)
            self.raw_values.setText(
                f"Raw values range : {self.min_value} - {self.max_value}"
            )

    def _update_unit_label(self) -> None:
        """
        Update unit label from the assay definition of the selected variable.
        Leaves the label unchanged if the definition or column is not found.
        """
        value = self.variable_cbox.currentText()
        if value == "":
            return
        column_name, variable = split_variable_column(value)
        # scan assay definitions
        assay_def = next(
            (
                definition
                for definition in self.assay_definitions
                if definition.variable == variable
            ),
            None,
        )
        if assay_def is None:
            return
        assay_column = assay_def.columns.get(column_name)
        if assay_column is None:
            return

        unit_str = "No unit" if assay_column.unit == "" else assay_column.unit
        self.unit_label.setText(f"Unit : {unit_str}")

    def _update_min_max_values(self) -> None:
        """
        Update min and max values attributes and labels.

        Loads x, altitude and y values of the selected variable for the current
        hole (or aggregates child holes when hole_id == "") and resets the range
        spinboxes to the data range. Nothing is updated past the data loading
        when the series is empty or entirely NaN.
        """
        self.lut_min_value = self.min_val_sb.value()
        self.lut_max_value = self.max_val_sb.value()

        value = self.variable_cbox.currentText()
        if value == "":
            return
        if self.hole_id == "":
            self._update_min_max_from_children()
            return

        column_name, variable = split_variable_column(value)
        assay = self.assay_iface.get_assay(hole_id=self.hole_id, variable=variable)
        x, y = assay.get_all_values(column=column_name)
        self.x_values = x
        self.altitude_values = assay.get_altitude(planned=False)
        self.planned_altitude_values = assay.get_altitude(planned=True)
        self.y_values = y.astype(float)
        # NaN never compares equal to itself, so test with np.isnan; an
        # all-NaN (or empty) series has no usable range
        if self.y_values.size == 0 or np.isnan(self.y_values).all():
            return
        y_min = np.nanmin(self.y_values)
        y_max = np.nanmax(self.y_values)
        self._update_range_sb(y_min, y_max)
        self.lut_min_value = self.min_val_sb.value()
        self.lut_max_value = self.max_val_sb.value()
        self.min_value = y_min
        self.max_value = y_max
        self.raw_values.setText(
            f"Raw values range : {self.min_value} - {self.max_value}"
        )

    def update_variables(self):
        """
        Fill combobox with available variables.

        Lists every numerical column of assay definitions in the current
        domain as "<column> [<variable>]", selects the configured column and
        refreshes the dependent widgets. The propagation group is only shown
        for the global config (hole_id == "").
        """
        self.propagate_gpx.setVisible(self.hole_id == "")

        if not self.assay_iface:
            return

        # avoid running _on_variable_change for every item while refilling;
        # reconnect even if the refresh raises
        self.variable_cbox.currentTextChanged.disconnect(self._on_variable_change)
        try:
            self.variable_cbox.clear()
            self.assay_definitions = (
                self.assay_iface.get_all_available_assay_definitions()
            )
            for assay_def in self.assay_definitions:
                if assay_def.domain != self.domain:
                    continue
                for col_name, assay_column in assay_def.columns.items():
                    if assay_column.series_type == AssaySeriesType.NUMERICAL:
                        self.variable_cbox.addItem(
                            f"{col_name} [{assay_def.variable}]"
                        )

            # set current data as default value
            self.variable_cbox.setCurrentText(
                f"{self.config.column.name} [{self.config.assay_name}]"
            )
            self._on_variable_change()
        finally:
            self.variable_cbox.currentTextChanged.connect(self._on_variable_change)
        self.variableListChanged.emit()

    def get_gradient(self) -> QLinearGradient:
        """
        Return gradient according to parameters.

        The gradient is vertical for the depth domain and horizontal otherwise,
        in object coordinates, and defaults to plain white. When x values, a
        colormap and data are available, colors are placed at the normalized
        positions of the augmented series computed by the LUT colormap.
        """
        grad = (
            QLinearGradient(0, 0, 0, 1)
            if self.domain == AssayDomainType.DEPTH
            else QLinearGradient(0, 0, 1, 0)
        )
        grad.setCoordinateMode(QLinearGradient.CoordinateMode.ObjectMode)
        grad.setColorAt(0, QColor("white"))
        grad.setColorAt(1, QColor("white"))

        if self.altitude == "length":
            x_values = self.x_values
        elif self.altitude == "effective":
            x_values = self.altitude_values
        else:
            x_values = self.planned_altitude_values
        if x_values is None:
            # no positions loaded yet: return the default (white) gradient,
            # like the other "no data" paths below
            return grad

        x_values = np.array(x_values)
        # extended data : set x values at middle of intervals
        if len(x_values.shape) == 2:
            # reshape if needed
            if x_values.shape[1] != 2:
                x_values = np.array([[f, t] for f, t in zip(x_values[0], x_values[1])])
            x_values = x_values.mean(axis=1)

        is_discretized = self._is_dicretized()
        cmap = self.lut_model.get_colormap(is_discretized=is_discretized)
        if cmap is None or len(x_values) == 0:
            return grad
        min_value = x_values.min()
        length = x_values.max() - min_value

        # interpolate values to find exact color change
        if is_discretized:
            cmap.fill_gaps(self.min_value, self.max_value)
            augmented_depths, _, colors = cmap.get_augmented_series(
                x_values, self.y_values
            )
        else:
            cmap.set_min_max_values(self.lut_min_value, self.lut_max_value)
            breakpoints = self.lut_model.get_lut()[0]
            augmented_depths, _, colors = cmap.get_augmented_series(
                x_values, self.y_values, breakpoints
            )

        for depth, color in zip(augmented_depths, colors):
            grad.setColorAt((depth - min_value) / length, color)

        return grad
