from __future__ import annotations

import numpy as np
import pyvista as pv


def dataset_from_upper(
	dataset: pv.DataSet, pts: int | tuple[int, int] | np.ndarray | None = None
) -> np.ndarray:
	"""Sample the upper surface of *dataset* by casting rays straight down.

	Each query location (x, y) becomes a ray starting one unit above the
	dataset's maximum z and pointing in direction ``(0, 0, -1)``. Only the
	first intersection of each ray is kept (``first_point=True``). For a
	downward ray cast from above the mesh, that is presumably the topmost
	surface crossing.

	Parameters
	----------
	dataset:
		Mesh to sample. Inputs that are not ``pv.PolyData`` are converted to an
		``UnstructuredGrid`` and reduced to their outer surface. Non-triangular
		surfaces are triangulated before tracing.
	pts:
		Where to cast rays, in one of three forms:

		* ``None`` (default): the unique (x, y) coordinates of the prepared
		  surface's points.
		* an integer ``n`` or a pair ``(nx, ny)`` (each an integer >= 2): a
		  regular grid of ``nx * ny`` locations spanning the surface's x/y
		  bounds, endpoints included (``n`` means ``nx = ny = n``). Points are
		  ordered with x varying slowest and y fastest.
		* an array-like of shape ``(N, k)`` with ``k >= 2``: explicit
		  locations. Only the first two columns (x, y) are used.

	Returns
	-------
	numpy.ndarray
		Float array of shape ``(N, 3)``, one row per query location. A row
		whose ray hit the surface holds the intersection point. A row whose
		ray missed keeps its query x, y and has ``z = nan``.

	Raises
	------
	TypeError
		If a grid specification (``n`` / ``(nx, ny)``) is not integer-valued.
	ValueError
		If a grid specification has the wrong number of entries or a count
		below 2, or if an explicit point array is not 2-D with at least two
		columns.
	"""
	if not isinstance(dataset, pv.PolyData):
		# Ray tracing needs a surface mesh: reduce other dataset types to their
		# outer surface first.
		dataset = pv.UnstructuredGrid(dataset).extract_surface(
			pass_pointid=False, pass_cellid=False, nonlinear_subdivision=1
		)
	if not dataset.is_all_triangles:
		dataset = dataset.triangulate()
	xmin, xmax, ymin, ymax, _, zmax = dataset.bounds

	if pts is None:
		# One ray per distinct (x, y) location of the surface's points.
		xy = np.unique(dataset.points[:, :2], axis=0)
	else:
		spec = np.asarray(pts)
		if spec.ndim < 2:
			# Grid specification: ``n`` or ``(nx, ny)``.
			if spec.size not in (1, 2):
				raise ValueError(
					f"grid specification must be n or (nx, ny), got {pts!r}"
				)
			# issubdtype accepts every integer width (int32, uint8, ...),
			# unlike an equality test against the builtin ``int`` dtype.
			if not np.issubdtype(spec.dtype, np.integer):
				raise TypeError(f"grid counts must be integers, got {pts!r}")
			if np.any(spec < 2):
				raise ValueError(f"grid counts must each be >= 2, got {pts!r}")
			counts = spec.ravel()
			# For a single ``n`` the first and last entries coincide, so nx == ny.
			nx, ny = int(counts[0]), int(counts[-1])
			gx, gy = np.meshgrid(
				np.linspace(xmin, xmax, nx),
				np.linspace(ymin, ymax, ny),
				indexing="ij",
			)
			# "ij" indexing + C-order ravel: x varies slowest, y fastest.
			xy = np.column_stack((gx.ravel(), gy.ravel()))
		elif spec.ndim == 2 and spec.shape[1] >= 2:
			xy = spec[:, :2]
		else:
			raise ValueError(
				"explicit points must have shape (N, k) with k >= 2, "
				f"got shape {spec.shape}"
			)
	# Work in float64 regardless of the input / mesh point dtype.
	xy = np.asarray(xy, dtype=float)

	# Origins sit one unit above the top of the bounding box so every ray
	# starts outside the mesh.
	origins = np.column_stack((xy, np.full(len(xy), zmax + 1.0)))
	if len(origins) == 0:
		# Nothing to trace; return an empty (0, 3) result without calling
		# multi_ray_trace on a zero-length batch.
		return origins

	directions = np.tile((0.0, 0.0, -1.0), (len(origins), 1))
	hits, ray_ids, _ = dataset.multi_ray_trace(
		origins, directions, first_point=True, retry=True
	)

	# Reuse the origins buffer as the result. Misses keep their x, y and get
	# z = nan; rows whose ray hit are overwritten with the intersection point.
	result = origins
	result[:, -1] = np.nan
	# NOTE(review): when no ray hits, multi_ray_trace may hand back empty
	# arrays built from Python lists (float-typed), which cannot be used as
	# indices. Coerce to an integer dtype and skip the assignment when empty.
	ray_ids = np.asarray(ray_ids, dtype=np.intp)
	if ray_ids.size:
		result[ray_ids] = hits
	return result
