from pathlib import Path

import numpy as np
from forgeo.rigs import all_intersections
from forgeo.rigs.tetcube import tetgrid
from qgis import utils as qutils
from qgis.core import (
    QgsCoordinateTransform,
    QgsProcessingException,
    QgsProcessingParameterExtent,
    QgsProcessingParameterNumber,
)

from ._utils import colors_as_strings
from .gmprocessing import GmProcessingAlgorithm


class InterfacesToPvgis(
    GmProcessingAlgorithm,
    name=Path(__file__).stem,
    display_name="send interfaces to PVGIS",
):
    """Extract the interfaces of a geomodel in a box and show them in pvGIS.

    The box is defined by a 2D extent and a vertical range (zmin, zmax) and is
    discretized with (NX, NY, NZ) cells. The surfaces found inside the box are
    loaded into the pvGIS viewer, grouped under the name of the model layer.
    """

    # Identifiers of the processing parameters.
    EXTENT = "EXTENT"
    ZMIN = "ZMIN"
    ZMAX = "ZMAX"
    NX = "NX"
    NY = "NY"
    NZ = "NZ"

    def initAlgorithm(self, config=None):  # noqa: ARG002
        """Declare the algorithm parameters: model, extent, z range, resolution."""
        self.add_geomodel_parameter()
        self.addParameter(
            QgsProcessingParameterExtent(self.EXTENT, self.tr("Extent"), optional=False)
        )

        def add_double_parameter(name, description, default):
            self.addParameter(
                QgsProcessingParameterNumber(
                    name,
                    self.tr(description),
                    defaultValue=default,
                    type=QgsProcessingParameterNumber.Double,
                )
            )

        add_double_parameter(self.ZMIN, "bottom of the box (zmin)", -1000.0)
        add_double_parameter(self.ZMAX, "top of the box (zmax)", 1000.0)

        def add_int_parameter(name, description):
            self.addParameter(
                QgsProcessingParameterNumber(
                    name,
                    self.tr(description),
                    defaultValue=10,
                    type=QgsProcessingParameterNumber.Integer,
                    # A grid needs at least one cell along each direction.
                    minValue=1,
                )
            )

        add_int_parameter(self.NX, "number of cells along (Ox)")
        add_int_parameter(self.NY, "number of cells along (Oy)")
        add_int_parameter(self.NZ, "number of cells along (Oz)")

    def processAlgorithm(self, parameters, context, feedback):
        """Compute the model surfaces inside the box and load them in pvGIS.

        Surfaces are sorted into topography, faults and formations; faults and
        formations each get their own sub-group in the viewer.

        Returns:
            An empty dict: the algorithm has no processing outputs, its result
            is the content added to the pvGIS viewer.

        Raises:
            QgsProcessingException: if the pvGIS plugin is not available or if
                the vertical range is empty (zmin >= zmax).
        """
        pvgis = qutils.plugins.get("pvGIS", None)
        if pvgis is None:
            msg = "pvGIS plugin cannot be found: is it loaded?"
            raise QgsProcessingException(msg)
        from pyvista import PolyData  # noqa: PLC0415 <comes with pvgis>

        feedback.setProgress(0)
        input_layer = self.model_layer(parameters)
        zmin, zmax = (
            self.parameterAsDouble(parameters, name, context)
            for name in (self.ZMIN, self.ZMAX)
        )
        if zmin >= zmax:
            # Fail early with an explicit message rather than meshing an
            # empty or inverted box.
            msg = f"zmin ({zmin}) must be strictly lower than zmax ({zmax})"
            raise QgsProcessingException(msg)
        nx, ny, nz = (
            self.parameterAsInt(parameters, name, context)
            for name in (self.NX, self.NY, self.NZ)
        )
        extent = self.parameterAsExtent(parameters, self.EXTENT, context)
        extent_crs = self.parameterAsExtentCrs(parameters, self.EXTENT, context)
        # The extent may be expressed in another CRS than the model layer.
        # FIXME: should we test crs is the same to skip transform?
        ct = QgsCoordinateTransform(
            extent_crs, input_layer.crs(), context.transformContext()
        )
        extent = ct.transform(extent)
        vertices, tets = tetgrid(
            (nx, ny, nz),
            extent=(
                extent.xMinimum(),
                extent.xMaximum(),
                extent.yMinimum(),
                extent.yMaximum(),
                zmin,
                zmax,
            ),
        )
        params = self.extraction_parameters(
            parameters, with_topography=True, faults_only=False
        )
        results = all_intersections(vertices, tets, **params, return_iso_surfaces=True)
        if feedback.isCanceled():
            # Do not touch the viewer when the user cancelled the computation.
            return {}
        iso = results.iso_surfaces

        # `part` is indexed like the faces and gives the id of the surface
        # each face belongs to.
        v, f, part = iso.vertices, iso.faces, iso.color
        names = params["names"]
        colormap = colors_as_strings(params["colors"])
        topoid = params.get("topography")
        is_fault = params["is_fault"]
        viewer = pvgis.viewer

        # Group the faces by surface id in a single pass instead of scanning
        # the whole face list once per surface.
        faces_by_uid = {}
        for i, face in enumerate(f):
            faces_by_uid.setdefault(part[i], []).append(face)

        faults = []
        formations = []
        topography = None
        for uid in np.unique(part):
            mesh = PolyData.from_irregular_faces(v, faces_by_uid.get(uid, []))
            item = (mesh, names[uid], colormap[uid])
            if is_fault(uid):
                faults.append(item)
            elif uid == topoid:
                topography = item
            else:
                formations.append(item)
        model_group = viewer.model.add_group(input_layer.name())

        def load_item(mesh, name, color, parent):
            index = viewer.loader.load_mesh(mesh, name=name, parent=parent)
            layer = viewer.model.item(index)
            layer.symbology.color = color

        if topography is not None:
            load_item(*topography, model_group)
        if faults:
            faults_group = viewer.model.add_group("Faults", parent=model_group)
            for item in faults:
                load_item(*item, faults_group)
        if formations:
            formations_group = viewer.model.add_group("Formations", parent=model_group)
            for item in formations:
                load_item(*item, formations_group)
        feedback.setProgress(100)
        return {}
