from math import isnan
from pathlib import Path

from pyvista import MultiBlock, PolyData, UnstructuredGrid
from qgis.core import (
    QgsCategorizedSymbolRenderer,
    QgsCoordinateTransform,
    QgsPointXY,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterFileDestination,
    QgsProject,
    QgsSingleSymbolRenderer,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QCoreApplication

from .utils import load_3d_file, reencode


class MeshifyVector(QgsProcessingAlgorithm):
    """Convert a vector layer to mesh file.

    This filter convert any vector geometries into a 3D mesh.
    """

    # Inputs parameters
    INPUT = "INPUT"
    # Output parameters
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """
        Returns a translatable string with the self.tr() function.
        """
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        return self.__class__()

    @classmethod
    def name(cls):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return Path(__file__).stem

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Export vector as 3D mesh")

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm. This string
        should provide a basic description about what the algorithm does and the
        parameters and outputs associated with it..
        """
        return self.tr(self.__doc__)

    def initAlgorithm(self, config=None):  # noqa: ARG002
        """
        Here we define the inputs and output of the algorithm, along
        with some other properties.
        """

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT,
                self.tr("Input vector layer"),
                [QgsProcessing.TypeVectorAnyGeometry],
                optional=False,
            )
        )

        self.addParameter(
            QgsProcessingParameterFileDestination(
                self.OUTPUT,
                self.tr("Output mesh file"),
                fileFilter="VTK polydata (*.vtp);; VTK unstructured grid (*.vtu)",
                defaultValue=None,
                optional=True,
                createByDefault=True,
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Here is where the processing itself takes place.
        """

        input_source = self.parameterAsSource(parameters, self.INPUT, context)
        input_layer = self.parameterAsLayer(parameters, self.INPUT, context)
        output_file = self.parameterAsFileOutput(parameters, self.OUTPUT, context)

        file = Path(output_file)
        if file.suffix not in (".vtp", ".vtu", ".vtk", ".vtm"):
            feedback.reportError(f"Wrong file format: {file.suffix}", fatalError=True)

        geom_type = input_source.wkbType()

        if not QgsWkbTypes.hasZ(geom_type):
            feedback.pushInfo("Inferring elevation from `elevationProperties()`")
            heightAt = (
                QgsProject.instance().elevationProperties().terrainProvider().heightAt
            )
            reproject = QgsCoordinateTransform(
                input_source.sourceCrs(),
                context.project().elevationProperties().terrainProvider().crs(),
                context.project(),
            ).transform
            z_offset = float(input_layer.elevationProperties().zOffset())
            z_scale = float(input_layer.elevationProperties().zScale())

            def elevation(pt):
                coords = reproject(QgsPointXY(pt))
                z = heightAt(coords.x(), coords.y()) * z_scale + z_offset
                if isnan(z):
                    return z_offset
                return z

        else:
            feedback.pushInfo("Inferring elevation from vertices $z")

            def elevation(pt):
                return pt.z()

        geom_dim = QgsWkbTypes.wkbDimensions(geom_type)

        off, points, cells, cell_data = 0, [], [], []
        total = input_source.featureCount()
        for i, feature in enumerate(input_source.getFeatures()):
            if feedback.isCanceled():
                feedback.reportError("Algorithm was canceld !", fatalError=True)
            else:
                feedback.setProgress(i / total)

            if not feature.hasGeometry:
                continue
            geometry = feature.geometry()
            if geometry.isEmpty():
                continue

            for part in geometry.constParts():
                for _n, pt in enumerate(part.vertices()):
                    coords = (pt.x(), pt.y(), elevation(pt))
                    points.append(coords)
                _n += 1
                cells.extend([_n, *range(off, off + _n)])
                off += _n
                cell_data.append(feature.attributes())

        if not points:
            feedback.reportError("No geometries found !", fatalError=True)

        # build PolyData mesh
        cells = {
            "lines": cells if geom_dim == 1 else None,
            "faces": cells if geom_dim == 2 else None,
        }
        mesh = PolyData(points, **cells)
        # add data
        mesh.add_field_data(
            input_source.sourceCrs().authid().split(":"), "CoordinateReferenceSystem"
        )
        for name, data in zip(
            input_source.fields().names(), zip(*cell_data, strict=False), strict=False
        ):
            mesh[reencode(name)] = [
                reencode(d) if isinstance(d, str) else d for d in data
            ]

        renderer = input_layer.renderer()
        if isinstance(renderer, QgsSingleSymbolRenderer):
            color = renderer.symbol().color().name()
            mesh.add_field_data(color, "color")
        elif isinstance(renderer, QgsCategorizedSymbolRenderer):
            # TODO: active scalars ? custom colormap ? point based
            ...

        match file.suffix.lower():
            case ".vtu":
                UnstructuredGrid(mesh).save(output_file)
            case ".vtm":
                MultiBlock(mesh).save(output_file)
            case _:
                mesh.save(output_file)

        feedback.pushInfo(f"OUTPUT: {output_file}")

        # symbology = pvLayerSymbology(
        # color =

        # color: str = "white"
        # edge_color: str = "white"
        # opacity: str = 1.0
        # show_edges: str = False
        # render_points_as_spheres: str = False
        # point_size: str = 5.0
        # render_lines_as_tubes: str = False
        # line_width: str = 1.0

        try:
            load_3d_file(output_file, name=input_layer.name())
            feedback.pushInfo(f"{output_file} was loaded into 3D viewer")
        except Exception as err:
            feedback.reportError(
                f"Could NOT load {output_file} into 3D viewer:\n {err}",
                fatalError=False,
            )

        return {self.OUTPUT: output_file}
