from math import isnan
from pathlib import Path

from pyvista import DataSet, PolyData, UnstructuredGrid
from qgis.core import (
    QgsCoordinateTransform,
    QgsPointXY,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterFileDestination,
    QgsProject,
    QgsRenderContext,
    QgsSingleSymbolRenderer,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QCoreApplication

from .utils import load_3d_file, reencode


class MeshifyVector(QgsProcessingAlgorithm):
    """Convert a vector layer to mesh file.

    This filter convert any vector geometries into a 3D mesh.
    Grouping defaults on renderer symbols (legend), but can be deactivated (raw polygon soup).
    """

    # NOTE: the class docstring above is returned verbatim by `shortHelpString()`,
    # so it is user-facing text and not only developer documentation.

    # Inputs parameters
    INPUT = "INPUT"
    # Output parameters
    # GROUPBY = "GROUPBY"
    SPLIT = "SPLIT"
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """Return `string` translated in the "Processing" context."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a fresh instance of this algorithm class."""
        return self.__class__()

    @classmethod
    def name(cls):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.

        The name is derived from this module's file name (without extension).
        """
        return Path(__file__).stem

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Export vector as 3D mesh")

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm. This string
        is the translated class docstring.
        """
        return self.tr(self.__doc__)

    def initAlgorithm(self, config=None):  # noqa: ARG002
        """
        Declare the algorithm parameters: the input vector layer, the
        "group by legend" switch and the output mesh file.
        """

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT,
                self.tr("Input vector layer"),
                [QgsProcessing.TypeVectorAnyGeometry],
                optional=False,
            )
        )

        # self.addParameter(
        #     QgsProcessingParameterExpression(
        #         self.GROUPBY,
        #         self.tr("Group features by attributes"),
        #         parentLayerParameterName=self.INPUT,
        #         defaultValue=None,
        #         optional=True,
        #     )
        # )

        self.addParameter(
            QgsProcessingParameterBoolean(
                self.SPLIT,
                self.tr("Group features by legend items"),
                defaultValue=True,
                optional=True,
            )
        )
        # NOTE(review): `*vtk` and `*vtm` lack the dot used by the other
        # patterns. Left untouched on purpose: the `.file` suffix fallback in
        # `processAlgorithm` may depend on how QGIS derives the default
        # extension from this filter -- confirm before changing it.
        self.addParameter(
            QgsProcessingParameterFileDestination(
                self.OUTPUT,
                self.tr("Output mesh file"),
                fileFilter="VTK mesh (*vtk *vtm *.vtu *.vtp);; All files (*.*)",
                defaultValue=None,
                optional=True,
                createByDefault=True,
            )
        )

    def _elevation_sampler(self, input_source, input_layer, context, feedback):
        """Return a callable mapping a vertex to its z-coordinate.

        When the source geometries carry Z values they are used directly.
        Otherwise the height is sampled from the project terrain provider
        (after reprojecting the vertex to the terrain CRS), then scaled and
        shifted by the layer elevation properties. A NaN sample falls back to
        the layer z-offset.
        """
        if QgsWkbTypes.hasZ(input_source.wkbType()):
            feedback.pushInfo("Inferring elevation from vertices $z")

            def elevation(pt):
                return pt.z()

            return elevation

        feedback.pushInfo("Inferring elevation from `elevationProperties()`")
        # use the project of the processing context for both the terrain
        # lookup and the coordinate transform
        project = context.project()
        terrain = project.elevationProperties().terrainProvider()
        height_at = terrain.heightAt
        reproject = QgsCoordinateTransform(
            input_source.sourceCrs(),
            terrain.crs(),
            project,
        ).transform
        z_offset = float(input_layer.elevationProperties().zOffset())
        z_scale = float(input_layer.elevationProperties().zScale())

        def elevation(pt):
            coords = reproject(QgsPointXY(pt))
            z = height_at(coords.x(), coords.y()) * z_scale + z_offset
            return z_offset if isnan(z) else z

        return elevation

    def _collect_cells(
        self, input_source, renderer, render_context, elevation, feedback
    ):
        """Walk the features and accumulate mesh points and connectivity.

        Each geometry part becomes one cell, encoded as
        ``[n, id_0, ..., id_(n-1)]`` in the flat ``cells`` list.

        Returns:
            ``(points, cells, cell_data, cell_legend)`` where ``cell_data``
            holds the feature attributes and ``cell_legend`` the legend key
            (or ``None``) of each cell.

        Raises:
            QgsProcessingException: if the user cancels the algorithm.
        """
        points, cells, cell_data, cell_legend = [], [], [], []
        offset = 0
        total = input_source.featureCount()
        # progress is expressed in percent; guard against empty/unknown counts
        step = 100.0 / total if total > 0 else 0.0
        for i, feature in enumerate(input_source.getFeatures()):
            if feedback.isCanceled():
                # reporting an error does not interrupt the loop by itself
                raise QgsProcessingException("Algorithm was canceled !")
            feedback.setProgress(i * step)

            if not feature.hasGeometry():
                continue
            geometry = feature.geometry()
            if geometry.isEmpty():
                continue

            keys = renderer.legendKeysForFeature(feature, render_context)
            key = keys.pop() if keys else None
            attributes = feature.attributes()

            for part in geometry.constParts():
                coords = [(pt.x(), pt.y(), elevation(pt)) for pt in part.vertices()]
                if not coords:  # a part without vertices yields no cell
                    continue
                count = len(coords)
                points.extend(coords)
                cells.extend([count, *range(offset, offset + count)])
                offset += count
                cell_data.append(attributes)
                cell_legend.append(key)
        return points, cells, cell_data, cell_legend

    def _split_by_legend(self, mesh, cell_legend, legend_items):
        """Split `mesh` into one block per legend key, named and colored.

        Cells are tagged with an integer "lid" (rank of first appearance of
        their legend key) and split on it. Blocks are then matched to legend
        keys by position, which assumes `split_values` returns one block per
        "lid" in ascending order (same assumption as the former code).
        Keys without a matching legend item (e.g. ``None``) get a label built
        from the key itself and no color.
        """
        legend_keys = list(dict.fromkeys(cell_legend))  # stable, first-seen order
        legend_ids = {key: i for i, key in enumerate(legend_keys)}
        mesh["lid"] = [legend_ids[key] for key in cell_legend]
        blocks = mesh.split_values(
            scalars="lid",
            preference="cell",
            pass_point_ids=False,
            pass_cell_ids=False,
        )
        for i, block in enumerate(blocks):  # rename & color
            key = legend_keys[i]
            item = legend_items.get(key)
            label = reencode(item.label() if item is not None else str(key))
            blocks.set_block_name(i, label)
            block.add_field_data(label, "label")
            symbol = item.symbol() if item is not None else None
            if symbol is not None:
                block.add_field_data(symbol.color().name(), "color")
        blocks.clean(empty=True)
        return blocks

    def processAlgorithm(self, parameters, context, feedback):
        """
        Build a mesh from the input features, write it to disk and try to
        load it into the 3D viewer.

        Returns:
            A dict mapping `OUTPUT` to the path of the written file.

        Raises:
            QgsProcessingException: if the input is invalid, the run is
                canceled, or no geometry could be extracted.
        """

        input_source = self.parameterAsSource(parameters, self.INPUT, context)
        input_layer = self.parameterAsLayer(parameters, self.INPUT, context)
        if input_source is None or input_layer is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT)
            )
        split_by_legend = self.parameterAsBoolean(parameters, self.SPLIT, context)
        output_file = self.parameterAsFileOutput(parameters, self.OUTPUT, context)

        # infer z-coordinate from input layer topology
        geom_type = input_source.wkbType()
        elevation = self._elevation_sampler(
            input_source, input_layer, context, feedback
        )

        # clone renderer & enable all legend items
        renderer = input_layer.renderer().clone()
        for it in renderer.legendSymbolItems():  # prevent `None` for unchecked items
            renderer.checkLegendSymbolItem(it.ruleKey())
        # create a rendering context for symbology sampling
        render_context = QgsRenderContext()
        renderer.startRender(render_context, input_layer.fields())
        try:
            legend_items = {it.ruleKey(): it for it in renderer.legendSymbolItems()}

            # loop over feature to build mesh cells
            feedback.pushInfo("Building mesh cells from features ...")
            points, cells, cell_data, cell_legend = self._collect_cells(
                input_source, renderer, render_context, elevation, feedback
            )
        finally:
            # always release the renderer, even on cancel or failure
            renderer.stopRender(render_context)
        if not points:
            raise QgsProcessingException("No geometries found !")

        # build mesh (PolyData or UnstructuredGrid)
        geom_dim = QgsWkbTypes.wkbDimensions(geom_type)
        connectivity = {
            "lines": cells if geom_dim == 1 else None,
            "faces": cells if geom_dim == 2 else None,
        }
        mesh = UnstructuredGrid(PolyData(points, **connectivity))
        # keep track of CRS as field_data
        mesh.add_field_data(
            input_source.sourceCrs().authid().split(":"), "CoordinateReferenceSystem"
        )
        # wrap features fields as cell_data
        for name, data in zip(
            input_source.fields().names(), zip(*cell_data, strict=False), strict=False
        ):
            mesh[reencode(name)] = [
                reencode(d) if isinstance(d, str) else d for d in data
            ]
        # group-by sub-meshes
        if split_by_legend:
            mesh = self._split_by_legend(mesh, cell_legend, legend_items)
            if len(mesh) > 1:  # force output file to multiblock
                output_file = Path(output_file).with_suffix(".vtm").as_posix()
            else:  # handle single legend entries
                mesh = mesh[0]
        elif isinstance(renderer, QgsSingleSymbolRenderer):
            color = renderer.symbol().color().name()
            mesh.add_field_data(color, "color")

        # save results
        if Path(output_file).stem == self.OUTPUT:  # default name to layer name
            output_file = Path(output_file).with_stem(input_layer.name()).as_posix()
        if Path(output_file).suffix.lower() == ".file":
            output_file = Path(output_file).with_suffix(".vtu").as_posix()
        if (
            isinstance(mesh, DataSet) and Path(output_file).suffix.lower() == ".vtp"
        ):  # handle polydata
            mesh = mesh.extract_surface()
        mesh.save(output_file)
        feedback.pushInfo(f"OUTPUT: {output_file}")

        try:  # load results (best effort: a viewer failure must not fail the export)
            load_3d_file(output_file, name=input_layer.name())
            feedback.pushInfo(f"{output_file} was loaded into 3D viewer")
        except Exception as err:
            feedback.reportError(
                f"Could NOT load {output_file} into 3D viewer:\n {err}",
                fatalError=False,
            )

        return {self.OUTPUT: output_file}
