from pathlib import Path

from forgeo.core import Model, ModellingUnit
from qgis.core import QgsProject
from qgis.gui import QgsCheckableComboBox
from qgis.PyQt.QtCore import QSignalBlocker, Qt
from qgis.PyQt.QtGui import QColor, QStandardItem
from qgis.PyQt.QtWidgets import (
    QAction,
    QComboBox,
    QDialog,
    QDialogButtonBox,
    QFileDialog,
    QGridLayout,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMenuBar,
    QMessageBox,
    QPushButton,
    QScrollArea,
    QVBoxLayout,
    QWidget,
)
from qgis.utils import iface

import forgeo.io.xml as fxml
from forgeo.io.xml import deep_copy

from ..layers import FaultNetworkLayer, ModelLayer, PileLayer
from ..layers.model import TemporaryModelLayer
from ..utils import (
    DEFAULT_MODEL_NAME,
    cleargridlayoutcolumn,
    clearlayout,
    get_forgeo_data_dir,
    popup_save_changes,
    qicon,
    save_as_png,
    save_as_xml,
)
from .adddata_widget import AddItemDataDialog
from .color_picker import IcsColorDialog as QICSColorDialog
from .discretization import SurfaceExtractionDialog
from .interpolator_widget import InterpolationParametersDialog
from .utils import (
    QgsPluginLayerComboBox,
    display_data_icon,
    display_data_icon_itemgroup,
    surface_symbols,
)


class ContactLabel(QWidget):
    """Widget showing a contact name between two colored dash lines.

    The label background and the dash lines use the contact color, read from
    ``contact.info["color"]`` (a light grey is used when no color is stored).
    Double-clicking the widget opens a color dialog; the chosen color is
    applied to the widget and written back to ``contact.info["color"]``.
    """

    def __init__(self, contact, parent=None):
        super().__init__(parent)
        self.contact = contact
        self.label = QLabel(contact.name, alignment=Qt.AlignmentFlag.AlignCenter)
        self.label.setFixedHeight(50)
        self.symbol_left = QLabel("--" * 20)
        self.symbol_left.setAlignment(
            Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter
        )
        self.symbol_right = QLabel("--" * 20)
        self.symbol_right.setAlignment(
            Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter
        )
        # ``info`` may be None, or a dict without a "color" entry: fall back
        # to the default color instead of raising KeyError.
        info = self.contact.info or {}
        stored_color = info.get("color")
        if stored_color is None:
            stored_color = "#f0f0f0"
        # set_color also styles both dash lines.
        self.set_color(QColor(stored_color))
        layout = QHBoxLayout()
        layout.addWidget(self.symbol_left)
        layout.addWidget(self.label)
        layout.addWidget(self.symbol_right)
        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

    def set_color(self, color):
        """Apply ``color`` (a QColor) to the label background and dash lines."""
        self.label.setStyleSheet(f"background-color: {color.name()};")
        self.symbol_left.setStyleSheet(f"color: {color.name()};")
        self.symbol_right.setStyleSheet(f"color: {color.name()};")
        self.color = color

    def mouseDoubleClickEvent(self, event):  # noqa: ARG002
        """Let the user pick a new color and store it in ``contact.info``."""
        dialog = QICSColorDialog(parent=self)
        dialog.setWindowTitle(f"Change {self.contact.name} color")
        dialog.setCurrentColor(self.color)
        ok = dialog.exec()
        if ok:
            self.set_color(dialog.currentColor())
            if self.contact.info is None:
                self.contact.info = {}
            self.contact.info["color"] = str(self.color.name())

    def text(self):
        """Return the displayed contact name."""
        return self.label.text()


class ModelEditionDialog(QDialog):
    def __init__(self, model_layer, parent=None):
        """Build the dialog editing ``model_layer``.

        Args:
            model_layer: the ModelLayer whose model is displayed and edited.
            parent: optional parent widget.
        """
        super().__init__(parent)
        # Init
        self.setWindowTitle("Model edition")
        assert model_layer is not None
        assert isinstance(model_layer, ModelLayer)
        self.model_layer = model_layer

        # Menus
        menu_bar = QMenuBar()
        self.menu_bar = menu_bar
        # Menu model
        model_menu = menu_bar.addMenu(self.tr("Model"))
        self.model_menu = model_menu
        new_action = QAction(self.tr("New"), parent=self)
        # The lambda drops the 'checked' bool emitted by triggered, which
        # would otherwise be passed to new() as its 'name' argument.
        new_action.triggered.connect(lambda: self.new())
        model_menu.addAction(new_action)
        open_action = QAction(self.tr("Open"), parent=self)
        open_action.triggered.connect(self.open_layer)
        model_menu.addAction(open_action)
        load_action = QAction(self.tr("Load from XML"), parent=self)
        load_action.triggered.connect(self.load_from_xml_file)
        model_menu.addAction(load_action)
        # addMenu(title) already inserts the submenu in model_menu: it must
        # not be added a second time.
        save_as_menu = model_menu.addMenu(self.tr("Save as"))
        save_as_xml_action = QAction("XML", parent=self)
        save_as_xml_action.triggered.connect(lambda: save_as_xml(self, self.model))
        save_as_menu.addAction(save_as_xml_action)
        save_as_img_action = QAction("PNG", parent=self)
        save_as_img_action.triggered.connect(
            lambda: save_as_png(self.scroll_area_widget, self, self.model.name)
        )
        save_as_menu.addAction(save_as_img_action)
        save_as_new_action = QAction(self.tr("New layer"), parent=self)
        save_as_new_action.triggered.connect(self.open_copy)
        save_as_menu.addAction(save_as_new_action)
        # Menu data
        data_menu = menu_bar.addMenu(self.tr("Data"))
        self.data_menu = data_menu
        self.connect_action = QAction(self.tr("Automatic data update"), parent=self)
        self.connect_action.setCheckable(True)
        self.connect_action.setChecked(model_layer.connected_to_layers)
        self.connect_action.triggered.connect(
            lambda checked: self.automatic_data_update(checked)
        )
        data_menu.addAction(self.connect_action)
        update_action = QAction(self.tr("Update model"), parent=self)
        update_action.triggered.connect(self.update_dataset)
        data_menu.addAction(update_action)
        fill_model_action = QAction(
            self.tr("Load data from another model"), parent=self
        )
        fill_model_action.triggered.connect(self.load_data_from_model)
        data_menu.addAction(fill_model_action)
        # Menu Faults
        faultnet_menu = menu_bar.addMenu(self.tr("Faults"))
        self.faultnet_menu = faultnet_menu
        addfaultnet_action = QAction(self.tr("Select fault network"), parent=self)
        addfaultnet_action.triggered.connect(self.add_fault_network)
        faultnet_menu.addAction(addfaultnet_action)
        removefaultnet_action = QAction(self.tr("Remove fault network"), parent=self)
        removefaultnet_action.triggered.connect(self.remove_fault_network)
        faultnet_menu.addAction(removefaultnet_action)
        # Menu discretization
        menu_discretization = menu_bar.addMenu(self.tr("Discretization"))
        self.menu_discretization = menu_discretization
        action_surfaces = QAction(self.tr("Extract surfaces"), parent=self)
        action_surfaces.triggered.connect(self.extract_model_surfaces)
        menu_discretization.addAction(action_surfaces)

        # Buttons
        cancel_button = QPushButton("Cancel")
        cancel_button.clicked.connect(self.reject)
        save_button = QPushButton("Save")
        # FIXME (the lambda discards the 'checked' argument of clicked)
        save_button.clicked.connect(lambda: self.save())
        ok_button = QPushButton("OK")
        ok_button.clicked.connect(self.accept)

        # Signals: accepting the dialog saves the model; the dialog deletes
        # itself once closed.
        self.accepted.connect(lambda: self.save())
        self.finished.connect(self.deleteLater)

        # Model: the pile/data/interpolator grid lives in a scroll area
        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        self.scroll_area_widget = QWidget()
        self.grid_layout = QGridLayout()
        self.grid_layout.setVerticalSpacing(0)
        self.scroll_area_widget.setLayout(self.grid_layout)
        scroll_area.setWidget(self.scroll_area_widget)

        # Layouts
        # NOTE: this attribute shadows QWidget.layout(); kept for callers.
        self.layout = QVBoxLayout()
        buttons_layout = QHBoxLayout()
        buttons_layout.addWidget(cancel_button)
        buttons_layout.addStretch(1)
        buttons_layout.addWidget(save_button)
        buttons_layout.addWidget(ok_button)
        self.layout.setMenuBar(menu_bar)
        self.layout.addWidget(scroll_area)
        self.layout.addLayout(buttons_layout)
        self.setLayout(self.layout)
        self.resize(800, 500)

        # Display model
        self.refresh()

    @property
    def model(self):
        """The forgeo Model held by ``self.model_layer``."""
        edited_layer = self.model_layer
        return edited_layer.model

    @property
    def fault_network(self):
        """Fault network attached to the edited model, or None.

        Returns None when the model layer references no fault network layer,
        or when the referenced layer is no longer in the QGIS project.
        """
        fid = self.model_layer.faultnetlayer_id
        if fid is None:
            return None
        layer = QgsProject.instance().mapLayer(fid)
        if layer is None:
            # The fault network layer was removed from the project: behave as
            # if there were no fault network instead of raising AttributeError.
            return None
        return layer.faultnet

    def refresh(self):
        """Rebuild the whole grid: pile, data and interpolator columns."""
        # Remove every widget of the previous display
        clearlayout(self.grid_layout)
        for fill_column in (
            self._refresh_pile_column,
            self._refresh_data_column,
            self._refresh_interpolators_column,
        ):
            fill_column()

    def _refresh_pile_column(self):
        """Fill columns 0-2 with the pile elements, one row per element.

        Elements are laid out in reverse order of ``pile.subunits()``.
        Modelling units and erosions each get a row. Between two consecutive
        modelling units (no erosion in between), the contact item of the model
        is inserted on its own row.
        """
        # FIXME iterate over dataset instead of pile
        pile_layer = QgsProject.instance().mapLayer(self.model_layer.pilelayer_id)
        row = 0
        previous_unit = None
        for elem in reversed(list(pile_layer.pile.subunits())):
            if not isinstance(elem, ModellingUnit):  # Erosion
                self._add_erosion_to_layout(row, elem)
                # No contact row between an erosion and the next unit
                previous_unit = None
                row += 1
                continue
            if previous_unit is not None:
                # NOTE(review): the contact name is built by hand and must match
                # the model item names; Model.get_contact_name may be the
                # canonical builder — confirm both agree.
                name = f"Contact {elem.name} - {previous_unit}"
                self._add_contact_to_layout(row, self.model.get_item(name))
                row += 1
            self._add_modelling_unit_to_layout(row, elem)
            previous_unit = elem
            row += 1

    def _add_modelling_unit_to_layout(self, row, unit):
        """Add a label for ``unit`` spanning columns 0-2 of ``row``.

        The label background is the unit color, ``unit.info["color"]``.
        """
        css_color = QColor(unit.info["color"]).name()
        unit_label = QLabel(unit.name, alignment=Qt.AlignmentFlag.AlignCenter)
        unit_label.setFixedHeight(50)
        unit_label.setStyleSheet(f"background-color: {css_color};")
        self.grid_layout.addWidget(unit_label, row, 0, 1, 3)

    def _add_contact_to_layout(self, row, item):
        """Add the contact ``item`` on ``row``: colored label and type selector.

        The label spans columns 0-2 and the contact type combo box sits in
        column 2. Choosing a type updates ``item.type``, the model
        interpolators and the interpolators column.
        """
        self.grid_layout.addWidget(ContactLabel(item), row, 0, 1, 3)
        # Add combobox: entries 0 ("Conformable") and 4 ("Unconformable") are
        # disabled section headers.
        contact_type = QComboBox()
        contact_type.setFixedWidth(150)
        contact_type.addItem("Conformable")
        conf = contact_type.model().item(contact_type.count() - 1)
        conf.setEnabled(False)
        contact_type.addItem(qicon("Conformable"), "")
        contact_type.addItem(qicon("Top"), "with formation below")
        contact_type.addItem(qicon("Base"), "with formation above")
        contact_type.addItem("Unconformable")
        unconf = contact_type.model().item(contact_type.count() - 1)
        unconf.setEnabled(False)
        contact_type.addItem(qicon("Unconformable"), "")
        idx_to_str = {1: "Conformable", 2: "Top", 3: "Base", 5: "Surface only"}
        str_to_idx = {v: k for k, v in idx_to_str.items()}

        def update_contact_element_type(index):
            # Use the index emitted by currentIndexChanged rather than relying
            # on self.sender() inside a closure.
            contact_element_type = idx_to_str.get(index)
            if contact_element_type is None:  # Header entry: nothing to do
                return
            item.type = contact_element_type
            self.model.update_interpolators()
            self._refresh_interpolators_column()

        contact_type.currentIndexChanged.connect(update_contact_element_type)
        # Reload previously selected contact type without triggering the slot
        with QSignalBlocker(contact_type):
            contact_type.setCurrentIndex(str_to_idx[item.type])
        self.grid_layout.addWidget(contact_type, row, 2)
        self.grid_layout.setAlignment(contact_type, Qt.AlignmentFlag.AlignCenter)

    def _add_erosion_to_layout(self, row, erosion):
        """Add an erosion row: its name in column 1, erosion symbols in 0 and 2.

        The label background and the symbols use ``erosion.info["color"]``.
        """
        erosion_color = QColor(erosion.info["color"])
        name_label = QLabel(erosion.name, alignment=Qt.AlignmentFlag.AlignCenter)
        name_label.setFixedHeight(50)
        name_label.setStyleSheet(f"background-color: {erosion_color.name()};")
        left_symbol, right_symbol = surface_symbols(erosion_color, is_erosion=True)
        for widget, column in ((name_label, 1), (left_symbol, 0), (right_symbol, 2)):
            self.grid_layout.addWidget(widget, row, column)

    def _refresh_data_column(self):
        """Redraw the "Data" buttons (column 4) and data icons (column 3).

        Row ``r`` shows the dataset in reverse order, i.e. the last dataset
        item is on row 0.
        """
        cleargridlayoutcolumn(self.grid_layout, 4)
        # NOTE(review): column 3 (data icons) is not cleared here; confirm that
        # display_data_icon replaces a previous icon rather than stacking one.
        self.grid_layout.setColumnMinimumWidth(3, 50)
        self.grid_layout.setColumnStretch(3, 0)
        rows_top_down = enumerate(reversed(self.model.dataset))
        for row, dataset_item in rows_top_down:
            data_button = QPushButton("Data")
            data_button.clicked.connect(self.add_data)
            self.grid_layout.addWidget(data_button, row, 4)
            display_data_icon(dataset_item, self.grid_layout, row, 3)

    def _refresh_interpolators_column(self):
        """Redraw the interpolator groups (columns 3 to 7).

        For each group returned by ``_create_subpiles``, except empty groups
        and lone units, this adds:

        - a grey background box starting at column 3;
        - the group data icons (column 5);
        - an "Interpolator" button (column 6);
        - a checkable combo box of faults (column 7), when the model has a
          fault network.
        """
        cleargridlayoutcolumn(self.grid_layout, 5)  # Icons
        cleargridlayoutcolumn(self.grid_layout, 6)  # Interpolators buttons
        cleargridlayoutcolumn(self.grid_layout, 7)  # Faults
        fault_network = self.fault_network
        nb_elements = len(self.model.dataset)
        # Number of columns covered by the grey box (the faults column is
        # included only when there is a fault network).
        col_span = 4 if fault_network is None else 5
        # Loop-invariant grey box style
        lightgrey = QColor("#f0f0f0").name()
        darkgrey = QColor("#c8c8c8").name()
        grey_box_style = f"background-color: {darkgrey}; border-top: 1px solid {lightgrey}; border-bottom: 1px solid {lightgrey};"
        for subpile, start_idx in self._create_subpiles():
            nb_items = len(subpile)
            # Skip empty groups (no rows to span) and lone units
            if nb_items == 0 or (nb_items == 1 and subpile[0].type == "Unit"):
                continue
            button_interp = QPushButton("Interpolator")
            button_interp.clicked.connect(self.define_interp)
            grey_box = QLabel()
            grey_box.setStyleSheet(grey_box_style)
            # The dataset is displayed reversed: topmost row of the group
            from_row = nb_elements - start_idx - nb_items
            self.grid_layout.addWidget(grey_box, from_row, 3, nb_items, col_span)
            grey_box.lower()  # Send to background
            self.grid_layout.addWidget(button_interp, from_row, 6, nb_items, 1)
            # Data icon
            display_data_icon_itemgroup(
                subpile, self.grid_layout, from_row, 5, nb_items
            )
            # Faults
            if fault_network is not None:
                self._add_faults_combobox(fault_network, from_row, nb_items)

    def _add_faults_combobox(self, fault_network, from_row, nb_items):
        """Add the faults combo box of the group whose top row is ``from_row``.

        The combo box spans ``nb_items`` rows of column 7. Its first entry is
        "All faults". Inactive faults are disabled, and the faults listed in
        the group interpolator's discontinuities are checked.
        """
        name = self.get_item_name(
            self.grid_layout.itemAtPosition(from_row, 0).widget()
        )
        interp = self.model.get_interpolator(name)
        checked_fnames = interp.discontinuities
        faults_cbox = QgsCheckableComboBox()
        faults_cbox.checkedItemsChanged.connect(
            lambda names: self.new_fault_checked(names)
        )
        faults_cbox.addItem("All faults")
        for fault in fault_network.dataset:
            item = QStandardItem(fault.name)
            item.setData(fault)
            if not fault_network.is_active(fault.name):
                item.setEnabled(False)
            elif checked_fnames is not None and fault.name in checked_fnames:
                item.setCheckState(Qt.CheckState.Checked)
            faults_cbox.model().appendRow(item)
        self.grid_layout.addWidget(faults_cbox, from_row, 7, nb_items, 1)

    def _create_subpiles(self):
        """Group consecutive dataset items that share an interpolator.

        Returns:
            A list of ``(items, start_idx)`` tuples, in dataset order. Each
            ``items`` is a non-empty list of consecutive dataset items, and
            ``start_idx`` is the dataset index of ``items[0]``.

        Grouping rules, walking the dataset by increasing index:

        - "Erosion" and "Surface only" items form a group of their own;
        - a "Top" contact closes the current group and belongs to it;
        - a "Base" contact opens a new group and is its first item;
        - units and conformable contacts join the current group.
        """
        # FIXME Replace by model.interpolators...
        subpiles = []
        current_subpile = []
        start_idx = 0
        for idx, item in enumerate(self.model.dataset):
            if item.type in ["Erosion", "Surface only"]:
                if current_subpile:
                    subpiles.append((current_subpile, start_idx))
                subpiles.append(([item], idx))
                current_subpile = []
                start_idx = idx + 1
            elif item.type == "Top":
                current_subpile.append(item)
                subpiles.append((current_subpile, start_idx))
                current_subpile = []
                start_idx = idx + 1
            elif item.type == "Base":
                # Do not emit an empty group (Base first, or right after a
                # group that was just closed).
                if current_subpile:
                    subpiles.append((current_subpile, start_idx))
                current_subpile = [item]
                start_idx = idx
            else:  # Unit or Conformable contact
                current_subpile.append(item)
        # Emit the trailing group whatever the last item type is, so that a
        # group opened by a final "Base" contact is not lost.
        if current_subpile:
            subpiles.append((current_subpile, start_idx))
        return subpiles

    def get_item_name(self, sender):
        """Return the name of the pile element displayed on ``sender``'s row.

        The name is the text of the column-0 widget. On erosion rows, column 0
        holds the erosion symbol, so the name is read from column 1 instead.
        """
        position = self.grid_layout.getItemPosition(self.grid_layout.indexOf(sender))
        row_idx = position[0]

        def text_in_column(column):
            return self.grid_layout.itemAtPosition(row_idx, column).widget().text()

        name = text_in_column(0)
        erosion_symbol = "~~" * 10
        return text_in_column(1) if name == erosion_symbol else name

    def add_data(self):
        """Open the data dialog for the element on the clicked button's row.

        The dialog is opened with the element's current filter from the model
        layer. When it is accepted, the new filter is passed to
        ``model_layer.update_item`` and the data and interpolator columns are
        redrawn.
        """
        item = self.model.get_item(self.get_item_name(self.sender()))
        # 'item_filter' rather than 'filter': do not shadow the builtin
        item_filter = self.model_layer.filters.get(item.name)
        # Open AddItemDataDialog dialog
        dlg = AddItemDataDialog(item, item_filter, item.get_color())

        def _update_item_data(result):
            if result == QDialog.DialogCode.Rejected:
                return
            dlg.close()
            # Update model and model layer
            self.model_layer.update_item(item.name, dlg.filter)
            # Refresh display
            self._refresh_data_column()
            self._refresh_interpolators_column()

        dlg.finished.connect(_update_item_data)
        dlg.open()

    def update_dataset(self):
        """Ask the model layer to update its dataset, then redraw the grid columns."""
        layer = self.model_layer
        layer.update_dataset()
        for refresh_column in (
            self._refresh_data_column,
            self._refresh_interpolators_column,
        ):
            refresh_column()

    def automatic_data_update(self, checked):
        """Toggle the automatic update of the model from its source layers.

        When enabling, the user must confirm (enabling saves the current
        model); if the user cancels, the menu action is unchecked again.
        When disabling, the model layer disconnects from its data layers.

        Args:
            checked: new checked state of the "Automatic data update" action.
        """
        if not checked:
            self.model_layer.disconnect_data_layers()
            return
        dialog = QDialog(parent=iface.mainWindow())
        dialog.setWindowTitle("Automatic data update")
        label = QLabel(
            "The model will be automatically updated when a source layer has been modified (Edited, deleted, change of coordinate reference system).\n"
            "Activating automatic data update will save the current model."
        )
        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Cancel
            | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(dialog.accept)
        buttons.rejected.connect(dialog.reject)
        layout = QVBoxLayout()
        layout.addWidget(label)
        layout.addWidget(buttons)
        dialog.setLayout(layout)
        if dialog.exec() == QDialog.DialogCode.Accepted:
            self.save()
            self.model_layer.reconnect_data_layers()
        else:
            # QAction.setChecked expects a bool, not a Qt.CheckState value.
            # setChecked does not emit triggered, so this does not re-enter.
            self.connect_action.setChecked(False)

    def load_data_from_model(self):
        """Copy into this model the data of another model layer.

        The user chooses the source model layer. Data of elements with the same
        names in both models is replaced, then the grid columns are redrawn.
        """
        dialog = QDialog(parent=iface.mainWindow())
        dialog.setWindowTitle("Load data from another model")
        label = QLabel(
            "Replaces data of elements with the same names in both models.\n"
            "Choose source model :"
        )
        cbox = QgsPluginLayerComboBox(ModelLayer)
        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(dialog.accept)
        buttons.rejected.connect(dialog.reject)
        layout = QVBoxLayout()
        layout.addWidget(label)
        layout.addWidget(cbox)
        layout.addWidget(buttons)
        dialog.setLayout(layout)
        if dialog.exec() != QDialog.DialogCode.Accepted:
            return
        source_layer = cbox.currentLayer()
        if source_layer is None:
            # Presumably no model layer was available to choose from
            return
        self.model_layer.load_data_from_model(source_layer)
        # Update layout
        self._refresh_data_column()
        self._refresh_interpolators_column()

    def define_interp(self):
        """Open the interpolation parameters dialog for the clicked group.

        When the dialog is accepted, the edited interpolator is stored back in
        the model.
        """
        # FIXME I kept this implementation to save time, but this is dangerous...
        # The dataset index of the group's first item is recovered from the grid
        # position (top row) and row span of the clicked button.
        clicked_position = self.grid_layout.getItemPosition(
            self.grid_layout.indexOf(self.sender())
        )
        top_row, row_span = clicked_position[0], clicked_position[2]
        first_item = self.model.dataset[len(self.model.dataset) - top_row - row_span]
        # Open widget
        dlg = InterpolationParametersDialog(
            self.model.get_interpolator(first_item.name)
        )

        def _update_interpolator(result):
            if result != QDialog.DialogCode.Rejected:
                self.model.update_interpolator(dlg.interpolator)
                dlg.close()

        dlg.finished.connect(_update_interpolator)
        dlg.show()

    def open_layer(self):
        """Switch the edition to another existing model layer.

        The user picks a model layer and is asked whether to save the current
        model. A new edition dialog is opened on the chosen layer, then this
        dialog is closed (accepted if saving, rejected otherwise).
        """
        dlg = QDialog(parent=iface.mainWindow())
        dlg.setWindowTitle(dlg.tr("Open an existing model"))
        cbox_layers = QgsPluginLayerComboBox(ModelLayer)
        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(dlg.accept)
        buttons.rejected.connect(dlg.reject)
        layout = QVBoxLayout()
        layout.addWidget(QLabel(dlg.tr("Select a model layer")))
        layout.addWidget(cbox_layers)
        layout.addWidget(buttons)
        dlg.setLayout(layout)

        def process_result(result):
            if result == QDialog.DialogCode.Rejected:  # Cancel button
                return
            layer = cbox_layers.currentLayer()
            if layer is None:
                # Presumably no model layer was available: nothing to open
                return
            save_current_model = popup_save_changes(self.model.name)
            # Open a new ModelEditionDialog
            self.edit(layer)
            # Close the current ModelEditionDialog
            if save_current_model:
                close_code = QDialog.DialogCode.Accepted
                self.save()
            else:
                close_code = QDialog.DialogCode.Rejected
            self.done(close_code)  # Causes self to close(), and emit accepted/rejected

        dlg.finished.connect(process_result)
        dlg.open()

    def load_from_xml_file(self):
        """Load a Model from an XML file and open it in a new edition dialog.

        The pile layer (and fault network layer, if the model has one) named
        in the model must already be in the QGIS project; otherwise a warning
        is shown and nothing is loaded. On success the current dialog is
        closed without saving, a ModelLayer is added to the project, and a new
        edition dialog is opened on it.
        """
        project = QgsProject.instance()

        def first_layer_named(layer_name):
            matches = project.mapLayersByName(layer_name)
            return matches[0] if matches else None

        # getOpenFileName returns a 2-tuple (filename, filter)
        filename, _ = QFileDialog.getOpenFileName(
            parent=self,
            caption=self.tr("Load an existing model"),
            directory=str(get_forgeo_data_dir()),
            filter=self.tr("XML (*.xml)"),
        )
        if not filename:
            return
        model = fxml.load(Path(filename))
        if not isinstance(model, Model):  # Also rejects None
            return

        # Retrieve model pilelayer
        pilename = model.pilename
        pile_layer = first_layer_named(pilename)
        if not isinstance(pile_layer, PileLayer):
            QMessageBox.warning(
                self,
                f"Cannot find pile layer matching the model pile: {pilename}",
                "Please reload the XML file corresponding to the pile and retry",
            )
            return

        # Retrieve model fault network layer (optional)
        faults_layer_id = None
        faultnetname = model.faultnetname
        if faultnetname is not None:
            faults_layer = first_layer_named(faultnetname)
            if not isinstance(faults_layer, FaultNetworkLayer):
                QMessageBox.warning(
                    self,
                    f"Cannot find fault network layer matching the model fault network: {faultnetname}",
                    "Please reload the XML file corresponding to the fault network and retry",
                )
                return
            faults_layer_id = faults_layer.id()

        # Cancel the changes in the current model
        self.done(QDialog.DialogCode.Rejected)
        # Everything fine: register the new layer and edit it
        layer = ModelLayer(model, pile_layer.id(), faults_layer_id)
        project.addMapLayer(layer)
        self.edit(layer)

    def add_fault_network(self):
        """Let the user choose a fault network layer and attach it to the model.

        Cancelling or closing the dialog, or having no layer selected, leaves
        the model unchanged.
        """
        dlg = QDialog(parent=iface.mainWindow())
        dlg.setWindowTitle("Add fault network")
        cbox_faultnet = QgsPluginLayerComboBox(FaultNetworkLayer)
        # Cancel button, consistent with the other selection dialogs
        buttons = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(dlg.accept)
        buttons.rejected.connect(dlg.reject)
        layout = QVBoxLayout()
        layout.addWidget(QLabel("Fault network"))
        layout.addWidget(cbox_faultnet)
        layout.addWidget(buttons)
        dlg.setLayout(layout)
        if dlg.exec() != QDialog.DialogCode.Accepted:
            return
        layer = cbox_faultnet.currentLayer()
        if layer is None:
            # Presumably no fault network layer in the project: avoid an
            # AttributeError on layer.faultnet
            return
        self.model.set_fault_network(layer.faultnet)
        self.model_layer.faultnetlayer_id = layer.id()
        self._refresh_interpolators_column()

    def remove_fault_network(self):
        """Detach the fault network from the model and its layer, then redraw."""
        model, layer = self.model, self.model_layer
        model.remove_fault_network()
        layer.faultnetlayer_id = None
        self._refresh_interpolators_column()

    def new_fault_checked(self, new_checked_faultnames):
        """Update an interpolator's discontinuities after a fault check change.

        Slot for ``checkedItemsChanged`` of a faults combo box. The
        interpolator is the one of the element displayed on the combo box's
        row. The changed name is found by comparing ``new_checked_faultnames``
        with the interpolator's current ``discontinuities``:

        - "All faults" changed: every active fault is (un)checked in the combo
          box and added to / removed from the discontinuities;
        - an inactive fault changed: it is unchecked again;
        - otherwise: the discontinuities become ``new_checked_faultnames``.

        The modified copy of the interpolator is then stored in the model.

        Args:
            new_checked_faultnames: texts of the currently checked combo box items.
        """
        fault_network = self.fault_network
        cbox = self.sender()
        interp = self.model.get_interpolator(self.get_item_name(self.sender()))
        # Work on a copy; the model is updated once, at the end
        interp = deep_copy(interp)
        checked_faultnames = interp.discontinuities or []
        # Negative: one more name checked; positive: one more name unchecked
        nb_checked = len(checked_faultnames) - len(new_checked_faultnames)
        # Assumes a single item changes per signal.
        # NOTE(review): setItemCheckState/setCheckState below may re-emit
        # checkedItemsChanged and re-enter this slot before the model is
        # updated; confirm this assert cannot fail in that case.
        assert nb_checked in [-1, 0, 1]
        more_checked = nb_checked < 0
        if nb_checked == 0:
            fault_name = "All faults"
        if more_checked:
            # Name checked now but not before
            for fname in new_checked_faultnames:
                if fname not in checked_faultnames:
                    fault_name = fname
        else:
            # Name checked before but not now (also runs when nb_checked == 0)
            for fname in checked_faultnames:
                if fname not in new_checked_faultnames:
                    fault_name = fname
        # NOTE(review): if nb_checked != 0 and no differing name is found,
        # fault_name is unbound here and UnboundLocalError is raised.
        if fault_name == "All faults":  # Checked or unchecked all faults
            # Index 0 is the "All faults" entry itself
            for i in range(1, cbox.count()):
                fname = cbox.itemText(i)
                if fault_network.is_active(fname):
                    if "All faults" in new_checked_faultnames:
                        cbox.setItemCheckState(i, Qt.CheckState.Checked)
                        interp.add_discontinuity(fname)
                    else:
                        cbox.setItemCheckState(i, Qt.CheckState.Unchecked)
                        interp.remove_discontinuity(fname)
        elif not fault_network.is_active(fault_name):
            # If not active, uncheck it and do not update.
            # setEnabled just makes the item grey, and do not prevent to check it
            # (setCheckable and setSelectable neither)
            item = cbox.model().findItems(fault_name)[0]
            item.setCheckState(Qt.CheckState.Unchecked)
        else:
            # NOTE(review): new_checked_faultnames may include "All faults"
            # itself, which would then be stored as a discontinuity — verify.
            interp.set_discontinuities(new_checked_faultnames)
        self.model.update_interpolator(interp)

    def save(self):
        """Copy the edited model into the project ModelLayer of the same name.

        If no layer of that name exists, a new ModelLayer is created and added
        to the project. Afterwards, the dialog edits a fresh
        TemporaryModelLayer clone of the permanent layer.
        """
        # FIXME Easier if the widget stores a ref to the permanent layer too?
        project = QgsProject.instance()
        matches = project.mapLayersByName(self.model_layer.name())
        if matches:
            permanent_layer = matches[0]
            assert isinstance(permanent_layer, ModelLayer)
        else:
            permanent_layer = ModelLayer()
            project.addMapLayer(permanent_layer)
        permanent_layer.update_from(self.model_layer)
        self.model_layer = TemporaryModelLayer.clone(permanent_layer)

    def open_copy(self):
        """Close this dialog (accepting it) and edit a renamed copy of the model.

        The copy is named "<name> copy". It is not added to the project here.
        """
        # Accepting emits 'accepted', which is connected to save()
        self.accept()
        copy_layer = ModelLayer.clone(self.model_layer)
        copy_layer.setName(f"{copy_layer.name()} copy")
        self.edit(copy_layer)

    @classmethod
    def new(cls, name=None, pile_layer=None, faultnet_layer=None):
        """Ask for the parameters of a new model and open an edition dialog on it.

        The user is asked again while the chosen name is already used by a
        layer of the QGIS project. Nothing happens if the popup is cancelled
        or no pile layer is selected.

        Args:
            name: initial model name shown in the popup.
            pile_layer: PileLayer preselected in the popup.
            faultnet_layer: FaultNetworkLayer preselected in the popup.
        """
        project = QgsProject.instance()
        name, pile_layer, faultnet_layer = _popup_new_model_dialog(
            name, pile_layer, faultnet_layer
        )
        # If there is another one with this name in the QGIS project, ask for a
        # new name; stop as soon as the popup is cancelled (name is None).
        while name is not None and project.mapLayersByName(name):
            QMessageBox.warning(
                iface.mainWindow(), "Already existing", "Please enter a different name"
            )
            name, pile_layer, faultnet_layer = _popup_new_model_dialog(
                name, pile_layer, faultnet_layer
            )
        if name is None:
            return
        if pile_layer is None:
            # Presumably no pile layer in the project: a model needs a pile
            QMessageBox.warning(
                iface.mainWindow(), "No pile", "Please select a pile layer"
            )
            return
        assert pile_layer.pile is not None
        faultnetlayer_id = faultnet_layer.id() if faultnet_layer is not None else None
        layer = ModelLayer.new(name, pile_layer.id(), faultnetlayer_id)
        dlg = cls(layer, parent=iface.mainWindow())
        dlg.show()

    @classmethod
    def edit(cls, model_layer):
        """Show an edition dialog working on a temporary clone of ``model_layer``.

        The dialog is displayed with show(); ``model_layer`` itself is only
        modified when the dialog saves.
        """
        dialog = cls(
            TemporaryModelLayer.clone(model_layer), parent=iface.mainWindow()
        )
        dialog.show()

    # Actions for the "Discretization" menu

    def extract_model_surfaces(self):
        """Save the model, then open the surface extraction dialog on the saved layer."""
        self.save()
        # FIXME Easier if the widget stores a ref to the permanent layer too?
        matches = QgsProject.instance().mapLayersByName(self.model_layer.name())
        permanent_layer = matches[0]
        assert isinstance(permanent_layer, ModelLayer)
        # parent=self centers the dialog on this one; show() is used instead
        # of open(), which would create a modal dialog.
        SurfaceExtractionDialog(permanent_layer, parent=self).show()


def _popup_new_model_dialog(model_name=None, pile_layer=None, faultnet_layer=None):
    """Ask the user for the name, pile and fault network of a new model.

    Args:
        model_name: initial name; DEFAULT_MODEL_NAME is used when falsy.
        pile_layer: PileLayer preselected in the pile combo box.
        faultnet_layer: FaultNetworkLayer preselected in the fault network
            combo box, whose empty entry reads "No faults".

    Returns:
        ``(name, pile_layer, faultnet_layer)`` chosen by the user when the
        dialog is accepted, ``(None, None, None)`` when it is cancelled or
        closed.
    """
    dlg = QDialog(parent=iface.mainWindow())
    dlg.setWindowTitle("New model")
    edt_name = QLineEdit(model_name or DEFAULT_MODEL_NAME)
    cbox_pile = QgsPluginLayerComboBox(PileLayer, defaultLayer=pile_layer)
    cbox_faultnet = QgsPluginLayerComboBox(
        FaultNetworkLayer,
        defaultLayer=faultnet_layer,
        emptyLayerMessage="No faults",
    )
    # Cancel button, consistent with the other selection dialogs
    buttons = QDialogButtonBox(
        QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
    )
    buttons.accepted.connect(dlg.accept)
    buttons.rejected.connect(dlg.reject)
    layout = QGridLayout()
    layout.addWidget(QLabel("Name"), 0, 0)
    layout.addWidget(edt_name, 0, 1)
    layout.addWidget(QLabel("Pile"), 1, 0)
    layout.addWidget(cbox_pile, 1, 1)
    layout.addWidget(QLabel("Fault network"), 2, 0)
    layout.addWidget(cbox_faultnet, 2, 1)
    layout.addWidget(buttons, 3, 0, 1, 2)
    dlg.setLayout(layout)
    if dlg.exec() != QDialog.DialogCode.Accepted:
        return (None, None, None)
    return edt_name.text(), cbox_pile.currentLayer(), cbox_faultnet.currentLayer()
