import sys
from typing import Union

import numpy as np
import pyqtgraph as pg
from qgis.PyQt import QtCore
from qgis.PyQt.QtCore import QLocale, QVariant
from qgis.PyQt.QtGui import QColor, QStandardItemModel
from qgis.PyQt.QtWidgets import QDoubleSpinBox, QStyledItemDelegate


def interpolate(
    start_x: float, start_y: float, end_x: float, end_y: float, y: float
) -> float:
    """
    Find the abscissa of a point lying on a straight line.

    The line goes through (start_x, start_y) and (end_x, end_y); the value
    returned is the x at which that line reaches the ordinate y.
    """
    # Scale the y offset by the x extent first, then divide by the y extent.
    scaled_offset = (y - start_y) * (end_x - start_x)
    return start_x + scaled_offset / (end_y - start_y)


class ContinuousColorMap(pg.ColorMap):
    """
    Colormap used for continuous data.

    Extends the pyqtgraph colormap with a helper inserting extra points in a
    data series wherever it crosses a break value.
    """

    # Range used to rescale raw values before they are mapped to a color.
    min_value = None
    max_value = None

    def _get_breakpoints(
        self, start_value: float, end_value: float, breaks: tuple
    ) -> list:
        """
        Return the breaks lying strictly between start_value and end_value.

        The two values may be given in either order; the order of `breaks`
        is preserved in the result.
        """
        if start_value < end_value:
            low, high = start_value, end_value
        else:
            low, high = end_value, start_value
        return [br for br in breaks if low < br < high]

    def set_min_max_values(self, min_value: float, max_value: float) -> None:
        """
        Store the min / max values used to standardize values.
        """
        self.min_value, self.max_value = min_value, max_value

    def get_augmented_series(
        self, x_values: list, y_values: list, breaks: tuple
    ) -> tuple:
        """
        Given a series of 2D points, add points where it crosses the breaks.

        Returns the x values, the y values (both sorted by x, as tuples) and
        the list of colors obtained by mapping each standardized y value.
        """
        ordered_breaks = sorted(breaks)

        # Points created on each segment, kept apart until the final merge.
        inserted_x = []
        inserted_y = []
        for i in range(len(x_values) - 1):
            seg_start_x, seg_end_x = x_values[i], x_values[i + 1]
            seg_start_y, seg_end_y = y_values[i], y_values[i + 1]
            crossed = self._get_breakpoints(
                start_value=seg_start_y, end_value=seg_end_y, breaks=ordered_breaks
            )
            for bp in crossed:
                if bp in [seg_start_y, seg_end_y]:
                    continue
                inserted_x.append(
                    interpolate(seg_start_x, seg_start_y, seg_end_x, seg_end_y, bp)
                )
                inserted_y.append(bp)

        # Merge original and inserted points, then order everything along x.
        merged = sorted(
            zip(
                list(x_values.copy()) + inserted_x,
                list(y_values.copy()) + inserted_y,
            )
        )
        x_val, y_val = zip(*merged)

        colors = []
        for y in y_val:
            standardized = (y - self.min_value) / (self.max_value - self.min_value)
            colors.append(self.map(standardized, mode=self.QCOLOR))
        return x_val, y_val, colors


class DiscretizedColorMap(dict):
    """
    Colormap for discretized data.

    Keys are ``(from, to)`` tuples describing a class and values are the
    associated colors (they are handed to ``QColor``). A value belongs to a
    class when ``from <= value < to``. Classes are examined in the dict
    insertion order, so the map is expected to be filled in ascending order.
    """

    # Color given to the classes inserted by fill_gaps().
    GAP_COLOR = "#FFFFFF"

    def fill_gaps(
        self, min_value: Union[float, None] = None, max_value: Union[float, None] = None
    ) -> None:
        """
        Insert classes colored with GAP_COLOR where there are gaps.

        Args:
            min_value: if lower than the lowest class bound, a gap class is
                added from min_value up to that bound.
            max_value: if greater than the highest class bound, a gap class is
                added from that bound up to max_value.

        The map is rebuilt in place, sorted by class. An empty map is left
        untouched since there is nothing to fill.
        """
        if not self:
            return

        from_ = [classe[0] for classe in self]
        to_ = [classe[1] for classe in self]
        colors = list(self.values())

        # First extend the covered range down to min_value / up to max_value.
        lowest = min(from_)
        if min_value is not None and min_value < lowest:
            from_.append(min_value)
            to_.append(lowest)
            colors.append(self.GAP_COLOR)
            lowest = min_value
        highest = max(to_)
        if max_value is not None and highest < max_value:
            from_.append(highest)
            to_.append(max_value)
            colors.append(self.GAP_COLOR)
            highest = max_value

        # A gap ends at every lower bound that is not the upper bound of
        # another class, and starts at every upper bound that is not the lower
        # bound of another class. The overall extremes are excluded: nothing
        # lies beyond them. They are taken from the classes themselves so a
        # min_value / max_value falling inside the range cannot disturb the
        # pairing of gap starts with gap ends.
        gap_ends = sorted(f for f in from_ if f not in to_ and f != lowest)
        gap_starts = sorted(t for t in to_ if t not in from_ and t != highest)
        for gap_start, gap_end in zip(gap_starts, gap_ends):
            from_.append(gap_start)
            to_.append(gap_end)
            colors.append(self.GAP_COLOR)

        ordered = sorted(zip(from_, to_, colors))
        self.clear()
        for f, t, c in ordered:
            self[(f, t)] = c

    def get_lut(self) -> tuple:
        """
        Return 3 lists describing the LUT: from values, to values, colors.
        """
        from_ = [classe[0] for classe in self]
        to_ = [classe[1] for classe in self]
        colors = list(self.values())
        return from_, to_, colors

    def _find_classe(self, value: float):
        """
        Return ``(index, classe, color)`` for the class containing value.

        When no class contains the value, the last class is returned (this is
        what makes a value equal to the upper bound of the last class belong
        to it). Returns None when the map is empty.
        """
        entry = None
        for index, (classe, color) in enumerate(self.items()):
            entry = (index, classe, color)
            if classe[0] <= value < classe[1]:
                return entry
        return entry

    def get_color(self, value: float) -> QColor:
        """
        Return the color of the class containing value.

        Falls back to the color of the last class when no class contains the
        value, and returns None when the map is empty.
        """
        entry = self._find_classe(value)
        if entry is None:
            return None
        return QColor(entry[2])

    def _get_classe(self, value: float) -> tuple:
        """
        For a value, return classe (float, float).

        Falls back to the last class when no class contains the value, and
        returns None when the map is empty.
        """
        entry = self._find_classe(value)
        if entry is None:
            return None
        return entry[1]

    def _get_classe_index(self, value: float) -> int:
        """
        For a value, return classe index.

        Falls back to the index of the last class when no class contains the
        value (hence -1 for an empty map).
        """
        entry = self._find_classe(value)
        if entry is None:
            return len(self) - 1
        return entry[0]

    def _get_classe_from_index(self, index: int) -> tuple:
        """
        Given an index, return corresponding classe (None if out of range).
        """
        for position, classe in enumerate(self):
            if position == index:
                return classe
        return None

    def _get_breakpoints(self, start_value: float, end_value: float) -> list:
        """
        Given starting and ending value, return changing color breakpoints.

        Breakpoints are listed in the order they are met when going from
        start_value to end_value.
        """
        slope = end_value - start_value
        # Going up, a color change happens at the lower bound of each class
        # entered; going down (or flat), at the upper bound.
        if slope > 0:
            bound, step = 0, 1
        else:
            bound, step = 1, -1
        start_index = self._get_classe_index(start_value)
        end_index = self._get_classe_index(end_value)
        classes = list(self)
        # The starting class itself is not a change, hence the first offset.
        return [
            classes[i][bound]
            for i in range(start_index + step, end_index + step, step)
        ]

    def get_interval_from_color(self, qcolor: QColor):
        """
        Return the first classe whose color equals qcolor, None if no match.
        """
        for classe, color in self.items():
            if QColor(color) == qcolor:
                return classe
        return None

    def get_augmented_series(self, x_values: list, y_values: list) -> tuple:
        """
        Given 2D point series, add points at classes breaks for exact color change.

        Three steps are applied:
            1. a point is inserted wherever a segment crosses a class bound;
            2. each inner point lying exactly on a class bound gets two
               neighbours, epsilon before and after it along x, so the color
               switches right at the bound;
            3. such a point is dropped when both neighbours share a class that
               differs from its own (a peak that would show a single color).

        Returns x values and y values (tuples sorted by x) and the list of
        colors of each point. With fewer than two points there is no segment
        to process and the input points are returned with their colors.
        """
        x_val = list(x_values)
        y_val = list(y_values)
        if len(x_val) < 2:
            return tuple(x_val), tuple(y_val), [self.get_color(y) for y in y_val]

        # Envelope half-width: a thousandth of the smallest input x step.
        epsilon = abs(np.diff(np.array(x_val)).min() / 1000)

        # 1. add breakpoints
        extra_x = []
        extra_y = []
        segments = zip(x_val, y_val, x_val[1:], y_val[1:])
        for current_x, current_y, next_x, next_y in segments:
            breakpoints = self._get_breakpoints(start_value=current_y, end_value=next_y)
            for bp in breakpoints:
                # A bound equal to a segment end is already a point.
                if bp in (current_y, next_y):
                    continue
                extra_x.append(interpolate(current_x, current_y, next_x, next_y, bp))
                extra_y.append(bp)
        x_val, y_val = zip(*sorted(zip(x_val + extra_x, y_val + extra_y)))

        # 2. add breakpoints envelope
        bounds = [bound for classe in self for bound in classe]
        env_x = []
        env_y = []
        rm_indexes = set()
        for i in range(1, len(x_val) - 1):
            current_y = y_val[i]
            if current_y not in bounds:
                continue
            previous_x, current_x, next_x = x_val[i - 1], x_val[i], x_val[i + 1]
            previous_y, next_y = y_val[i - 1], y_val[i + 1]
            lower_x = current_x - epsilon
            upper_x = current_x + epsilon
            # Axes are swapped on purpose: interpolate() solves for its first
            # coordinate, here y as a function of x.
            lower_y = interpolate(previous_y, previous_x, current_y, current_x, lower_x)
            upper_y = interpolate(current_y, current_x, next_y, next_x, upper_x)
            env_x += [lower_x, upper_x]
            env_y += [lower_y, upper_y]
            # 3. remove peaks with single color
            lower_index = self._get_classe_index(lower_y)
            upper_index = self._get_classe_index(upper_y)
            if lower_index == upper_index != self._get_classe_index(current_y):
                rm_indexes.add(i)

        kept_x = [value for i, value in enumerate(x_val) if i not in rm_indexes]
        kept_y = [value for i, value in enumerate(y_val) if i not in rm_indexes]
        x_val, y_val = zip(*sorted(zip(kept_x + env_x, kept_y + env_y)))

        colors = [self.get_color(y) for y in y_val]
        return x_val, y_val, colors


class LutModel(QStandardItemModel):
    """
    Table model holding a look-up table: one row per class, with the columns
    "From", "To" and "Color".
    """

    FROM_COL = 0
    TO_COL = 1
    COLOR_COL = 2

    def __init__(self, parent=None) -> None:
        """
        QStandardItemModel for BarSymbology color dict display

        Args:
            parent: QWidget
        """
        super().__init__(parent=parent)
        self.setHorizontalHeaderLabels(["From", "To", "Color"])
        # Rows whose From / To cells are drawn with an orange background.
        self.colored_rows = []

    def data(
        self, index: QtCore.QModelIndex, role: int = QtCore.Qt.DisplayRole
    ) -> QVariant:
        """
        Override QStandardItemModel data() to customize cell backgrounds:
        - cells of the color column are filled with the color they contain
        - other cells of the rows listed in colored_rows are filled in orange

        Args:
            index: QModelIndex
            role: Qt role

        Returns: QVariant

        """
        if role == QtCore.Qt.BackgroundRole:
            if index.column() == self.COLOR_COL:
                html_color = str(self.data(index, QtCore.Qt.DisplayRole))
                return QColor(html_color)
            if index.row() in self.colored_rows:
                return QColor("orange")

        return super().data(index, role)

    def set_lut(self, from_: list, to_: list, colors: list) -> None:
        """
        Fill table: existing rows are removed, then one row is added per
        (from, to, color) triplet.
        """
        # clear (rows only, the header labels are kept)
        while self.rowCount():
            self.removeRow(0)

        for f, t, c in zip(from_, to_, colors):
            row = self.rowCount()
            self.insertRow(row)
            self.setData(self.index(row, self.FROM_COL), f)
            self.setData(self.index(row, self.TO_COL), t)
            self.setData(self.index(row, self.COLOR_COL), c)

    def get_lut(self) -> tuple:
        """
        Return 3 lists : from, to, colors
        """
        rows = range(self.rowCount())
        from_ = [self.data(self.index(row, self.FROM_COL)) for row in rows]
        to_ = [self.data(self.index(row, self.TO_COL)) for row in rows]
        colors = [self.data(self.index(row, self.COLOR_COL)) for row in rows]

        return from_, to_, colors

    def get_range(self) -> tuple:
        """
        Return min max values found in the From and To columns.

        Cells without a value are ignored; (None, None) is returned when no
        cell holds a value.
        """
        from_, to_, _ = self.get_lut()
        values = [v for v in from_ + to_ if v is not None]
        if not values:
            return None, None

        return min(values), max(values)

    def get_colormap(
        self, is_discretized: bool
    ) -> Union[ContinuousColorMap, DiscretizedColorMap, None]:
        """
        Return either:
            - a ContinuousColorMap corresponding to LUT table (continuous),
              or None when the table has no row
            - a DiscretizedColorMap describing classes and associated colors,
              sorted by class (discretized)
        """
        # NOTE(review): rows created by add_classe() have no From / To value
        # until edited; get_range() skips such cells but nothing does here.
        if is_discretized:
            from_, to_, colors = self.get_lut()
            classes = {(f, t): c for f, t, c in zip(from_, to_, colors)}
            return DiscretizedColorMap(sorted(classes.items(), key=lambda x: x[0]))

        # Continuous: only the From column is used, rescaled to [0, 1].
        breaks, _, colors = self.get_lut()
        if len(breaks) == 0:
            return None
        max_ = np.max(breaks)
        min_ = np.min(breaks)
        relative_breaks = [(break_ - min_) / (max_ - min_) for break_ in breaks]
        return ContinuousColorMap(pos=relative_breaks, color=colors)

    def add_classe(self) -> None:
        """
        Add a new row, with black as default color.
        """
        row = self.rowCount()
        self.insertRow(row)
        self.setData(self.index(row, self.COLOR_COL), "#000000")

    def delete_classe(self, indexes: list) -> None:
        """
        Remove selected rows.

        Args:
            indexes: QModelIndex list; several indexes may point to the same
                row (e.g. one per selected cell), each row is removed once.
        """
        # De-duplicate: removing the same row number twice would delete the
        # row that moved up into its place. Go bottom-up so that pending row
        # numbers stay valid.
        rows = sorted({index.row() for index in indexes}, reverse=True)
        for row in rows:
            self.removeRow(row)


class NumericDelegate(QStyledItemDelegate):
    """
    Item delegate editing cells with a double spin box.
    """

    def createEditor(self, parent, option, index):
        """
        Build the spin box: "C" locale, 15 decimals and a range spanning the
        whole float domain.
        """
        largest_float = sys.float_info.max
        spin_box = QDoubleSpinBox(parent)
        spin_box.setLocale(QLocale("C"))
        spin_box.setDecimals(15)
        spin_box.setMaximum(largest_float)
        spin_box.setMinimum(-largest_float)
        return spin_box
