from __future__ import annotations

from pathlib import Path

import pyvista as pv
from qgis.core import (
	Qgis,
	QgsCoordinateReferenceSystem,
	QgsFeature,
	QgsFeatureSink,
	QgsField,
	QgsFields,
	QgsProcessingAlgorithm,
	QgsProcessingException,
	QgsProcessingParameterEnum,
	QgsProcessingParameterFeatureSink,
	QgsProcessingParameterFile,
	QgsWkbTypes,
)
from qgis.PyQt.QtCore import QCoreApplication, QMetaType

from .utils import (
	coords2LineString,
	coords2MultiLineString,
	coords2MultiPoint,
	coords2MultiPolygon,
	coords2Points,
	coords2Polygon,
	coords2PolyhedralSurface,
	coords2TriangulatedSurface,
	mesh2coords,
)

# WKT geometry name -> `.utils` builder that turns mesh coordinates into that
# geometry kind. Insertion order matters: it is the option order of the
# "Output geometry type" enum built in VectorifyMesh.initAlgorithm.
points_to_geom = dict(
	Point=coords2Points,
	MultiPoint=coords2MultiPoint,
	LineString=coords2LineString,
	MultiLineString=coords2MultiLineString,
	Polygon=coords2Polygon,
	MultiPolygon=coords2MultiPolygon,
	PolyhedralSurface=coords2PolyhedralSurface,
	TriangulatedSurface=coords2TriangulatedSurface,
)


class VectorifyMesh(QgsProcessingAlgorithm):
	"""Load a mesh file as vector layer.

	This filter convert any VTK mesh into a vector layer,
	with an option on the geometry wanted.
	"""

	INPUT = "INPUT"
	TYPE = "TYPE"
	OUTPUT = "OUTPUT"

	def tr(self, string):
		"""
		Returns a translatable string with the self.tr() function.
		"""
		return QCoreApplication.translate("Processing", string)

	def createInstance(self):
		return self.__class__()

	@classmethod
	def name(cls):
		"""
		Returns the algorithm name, used for identifying the algorithm. This
		string should be fixed for the algorithm, and must not be localised.
		The name should be unique within each provider. Names should contain
		lowercase alphanumeric characters only and no spaces or other
		formatting characters.
		"""
		return Path(__file__).stem

	def displayName(self):
		"""
		Returns the translated algorithm name, which should be used for any
		user-visible display of the algorithm name.
		"""
		return self.tr("Import 3D mesh as vector")

	def shortHelpString(self):
		"""
		Returns a localised short helper string for the algorithm. This string
		should provide a basic description about what the algorithm does and the
		parameters and outputs associated with it..
		"""
		return self.tr(self.__doc__)

	def initAlgorithm(self, config=None):  # noqa: ARG002
		"""
		Here we define the inputs and output of the algorithm, along
		with some other properties.
		"""

		self.addParameter(
			QgsProcessingParameterFile(
				self.INPUT,
				self.tr("Input mesh file"),
				fileFilter="VTK mesh (*.vtk *.vtp *vtu *vtm);; All files (*.*)",
			)
		)

		self.addParameter(
			QgsProcessingParameterEnum(
				self.TYPE,
				self.tr("Output geometry type"),
				points_to_geom.keys(),
				allowMultiple=False,
				defaultValue=None,
				optional=True,
				usesStaticStrings=True,
			)
		)

		self.addParameter(
			QgsProcessingParameterFeatureSink(
				name=self.OUTPUT,
				description=self.tr("Output layer"),
				optional=True,
				createByDefault=True,
			)
		)

	def processAlgorithm(self, parameters, context, feedback):
		"""
		Here is where the processing itself takes place.
		"""

		source = self.parameterAsFile(parameters, self.INPUT, context)
		wkt_type = self.parameterAsString(parameters, self.TYPE, context)

		# wrap mesh using pyvista
		if Path(source).is_file():
			dataset = pv.read(source)
		else:
			try:
				dataset = pv.wrap(source)
			except Exception as err:
				feedback.reportError(err, fatalError=True)

		feedback.pushInfo(f"mesh type = {dataset.__class__.__name__}")

		# fuse multiblocks ... TODO: track block names ?
		if isinstance(dataset, pv.MultiBlock):
			dataset = dataset.combine()

		# make sure we have a common API
		dataset = pv.UnstructuredGrid(dataset)
		if not wkt_type:  # infer type from max-D cell type
			match cell_t := max(dataset.celltypes):
				case _ if cell_t >= 10:  # 3D cells > use Point !
					wkt_type = "Point"
				case _ if cell_t >= 5:  # 2D faces > use Polygon !
					wkt_type = "Polygon"
				case _ if cell_t >= 3:  # 2D lines > use LineString !
					wkt_type = "LineString"
				case _:
					wkt_type = "MultiPoint"

		feedback.pushInfo(f"{wkt_type = }")
		match wkt_type:
			case "Point":
				dataset = dataset.cell_data_to_point_data()
				geometries = points_to_geom[wkt_type](dataset.points)
				fields = QgsFields(
					[QgsField(n) for n in dataset.point_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.point_data.values(), strict=False))
			case "MultiPoint":
				geometries = [points_to_geom[wkt_type](dataset.points)]
				fields = QgsFields(
					[QgsField(n) for n in dataset.field_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.field_data.values(), strict=False))
			case "PolyhedralSurface":
				dataset = pv.UnstructuredGrid(dataset.extract_surface(False, False, 0))
				coords = mesh2coords(dataset.points, dataset.cells)
				geometries = [points_to_geom[wkt_type](coords)]
				fields = QgsFields(
					[QgsField(n) for n in dataset.field_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.field_data.values(), strict=False))
			case "TriangulatedSurface":
				dataset = pv.UnstructuredGrid(dataset.extract_surface(False, False, 1))
				coords = mesh2coords(dataset.points, dataset.cells)
				geometries = [points_to_geom[wkt_type](coords)]
				fields = QgsFields(
					[QgsField(n) for n in dataset.field_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.field_data.values(), strict=False))
			case _ if wkt_type.startswith("Multi"):
				dataset = pv.UnstructuredGrid(dataset.extract_surface(False, False, 0))
				coords = mesh2coords(dataset.points, dataset.cells)
				geometries = [points_to_geom[wkt_type](coords)]
				fields = QgsFields(
					[QgsField(n) for n in dataset.field_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.field_data.values(), strict=False))
			case _:  # Polygon or LineString
				dataset = pv.UnstructuredGrid(
					dataset.point_data_to_cell_data().extract_surface(False, False, 0)
				)
				coords = mesh2coords(dataset.points, dataset.cells)
				geometries = [points_to_geom[wkt_type](points) for points in coords]
				fields = QgsFields(
					[QgsField(n) for n in dataset.cell_data.keys()]  # noqa: SIM118
				)
				attributes = list(zip(*dataset.cell_data.values(), strict=False))
		feedback.pushInfo(f"{fields.names() = }")

		# parse dataset arrays
		crs = QgsCoordinateReferenceSystem(
			next(
				(
					":".join(dataset.field_data[f])
					for f in dataset.field_data
					if f.lower() in ("crs", "srs", "epsg", "coordinatereferencesystem")
				),
				"",
			)
		)
		if crs:
			feedback.pushInfo(f"{crs.authid() = }")

		total = len(geometries)
		if total == len(attributes):
			feedback.pushInfo(f"{total} features wrapped !")
		else:
			feedback.pushWarning(
				f"features length [{total}] missmatch fields length [{len(attributes)}], discarding cells data"
			)
			fields = QgsFields([QgsField("fid", QMetaType.UInt)])
			attributes = range(len(attributes))

		# initialize feature sink
		(sink, sink_id) = self.parameterAsSink(
			parameters,
			self.OUTPUT,
			createOptions={"layerName": Path(source).stem},
			context=context,
			fields=fields,
			geometryType=QgsWkbTypes.parseType(wkt_type),
			crs=crs,
		)
		for i, (geom, attrs) in enumerate(zip(geometries, attributes, strict=False)):
			if feedback.isCanceled():
				feedback.reportError("Algorithm was canceld !", fatalError=True)
			else:
				feedback.setProgress(i / total)
			if not geom.isValid():
				match context.invalidGeometryCheck():
					case Qgis.InvalidGeometryCheck.SkipInvalid:
						continue
					case Qgis.InvalidGeometryCheck.AbortOnInvalid:
						feedback.reportError(
							f"feature #{i} is invalid !", fatalError=False
						)
					case Qgis.InvalidGeometryCheck.NoCheck:
						pass
			feature = QgsFeature(fields)
			feature.setGeometry(geom)
			if attrs:
				feature.setAttributes([str(el) for el in attrs])
			sink.addFeature(feature, QgsFeatureSink.FastInsert)

		return {self.OUTPUT: sink_id}
