from datetime import datetime
from typing import List, Union, Optional
import os
from enum import Enum

import numpy as np
import pandas as pd
from qgis.PyQt import uic
from PyQt5.QtWidgets import QListWidgetItem, QDialog
from qgis.PyQt import QtWidgets, QtCore
from qgis.PyQt.QtCore import QAbstractTableModel, Qt, QCoreApplication, QModelIndex, QEvent, QAbstractItemModel

from landsklim.lk.regressor_factory import RegressorFactory, RegressorDefinition


class LandsklimTableColumnType(Enum):
    """
    Kinds of column a Landsklim table model can hold.

    The table model compares the type of a column against these members to decide
    how a cell is displayed, which item flags it gets and how an edited value is validated.
    """

    COLUMN_LABEL = 1
    """
    Plain label column
    """
    COLUMN_WINDOWS = 2
    """
    Column holding, for each regressor, its list of windows
    """
    COLUMN_DEGREE = 3
    """
    Column holding the polynomial degree of each variable of an analysis
    """
    COLUMN_CHECKBOX = 4
    """
    Checkbox column
    """
    COLUMN_DATE = 5
    """
    Date column
    """


class LandsklimTableModel(QAbstractTableModel):
    """
    Represents the table model used inside Landsklim

    :param labels: Visible content of the table.
    :type labels: List[List[object]]

    :param data: Data linked to cells
    :type data: List[List[object]]

    :param headers: Column headers
    :type headers: List[str]

    :param columns_type: Type of columns
    :type columns_type: List[LandsklimTableColumnType]
    """

    # Textual form of a date cell, shared by display, validation and parsing
    _DATE_FORMAT = "%Y-%m-%d %H:%M"

    def __init__(self, labels, data, headers, columns_type):
        super(LandsklimTableModel, self).__init__()
        self._labels = labels
        self._data = data
        self._headers = headers
        self._columns_type = columns_type

    def _column_type(self, index: QModelIndex) -> LandsklimTableColumnType:
        """
        Get the type of the column targeted by an index
        """
        return self._columns_type[index.column()]

    def raw_label(self, index: QModelIndex):
        """
        Get the label stored for a cell, without any formatting
        """
        return self._labels[index.row()][index.column()]

    def data(self, index: QModelIndex, role=Qt.DisplayRole):
        """
        Get cell value

        :returns: The check state for a checkbox column asked with ``Qt.CheckStateRole``,
                  the label (formatted as text for a date column) for ``Qt.DisplayRole`` and
                  ``Qt.EditRole``, the data linked to the cell for ``Qt.UserRole``, None otherwise.
        """
        if not index.isValid():
            return None

        column_type = self._column_type(index)
        if role == Qt.CheckStateRole and column_type == LandsklimTableColumnType.COLUMN_CHECKBOX:
            return Qt.Checked if self.raw_label(index) else Qt.Unchecked
        if role in (Qt.DisplayRole, Qt.EditRole):
            label = self.raw_label(index)
            if column_type == LandsklimTableColumnType.COLUMN_DATE:
                return label.strftime(self._DATE_FORMAT)
            return label
        if role == Qt.UserRole:
            return self._data[index.row()][index.column()]
        return None

    def check_window_string_validity(self, value: str, min_window_size: Optional[int], max_window_size: Optional[int]) -> bool:
        """
        Check if a string is correctly formatted to describe windows of a regressor

        A valid string is either empty or a comma separated list of distinct odd integers,
        each one lying between ``min_window_size`` and ``max_window_size`` (a bound set to None is ignored).

        :param value: String to check
        :type value: str

        :param min_window_size: Smallest accepted window, None for no lower bound
        :type min_window_size: Optional[int]

        :param max_window_size: Largest accepted window, None for no upper bound
        :type max_window_size: Optional[int]

        :returns: True if the string is valid
        :rtype: bool
        """
        if len(value) == 0:
            return True

        windows: List[int] = []
        for window_str in value.split(','):
            window_str = window_str.strip()
            # isdecimal() rather than isnumeric(): the latter accepts characters such as '²'
            # that int() cannot parse, which raised ValueError instead of rejecting the input
            if not window_str.isdecimal():
                return False
            window = int(window_str)
            if window % 2 != 1:
                return False
            if min_window_size is not None and window < min_window_size:
                return False
            if max_window_size is not None and window > max_window_size:
                return False
            windows.append(window)

        # A window must not be listed twice
        return len(windows) == len(set(windows))

    def check_value(self, index: QModelIndex, value: str) -> bool:
        """
        Check if an input value is valid according the column type

        Label and checkbox columns never accept a value through this check.
        """
        column_type = self._column_type(index)
        if column_type == LandsklimTableColumnType.COLUMN_WINDOWS:
            # For a windows column, the data linked to the cell is the (min, max) window size pair
            (min_window_size, max_window_size) = self.data(index, Qt.UserRole)
            return self.check_window_string_validity(value, min_window_size, max_window_size)
        if column_type == LandsklimTableColumnType.COLUMN_DEGREE:
            # isdecimal() guarantees int() can parse the text (isnumeric() does not)
            return str(value).isdecimal() and int(value) > 0
        if column_type == LandsklimTableColumnType.COLUMN_DATE:
            try:
                datetime.strptime(value, self._DATE_FORMAT)
            except (ValueError, TypeError):
                # ValueError: text not matching the format, TypeError: value is not a string
                return False
            return True
        return False

    def flag_type(self, index: QModelIndex) -> int:
        """
        Get flags according to column type
        """
        column_type = self._column_type(index)
        if column_type in (LandsklimTableColumnType.COLUMN_WINDOWS, LandsklimTableColumnType.COLUMN_DATE):
            return Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable
        if column_type == LandsklimTableColumnType.COLUMN_CHECKBOX:
            return Qt.ItemIsUserCheckable | Qt.ItemIsEnabled
        if column_type == LandsklimTableColumnType.COLUMN_DEGREE:
            # Selectable only: the cell carries neither the enabled nor the editable flag
            return Qt.ItemIsSelectable
        # Labels, and any other column type
        return Qt.ItemIsSelectable | Qt.ItemIsEnabled

    def setData(self, index: QModelIndex, value: str, role=Qt.EditRole) -> bool:
        """
        Edit cell value

        The value is stored when the role is ``Qt.CheckStateRole`` on a checkbox column, or when the role
        is ``Qt.EditRole`` and the value passes :meth:`check_value`. ``dataChanged`` is emitted on success.

        :returns: True if the value was stored, False otherwise
        :rtype: bool
        """
        if not index.isValid():
            # An invalid index has a negative column, which would silently address the last column
            return False

        column_type = self._column_type(index)
        if role == Qt.CheckStateRole and column_type == LandsklimTableColumnType.COLUMN_CHECKBOX:
            new_label = bool(value)
        elif role == Qt.EditRole and self.check_value(index, value):
            if column_type == LandsklimTableColumnType.COLUMN_DATE:
                new_label = datetime.strptime(value, self._DATE_FORMAT)
            elif column_type == LandsklimTableColumnType.COLUMN_DEGREE:
                # Keep the degree as an integer whatever the type of the validated input
                new_label = int(value)
            else:
                new_label = value
        else:
            return False

        self._labels[index.row()][index.column()] = new_label
        # Let attached views know the cell changed
        self.dataChanged.emit(index, index, [role])
        return True

    def flags(self, index: QModelIndex) -> int:
        """
        Flags of a cell
        """
        if not index.isValid():
            return Qt.NoItemFlags
        return self.flag_type(index)

    def rowCount(self, index: QModelIndex=None) -> int:
        """
        Number of rows in the table
        """
        return len(self._labels)

    def columnCount(self, index: QModelIndex=None) -> int:
        """
        Number of columns in the table

        The count is taken from the first row, so a table without rows has no column.
        """
        if len(self._labels) > 0:
            return len(self._labels[0])
        return 0

    def headerData(self, p_int, orientation, role=None) -> Optional[str]:
        """
        Get headers title

        :returns: The title of a horizontal header section asked with ``Qt.DisplayRole``, None otherwise
        """
        if role == Qt.DisplayRole and orientation == Qt.Horizontal:
            if 0 <= p_int < len(self._headers):
                return self._headers[p_int]
        return None

    def refresh(self):
        """
        Refresh view after manual model changes
        """
        self.beginResetModel()
        self.endResetModel()


class LandsklimTableModelRegressors(LandsklimTableModel):
    """
    Table model used inside Landsklim to display/edit regressors.

    One row per regressor, with three columns: label, windows and polynomial degree.

    :param data: Regressors to list in the table
    :type data: List[RegressorDefinition]

    :param max_window_size: Upper bound linked to every windows cell, None for no bound
    :type max_window_size: Optional[int]
    """

    def __init__(self, data: List[RegressorDefinition], max_window_size: Optional[int]=None):
        columns_type = [
            LandsklimTableColumnType.COLUMN_LABEL,
            LandsklimTableColumnType.COLUMN_WINDOWS,
            LandsklimTableColumnType.COLUMN_DEGREE,
        ]
        visible_cells = self.get_variables_labels(data)
        linked_cells = self.get_variables_data(data, max_window_size)
        super(LandsklimTableModelRegressors, self).__init__(visible_cells, linked_cells, [], columns_type)
        # self.tr is needed to build the headers, so they are assigned once the parent constructor has run
        self._headers = [self.tr("Explanatory variables"), self.tr("Windows"), self.tr("Polynomial degree")]

    def get_variables_data(self, data: List[RegressorDefinition], max_window_size: Optional[int]) -> List[List[object]]:
        """
        Build the data linked to each cell: regressor name, (min, max) window size pair, None
        """
        rows = []
        for definition in data:
            regressor_name = definition.regressor_name
            smallest_window = RegressorFactory.get_regressor_class(regressor_name).min_window()
            rows.append([regressor_name, (smallest_window, max_window_size), None])
        return rows

    def get_variables_labels(self, data: List[RegressorDefinition]) -> List[List[object]]:
        """
        Build the visible content of each row: regressor label, comma separated windows, degree
        """
        rows = []
        for definition in data:
            label: str = RegressorFactory.get_regressor(definition.regressor_name, 0, 0).name()
            windows_text = ", ".join(map(str, definition.windows))
            # NOTE(review): the degree is always displayed as 1, whatever the definition holds - confirm this is intended
            rows.append([label, windows_text, 1])
        return rows

    def get_explicative_variables(self) -> List[RegressorDefinition]:
        """
        Rebuild the regressors definition from the current content of the table
        """
        definitions: List[RegressorDefinition] = []
        for row in range(self.rowCount()):
            regressor_name: str = self.data(self.index(row, 0), Qt.UserRole)
            windows_text: str = self.data(self.index(row, 1))
            compact_text = windows_text.replace(' ', '')
            # An empty cell means the regressor has no window
            windows = [int(window) for window in compact_text.split(',')] if compact_text else []
            polynomial_degree: int = self.data(self.index(row, 2))
            definitions.append(RegressorDefinition(regressor_name, windows, polynomial_degree))
        return definitions

class LandsklimTableSituationsDate(LandsklimTableModel):
    """
    Represents the table model used inside Landsklim to edit situations date

    :param situations: Situations name, displayed in the first column
    :type situations: List[str]

    :param dates: Date of each situation, displayed in the second column
    :type dates: List[datetime]
    """

    def __init__(self, situations: List[str], dates: List[datetime]):
        columns_type = [LandsklimTableColumnType.COLUMN_LABEL, LandsklimTableColumnType.COLUMN_DATE]
        # One independent row per situation: [[None, None]] * n would repeat the very same
        # list object n times, so writing into one row would alter all of them
        data = [[None, None] for _ in situations]
        # zip stops at the shortest of the two lists
        items = [list(item) for item in zip(situations, dates)]
        super(LandsklimTableSituationsDate, self).__init__(items, data, [], columns_type)
        # Need self.tr so self._headers is redefined after super constructor call
        self._headers = [self.tr("Situations"), self.tr("Date")]

    def get_dates(self) -> List[datetime]:
        """
        Get the date currently stored for each row of the table
        """
        return [self.raw_label(self.index(row, 1)) for row in range(self.rowCount())]


class TableModelPandas(QAbstractTableModel):
    """
    Read-only table model displaying a pandas DataFrame

    :param dataframe: Table to display. The model keeps a copy rounded to 3 decimals.
    :type dataframe: pd.DataFrame

    :param parent: Parent object forwarded to the Qt constructor
    """

    def __init__(self, dataframe: pd.DataFrame, parent=None):
        super(TableModelPandas, self).__init__(parent=parent)
        self.__dataframe: pd.DataFrame = dataframe.round(decimals=3)

    def rowCount(self, parent: QModelIndex=None) -> int:
        """
        Number of rows of the DataFrame
        """
        return len(self.__dataframe)

    def columnCount(self, parent: QModelIndex=None) -> int:
        """
        Number of columns of the DataFrame
        """
        return len(self.__dataframe.columns)

    def data(self, index: QModelIndex, role=Qt.DisplayRole):
        """
        Get the text of a cell

        :returns: For ``Qt.DisplayRole``, the cell converted to a string, or an empty string
                  for a missing value. None for any other role or for an invalid index.
        """
        if not index.isValid():
            return None

        if role == Qt.DisplayRole:
            data = self.__dataframe.iloc[index.row(), index.column()]
            # pd.isna also handles non numeric cells (str, None, ...), on which np.isnan raises TypeError
            if pd.isna(data):
                return ""
            return str(data)

        return None

    def headerData(self, section: int, orientation: Qt.Orientation, role: Qt.ItemDataRole = Qt.DisplayRole):
        """
        Get a header title: column name for a horizontal header, index value for a vertical one
        """
        if role == Qt.DisplayRole:
            if orientation == Qt.Horizontal:
                return str(self.__dataframe.columns[section])

            if orientation == Qt.Vertical:
                return str(self.__dataframe.index[section])

        return None


# Load the form described by 'widget_multiple_selection.ui', located in the same directory as this module.
# Only the first element of the returned pair is kept; it is used below as a base class of DialogSelection.
SELECTION_CLASS, _ = uic.loadUiType(os.path.join(os.path.dirname(__file__), 'widget_multiple_selection.ui'))


class QListWidgetItemDataComparable(QListWidgetItem):
    """
    List item whose ``<`` comparison relies on the data stored under ``Qt.UserRole``
    """

    def __lt__(self, other: "QListWidgetItemDataComparable"):
        role = QtCore.Qt.UserRole
        own_key = self.data(role)
        other_key = other.data(role)
        return own_key < other_key


class DialogSelection(QDialog, SELECTION_CLASS):
    """
    Dialog letting the user pick elements out of a list.

    Items are split between an "available" list and a "selected" list, and buttons move
    them from one list to the other.

    :param title: Title of the dialog window
    :type title: str

    :param available: List of items
    :type available: List[str]

    :param selected: List of selected item. Must be a sublist of [self.available]
    :type selected: List[str]

    """

    def __init__(self, title: str, available: List[str], selected: List[str], parent=None):
        super(DialogSelection, self).__init__(parent)
        self.setupUi(self)
        self.__available: List[str] = available
        self.__selected:  List[str] = selected
        self.fill_lists()

        button_slots = (
            (self.pb_add, self.on_add),
            (self.pb_remove, self.on_remove),
            (self.pb_add_all, self.on_add_all),
            (self.pb_remove_all, self.on_remove_all),
        )
        for button, slot in button_slots:
            button.clicked.connect(slot)

        self.setWindowTitle(title)

    @staticmethod
    def _move_highlighted(source, target):
        """
        Move the items currently highlighted in ``source`` to ``target``
        """
        for entry in source.selectedItems():
            target.addItem(source.takeItem(source.row(entry)))

    @staticmethod
    def _move_everything(source, target):
        """
        Move every item of ``source`` to ``target``
        """
        # The count is read once; the first item is taken as many times as there were items
        for _ in range(source.count()):
            target.addItem(source.takeItem(0))

    def fill_lists(self):
        """
        Rebuild both list widgets from the available items and the stored selection
        """
        self.list_available.clear()
        self.list_selected.clear()
        for position, text in enumerate(self.__available):
            entry = QListWidgetItemDataComparable()
            entry.setText(text)
            # The position in the available list is kept as item data; it is the sort key of the entry
            entry.setData(QtCore.Qt.UserRole, position)
            destination = self.list_selected if text in self.__selected else self.list_available
            destination.addItem(entry)
        self.sort_lists()

    def sort_lists(self):
        """
        Sort both list widgets
        """
        for list_widget in (self.list_available, self.list_selected):
            list_widget.sortItems()

    def on_add(self):
        """
        Move the highlighted available items to the selected list
        """
        self._move_highlighted(self.list_available, self.list_selected)
        self.sort_lists()

    def on_remove(self):
        """
        Move the highlighted selected items back to the available list
        """
        self._move_highlighted(self.list_selected, self.list_available)
        self.sort_lists()

    def on_add_all(self):
        """
        Move every available item to the selected list
        """
        self._move_everything(self.list_available, self.list_selected)
        self.sort_lists()

    def on_remove_all(self):
        """
        Move every selected item back to the available list
        """
        self._move_everything(self.list_selected, self.list_available)
        self.sort_lists()

    def get_available(self) -> List[str]:
        """
        Get the text of the items currently in the available list
        """
        return [self.list_available.item(row).text() for row in range(self.list_available.count())]

    def get_selected(self, get_indices: bool = False) -> List[Union[str, int]]:
        """
        Get the items currently in the selected list

        :param get_indices: Instead of return selected items' text, returns selected items' index
        :type get_indices: bool
        """
        entries = [self.list_selected.item(row) for row in range(self.list_selected.count())]
        if get_indices:
            return [entry.data(QtCore.Qt.UserRole) for entry in entries]
        return [entry.text() for entry in entries]

    def accept(self):
        """
        Store the current selection, then accept the dialog
        """
        self.__selected = self.get_selected()
        super().accept()

    def reject(self):
        """
        Restore the lists from the stored selection, then reject the dialog
        """
        self.fill_lists()
        super().reject()