from pathlib import Path

import numpy as np
import pyvista as pv
from qgis.core import (
	QgsProcessingAlgorithm,
	QgsProcessingParameterBand,
	QgsProcessingParameterEnum,
	QgsProcessingParameterFileDestination,
	QgsProcessingParameterRasterLayer,
)
from qgis.PyQt.QtCore import QCoreApplication

from .utils import load_3d_file, ravel_cells, reencode


class MeshifyRaster(QgsProcessingAlgorithm):
	"""This filter converts any raster layer into a 3D mesh (pyvista supported file format, see VTK and/or meshio).
	It provides options for the cells geometry (points only, triangles, quads).
	One mesh point is created per pixel, located at the pixel centre; NoData pixels get a NaN elevation.
	Bands other than the elevation band are attached to the mesh points as data arrays.

	Args:
	    INPUT (QgsRasterLayer): Raster layer.
	    ZBAND (int): INPUT band to use as elevation/z value.
	    GEOMETRY (str): OUTPUT cells type (ignored for *.vts). Must be one of {'points', 'triangles', 'quads'}.

	Returns:
	    OUTPUT (Path): Mesh file path.

	"""

	# Inputs parameters
	INPUT = "INPUT"
	ZBAND = "ZBAND"
	# Filter parameters
	GEOMETRY = "GEOMETRY"
	# Output parameters
	OUTPUT = "OUTPUT"
	# Output file extensions processAlgorithm knows how to build a mesh for
	SUPPORTED_SUFFIXES = (".vtp", ".vtu", ".vts", ".vtk")

	def tr(self, string):
		"""
		Returns a translatable string with the self.tr() function.
		"""
		return QCoreApplication.translate("Processing", string)

	def createInstance(self):
		return self.__class__()

	@classmethod
	def name(cls):
		"""
		Returns the algorithm name, used for identifying the algorithm. This
		string should be fixed for the algorithm, and must not be localised.
		The name should be unique within each provider. Names should contain
		lowercase alphanumeric characters only and no spaces or other
		formatting characters.
		"""
		return Path(__file__).stem

	def displayName(self):
		"""
		Returns the translated algorithm name, which should be used for any
		user-visible display of the algorithm name.
		"""
		return self.tr("Export raster as 3D mesh")

	def shortHelpString(self):
		"""
		Returns a localised short helper string for the algorithm. This string
		should provide a basic description about what the algorithm does and the
		parameters and outputs associated with it.
		"""
		return self.tr(self.__doc__)

	def initAlgorithm(self, config=None):  # noqa: ARG002
		"""
		Here we define the inputs and output of the algorithm, along
		with some other properties.
		"""

		self.addParameter(
			QgsProcessingParameterRasterLayer(
				self.INPUT,
				self.tr("Input raster layer"),
			)
		)

		self.addParameter(
			QgsProcessingParameterBand(
				self.ZBAND,
				self.tr("Elevation band (z)"),
				parentLayerParameterName=self.INPUT,
				defaultValue=1,
				optional=True,
			)
		)

		self.addParameter(
			QgsProcessingParameterEnum(
				self.GEOMETRY,
				self.tr("Mesh cells type (only apply for *.vtu *.vtp *.vtk)"),
				[
					"quads",
					"triangles",
					"points",
				],
				allowMultiple=False,
				defaultValue="quads",
				optional=False,
				usesStaticStrings=True,
			)
		)

		self.addParameter(
			QgsProcessingParameterFileDestination(
				self.OUTPUT,
				self.tr("Output mesh file"),
				fileFilter="VTK polydata (*.vtp);; VTK unstructured grid (*.vtu);; VTK structured grid (*.vts)",
				defaultValue=None,
				optional=True,
				createByDefault=True,
			)
		)

	@staticmethod
	def _read_bands(raster) -> np.ndarray:
		"""Return every band of *raster* as a float array of shape (bands, rows, cols).

		NoData pixels (masked by as_numpy() and/or equal to the provider's
		source NoData value) are replaced by NaN.
		"""
		provider = raster.dataProvider()
		arrays = raster.as_numpy()
		if np.ndim(arrays) == 2:
			# NOTE(review): defensive, in case a single-band layer yields a 2D array
			arrays = arrays[np.newaxis, ...]
		bands = []
		for band_no, array in enumerate(arrays, start=1):
			# as_numpy() may return a masked array; assigning through a masked
			# comparison skips masked cells, so fill the mask with NaN explicitly
			# or the raw NoData value would end up as an elevation.
			values = np.ma.filled(array.astype(float), np.nan)
			# Only compare against the NoData value when the band actually has one.
			if provider.sourceHasNoDataValue(band_no):
				values[values == provider.sourceNoDataValue(band_no)] = np.nan
			bands.append(values)
		return np.asarray(bands)

	@staticmethod
	def _pixel_centers(extent, nx: int, ny: int):
		"""Return (x, y) coordinate grids of shape (ny, nx) at the pixel centres of *extent*.

		Row 0 is the top (yMaximum) row and column 0 the left (xMinimum) column.
		"""
		dx = (extent.xMaximum() - extent.xMinimum()) / nx
		dy = (extent.yMaximum() - extent.yMinimum()) / ny
		xs = extent.xMinimum() + dx * (np.arange(nx) + 0.5)
		ys = extent.yMaximum() - dy * (np.arange(ny) + 0.5)
		return np.meshgrid(xs, ys)

	@staticmethod
	def _build_cells(geom_type: str, nx: int, ny: int) -> np.ndarray:
		"""Return cell connectivity for a row-major grid of ny x nx points.

		Each row of the result holds the point indices of one cell:
		'quads' gives one quad per grid square, 'triangles' two triangles per
		grid square, and any other value one single-point cell per point.
		Cells are ordered row by row, then column by column.
		"""
		if geom_type == "quads":
			template = np.array([[0, 1, nx + 1, nx]], dtype=int)  # single quad
		elif geom_type == "triangles":
			template = np.array([[0, 1, nx + 1], [0, nx + 1, nx]], dtype=int)  # dual triangles
		else:
			return np.arange(nx * ny, dtype=int).reshape(-1, 1)
		# Index of the first corner of every grid square, in row-major order.
		corners = (np.arange(ny - 1)[:, None] * nx + np.arange(nx - 1)).ravel()
		cells = corners[:, None, None] + template
		return cells.reshape(-1, template.shape[1])

	def processAlgorithm(self, parameters, context, feedback) -> dict:
		"""
		Here is where the processing itself takes place.

		Reads every band of the input raster, builds a mesh whose points sit at
		the pixel centres with the chosen band as z, attaches the other bands as
		point data, stores the CRS as field data, writes OUTPUT and tries to
		load it into the 3D viewer.

		Raises:
			NotImplementedError: if the OUTPUT extension is not a supported format.
			ValueError: if ZBAND is not a band of the input raster.
		"""
		raster = self.parameterAsRasterLayer(parameters, self.INPUT, context)
		band = self.parameterAsInt(parameters, self.ZBAND, context)
		geom_type = self.parameterAsString(parameters, self.GEOMETRY, context)
		output_file = self.parameterAsFileOutput(parameters, self.OUTPUT, context)

		feedback.pushInfo(f"OUTPUT: {output_file}")
		file = Path(output_file)
		if file.suffix not in self.SUPPORTED_SUFFIXES:
			feedback.reportError(f"Wrong file format: {file.suffix}", fatalError=True)
			msg = f"Unsupported file format: {file.suffix}"
			raise NotImplementedError(msg)

		n_bands = raster.bandCount()
		if not 1 <= band <= n_bands:
			# band - 1 would otherwise silently index from the end (0 -> last band).
			msg = f"Invalid elevation band {band}: raster has {n_bands} band(s)"
			feedback.reportError(msg, fatalError=True)
			raise ValueError(msg)

		data = self._read_bands(raster)
		z = data[band - 1]
		ny, nx = z.shape
		x, y = self._pixel_centers(raster.extent(), nx, ny)

		if file.suffix == ".vts":
			mesh = pv.StructuredGrid(x, y, z)
			# pyvista flattens StructuredGrid coordinate arrays in Fortran order,
			# so point data must be flattened the same way to match the points.
			order = "F"
		else:
			points = np.column_stack((x.ravel(), y.ravel(), z.ravel()))
			cells = self._build_cells(geom_type, nx, ny)
			if file.suffix == ".vtu":
				mesh = pv.UnstructuredGrid(points, ravel_cells(cells))
			else:  # ".vtp" or ".vtk"
				mesh = pv.PolyData(points, ravel_cells(cells))
			order = "C"

		for index in range(n_bands):
			if index == band - 1:
				continue  # already encoded as the z coordinate
			mesh.point_data[reencode(raster.bandName(index + 1))] = data[index].ravel(order=order)

		mesh.add_field_data(
			raster.crs().authid().split(":"), "CoordinateReferenceSystem"
		)
		mesh.save(output_file)

		# Loading into the viewer is best-effort: report failures without failing the algorithm.
		try:
			load_3d_file(output_file, name=raster.name())
			feedback.pushInfo(f"{output_file} was loaded into 3D viewer")
		except Exception as err:
			feedback.reportError(
				f"Could NOT load {output_file} into 3D viewer:\n {err}",
				fatalError=False,
			)

		return {self.OUTPUT: output_file}
