from __future__ import annotations

import codecs
import pickle
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any
from xml.etree.ElementTree import Element, SubElement

import pyvista as pv
from qtpy.QtWidgets import QFileDialog
from unidecode import unidecode

from .geometryType import pvGeometryType, vtkIOXMLFileExtension
from .layerContext import pvLayerContext
from .layerSymbology import pvLayerSymbology


@dataclass
class pvItem(ABC):
	"""Abstract node of the layer tree (base of groups and layers).

	Attributes:
		name: Display name of the item.
		parent: Enclosing item, or ``None`` for a root item.
		state: Local enabled flag; see :attr:`is_enabled` for the effective state.
	"""

	name: str = "New item"
	parent: object = None
	state: bool = True

	def __bool__(self) -> bool:
		# Always truthy: the view ignores falsy items, and without this override
		# ``__len__() == 0`` would make every childless item falsy.
		return True

	def __len__(self) -> int:
		return 0

	def __iter__(self):
		# A plain item iterates over itself; containers override this.
		yield self

	@property
	def geom_t(self) -> pvGeometryType:
		"""Geometry type of the item (always ``INVALID`` for the base class)."""
		return pvGeometryType.INVALID

	@property
	def is_enabled(self) -> bool:
		"""Check global state (recurse over parents).

		Returns:
			bool: Whether an item and all its parents are enabled.
		"""
		current = self
		while current.state:
			if current.parent is None:
				return True
			current = current.parent
		return False

	def to_xml(self, tag: str = "Item") -> Element:
		"""Serialize ``name`` and ``state`` as string attributes of a ``tag`` element."""
		return Element(tag, name=str(self.name), state=str(self.state))

	@staticmethod
	def _parse_state(value: Any) -> bool:
		"""Convert a serialized state (``"True"``/``"False"``) back to a bool."""
		if isinstance(value, bool):
			return value
		return str(value).strip().lower() in ("true", "1")

	@classmethod
	def from_xml(cls, el: Element) -> pvItem:
		"""Rebuild an item from an element produced by :meth:`to_xml`.

		XML attributes are strings, so ``state`` is parsed explicitly: passing
		``"False"`` through unchanged would be truthy and re-enable the item.
		"""
		kwargs = dict(el.attrib)
		if "state" in kwargs:
			kwargs["state"] = cls._parse_state(kwargs["state"])
		return cls(**kwargs)

	@abstractmethod
	def export(self, destination: Path, index: int = 0, prefix: str = "") -> Element:
		"""Export the item below ``destination`` and return its XML description."""
		...


@dataclass
class pvGroup(pvItem):
	"""Container node of the layer tree, holding layers and nested groups.

	Attributes:
		children: Ordered child items. ``append``/``insert`` set the child's
			``parent`` to this group.
	"""

	name: str = "New Group"
	children: list = field(default_factory=list)

	def __str__(self) -> str:
		"""Return the group name."""
		return f"{self.name}"

	def __iter__(self):
		# Iterating a group walks its direct children only (not recursive).
		yield from self.children

	def __len__(self):
		return len(self.children)

	@property
	def geom_t(self):
		"""``MULTI`` if any direct child has ``geom_t > 0``, else ``INVALID``."""
		if any(child.geom_t > 0 for child in self):
			return pvGeometryType.MULTI
		return pvGeometryType.INVALID

	def append(self, child):
		"""Add ``child`` at the end and make this group its parent."""
		child.parent = self
		self.children.append(child)

	def insert(self, index, child):
		"""Insert ``child`` at ``index`` and make this group its parent."""
		child.parent = self
		self.children.insert(index, child)

	def pop(self, *args) -> pvItem:
		"""Remove and return a child (same arguments as ``list.pop``)."""
		return self.children.pop(*args)

	def remove(self, child) -> pvItem:
		"""Remove and return the first child comparing equal to ``child``.

		Raises:
			ValueError: If no child compares equal to ``child`` (like ``list.remove``).
				A bare ``StopIteration`` would leak out of ``next()`` otherwise,
				and turn into ``RuntimeError`` inside generators (PEP 479).
		"""
		for idx, item in enumerate(self.children):
			if item == child:
				return self.pop(idx)
		raise ValueError(f"{child!r} is not a child of group {self.name!r}")

	def to_xml(self, **kwargs) -> Element:  # noqa: ARG002
		return super().to_xml(tag="Group")

	def export(self, destination: Path, index: int = 0, prefix: str = "") -> Element:
		"""Export every child below ``destination`` and describe them as a ``Block``.

		Args:
			destination: Root output directory, forwarded to the children.
			index: Position of this group within its parent block.
			prefix: Slash-separated path of the enclosing groups ("" at the root).

		Returns:
			Element: ``Block`` element containing the children's elements.
		"""
		name = unidecode(self.name)
		# Skip empty parts: "/".join(("", name)) would give "/name", an absolute
		# path that makes ``Path(destination, ...)`` in the children ignore
		# ``destination`` entirely.
		prefix = "/".join(part for part in (unidecode(prefix), name) if part)
		root = Element("Block", name=name, index=str(index))
		for child in self.children:
			# Children may return None (e.g. invalid layers); Element.append(None)
			# raises TypeError, so skip them and keep indices contiguous.
			exported = child.export(destination, index=len(root), prefix=prefix)
			if exported is not None:
				root.append(exported)
		return root

	def to_multiblock(
		self,
	) -> pv.MultiBlock:
		"""Convert the tree to a nested ``pv.MultiBlock`` (groups become sub-blocks)."""
		block = pv.MultiBlock()
		for child in self.children:
			if isinstance(child, pvGroup):
				# Call the method: appending ``child.to_multiblock`` itself would
				# insert a bound method instead of a dataset.
				block.append(child.to_multiblock(), name=child.name)
			elif isinstance(child, pvLayer):
				block.append(child.geometry, name=child.name)
		return block


@dataclass
class pvLayer(pvItem):
	name: str = "New Layer"
	geometry: Any = None
	context: pvLayerContext = field(default_factory=pvLayerContext)
	symbology: pvLayerSymbology = field(default_factory=pvLayerSymbology)

	_uid: str = None  # WARNING: should be initialized only on deserialization !

	def __post_init__(self):
		self._uid = (
			self._uid or f"{unidecode(self.name).replace(' ', '_')}_{uuid.uuid4()}"
		)
		if self.source and not self.geometry:
			self.geometry = pv.read(self.source)

	def __str__(self) -> str:
		return str(self.uid)

	def __repr__(self) -> str:
		return (
			f"<{self.__class__.__name__}:{self.geom_t.name} '{self.name}' ({self.uid})>"
		)

	def __eq__(self, other):
		"""allow comparison with item content"""
		return self.uid == getattr(other, "uid", None)

	def __setattr__(self, attr: str, value: Any):
		if attr == "geometry":
			self.__dict__.pop("geom_t", None)
		super().__setattr__(attr, value)

	@property
	def is_valid(self) -> bool:
		return (self.geom_t != pvGeometryType.INVALID) and (self.geometry.n_points > 0)

	@property
	def is_temp(self):
		return bool(self.source is None or not Path(self.source).is_file())

	@property
	def source(self):
		return self.context.source

	@property
	def uid(self):
		return self._uid

	@cached_property
	def geom_t(self):
		return pvGeometryType.of(self.geometry)

	@staticmethod
	def from_file(file: str, **kwargs) -> pvLayer:
		# read file
		path = Path(file)
		mesh = pv.read(path)
		# build layer
		kwargs.setdefault("name", path.stem)
		kwargs["geometry"] = mesh
		layer = pvLayer(**kwargs)
		# overwrite context
		layer.context.source = path.as_posix()
		for f in mesh.field_data:
			match f.lower():
				case ("crs", "srs", "epsg", "coordinatereferencesystem"):
					crs = str(mesh.field_data[f])
					layer.context.crs = crs.strip("[]").strip("")
				case "color":
					color = " ".join([str(e) for e in mesh.field_data[f]])
					layer.symbology.color = color
				case "edg_color":
					color = " ".join([str(e) for e in mesh.field_data[f]])
					layer.symbology.edge_color = color.strip("[]").strip("")
		# return layer
		return layer

	@staticmethod
	def from_memory(obj: Any, **kwargs):
		# wrap geometry
		geom = pv.wrap(obj)
		# build layer
		kwargs["geometry"] = geom
		layer = pvLayer(**kwargs)
		# overwrite context
		layer.context.source = None
		for f in geom.field_data:
			try:
				if f.lower().replace("*", ":") in (
					"crs",
					"crs:name",
					"srs",
					"epsg",
					"coordinatereferencesystem",
				):
					layer.context.crs = ":".join(geom.field_data[f])
				elif f.lower().replace("*", ":") in (
					"color",
					"solid:color",
					"label:color",
					"path:color",
				):
					layer.symbology.color = "".join(geom.field_data[f])
				elif f.lower().replace("*", ":") in (
					"edge_color",
					"line:color",
					"path:color",
				):
					layer.symbology.edge_color = "".join(geom.field_data[f])
			except Exception:
				continue
		# return layer
		return layer

	def to_xml(self, uid_only: bool = False, include_data: bool = False) -> Element:
		if uid_only:
			return Element("Layer", uid=self.uid)
		xml = super().to_xml("Layer")
		xml.attrib["uid"] = str(self.uid)
		xml.append(self.context.to_xml())
		xml.append(self.symbology.to_xml(include_data))
		if self.geometry and include_data:
			geom = SubElement(xml, "Geometry")
			geom.text = codecs.encode(pickle.dumps(self.geometry), "base64").decode()
			xml.find("Context").set("source", "memory:")

		return xml

	@classmethod
	def from_xml(cls, el: Element) -> pvLayer:
		kwargs = dict(el.attrib)
		kwargs["_uid"] = kwargs.pop("uid")
		kwargs["context"] = pvLayerContext.from_xml(el.find("Context"))
		kwargs["symbology"] = pvLayerSymbology.from_xml(el.find("Symbology"))
		geometry = el.find("Geometry")
		if geometry is not None:
			kwargs["geometry"] = pickle.loads(
				codecs.decode(geometry.text.encode(), "base64")
			)
		layer = cls(**kwargs)
		if not layer.geometry and layer.source:
			layer.geometry = cls.from_file(layer.source).geometry
		return layer

	def save(self, name: str | None = None, prefix: str | None = None) -> bool:
		if not self.is_valid:
			# TODO: Warning box ?!
			return False

		# file picker
		if not (name or prefix):
			fmts = " ".join([f"*{_}" for _ in self.geometry._WRITERS])
			file = QFileDialog.getSaveFileName(
				None,
				"Save {self.name} as ...",
				dir=str(Path(prefix or "", name or self.name)),
				filter=(f"Supported formats ({fmts})"),
			)[0]
			if not file:
				return False
			path = Path(file)
		else:
			path = Path(prefix, name or self.uid) if prefix else Path(name)

		# append default ext if none provided
		if not path.suffix:
			path = path.with_suffix(vtkIOXMLFileExtension(self.geometry))
		path.parent.mkdir(parents=True, exist_ok=True)
		self.geometry.save(path)
		self.context.source = path
		return True

	def export(self, destination: Path, index: int = 0, prefix: str = "") -> Element:
		if not self.is_valid:
			return None

		name = unidecode(self.name)
		file = "/".join((prefix, self.uid))
		path = Path(destination, file).with_suffix(vtkIOXMLFileExtension(self.geometry))
		path.parent.mkdir(parents=True, exist_ok=True)
		self.geometry.save(path)
		return Element("DataSet", index=str(index), name=name, file=file)
