from enum import Enum
from typing import Dict, List, Optional, Tuple, Callable, Any, Iterable

import numpy as np
from PyQt5.QtCore import QCoreApplication
from matplotlib.axes import Axes
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from cycler import cycler
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D
from matplotlib.text import Annotation
import matplotlib.dates as mdates


class ChartAttribute(Enum):
    """
    Optional rendering tweaks that can be enabled on a ``LandsklimChart`` through its ``chart_attributes`` argument.
    """

    # Draw a red horizontal line at y=0 on line charts
    DRAW_LINE_ON_0 = 0
    # Paint a semi-transparent "plasma" gradient behind line charts
    SUNSET_BACKGROUND = 1
    # Draw the first series of a line chart in white
    PLOT_ONLY_WHITE_LINES = 2


class LandsklimChart:
    """
    Encapsulate Matplotlib to ease chart creation

    Each instance owns one Matplotlib figure embedded in a Qt canvas (``FigureCanvasQTAgg``). Every ``plot_*``
    method clears that figure, draws a new chart into a single axes, redraws the canvas and returns it.
    """

    def __init__(self, fig_height: Optional[float] = None, chart_attributes: Optional[List[ChartAttribute]] = None):
        """
        :param fig_height: Figure height passed to ``Figure.set_figheight``. The style default is kept when None.
        :type fig_height: Optional[float]

        :param chart_attributes: Rendering tweaks applied by the plotting methods
        :type chart_attributes: Optional[List[ChartAttribute]]
        """
        # Connection id of the hover callback bound to the current axes (None when no callback is connected)
        self.__hover_cid: Optional[int] = None
        self.init_matplotlib()
        # Build the Figure directly instead of through plt.figure(): pyplot keeps a global reference to every figure
        # it creates (until plt.close), so each chart instance would otherwise never be released.
        self.figure: plt.Figure = plt.Figure()
        if fig_height is not None:
            self.figure.set_figheight(fig_height)
        self.__chart_attributes = [] if chart_attributes is None else chart_attributes
        self.canvas = FigureCanvas(self.figure)

    def tr(self, string: str) -> str:
        """Translate ``string`` within the ``LandsklimChart`` Qt translation context."""
        return QCoreApplication.translate("LandsklimChart", string)

    def has_attribute(self, attribute: ChartAttribute) -> bool:
        """Return True when ``attribute`` was enabled for this chart."""
        return attribute in self.__chart_attributes

    def get_canvas(self) -> FigureCanvas:
        """Return the Qt canvas displaying this chart's figure."""
        return self.canvas

    def init_matplotlib(self):
        """
        Apply the Matplotlib style used by Landsklim charts.

        Selects the seaborn style under whichever name the installed Matplotlib provides, then overrides colors and
        fonts. This updates the process-wide ``plt.rcParams``, so it affects every figure created afterwards.
        """
        if 'seaborn' in plt.style.available:
            plt.style.use('seaborn')
        elif 'seaborn-v0_8' in plt.style.available:
            plt.style.use('seaborn-v0_8')
        params = {"ytick.color": "black",
                  "xtick.color": "black",
                  "axes.labelcolor": "black",
                  "axes.edgecolor": "black",
                  "axes.formatter.use_mathtext": True,
                  "text.usetex": False,
                  "font.family": "serif",
                  "font.serif": ["Garamond", "cmr10"]}
        plt.rcParams.update(params)

    def __reset_axes(self) -> Axes:
        """
        Clear the figure and return a single fresh axes covering it.

        The hover callback bound to the previous axes is disconnected first: after ``clf()`` that axes can never match
        ``event.inaxes`` again, and leaving the callback connected would accumulate one dead handler per redraw.
        """
        if self.__hover_cid is not None:
            self.canvas.mpl_disconnect(self.__hover_cid)
            self.__hover_cid = None
        self.figure.clf()
        return self.figure.add_subplot(111)  # == add_subplot(1, 1, 1)

    def plot_scatter(self, X: np.ndarray, Y: np.ndarray, data_labels: List[str], x_label: str, y_label: str, title: str, plot_fit: bool = True, rse_interval: Optional[float] = None, data: Optional[List[str]] = None) -> FigureCanvas:
        """
        Make a scatter plot

        :param X: X values. If X is a 2d numpy array, one scatter plot will be drawn for each sample
        :type X: np.ndarray

        :param Y: Y values. If Y is a 2d numpy array, one scatter plot will be drawn for each sample
        :type Y: np.ndarray

        :param data_labels: Data labels (one per row when X is 2d)
        :type data_labels: List[str]

        :param x_label: x axis label
        :type x_label: str

        :param y_label: y axis label
        :type y_label: str

        :param title: chart title
        :type title: str

        :param plot_fit: Draw the least-squares regression line of Y on X, fitted on all points together
        :type plot_fit: bool

        :param rse_interval: When set, also draw the regression line shifted by -rse_interval and +rse_interval
        :type rse_interval: Optional[float]

        :param data: Text shown when the mouse hovers a point, indexed like the points.
                     With 2d input, only the last series gets hover annotations.
        :type data: Optional[List[str]]

        :raises RuntimeError: If X and Y do not have the same shape

        :returns: Created figure
        :rtype: FigureCanvas
        """
        if X.shape != Y.shape:
            raise RuntimeError("X and Y must be the same size")

        ax: Axes = self.__reset_axes()

        sc: Optional[PathCollection] = None
        if X.ndim > 1:
            for i, (series_x, series_y) in enumerate(zip(X, Y)):
                sc = ax.scatter(series_x, series_y, s=4, label=data_labels[i])
        else:
            sc = ax.scatter(X, Y, c='red', s=8, label=data_labels)

        if plot_fit or rse_interval is not None:
            # np.polyfit only accepts 1d x, so all series are fitted together; a line needs at least 2 points
            x_flat = np.ravel(X)
            y_flat = np.ravel(Y)
            if x_flat.size >= 2:
                m, b = np.polyfit(x_flat, y_flat, 1)
                fit = m * x_flat + b
                if plot_fit:
                    ax.plot(x_flat, fit, linewidth=1)
                if rse_interval is not None:
                    ax.plot(x_flat, fit - rse_interval, linewidth=1, color='blue', label="Residual Standard Error")
                    ax.plot(x_flat, fit + rse_interval, linewidth=1, color='blue')

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

        if sc is not None:
            self.__set_annotations(ax, data, sc, self.hover_scatter)

        # refresh canvas
        self.canvas.draw()
        return self.canvas

    def update_annot(self, pos: Any, ind: Dict[str, np.ndarray], data: List[str], annot: Annotation):
        """
        Move ``annot`` onto the hovered point and show the label of the first hovered point.

        :param pos: (x, y) data coordinates of the hovered point
        :type pos: Any

        :param ind: Result of ``Artist.contains``; ``ind["ind"]`` holds the indices of the hovered points
        :type ind: Dict[str, np.ndarray]

        :param data: Labels, indexed like the points of the hovered artist
        :type data: List[str]

        :param annot: Annotation to update
        :type annot: Annotation
        """
        annot.xy = pos

        # Put the text left of the point when it lies in the right 30% of the x range, so it stays inside the axes
        x_min, x_max = annot.axes.get_xlim()
        if pos[0] > 0.7 * (x_max - x_min) + x_min:
            annot.set_x(-100)
        else:
            annot.set_x(20)

        annot.set_text(data[ind["ind"][0]])
        bbox = annot.get_bbox_patch()
        bbox.set_facecolor("lightblue")
        bbox.set_alpha(0.7)

    def update_annot_scatter(self, ind: Dict[str, np.ndarray], data: List[str], sc: PathCollection, annot: Annotation):
        """
        Update ``annot`` for a hovered point of the scatter collection ``sc``.

        :param ind: Result of ``PathCollection.contains``; ``ind["ind"]`` holds the indices of the hovered points
        :type ind: Dict[str, np.ndarray]
        """
        pos = sc.get_offsets()[ind["ind"][0]]
        self.update_annot(pos, ind, data, annot)

    def update_annot_line(self, ind: Dict[str, np.ndarray], data: List[str], line: Line2D, annot: Annotation):
        """
        Update ``annot`` for a hovered vertex of ``line``.

        :param ind: Result of ``Line2D.contains``; ``ind["ind"]`` holds the indices of the hovered vertices
        :type ind: Dict[str, np.ndarray]
        """
        x, y = line.get_data()
        first = ind["ind"][0]
        self.update_annot((x[first], y[first]), ind, data, annot)

    def hover_scatter(self, event, data: List[str], ax: Axes, sc: PathCollection, annot: Annotation):
        """Motion-notify callback: show ``annot`` over the hovered point of ``sc``, hide it when no point is hovered."""
        if event.inaxes != ax:
            return
        cont, ind = sc.contains(event)
        if cont:
            self.update_annot_scatter(ind, data, sc, annot)
            annot.set_visible(True)
            self.canvas.draw_idle()
        elif annot.get_visible():
            annot.set_visible(False)
            self.canvas.draw_idle()

    def hover_line(self, event, data: Dict[str, Dict[int, str]], ax: Axes, lines: List[Line2D], annot: Annotation):
        """
        Motion-notify callback: show ``annot`` over the hovered vertex of the first matching line, hide it otherwise.

        Lines without a matching label series (when ``data`` has fewer entries than ``lines``) are ignored.
        """
        if event.inaxes != ax:
            return
        # FIXME: lines and label series are paired by position (i-th line <-> i-th entry of data), which relies on
        # data having the same order as the plotted series
        for line, series_labels in zip(lines, data.values()):
            cont, ind = line.contains(event)
            if cont:
                self.update_annot_line(ind, list(series_labels.values()), line, annot)
                annot.set_visible(True)
                self.canvas.draw_idle()
                return
        if annot.get_visible():
            annot.set_visible(False)
            self.canvas.draw_idle()

    def plot_lines(self, data: Dict[str, Dict[float, float]], x_label: str, y_label: str, title: str, display_marker: bool = True, x_ticks: Optional[Dict[float, str]] = None, y_lim: Optional[Tuple[float, float]] = None, x_axis_as_month: bool = False, labels: Optional[Dict[str, Dict[int, str]]] = None) -> FigureCanvas:
        """
        Plot multiple lines

        :param data: Dictionary mapping each data label with its data. Data is represented as another dictionary
                     mapping x with y.
        :type data: Dict[str, Dict[float, float]]

        :param x_label: x axis label
        :type x_label: str

        :param y_label: y axis label
        :type y_label: str

        :param title: chart title
        :type title: str

        :param display_marker: Display a marker for each sample on the line
        :type display_marker: bool

        :param x_ticks: Display names instead of numbers for x axis
        :type x_ticks: Dict[float, str]

        :param y_lim: Min/Max of y axis
        :type y_lim: Optional[Tuple[float, float]]

        :param x_axis_as_month: Show months names along the x axis
        :type x_axis_as_month: bool

        :param labels: Label attached to data. Displayed when mouse passes over the point.
        :type labels: Dict[str, Dict[int, str]]

        :returns: Created figure
        :rtype: FigureCanvas
        """
        ax: Axes = self.__reset_axes()

        ax.set_prop_cycle('color', plt.cm.Spectral(np.linspace(0, 1, len(data) + 1)))

        markersize = 15 if display_marker else 0
        # NOTE(review): despite its name, PLOT_ONLY_WHITE_LINES only whitens the first series — confirm intended
        white_first = self.has_attribute(ChartAttribute.PLOT_ONLY_WHITE_LINES)
        lines: List[Line2D] = []
        for i, (data_label, data_data) in enumerate(data.items()):
            x = list(data_data.keys())
            y = list(data_data.values())
            color = {'color': 'white'} if (i == 0 and white_first) else {}
            line2d, = ax.plot(x, y, marker=".", markersize=markersize, label=data_label, **color)
            lines.append(line2d)

        if x_ticks:
            ax.set_xticks(list(x_ticks.keys()), list(x_ticks.values()))
            plt.setp(ax.get_xticklabels(), rotation=45)

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        if x_axis_as_month:
            ax.xaxis.set_major_locator(mdates.MonthLocator())
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))

        if y_lim is not None:
            ax.set_ylim(y_lim[0], y_lim[1])

        ax.set_title(title)
        self.__set_legend(ax, data)
        self.__set_lines(ax)
        self.__set_background(ax)
        self.__set_annotations(ax, labels, lines, self.hover_line)

        self.canvas.draw()
        return self.canvas

    def __set_annotations(self, ax: Axes, labels: Optional[Any], graphics: Any, motion_notify_event: Callable):
        """
        Attach a hover annotation to ``ax`` when ``labels`` is provided.

        ``motion_notify_event`` is called on every mouse move as
        ``motion_notify_event(event, labels, ax, graphics, annot)``. Any previously connected hover callback is
        disconnected so only one is active at a time.
        """
        if labels is None:
            return
        annot: Annotation = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", bbox=dict(boxstyle="round"), arrowprops=dict(arrowstyle="->"))
        annot.set_visible(False)
        if self.__hover_cid is not None:
            self.canvas.mpl_disconnect(self.__hover_cid)
        self.__hover_cid = self.canvas.mpl_connect("motion_notify_event", lambda event: motion_notify_event(event, labels, ax, graphics, annot))

    def __set_legend(self, ax: Axes, data: Dict[str, Dict[float, float]]):
        """Draw the legend; with more than 10 series it is laid out in 6 columns above the axes."""
        if len(data) > 10:
            ax.legend(mode="expand", loc=3, bbox_to_anchor=(0., 1.02, 1., 0.102), ncol=6, fancybox=True, shadow=True)
            # Lay out this chart's figure; plt.tight_layout() would act on pyplot's current figure instead
            self.figure.tight_layout()
        else:
            ax.legend()

    def __set_lines(self, ax: Axes):
        """Draw a red horizontal line at y=0 when DRAW_LINE_ON_0 is enabled."""
        if self.has_attribute(ChartAttribute.DRAW_LINE_ON_0):
            ax.axhline(y=0, color='r', linestyle='-')

    def __set_background(self, ax: Axes):
        """Paint a semi-transparent 'plasma' gradient behind the data when SUNSET_BACKGROUND is enabled."""
        if self.has_attribute(ChartAttribute.SUNSET_BACKGROUND):
            xmin, xmax = ax.get_xlim()
            ymin, ymax = ax.get_ylim()
            # Vertical extent symmetric around 0, so the gradient image is centred on y=0
            y_bound = max(abs(ymin), abs(ymax))
            ax.imshow([[1, 1], [0, 0]], cmap="plasma", interpolation="bicubic", extent=(xmin, xmax, -y_bound, y_bound), alpha=0.5, aspect='auto')

    def plot_hist(self, data: np.ndarray, x_label: str, y_label: str, title: str, bin_width: float = 0.05, value_range: Tuple[float, float] = (-1., 1.)) -> FigureCanvas:
        """
        Plot histogram

        :param data: Data source for the histogram
        :type data: np.ndarray

        :param x_label: x axis label
        :type x_label: str

        :param y_label: y axis label
        :type y_label: str

        :param title: chart title
        :type title: str

        :param bin_width: Width of each bin
        :type bin_width: float

        :param value_range: (first bin edge, last bin edge). Values outside this range are not counted.
        :type value_range: Tuple[float, float]

        :returns: Created figure
        :rtype: FigureCanvas
        """
        ax: Axes = self.__reset_axes()

        bins = np.arange(value_range[0], value_range[1] + bin_width, bin_width)
        ax.hist(data, bins=bins, rwidth=0.75)

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)

        self.canvas.draw()
        return self.canvas

    def plot_bar(self, positive_data: np.ndarray, negative_data: np.ndarray, x_ticks: List[str], y_axis_as_percent: bool, x_label: str, y_label: str, title: str) -> FigureCanvas:
        """
        Plot bar chart with positive and negative values

        :param positive_data: Positive values, drawn in red
        :type positive_data: np.ndarray

        :param negative_data: Negative values, drawn in blue at the same positions
        :type negative_data: np.ndarray

        :param x_ticks: Display names instead of numbers for x axis
        :type x_ticks: List[str]

        :param y_axis_as_percent: Data on the y axis is percentages (y axis fixed to [-100, 100])
        :type y_axis_as_percent: bool

        :param x_label: x axis label
        :type x_label: str

        :param y_label: y axis label
        :type y_label: str

        :param title: chart title
        :type title: str

        :raises RuntimeError: If positive_data and negative_data differ in length

        :returns: Created figure
        :rtype: FigureCanvas
        """
        if len(positive_data) != len(negative_data):
            raise RuntimeError("Negative and positive data must be the same size")

        x = range(len(positive_data))

        ax: Axes = self.__reset_axes()

        ax.bar(x, negative_data, width=0.75, color='b')
        ax.bar(x, positive_data, width=0.75, color='r')

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)

        if y_axis_as_percent:
            ax.yaxis.set_major_formatter(mtick.PercentFormatter())
            ax.set_ylim(-100, 100)

        if len(x_ticks) > 0:
            ax.set_xticks(x, x_ticks)
            plt.setp(ax.get_xticklabels(), rotation=45)

        self.canvas.draw()
        return self.canvas

    def plot_variogram(self, lags: np.ndarray, semivariance: np.ndarray, function: np.ndarray) -> FigureCanvas:
        """
        Plot a variogram

        :param lags: Lags of the variogram
        :type lags: np.ndarray

        :param semivariance: Semivariance of the experimental variogram for each lags (red stars)
        :type semivariance: np.ndarray

        :param function: Semivariance of the theoretical variogram for each lags (black line)
        :type function: np.ndarray

        :returns: Created figure
        :rtype: FigureCanvas
        """
        ax: Axes = self.__reset_axes()

        ax.plot(lags, semivariance, "r*")
        ax.plot(lags, function, "k-")
        ax.set_xlabel(self.tr("Distance"))
        ax.set_ylabel(self.tr("Semivariance"))
        ax.set_title(self.tr("Variogram"))

        self.canvas.draw()
        return self.canvas

    def plot_boxplot(self, data: np.ndarray, x_ticks: List[str], y_lim: Tuple[float, float], x_label: str, y_label: str, title: str) -> FigureCanvas:
        """
        Plot a box plot

        :param data: 2D array who represents series
        :type data: np.ndarray

        :param x_ticks: Display names instead of numbers for x axis
        :type x_ticks: List[str]

        :param y_lim: Min/Max of y axis
        :type y_lim: Tuple[float, float]

        :param x_label: x axis label
        :type x_label: str

        :param y_label: y axis label
        :type y_label: str

        :param title: chart title
        :type title: str

        :returns: Created figure
        :rtype: FigureCanvas
        """
        ax: Axes = self.__reset_axes()
        style = dict(boxprops=dict(linewidth=0.5, color='red'),
                     medianprops=dict(linewidth=0.5, color='salmon'),
                     flierprops=dict(markeredgecolor='salmon'),
                     whiskerprops=dict(color='salmon'),
                     capprops=dict(color='salmon'))
        ax.set_ylim(y_lim[0], y_lim[1])
        try:
            # Matplotlib >= 3.9 renamed boxplot's `labels` to `tick_labels` and deprecated the old name
            ax.boxplot(data, tick_labels=x_ticks, **style)
        except TypeError:
            # Older Matplotlib releases only accept `labels`
            ax.boxplot(data, labels=x_ticks, **style)
        plt.setp(ax.get_xticklabels(), rotation=45)

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)

        self.canvas.draw()
        return self.canvas
