from math import ceil
from pathlib import Path

import pyvista as pv
from qgis.core import (
    Qgis,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterFile,
    QgsProcessingParameterNumber,
    QgsProcessingParameterRasterDestination,
)
from qgis.PyQt.QtCore import QCoreApplication

from .utils import array2geotiff, dataset_from_upper


class RasterizeTopView(QgsProcessingAlgorithm):
    """Create a top-view raster of a mesh."""

    # NOTE: the class docstring above is returned (translated) by
    # shortHelpString(), so it doubles as the user-facing help text.

    # Inputs parameters
    INPUT = "INPUT"
    # Raster resolution
    RESOLUTION = "RESOLUTION"
    # Output parameters
    OUTPUT = "OUTPUT"

    # Field-data keys (compared lower-cased, with "*" normalised to ":")
    # whose values are used as the CRS of the output raster.
    _CRS_KEYS = (
        "crs",
        "crs:name",
        "srs",
        "epsg",
        "coordinatereferencesystem",
    )

    def tr(self, string):
        """
        Returns a translatable string with the self.tr() function.
        """
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a fresh instance of this algorithm (required by QGIS)."""
        return self.__class__()

    @classmethod
    def name(cls):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        # The module file name is used as the stable algorithm identifier.
        return Path(__file__).stem

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Rasterize 3D mesh from top")

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm. This string
        should provide a basic description about what the algorithm does and the
        parameters and outputs associated with it.
        """
        return self.tr(self.__doc__)

    def initAlgorithm(self, config=None):  # noqa: ARG002
        """
        Declare the algorithm parameters: an input mesh file, a (required)
        raster resolution and an output raster destination.
        """

        self.addParameter(
            QgsProcessingParameterFile(
                self.INPUT,
                self.tr("Input mesh file"),
                fileFilter="VTK mesh (*.vtk *.vtp *.vtu *.vtm);; All files (*.*)",
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.RESOLUTION,
                self.tr("Raster resolution"),
                Qgis.ProcessingNumberParameterType.Double,
                optional=False,
                defaultValue=None,
            )
        )
        self.addParameter(
            QgsProcessingParameterRasterDestination(
                name=self.OUTPUT,
                description=self.tr("Output Raster"),
                optional=True,
                createByDefault=True,
            )
        )

    @staticmethod
    def _load_dataset(source):
        """Read *source* with ``pv.read`` if it is a file, else try ``pv.wrap``.

        Raises:
            QgsProcessingException: if ``pv.wrap`` fails or yields nothing.
        """
        if Path(source).is_file():
            return pv.read(source)
        try:
            dataset = pv.wrap(source)
        except Exception as err:
            # Abort the algorithm: continuing would leave no dataset to process.
            raise QgsProcessingException(
                f"Could not load mesh from {source!r}: {err}"
            ) from err
        if dataset is None:
            raise QgsProcessingException(f"Could not load mesh from {source!r}")
        return dataset

    @classmethod
    def _crs_from_field_data(cls, field_data):
        """Return the first CRS-like field-data entry joined with ":" ("" if none)."""
        for key in field_data:
            if key.lower().replace("*", ":") in cls._CRS_KEYS:
                # str() so numeric entries (e.g. an integer EPSG code) join too.
                return ":".join(str(value) for value in field_data[key])
        return ""

    def processAlgorithm(self, parameters, context, feedback):
        """
        Rasterize the input mesh as seen from above and write it as a GeoTIFF.

        Returns:
            dict mapping ``OUTPUT`` to the written raster path.

        Raises:
            QgsProcessingException: if the resolution is not strictly positive
                or the input cannot be loaded as a pyvista dataset.
        """

        source = self.parameterAsString(parameters, self.INPUT, context)
        resolution = self.parameterAsDouble(parameters, self.RESOLUTION, context)
        output = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        # A non-positive resolution would divide by zero / give bogus sizes.
        if resolution <= 0:
            raise QgsProcessingException(
                self.tr("Raster resolution must be strictly positive.")
            )

        # wrap mesh using pyvista
        dataset = self._load_dataset(source)
        feedback.pushInfo(f"mesh type = {dataset.__class__.__name__}")

        # Merge multi-block inputs into a single dataset before rasterizing.
        if isinstance(dataset, pv.MultiBlock):
            dataset = dataset.combine()

        xmin, xmax, ymin, ymax, _zmin, _zmax = dataset.bounds
        # At least 2 cells per axis; the actual spacing is stretched so the
        # cells exactly cover the x/y extent of the mesh.
        nx = max(2, ceil((xmax - xmin) / resolution))
        ny = max(2, ceil((ymax - ymin) / resolution))
        dx, dy = (xmax - xmin) / nx, (ymax - ymin) / ny
        array = dataset_from_upper(dataset, (nx, ny))[..., -1].reshape(nx, ny)

        crs = self._crs_from_field_data(dataset.field_data)

        array2geotiff(
            filename=output,
            array=array,
            origin=(xmin, ymin),
            spacing=(dx, dy),
            crs=crs,
            indexing="ij",
        )
        return {self.OUTPUT: output}
