from __future__ import annotations

from pathlib import Path
from xml.etree.ElementTree import Element, parse

import pygcd
from pyvista import DataObject, DataSet, MultiBlock, wrap
from qtpy.QtCore import QModelIndex
from qtpy.QtWidgets import QFileDialog

from .layer import pvGroup, pvItem, pvLayer
from .model import pvItemModel


class pvLoader:
    """Load meshes, multiblocks and GOCAD projects into a ``pvItemModel``.

    Every public ``load_*`` method takes an optional source. When the source
    is missing (``None`` or ``""``) a ``QFileDialog`` lets the user pick one
    or more files and a list with one result per selected file is returned.

    Args:
        model: Item model that receives the loaded items.
        currentIndex: Zero-argument callable returning the ``QModelIndex``
            used as parent when no parent is given explicitly. Defaults to a
            callable returning a default-constructed ``QModelIndex()``.

    Raises:
        TypeError: If ``model`` is not a ``pvItemModel`` or ``currentIndex()``
            does not return a ``QModelIndex``.
    """

    def __init__(
        self, model: pvItemModel, currentIndex: callable = lambda: QModelIndex()
    ):
        if not isinstance(model, pvItemModel):
            raise TypeError(model)
        # Call once up front so a bad callable fails here rather than on first load.
        if not isinstance(currentIndex(), QModelIndex):
            raise TypeError(currentIndex)

        self.model = model
        self.currentIndex = currentIndex

    @staticmethod
    def _is_missing(source) -> bool:
        """Return True when *source* means "ask the user" (``None`` or ``""``).

        Objects are deliberately not tested for truthiness: a container such
        as an empty ``MultiBlock`` has ``len() == 0`` and would otherwise be
        mistaken for a missing source.
        """
        return source is None or (isinstance(source, str) and not source)

    def _load(
        self, item: pvItem, *, parent: QModelIndex | None = None
    ) -> QModelIndex:
        """Add *item* to the model and return the result of ``model.add``.

        A falsy *parent* (including ``None``) falls back to
        ``self.currentIndex()``.

        Raises:
            TypeError: If *item* is not a ``pvItem``.
        """
        # Raise instead of assert so the check survives ``python -O``.
        if not isinstance(item, pvItem):
            raise TypeError(item)
        return self.model.add(item, parent=parent or self.currentIndex())

    def load_mesh(
        self, source: str | DataSet = None, parent: QModelIndex = None, **kwargs
    ):
        """Load a single mesh as a ``pvLayer``.

        Args:
            source: Path of a mesh file or an in-memory dataset. When missing
                (``None`` or ``""``) the user picks files in a dialog.
            parent: Index to attach the layer to; ``None`` means the current
                index.
            **kwargs: Forwarded to ``pvLayer.from_file`` or
                ``pvLayer.from_memory`` (also for dialog-selected files).

        Returns:
            The result of adding the layer to the model, or a list of such
            results when files were picked through the dialog.
        """
        if self._is_missing(source):
            return [
                self.load_mesh(f, parent=parent, **kwargs)
                for f in QFileDialog.getOpenFileNames(
                    parent=None,
                    caption="Select a 3D mesh file",
                    filter="VTK files (*.vtk *.vtp *.vti *.vtr *.vts *.vtu);; Other mesh file (*.obj *.ply *.stl);; All files (*.*)",
                )[0]
            ]

        # A TypeError from pvLayer.from_file is taken to mean "not a path", so
        # fall back to an in-memory dataset. Only the factory call is guarded:
        # a TypeError raised while adding to the model must not trigger the
        # fallback and create a second layer.
        try:
            layer = pvLayer.from_file(source, **kwargs)
        except TypeError:
            layer = pvLayer.from_memory(source, **kwargs)
        return self._load(layer, parent=parent)

    def load_multiblock(self, source: str | MultiBlock = None, **kwargs):
        """Load a ``.vtm`` file or an in-memory multiblock as a group tree.

        Args:
            source: Path of a ``.vtm`` file, or an object handed to
                ``pyvista.wrap``. When missing (``None`` or ``""``) the user
                picks files in a dialog.
            **kwargs: Forwarded down to the ``pvLayer`` factories. In the
                in-memory case a ``name`` entry names the top-level group.

        Returns:
            The index of the top-level group, or a list of them when files
            were picked through the dialog.
        """
        if self._is_missing(source):
            return [
                self.load_multiblock(f, **kwargs)
                for f in QFileDialog.getOpenFileNames(
                    parent=None,
                    caption="Select a 3D mesh file",
                    filter="VTK files (*.vtm);; All files (*.*)",
                )[0]
            ]

        # Only the path conversion decides between "file" and "in memory";
        # errors raised while walking the file must propagate instead of
        # silently retrying with wrap(<path string>).
        try:
            path = Path(source)
        except TypeError:
            return self._load_pv_block(wrap(source), **kwargs)

        # Parse before adding anything so an unreadable file leaves the model
        # untouched.
        xml = parse(path).getroot()
        root = self._load(pvGroup(path.stem))
        for vtm in xml.iterfind("vtkMultiBlockDataSet"):
            self._load_vtm_block(vtm, prefix=path.parent, parent=root, **kwargs)
        return root

    def _load_pv_block(
        self,
        block: DataObject,
        parent=QModelIndex(),  # noqa: B008
        **kwargs,
    ):
        """Recursively add a pyvista object under *parent*.

        A ``MultiBlock`` becomes a ``pvGroup`` named after ``kwargs["name"]``
        (default ``"MultiBlock"``). Its children are loaded recursively and
        named after their block keys, and ``None`` children are skipped. Any
        other object is loaded through ``load_mesh``.

        Returns:
            The group's index for a ``MultiBlock``, otherwise whatever
            ``load_mesh`` returns.
        """
        if isinstance(block, MultiBlock):
            group = self._load(pvGroup(kwargs.pop("name", "MultiBlock")), parent=parent)
            for part, name in zip(block, block.keys(), strict=False):
                # MultiBlock slots may be empty (None). Passing None to
                # load_mesh would be read as "no source" and open a dialog.
                if part is None:
                    continue
                self._load_pv_block(part, parent=group, name=name, **kwargs)
            return group
        return self.load_mesh(block, parent=parent, **kwargs)

    def _load_vtm_block(
        self,
        block: Element,
        prefix: Path,
        parent=QModelIndex(),  # noqa: B008
        **kwargs,
    ):
        """Recursively add the children of a ``.vtm`` XML element under *parent*.

        Each ``<Block>`` child becomes a ``pvGroup`` that is recursed into.
        Each ``<DataSet>`` child becomes a ``pvLayer`` read from
        ``prefix / file``. The layer is named after the element's ``name``
        attribute or, failing that, the file stem. ``<DataSet>`` elements
        without a ``file`` attribute are skipped.

        Args:
            block: XML element whose ``Block``/``DataSet`` children are loaded.
            prefix: Directory that the ``file`` attributes are joined onto.
            parent: Index the new items are attached to.
            **kwargs: Forwarded to ``pvLayer.from_file`` (with ``name``
                overridden per dataset) and to nested blocks.
        """
        for sub in block.iterfind("Block"):
            group = self._load(pvGroup(sub.get("name")), parent=parent)
            # Propagate kwargs so nested datasets get the same options as
            # top-level ones.
            self._load_vtm_block(sub, prefix, group, **kwargs)
        for ds in block.iterfind("DataSet"):
            file = ds.get("file")
            if not file:
                # Nothing to read (e.g. an empty block). Path(None) would
                # otherwise raise TypeError here.
                continue
            # Build a fresh dict per dataset instead of mutating kwargs.
            options = {**kwargs, "name": ds.get("name", Path(file).stem)}
            self._load(pvLayer.from_file(prefix / file, **options), parent=parent)

    def load_gocad(self, file: str | None = None):
        """Load a GOCAD ASCII project as a group with one sub-tree per geometry type.

        Args:
            file: Path of the GOCAD file. When missing (``None`` or ``""``)
                the user picks files in a dialog.

        Returns:
            The index of the project group, or a list of them when files were
            picked through the dialog.
        """
        if self._is_missing(file):
            return [
                self.load_gocad(f)
                for f in QFileDialog.getOpenFileNames(
                    parent=None,
                    caption="Select a GOCAD project",
                    filter="ASCII files (*.txt *.ascii *.ds);; All files (*.*)",
                )[0]
            ]

        content = pygcd.load(file)
        root = self._load(pvGroup(Path(file).stem))
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # layer order is reproducible. Set iteration order depends on hash
        # values and may vary between runs.
        for geom_t in dict.fromkeys(content.geometries):
            self._load_pv_block(
                pygcd.read(file, geometries=[geom_t], wrapper="pyvista"),
                name=geom_t.name,
                parent=root,
            )
        return root

    # def load_raster(self, file: str = None):
    #     raise NotImplementedError("WIP")

    # def load_vector(self, file: str = None):
    #     raise NotImplementedError("WIP")

    # def load_layer(self, layer: QgsMapLayer):
    #     try:
    #         pass
    #     except ImportError:
    #         from_memory()

    #     else:
    #         return self._load(pvLayer.from_memory(mesh))
