from tempfile import SpooledTemporaryFile

import numpy as np
from forgeo.interpolation import BSPTreeBuilder, is_valid
from forgeo.rigs import all_intersections
from forgeo.rigs.tetcube import tetgrid
from qgis import utils as qutils
from qgis.core import Qgis, QgsBox3D, QgsPluginLayer, QgsProcessingException, QgsProject
from qgis.gui import QgsCheckableComboBox, QgsExtentWidget, QgsMapLayerComboBox
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtGui import QColor, QDoubleValidator, QIntValidator, QStandardItem
from qgis.PyQt.QtWidgets import (
    QDialog,
    QDialogButtonBox,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMessageBox,
    QPushButton,
    QVBoxLayout,
)

from ..layers import FaultNetworkLayer, ModelLayer
from ..utils import get_forgeo_output_dir, raster_layer_to_description
from .utils import QgsPluginLayerComboBox


class SurfaceExtractionSelection(QGroupBox):
    """Group box to choose a model layer and which of its surfaces to extract.

    The surface list is rebuilt, with every entry checked, each time the
    selected model layer changes.
    """

    def __init__(self, layer: ModelLayer | None = None, parent=None):
        super().__init__(parent)
        self.setTitle(self.tr("Surfaces to extract"))

        # Model layer selection and the checkable list of its surfaces
        model_combo = QgsPluginLayerComboBox(ModelLayer, defaultLayer=layer)
        surface_combo = QgsCheckableComboBox()
        self.cbox_model = model_combo
        self.cbox_surfaces = surface_combo

        def apply_to_all(state):
            # Give the same check state to every entry currently listed
            for row in range(surface_combo.count()):
                surface_combo.setItemCheckState(row, state)

        def repopulate():
            # Rebuild the surface list from the currently selected model layer
            surface_combo.clear()
            selected = model_combo.currentLayer()
            # No layer: should only happen if the project does not contain ModelLayers
            if selected is not None:
                surface_names = [
                    entry.name for entry in selected.model.dataset if entry.is_surface
                ]
                surface_combo.addItems(surface_names)
                apply_to_all(Qt.CheckState.Checked)

        repopulate()  # First fill, the signal has not fired yet
        model_combo.layerChanged.connect(repopulate)

        def check_all():
            apply_to_all(Qt.CheckState.Checked)

        def uncheck_all():
            apply_to_all(Qt.CheckState.Unchecked)

        select_button = QPushButton(self.tr("Select all"))
        select_button.clicked.connect(check_all)
        deselect_button = QPushButton(self.tr("Deselect all"))
        deselect_button.clicked.connect(uncheck_all)
        # Exporting every surface is the proposed default
        apply_to_all(Qt.CheckState.Checked)

        # Buttons side by side, below the two combo boxes
        button_row = QHBoxLayout()
        button_row.addWidget(select_button)
        button_row.addWidget(deselect_button)
        column = QVBoxLayout()
        column.addWidget(model_combo)
        column.addWidget(surface_combo)
        column.addLayout(button_row)
        self.setLayout(column)

    @property
    def model(self):
        """Model carried by the layer currently selected in the combo box."""
        current = self.cbox_model.currentLayer()
        return current.model

    @property
    def surfaces(self):
        """Names of the checked surface entries."""
        checked_names = self.cbox_surfaces.checkedItems()
        return checked_names


class FaultsExtractionSelection(QGroupBox):
    """Group box to choose a fault network layer and the faults to extract.

    Faults the network reports as inactive are listed disabled and unchecked,
    and are unchecked again if the user ticks them. ``checked_faultnames``
    holds the validated selection (names of checked, active faults).
    """

    def __init__(self, layer: FaultNetworkLayer | None = None, parent=None):
        super().__init__(parent)
        self.setTitle(self.tr("Faults to extract"))

        # Select the fault network layer
        cbox_faultnet = QgsPluginLayerComboBox(FaultNetworkLayer, defaultLayer=layer)
        self.cbox_fault_network = cbox_faultnet

        # `layer` may be None (its default): fall back on the combo box
        # selection instead of dereferencing None.
        current_layer = cbox_faultnet.currentLayer()
        if layer is None:
            layer = current_layer
        faultnet = None if layer is None else layer.faultnet

        # Select faults to extract
        # FIXME TODO cbox_faultnet.layerChanged.connect(...)
        cbox_surfaces = QgsCheckableComboBox()
        self.cbox_surfaces = cbox_surfaces
        if current_layer is not None:
            listed_faultnet = current_layer.faultnet
            for fault in listed_faultnet.dataset:
                item = QStandardItem(fault.name)
                active = bool(listed_faultnet.is_active(fault.name))
                # Inactive faults are greyed out and start unchecked
                item.setEnabled(active)
                item.setCheckState(
                    Qt.CheckState.Checked if active else Qt.CheckState.Unchecked
                )
                cbox_surfaces.model().appendRow(item)
        self.checked_faultnames = cbox_surfaces.checkedItems()
        # Connect only once the list is filled and every attribute read by the
        # slot exists, so an early emission cannot hit a half-built widget.
        cbox_surfaces.checkedItemsChanged.connect(self.new_fault_checked)

        def update_all_surfaces_status(check_state):
            # Only active faults are (un)checked by the buttons
            if faultnet is None:
                return
            for i in range(faultnet.nb_faults):
                if faultnet.active_faults[i]:
                    cbox_surfaces.setItemCheckState(i, check_state)

        btn_select_all = QPushButton(self.tr("Select all"))
        btn_select_all.clicked.connect(
            lambda: update_all_surfaces_status(Qt.CheckState.Checked)
        )
        btn_unselect_all = QPushButton(self.tr("Deselect all"))
        btn_unselect_all.clicked.connect(
            lambda: update_all_surfaces_status(Qt.CheckState.Unchecked)
        )

        layout = QVBoxLayout()
        layout.addWidget(cbox_faultnet)
        layout.addWidget(cbox_surfaces)
        layout_btn_selection = QHBoxLayout()
        layout_btn_selection.addWidget(btn_select_all)
        layout_btn_selection.addWidget(btn_unselect_all)
        layout.addLayout(layout_btn_selection)
        self.setLayout(layout)

    @property
    def faultnet(self):
        """Fault network carried by the layer currently selected in the combo box."""
        return self.cbox_fault_network.currentLayer().faultnet

    @property
    def surfaces(self):
        """Names of the checked fault entries."""
        return self.cbox_surfaces.checkedItems()  # List of (item) names

    def new_fault_checked(self):
        """Validate the selection after the checked items changed.

        Newly checked faults that the network reports as inactive are unchecked
        again and left out of ``checked_faultnames``. Any number of changes is
        handled, including none: the signal may fire when an item is set to the
        state it already has.
        """
        faultnet = self.faultnet
        previously_checked = set(self.checked_faultnames)
        now_checked = self.cbox_surfaces.checkedItems()
        rejected = {
            fname
            for fname in now_checked
            if fname not in previously_checked and not faultnet.is_active(fname)
        }
        # setEnabled just makes the item grey, and does not prevent checking it
        # (setCheckable and setSelectable neither): uncheck it by hand.
        item_model = self.cbox_surfaces.model()
        for fname in rejected:
            for item in item_model.findItems(fname):
                item.setCheckState(Qt.CheckState.Unchecked)
        self.checked_faultnames = [
            fname for fname in now_checked if fname not in rejected
        ]


class BoundingBoxWidget(QGroupBox):
    """Group box to edit a 3D bounding box.

    The XY part is edited through a ``QgsExtentWidget``, the Z range through
    two line edits (``edit_zmin`` / ``edit_zmax``). Each coordinate property
    returns a float, or None when the value is missing.
    """

    def __init__(self, layer, parent=None):
        # layer: used to re-open the widget with a pre-existing parameter set
        super().__init__(parent)

        self.setTitle(self.tr("Bounding box"))
        layout = QVBoxLayout()

        # X Y
        extent = QgsExtentWidget()
        extent.setMapCanvas(qutils.iface.mapCanvas(), True)
        original_extent = layer.extent()
        original_crs = layer.crs()
        if not original_extent.isEmpty():
            extent.setOriginalExtent(original_extent, original_crs)
            extent.setCurrentExtent(original_extent, original_crs)
            extent.setOutputCrs(original_crs)
        extent.extentChanged.connect(
            lambda rect: self.extent.setCurrentExtent(rect, self.extent.currentCrs())
        )
        self.extent = extent
        # Z: pre-filled from the layer 3D extent when it has one
        if not (box3D := layer.extent3D()).isEmpty():
            zcoords = (box3D.zMinimum(), box3D.zMaximum())
        else:
            zcoords = ("", "")
        validator = QDoubleValidator()
        z_layout = QHBoxLayout()
        for name, value in zip(("zmin", "zmax"), zcoords, strict=True):
            label = QLabel(name)
            edit = QLineEdit(str(value))
            edit.setValidator(validator)
            # Stored as self.edit_zmin / self.edit_zmax
            setattr(self, "edit_" + name, edit)
            z_layout.addWidget(label)
            z_layout.addWidget(edit)

        layout.addWidget(self.extent)
        layout.addLayout(z_layout)
        self.setLayout(layout)

    @staticmethod
    def _coordinate(value, unset_value):
        """Return ``value`` as a float, or None when it equals ``unset_value``."""
        return None if value == unset_value else float(value)

    @staticmethod
    def _parse_z(text):
        """Return ``text`` as a float, or None when it is empty or not a number.

        The line edit can hold text that ``float()`` rejects, e.g. a lone "-"
        left while typing; such input counts as missing instead of raising.
        NOTE(review): a locale decimal comma would also end up here - confirm
        whether it should be parsed rather than treated as missing.
        """
        if not text:
            return None
        try:
            return float(text)
        except ValueError:
            return None

    @property
    def box3D(self):
        """Current box as a ``QgsBox3D`` built from the six coordinate properties."""
        return QgsBox3D(
            self.xmin, self.ymin, self.zmin, self.xmax, self.ymax, self.zmax
        )

    # The XY properties treat the largest (for minima) and the lowest (for
    # maxima) float64 as "no extent set".
    # NOTE(review): presumably how an unset extent is reported - confirm.

    @property
    def xmin(self):
        """Minimum X of the current extent, or None when unset."""
        return self._coordinate(
            self.extent.currentExtent().xMinimum(), np.finfo(np.float64).max
        )

    @property
    def xmax(self):
        """Maximum X of the current extent, or None when unset."""
        return self._coordinate(
            self.extent.currentExtent().xMaximum(), np.finfo(np.float64).min
        )

    @property
    def ymin(self):
        """Minimum Y of the current extent, or None when unset."""
        return self._coordinate(
            self.extent.currentExtent().yMinimum(), np.finfo(np.float64).max
        )

    @property
    def ymax(self):
        """Maximum Y of the current extent, or None when unset."""
        return self._coordinate(
            self.extent.currentExtent().yMaximum(), np.finfo(np.float64).min
        )

    @property
    def zmin(self):
        """Value of the zmin field, or None when empty or not a number."""
        return self._parse_z(self.edit_zmin.text())

    @property
    def zmax(self):
        """Value of the zmax field, or None when empty or not a number."""
        return self._parse_z(self.edit_zmax.text())


class TopographySelectionWidget(QGroupBox):
    """Checkable group box holding a raster layer selector for the topography.

    The box starts unchecked; ``layer`` is None for as long as it stays so.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setTitle(self.tr("Add topography"))
        self.setCheckable(True)
        self.setChecked(False)

        # Raster-only layer selector, with an explicit "nothing selected" entry
        self.cbox_raster = QgsMapLayerComboBox()
        self.cbox_raster.setFilters(Qgis.LayerFilter.RasterLayer)
        self.cbox_raster.setAllowEmptyLayer(True, self.tr("No raster selected"))

        box = QVBoxLayout()
        box.addWidget(self.cbox_raster)
        self.setLayout(box)

    @property
    def layer(self):
        """Selected raster layer, or None when the group box is unchecked."""
        return self.cbox_raster.currentLayer() if self.isChecked() else None


class GeometricLimitsWidget(QGroupBox):
    """Group box stacking the bounding box editor and the topography selector."""

    def __init__(self, layer, parent=None):
        # layer: forwarded to the bounding box editor, to re-open the widget
        # with a pre-existing parameter set
        super().__init__(parent)
        self.setTitle(self.tr("Domain (geometric limits)"))

        self.bbox = BoundingBoxWidget(layer)
        self.topography = TopographySelectionWidget()

        stack = QVBoxLayout()
        for child in (self.bbox, self.topography):
            stack.addWidget(child)
        self.setLayout(stack)


class DiscretizationParametersWidget(QGroupBox):
    """Group box with one integer field per axis (nx, ny, nz).

    Fields are pre-filled from ``layer.discretization_params`` when it is set,
    and with 10 otherwise. Each property returns an int, or None when the
    matching field is empty.
    """

    def __init__(self, layer, parent=None):
        # layer: used to re-open the widget with a pre-existing parameter set
        super().__init__(parent)
        self.setTitle(self.tr("Discretization steps"))

        row = QHBoxLayout()
        at_least_one = QIntValidator()
        at_least_one.setBottom(1)

        stored = layer.discretization_params
        steps = (None, None, None) if stored is None else stored

        for axis, step in zip(("x", "y", "z"), steps, strict=True):
            key = "n" + axis
            caption = QLabel(key)
            field = QLineEdit()
            field.setValidator(at_least_one)
            # 10 is the value proposed when nothing was stored for this axis
            field.setText(str(10 if step is None else step))
            # Stored as self.edt_nx / self.edt_ny / self.edt_nz
            setattr(self, "edt_" + key, field)
            row.addWidget(caption)
            row.addWidget(field)

        self.setLayout(row)

    @staticmethod
    def _as_step(text):
        """Return ``text`` converted to int, or None when it is empty."""
        if text:
            return int(text)
        return None

    @property
    def nx(self):
        """Value of the nx field, or None when empty."""
        return self._as_step(self.edt_nx.text())

    @property
    def ny(self):
        """Value of the ny field, or None when empty."""
        return self._as_step(self.edt_ny.text())

    @property
    def nz(self):
        """Value of the nz field, or None when empty."""
        return self._as_step(self.edt_nz.text())

    @property
    def steps(self):
        """The three values as an ``(nx, ny, nz)`` tuple."""
        per_axis = (self.nx, self.ny, self.nz)
        return per_axis


class OutputFileLocation(QGroupBox):
    """Group box meant to show the result file; only a placeholder label so far."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setTitle(self.tr("Result"))

        row = QHBoxLayout()
        row.addWidget(QLabel(self.tr("No file selected")))
        # FIXME Continue here
        self.setLayout(row)


class SurfaceExtractionDialog(QDialog):
    """Dialog collecting surface extraction parameters and running the extraction.

    It stacks a surface (or fault) selector, the geometric limits and the
    discretization steps. Accepting the dialog calls ``run_process``.
    """

    def __init__(self, layer, parent=None):
        """Build the dialog for ``layer``.

        Raises:
            TypeError: if ``layer`` is neither a ``ModelLayer`` nor a
                ``FaultNetworkLayer``.
        """
        # Note: this is important NOT to pass a temporary layer, otherwise the
        # layer selection combobox will be empty when opening
        super().__init__(parent)
        self.setWindowTitle(self.tr("Extract model surfaces"))

        # Select surfaces
        assert isinstance(layer, QgsPluginLayer)
        self.layer = layer
        if isinstance(layer, ModelLayer):
            self.surfaces_selector = SurfaceExtractionSelection(layer)
        elif isinstance(layer, FaultNetworkLayer):
            self.surfaces_selector = FaultsExtractionSelection(layer)
        else:
            msg = (
                f"Invalid input layer type: '{type(layer)}'. "
                "Expected 'ModelLayer' or 'FaultNetworkLayer'"
            )
            raise TypeError(msg)
        # Geometric limits
        self.limits = GeometricLimitsWidget(layer)
        self.discretization_params = DiscretizationParametersWidget(layer)
        # Buttons
        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        self.finished.connect(self.deleteLater)

        # Layouts
        main_layout = QVBoxLayout()
        main_layout.addWidget(self.surfaces_selector)
        main_layout.addWidget(self.limits)
        main_layout.addWidget(self.discretization_params)
        main_layout.addWidget(buttons)

        self.accepted.connect(
            lambda: self.run_process()
        )  #  FIXME Why is lambda needed again?!

        self.setLayout(main_layout)

    def run_process(self):
        """Extract the surfaces with the dialog parameters and load them in pvGIS.

        Missing or invalid parameters are reported in a warning box and abort
        the run before anything is stored on the layer. If the interpolation
        parameters cannot be built, the validity checks output is pushed to the
        QGIS message bar instead. On success the surfaces are saved as a
        ``.vtm`` file in the forgeo output directory and loaded in the pvGIS
        viewer.

        Raises:
            QgsProcessingException: if the pvGIS plugin is not loaded.
        """
        pvgis = qutils.plugins.get("pvGIS", None)
        if pvgis is None:
            msg = "pvGIS plugin cannot be found: is it loaded?"
            raise QgsProcessingException(msg)
        from pyvista import MultiBlock, PolyData  # noqa: PLC0415 <comes with pvgis>

        bbox = self.limits.bbox
        # Every XY bound is needed below, not only xmax
        if any(v is None for v in (bbox.xmin, bbox.xmax, bbox.ymin, bbox.ymax)):
            QMessageBox.warning(self, "Missing parameters", "Please enter an extent")
            return
        if bbox.zmin is None or bbox.zmax is None:
            QMessageBox.warning(self, "Missing parameters", "Please enter zmin / zmax")
            return
        steps = self.discretization_params.steps
        # A field can be empty (None) or hold a value below 1 (e.g. "0")
        if any(step is None or step < 1 for step in steps):
            QMessageBox.warning(
                self,
                "Missing parameters",
                "Please enter discretization steps (integers >= 1)",
            )
            return

        # Remember the parameters on the layer, to pre-fill the dialog next time
        box3D = bbox.box3D
        self.layer.setExtent3D(box3D)
        self.layer.discretization_params = steps
        # FIXME Keep also use_topo info in layer?
        input_params = {}
        # Get geological objects to discretize
        if isinstance(self.layer, ModelLayer):
            # FIXME For now, we always export all surfaces. Handle selection some day...
            model = self.surfaces_selector.model
            input_params["model"] = model
            fault_network = None
            if (fault_layer_id := self.layer.faultnetlayer_id) is not None:
                # NOTE(review): the mapLayer() result is used unchecked - confirm
                # the referenced fault layer always exists in the project.
                fault_layer = QgsProject.instance().mapLayer(fault_layer_id)
                fault_network = fault_layer.faultnet
            # FIXME Do we need to call 'get_fault_network' here also?
            input_params["fault_network"] = fault_network
            name = model.name
        elif isinstance(self.layer, FaultNetworkLayer):
            faultnet = self.surfaces_selector.faultnet
            selected_faults = self.surfaces_selector.checked_faultnames
            input_params["fault_network"] = faultnet.get_subnetwork(selected_faults)
            name = faultnet.name
        topo = self.limits.topography.layer  # QgsRasterLayer
        input_params["topography"] = raster_layer_to_description(topo)

        rigs_params = BSPTreeBuilder.from_params(**input_params)
        if rigs_params is None:
            # Run validity checks to detect the problematics elements
            with SpooledTemporaryFile(max_size=1e6, mode="w+") as f:
                for elt in input_params.values():
                    is_valid(elt, buffer=f)
                f.seek(0)
                warnings = f.read()
            warning_messagebar = qutils.iface.messageBar().createMessage(
                "Interpolators warnings\n", warnings
            )
            qutils.iface.messageBar().pushWidget(warning_messagebar, Qgis.Warning)
            return
        vertices, tets = tetgrid(
            steps,
            extent=(bbox.xmin, bbox.xmax, bbox.ymin, bbox.ymax, bbox.zmin, bbox.zmax),
        )
        results = all_intersections(
            vertices, tets, **rigs_params, return_iso_surfaces=True
        )
        result_name = f"{name}_surface"

        def surfaces_to_multiblock():
            """Group the iso-surfaces into Topography / Faults / Formations blocks."""
            # Adapted from gmsig
            surfaces = results.iso_surfaces
            v, f, part = surfaces.vertices, surfaces.faces, surfaces.color
            names = rigs_params["names"]
            colormap = colors_as_strings(rigs_params["colors"])
            topo_id = rigs_params.get("topography")
            is_fault = rigs_params["is_fault"]
            uids = np.unique(part)
            faults, formations, topography = [], [], []
            # One mesh per distinct id found in `part`
            for uid in uids:
                faces = [fi for i, fi in enumerate(f) if part[i] == uid]
                mesh = PolyData.from_irregular_faces(v, faces)
                if is_fault(uid):
                    faults.append((mesh, names[uid], colormap[uid]))
                elif uid == topo_id:
                    topography.append((mesh, names[uid], colormap[uid]))
                else:
                    formations.append((mesh, names[uid], colormap[uid]))

            def to_multiblock(meshes, order=None):
                """Build a MultiBlock from (mesh, name, color) tuples.

                When ``order`` (a list of names) is given, the meshes are
                sorted accordingly first.
                """
                # Reorder input individual PolyData
                if order is not None:
                    ordered_meshes = []
                    for name in order:
                        # elt = (mesh, name, color)
                        elt = next((elt for elt in meshes if elt[1] == name), None)
                        if elt is not None:  # Typically, skip units
                            ordered_meshes.append(elt)
                    assert len(ordered_meshes) == len(meshes)
                    meshes = ordered_meshes
                # create multiblock
                surfaces = {}
                for mesh, name, color in meshes:
                    if color is None:
                        color = "#cccccc"  # (204, 204, 204) : grey 80% white # noqa: PLW2901
                    mesh.field_data["color"] = color
                    surfaces[name] = mesh
                return MultiBlock(surfaces)

            all_groups = {}
            if topography:
                all_groups["Topography"] = to_multiblock(topography)
            if faults:
                order = [f.name for f in input_params["fault_network"].dataset]
                all_groups["Faults"] = to_multiblock(faults, order)
            if formations:
                # reversed so the youngest item is the first layer of the group in pvgis
                order = [i.name for i in reversed(input_params["model"].dataset)]
                all_groups["Formations"] = to_multiblock(formations, order)

            return MultiBlock(all_groups)

        mb = surfaces_to_multiblock()
        filepath = get_forgeo_output_dir() / (result_name + ".vtm")
        mb.save(str(filepath))
        pvgis.viewer.loader.load_multiblock(mb, name=result_name)


def colors_as_strings(colors: list):
    """Convert a sequence of color descriptions to a list of strings.

    Each entry is handled according to its type:

    - ``str``: kept unchanged;
    - ``QColor``: replaced by ``color.name()``;
    - ``None``: replaced by the name of ``QColor(127, 127, 127)``;
    - sequence of ints in [0, 255]: passed to ``QColor`` as is;
    - sequence of floats in [0, 1]: scaled by 255 and truncated to ints.

    Raises:
        AssertionError: if a component is out of range, or if an entry matches
            none of the cases above.
    """
    result = []
    for color in colors:
        if isinstance(color, str):
            result.append(color)
        elif isinstance(color, QColor):
            # Already a QColor: there is nothing to convert, just take its name
            result.append(color.name())
        elif color is None:
            result.append(QColor(127, 127, 127).name())
        elif all(isinstance(ck, int) for ck in color):
            # Raised explicitly so the check is not stripped under `python -O`
            if not all(0 <= ck <= 255 for ck in color):
                msg = f"color components of {color} must be in [0, 255]"
                raise AssertionError(msg)
            result.append(QColor(*color).name())
        elif all(isinstance(ck, float) for ck in color):
            if not all(0 <= ck <= 1 for ck in color):
                msg = f"color components of {color} must be in [0, 1]"
                raise AssertionError(msg)
            result.append(QColor(*(int(ck * 255) for ck in color)).name())
        else:
            msg = f"could not convert {color} to QColor"
            raise AssertionError(msg)
    return result
