from __future__ import annotations

from typing import Any
from xml.etree.ElementTree import Element, SubElement

from qtpy.QtCore import (
    QAbstractItemModel,
    QByteArray,
    QMimeData,
    QModelIndex,
    QObject,
    Qt,
    Signal,
)

from .layer import pvGroup, pvItem, pvLayer


class pvItemModel(QAbstractItemModel):
    """Qt Model for pvGIS layers.

    Args:
        parent (QObject, optional): parent QObject. Defaults to None.
        columns (list[str], optional): list of `item.data()` attributes to expose. Defaults to [].
    """

    layersAdded = Signal(list)
    layersRemoved = Signal(list)
    layersChanged = Signal(list)

    def __init__(
        self,
        parent: QObject | None = None,
        columns: list[str] | None = None,
    ):
        """Create an empty model that only holds the invisible root group.

        Args:
            parent (QObject, optional): parent QObject. Defaults to None.
            columns (list[str], optional): item attribute names exposed by the
                model, one per column. Defaults to ``["name"]``.
        """
        super().__init__(parent=parent)
        self.root = pvGroup("__root__")
        # Copy the caller's list so later edits to it cannot silently change
        # the number of columns reported by the model.
        self.columns = list(columns) if columns is not None else ["name"]

        # Re-publish Qt's structural signals through the layer-level signals.
        self.rowsInserted.connect(self.notifyInsertion)
        self.rowsAboutToBeRemoved.connect(self.notifyRemoval)
        self.dataChanged.connect(self.notifyChange)
        # TODO: add a method matching layers to model indexes so that
        # `layersChanged` can in turn drive `dataChanged`.

    def item(
        self,
        index: QModelIndex = QModelIndex(),  # noqa: B008
        *,
        child: int | None = None,
    ) -> pvItem:
        """Return a valid QModelIndex.internalPointer() (defaults to self.root)

        Args:
            index (QModelIndex, optional): The model index. Defaults to QModelIndex().
            child (int, optional): Return the n-th child instead. Defaults to None.

        Returns:
            object: The data stored at index.
        """
        item = index.internalPointer() or self.root

        if child is None:
            return item
        if not isinstance(child, int):
            raise TypeError(child)

        if n := len(item):
            if -n < child < n:
                return next((x for i, x in enumerate(item) if i == (child % n)))
            msg = f"index '{child}' out of '{item.name}' range '[-{n},{n}['"
            raise IndexError(msg)
        msg = f"'{item.name}' is empty"
        raise IndexError(msg)

    def layers(self, index: QModelIndex = QModelIndex()) -> list[pvLayer]:  # noqa: B008
        """(recursive) list of stored layers.

        Args:
            index (QModelIndex, optional): The root item. Defaults to `self.root`.

        Returns:
            list[object]: data
        """

        items = list(self.item(index))
        layers = []
        while items:
            el = items.pop()
            if isinstance(el, pvLayer):
                layers.append(el)
            elif isinstance(el, pvGroup):
                items += el.children
        return layers

    def hasFlag(self, index: QModelIndex, flag: Qt.ItemFlag):
        return self.flags(index) & flag == flag

    def moveItems(
        self,
        items: list[pvItem],
        parent: QModelIndex = QModelIndex(),  # noqa: B008
        row: int = -1,
    ) -> bool:
        """Move a list of items.

        Args:
            sourceItems (Iterable[pvItem]): List of items.
            destinationParent (QModelIndex): Index of the parent destination.
            destinationChild (int): Position in the parent destination.

        Returns:
            bool: whether the move was a success or not
        """

        # cannot move to a non-dropable item
        # if not self.hasFlag(parent, Qt.ItemFlag.ItemIsDropEnabled):
        # return False

        host = self.item(parent)

        for i, item in enumerate(filter(None, items)):
            donnor = self.parent(item)
            # ask Qt if the move is valid
            if not self.beginMoveRows(
                donnor,
                donnor.row(),
                donnor.row(),
                parent,
                row + i,
            ):
                return False
            if item.parent:  # pop from parent if any
                item.parent.remove(item)
            host.insert(row + i, item)
            # end the move event
            self.endMoveRows()

        return True

    def notifyInsertion(self, parent: QModelIndex, first: int, last: int):
        items = []
        for i in range(first, last + 1):
            items.extend(self.item(parent, child=i))
        if items:
            self.layersAdded.emit(items)

    def notifyRemoval(self, parent: QModelIndex, first: int, last: int):
        items = []
        for i in range(first, last + 1):
            items += self.item(parent, child=i)
        if items:
            self.layersRemoved.emit(items)

    def notifyChange(
        self,
        topLeft: QModelIndex,
        bottomRight: QModelIndex,
        roles: list[int] | None = None,
    ):
        if roles is None:
            roles = []
        layers = []

        if topLeft.parent() == bottomRight.parent():
            first, last = topLeft.row(), bottomRight.row() + 1
            parent = topLeft.parent()
            for i in range(first, last):
                layers += self.layers(self.index(i, parent=parent))
        else:
            layers += self.layers(topLeft)
            layers += self.layers(bottomRight)

        self.layersChanged.emit(layers)

    def add(
        self,
        item: pvItem,
        *,
        row: int = -1,
        parent: QModelIndex = QModelIndex(),  # noqa: B008
    ) -> QModelIndex:
        """Add an item to the model.

        Args:
            item (pvItem): The item to insert.
            row (int, optional): Row position in the parent. Defaults to 0.
            parent (QModelIndex, optional): Index of the insertion host. Defaults to `self.root`.

        Returns:
            QModelIndex: Index of the insertion.
        """

        if not isinstance(item, pvItem):
            raise TypeError(item)

        host = self.item(parent)
        if not isinstance(host, pvGroup):
            parent, row = parent.parent(), parent.row() + 1
            host = self.item(parent)

        row = row % (len(host) + 1)

        self.beginInsertRows(parent, row, row)
        host.insert(row, item)
        self.endInsertRows()
        self.layoutChanged.emit()
        return self.createIndex(row, 0, item)

    def add_group(
        self,
        name: str,
        *,
        row: int = -1,
        parent: QModelIndex = QModelIndex(),  # noqa: B008
    ) -> QModelIndex:
        """A shortcut to add a group to a model.

        Args:
            name (str): The name of the group to be created.
            row (int, optional): cf. model.add.
            parent (QModelIndex, optional): cf. model.add.

        Returns:
            QModelIndex: Index of the insertion.
        """

        return self.add(pvGroup(name), row=row, parent=parent)

    def remove(self, index: QModelIndex) -> pvItem:
        parent, row = index.parent(), index.row()
        host = self.item(parent)
        self.beginRemoveRows(parent, row, row)
        item = host.pop(row)
        self.endRemoveRows()

        return item

    def to_xml(self, include_data: bool = False) -> Element:
        root = Element(self.__class__.__name__)
        layers = SubElement(root, "Layers")
        tree = SubElement(root, "TreeView")

        def wrap_element(item: pvItem, parent: Element):
            if isinstance(item, pvLayer):
                layers.append(item.to_xml(uid_only=False, include_data=include_data))
                parent.append(item.to_xml(uid_only=True, include_data=False))
            elif isinstance(item, pvGroup):
                group = item.to_xml()
                for child in item.children:
                    wrap_element(child, group)
                parent.append(group)
            else:
                raise AssertionError()

        for el in self.root.children:
            wrap_element(el, tree)

        return root

    def load_xml(self, root: Element):
        # re-create layers from payload
        layers = {}
        for el in root.findall(".//Layers/*"):
            layer = pvLayer.from_xml(el)
            layers[layer.uid] = layer

        # recursive grafting function
        def graft(el: Element, parent: QModelIndex = QModelIndex()):  # noqa: B008
            if el.tag == "Layer":
                return self.add(layers.pop(el.get("uid")), parent=parent)
            if el.tag == "Group":
                group = self.add(pvGroup.from_xml(el), parent=parent)
                for sub in el:
                    graft(sub, group)
                return None
            return None

        # re-create treeview model
        for el in root.findall(".//TreeView/*"):
            graft(el)

        if layers:
            # TODO: warn there are invisible layers and handle them !
            ...

    def clear(self):
        self.beginResetModel()
        self.layersRemoved.emit(self.layers())
        self.root.children.clear()
        self.endResetModel()

    ####################################
    ### QAbstractItemModel overloads ###
    ####################################

    # overload QAbstractItemModel.rowCount()
    def rowCount(self, parent: QModelIndex = QModelIndex()) -> int:  # noqa: B008
        """Returns the number of rows under the given parent. When the parent is valid it means that rowCount is returning the number of children of parent."""
        return len(self.item(parent))

    # overload QAbstractItemModel.columnCount()
    def columnCount(
        self,
        parent: QModelIndex = QModelIndex(),  # noqa: B008, ARG002
    ) -> int:
        """Returns the number of columns for the children of the given parent."""
        return len(self.columns)

    # overload QAbstractItemModel.flags()
    def flags(self, index: QModelIndex) -> Qt.ItemFlag:
        """Returns the item flags for the given index."""

        flags = Qt.NoItemFlags

        if not index.isValid():
            return flags

        flags |= (
            Qt.ItemIsEnabled
            | Qt.ItemIsSelectable
            | Qt.ItemIsDragEnabled
            | Qt.ItemIsEditable
            | Qt.ItemIsUserCheckable
        )

        if isinstance(self.item(index), pvGroup):
            flags |= Qt.ItemIsDropEnabled

        return flags

    # overload QAbstractItemModel.index()
    def index(
        self,
        row: int,
        column: int = 0,
        parent: QModelIndex = QModelIndex(),  # noqa:B008
    ) -> QModelIndex:
        """Returns the index of the item in the model specified by the given row, column and parent index."""
        try:
            child = self.item(parent, child=row)
            return self.createIndex(row, column, child)
        except IndexError:
            return QModelIndex()

    # overload QAbstractItemModel.parent()
    def parent(self, child: QModelIndex | pvItem) -> QModelIndex:
        """Returns the parent of the model item with the given index. If the item has no parent, an invalid QModelIndex is returned."""
        if isinstance(child, QModelIndex):
            if child.isValid():
                return self.parent(self.item(child))
            return QModelIndex()
        if isinstance(child, pvItem):
            if parent := child.parent:
                row = next((i for i, x in enumerate(parent) if x == child))
                return self.createIndex(row, 0, parent)
            return QModelIndex()
        raise TypeError(child)

    # overload QAbstractItemModel.data()
    def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
        """Returns the data stored under the given role for the item referred to by the index."""

        if not index.isValid():
            return None

        item = index.internalPointer()
        assert isinstance(item, pvItem)

        if role == Qt.CheckStateRole:
            value = bool(item.state)
            return Qt.Checked if value else Qt.Unchecked
        if role in (Qt.DisplayRole, Qt.EditRole):
            attr = self.columns[index.column()]
            return getattr(item, attr, None)
        return None

    # overload QAbstractItemModel.setData()
    def setData(
        self,
        index: QModelIndex,
        value: Any,
        role: int = Qt.EditRole,
    ) -> bool:
        """Sets the role data for the item at index to value. Returns true if successful; otherwise returns false."""
        if role not in (Qt.CheckStateRole, Qt.EditRole) or not index.isValid():
            return False

        item = index.internalPointer()
        assert isinstance(item, pvItem)

        if role == Qt.CheckStateRole:
            item.state = bool(value)
        else:
            attr = self.columns[index.column()]
            setattr(item, attr, value)

        self.dataChanged.emit(index, index, [role])
        return True

    # overload QAbstractItemModel.supportedDragActions()
    def supportedDragActions(self) -> Qt.DropAction:
        """Returns the actions supported by the data in this model."""
        return Qt.DropAction.MoveAction

    # overload QAbstractItemModel.supportedDropActions()
    def supportedDropActions(self) -> Qt.DropAction:
        """Returns the drop actions supported by this model."""
        return Qt.DropAction.MoveAction

    # overload QAbstractItemModel.mimeTypes()
    def mimeTypes(self) -> list[str]:
        """Returns the list of allowed MIME types."""
        return ["application/x-pvitems"]

    # overload QAbstractItemModel.dropMimeData()
    def dropMimeData(
        self,
        data: QMimeData,
        action: Qt.DropAction,  # noqa: ARG002
        row: int,
        column: int,  # noqa: ARG002
        parent: QModelIndex,
    ) -> bool:
        if data.hasFormat("application/x-pvitems"):
            if items := data.x_pvitems:
                return self.moveItems(items, parent, row)
            return False
        return False

    # overload QAbstractItemModel.mimeData()
    def mimeData(self, indexes: list[QModelIndex]) -> QMimeData:
        """Returns an object that contains serialized items of data corresponding to the list of indexes specified."""

        items = [
            it for it in [idx.internalPointer() for idx in indexes] if it is not None
        ]
        if not items:
            return None

        data = QMimeData()
        data.setData("application/x-pvitems", QByteArray())

        # FIXME: avoid cast to QByteArray ... in a dirty way
        data.__setattr__("x_pvitems", list(items))

        return data
