import xml.etree.ElementTree as ET
from pathlib import Path

from forgeo.core import Model
from qgis.core import (
    Qgis,
    QgsBox3D,
    QgsMapLayerRenderer,
    QgsMessageLog,
    QgsPluginLayer,
    QgsPluginLayerType,
    QgsProject,
)
from qgis.PyQt.QtWidgets import QMessageBox
from qgis.utils import iface

import forgeo.io.xml as fxml

from ..io.xml.qgis import FiltersSerializer, ModelFilters
from ..utils import extract_item_data, get_forgeo_data_dir, qicon
from .elevation_profile.model import (
    ModelLayerElevationProperties,
    ModelLayerProfileSource,
)
from .utils import _repr_plugin_layer, connect, disconnect, set_unique_name_for_layer


class ModelLayerRenderer(QgsMapLayerRenderer):
    """Map renderer returned by `ModelLayer.createMapRenderer`.

    It does not draw anything on the map canvas.
    """

    def __init__(self, layerId, rendererContext):
        """Forwards the layer id and the render context to the QGIS base class."""
        super(ModelLayerRenderer, self).__init__(layerId, rendererContext)

    def render(self):
        """Draws nothing and returns True."""
        rendered = True
        return rendered


class ModelLayer(QgsPluginLayer):
    """Custom `qgis.core.QgsMapLayer` to display/interact with geological model.

    Main attributes:

    - ``model``: the wrapped ``forgeo.core.Model`` (None until set or loaded
      by ``readXml``);
    - ``filters``: mapping ``{item name: filters}``. Each filter exposes a
      ``layer_id`` referencing the QGIS data layer it extracts data from;
    - ``connected_to_layers``: whether the data layers referenced by
      ``filters`` are connected to this layer (through ``.utils.connect``);
    - ``pilelayer_id`` / ``faultnetlayer_id``: ids of the related pile and
      fault network layers;
    - ``discretization_params``: None, or a sequence whose first 3 values are
      saved in the project file.
    """

    LAYER_TYPE = "ModelLayer"

    def __init__(self, model=None, pilelayer_id=None, faultnetlayer_id=None):
        """Creates the layer, optionally wrapping an existing model.

        Parameters
        ----------
        model : forgeo.core.Model, optional
            Wrapped model. When given, the layer takes its name. It is left to
            None by `ModelLayerType.createLayer`, the model then being loaded
            by `readXml`.
        pilelayer_id, faultnetlayer_id : str, optional
            Ids of the related pile layer and fault network layer.
        """
        # NOTE(review): subclasses (e.g. TemporaryModelLayer) also register
        # with ModelLayer.LAYER_TYPE here, whereas writeXml uses
        # self.LAYER_TYPE -- confirm this is intended.
        super().__init__(ModelLayer.LAYER_TYPE)
        self.setValid(True)
        # Set model layer CRS: not used (at least up to now), but avoids QGIS
        # prompting when opening the project to ask for setting the input CRS
        self.setCrs(QgsProject.instance().crs())
        self.nameChanged.connect(self._on_rename)
        self.model = model
        if model is not None:
            self.setName(model.name)
        self.pilelayer_id = pilelayer_id
        self.faultnetlayer_id = faultnetlayer_id
        self.update_interpolators_from_faultnetwork()
        self.filters = {}
        self.discretization_params = None
        self.connected_to_layers = False

        # Elevation profile related stuffs

        # Not used in practice, but anticipate future needs for large models...
        # If True, the model section is computed only once when initializing the
        # profile, and future (de)zooms are only "in the drawing".
        # If False, the model section is recomputed at each (de)zoom, so small
        # elements can (dis)appear when (de)zooming in the elevation profile.
        cache_profile = False  # TODO Provide access to this option
        # Keep a reference to avoid garbage collection
        self.profile_source = ModelLayerProfileSource(self, cache_profile)

    __repr__ = _repr_plugin_layer

    # ------------------------------------------------------------------
    # Small helpers on filters
    # ------------------------------------------------------------------

    @staticmethod
    def _iter_filters(all_filters):
        """Yields every filter of an ``{item name: filters}`` mapping.

        Accepts None or an empty mapping, in which case nothing is yielded.
        """
        if not all_filters:
            return
        for item_filters in all_filters.values():
            yield from item_filters

    @staticmethod
    def _unique_layer_ids(filters):
        """Returns the layer ids referenced by `filters`, without duplicates,
        in order of first appearance."""
        return list(dict.fromkeys(f.layer_id for f in filters))

    # ------------------------------------------------------------------
    # QgsMapLayer / QgsPluginLayer API
    # ------------------------------------------------------------------

    def icon(self):
        return qicon("model_icon")

    def createMapRenderer(self, rendererContext):
        return ModelLayerRenderer(self.id(), rendererContext)

    def setTransformContext(self, ct):
        pass

    @staticmethod
    def _model_source_path(source):
        """Resolves the model XML path stored as ``source`` in the project file.

        The path is looked up relative to the project directory first, then
        relative to its parent directory (layout used by old projects).

        Raises
        ------
        FileNotFoundError
            If the file exists in neither location.
        """
        root_dir = Path(QgsProject.instance().absolutePath())
        # FIXME To remove "soon": the parent directory fallback only allows
        # reopening old projects
        for candidate in (root_dir / source, root_dir.parent / source):
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"Model file not found: {root_dir / source}")

    def _read_filters_file(self):
        """Returns the filters saved by `writeXml` for the current model, or
        None if no filters file exists in the forgeo data directory."""
        filepath = get_forgeo_data_dir() / (self.model.name + "_filters.xml")
        if not filepath.exists():
            return None
        filters = FiltersSerializer.load(ET.parse(filepath).getroot())
        if isinstance(filters, list):
            # FIXME Remove SOON, this is only a temporary hack to easily
            # handle the old serialization format (one filters entry per
            # dataset item, in dataset order)
            names = [item.name for item in self.model.dataset]
            filters = {
                name: item_filters
                for name, item_filters in zip(names, filters, strict=True)
                if len(item_filters) > 0
            }
        return filters

    def readXml(self, node, *_):
        """Called by readLayerXML(), used by children to read state specific to
        them from project files.

        Loads the model from the XML file referenced by the ``source``
        attribute, then restores the other layer attributes and the filters.

        Raises
        ------
        FileNotFoundError
            If the model XML file cannot be found.
        """
        super().readXml(node, *_)
        elt = node.toElement()
        # Load model from source XML file
        self.model = fxml.load(self._model_source_path(elt.attribute("source")))
        self.setName(self.model.name)  # TODO Check if this works
        # Other attributes (QDomElement.attribute returns "" when missing)
        if pilelayer_id := elt.attribute("pilelayer_id"):
            self.pilelayer_id = pilelayer_id
        if faultnetlayer_id := elt.attribute("faultnetlayer_id"):
            self.faultnetlayer_id = faultnetlayer_id
        if connected_to_layers := elt.attribute("connected_to_layers"):
            self.connected_to_layers = connected_to_layers == "True"
        if self.model.dataset is not None:
            filters = self._read_filters_file()
            if filters is not None:
                self.filters = filters
                if self.connected_to_layers:
                    self.reconnect_data_layers()
        if extent := elt.attribute("extent3D"):
            # Format written by QgsBox3D.toString(): "x,y,z : x,y,z"
            mins, maxs = extent.split(" : ")
            xmin, ymin, zmin = (float(m) for m in mins.split(","))
            xmax, ymax, zmax = (float(m) for m in maxs.split(","))
            self.setExtent3D(QgsBox3D(xmin, ymin, zmin, xmax, ymax, zmax))
        if params := elt.attribute("discretization_params"):
            # Values are read back as strings
            self.discretization_params = params.split(";")
        return True

    def writeXml(self, node, *_):
        """Called by writeLayerXML(), used by children to write state specific
        to them to project files.

        The model (and its filters, when set) are dumped as XML files in the
        forgeo data directory; the layer element only stores the model file
        path relative to the project directory.
        """
        super().writeXml(node, *_)
        elt = node.toElement()
        elt.setAttribute("type", "plugin")
        elt.setAttribute("name", self.LAYER_TYPE)
        elt.setAttribute("pilelayer_id", self.pilelayer_id)
        elt.setAttribute("faultnetlayer_id", self.faultnetlayer_id)
        connected = "True" if self.connected_to_layers else "False"
        elt.setAttribute("connected_to_layers", connected)
        # Dump the model in XML file
        data_dir = get_forgeo_data_dir()
        filepath = data_dir / (self.model.name + ".xml")
        fxml.dump(self.model, filepath)
        # self.filters may be None after being handed over by update_from
        if self.filters is not None:
            filepath_filters = data_dir / (self.model.name + "_filters.xml")
            dumped_filters = FiltersSerializer.dump(ModelFilters(self.filters))
            ET.ElementTree(dumped_filters).write(filepath_filters.as_posix())
        # Set the relative XML path as source for the model layer
        root_dir = Path(QgsProject.instance().absolutePath())
        relative_path = filepath.relative_to(root_dir).as_posix()
        elt.setAttribute("source", "./" + relative_path)
        # Other attributes
        if not self.extent3D().isEmpty():
            elt.setAttribute("extent3D", self.extent3D().toString())
        if (params := self.discretization_params) is not None:
            elt.setAttribute(
                "discretization_params", ";".join(str(p) for p in params[:3])
            )
        return True

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @classmethod
    def new(cls, name, pilelayer_id=None, faultnetlayer_id=None):
        """Creates a layer named `name` whose model is initialized from the
        pile of the layer `pilelayer_id`.

        Returns None, after warning the user, if a layer with this name already
        exists in the project.
        """
        project = QgsProject.instance()
        if project.mapLayersByName(name):
            QMessageBox.warning(
                iface.mainWindow(), "Name already used", "Layer not created"
            )
            return None
        pile = project.mapLayer(pilelayer_id).pile
        model = Model(name, pile.name)
        model.initialize(pile)
        return cls(model, pilelayer_id, faultnetlayer_id)

    @classmethod
    def clone(cls, layer):
        """Returns a new layer with a deep copy of `layer`'s model, shallow
        copies of its filters mapping and discretization parameters, and the
        same CRS and related layer ids. The connection state is not copied."""
        model = None
        if layer.model is not None:
            model = fxml.deep_copy(layer.model)
        new_layer = cls(model, layer.pilelayer_id, layer.faultnetlayer_id)
        if (filters := layer.filters) is not None:
            new_layer.filters = {**filters}
        if (params := layer.discretization_params) is not None:
            new_layer.discretization_params = list(params)
        new_layer.setCrs(layer.crs())
        return new_layer

    def update_from(self, layer):
        """Updates this layer properties (self.model, self.filters) by capturing
        the attributes of the input layer

        Data layer connections follow `layer.connected_to_layers`:
        - connected -> connected: only the data layers whose usage changed are
          (dis)connected;
        - not connected -> connected: all data layers are connected;
        - connected -> not connected: all previously used data layers are
          disconnected.

        Warning: The input layer is left in an undefined state after calling this
        method. This method is meant to take as input a temporary layer that should
        be deleted afterwards
        """
        was_connected = self.connected_to_layers
        will_be_connected = bool(layer.connected_to_layers)
        if was_connected and not will_be_connected:
            # Must happen before swapping filters: self.filters still lists the
            # data layers that are currently connected
            self.disconnect_data_layers()
        self.model = layer.model
        layer.model = None
        if self.model.name != self.name():
            self.setName(self.model.name)
        all_old_filters = self.filters
        self.filters = layer.filters
        layer.filters = None
        self.pilelayer_id = layer.pilelayer_id
        self.faultnetlayer_id = layer.faultnetlayer_id
        if (params := layer.discretization_params) is not None:
            self.discretization_params = list(params)
        if not will_be_connected:
            self.connected_to_layers = False
        elif was_connected:
            # Global diff over all items, so that a data layer moved from one
            # item to another is neither disconnected nor connected twice
            self.update_connected_data_layers(
                list(self._iter_filters(all_old_filters)),
                list(self._iter_filters(self.filters)),
            )
        else:
            # Nothing was connected before: connect every data layer
            self.reconnect_data_layers()

    def _on_rename(self):
        """Keeps the model name in sync with the layer name returned by
        `set_unique_name_for_layer`."""
        name = set_unique_name_for_layer(self)
        if self.model is not None:
            self.model.name = name

    # ------------------------------------------------------------------
    # Data layers connections
    # ------------------------------------------------------------------

    def connect_layers_added(self, layers=None):
        """Connects the given layers that are referenced by `self.filters`.

        Parameters
        ----------
        layers : iterable of QgsMapLayer, optional
            Candidate layers (e.g. those emitted by `layersAdded`). Defaults to
            all referenced data layers currently in the project.
        """
        project = QgsProject.instance()
        wanted_ids = set(self._unique_layer_ids(self._iter_filters(self.filters)))
        if layers is None:
            layers = [
                layer
                for layer_id in wanted_ids
                if (layer := project.mapLayer(layer_id)) is not None
            ]
        for layer in layers:
            if layer.id() in wanted_ids:
                connect(layer, self)

    def reconnect_data_layers(self):
        """Marks the layer as connected and connects its data layers: those
        already in the project and those added to it later on."""
        self.connected_to_layers = True
        if self.filters is None:
            return
        # A bit tricky, but necessary to handle the "random" order
        # in which layers are loaded when opening a project
        # Connect layers loaded before self
        self.connect_layers_added()
        # Connect layers loaded after self
        layerStore = QgsProject.instance().layerStore()
        layerStore.layersAdded.connect(self.connect_layers_added)

    def update_connected_data_layers(self, old_filters, new_filters):
        """Updates data layer connections after `old_filters` were replaced by
        `new_filters`.

        Layers only referenced by `old_filters` are disconnected, unless some
        filter of `self.filters` still references them; layers only referenced
        by `new_filters` are connected. Each layer is (dis)connected at most
        once per call. Expects `self.filters` to already hold the new filters.
        Layers no longer present in the project are skipped.
        """
        project = QgsProject.instance()
        old_ids = self._unique_layer_ids(old_filters or [])
        new_ids = self._unique_layer_ids(new_filters or [])
        new_id_set = set(new_ids)
        old_id_set = set(old_ids)
        still_used = set(self._unique_layer_ids(self._iter_filters(self.filters)))
        # Disconnect unused data layers
        for layer_id in old_ids:
            if layer_id in new_id_set or layer_id in still_used:
                continue
            if (layer := project.mapLayer(layer_id)) is not None:
                disconnect(layer, self)
        # Connect new data layers
        for layer_id in new_ids:
            if layer_id in old_id_set:
                continue
            if (layer := project.mapLayer(layer_id)) is not None:
                connect(layer, self)

    def disconnect_data_layers(self):
        """Marks the layer as not connected, disconnects its data layers and
        stops listening to layers added to the project."""
        self.connected_to_layers = False
        if self.filters is None:
            return
        project = QgsProject.instance()
        # Disconnect data layers (each one once; missing layers are skipped)
        for layer_id in self._unique_layer_ids(self._iter_filters(self.filters)):
            if (layer := project.mapLayer(layer_id)) is not None:
                disconnect(layer, self)
        # Stop connecting layers added in the project
        try:
            project.layerStore().layersAdded.disconnect(self.connect_layers_added)
        except (TypeError, RuntimeError):
            # TypeError: the slot was not connected; RuntimeError: the
            # underlying C++ object was already deleted. Both are harmless.
            pass
        QgsMessageLog.logMessage(f"{self.model.name} disconnected", "", level=Qgis.Info)

    def on_data_layers_deletion(self, layer_ids):
        """Drops the filters reading from any of the deleted `layer_ids`."""
        if not self.filters:
            return
        deleted = set(layer_ids)
        # Iterate over a snapshot: update_item reassigns entries of self.filters
        for item_name, item_filters in list(self.filters.items()):
            if any(f.layer_id in deleted for f in item_filters):
                self.update_item(
                    item_name, [f for f in item_filters if f.layer_id not in deleted]
                )

    # ------------------------------------------------------------------
    # Model data
    # ------------------------------------------------------------------

    def update_dataset(self):
        """Recomputes the data of every item having filters from the current
        content of the data layers."""
        for item_name, data_selection in (self.filters or {}).items():
            item = self.model.get_item(item_name)
            item.item_data = extract_item_data(data_selection, self.crs())
        QgsMessageLog.logMessage(
            f"{self.model.name} dataset updated", "", level=Qgis.Info
        )

    def update_item(self, item_name, new_filters):
        """Replaces all the QgisVectorDataFilter of a given item

        The item data is recomputed from `new_filters` and, when the layer is
        connected, the data layer connections are updated accordingly.

        Raises
        ------
        ValueError
            If the model has no item named `item_name`.
        """
        item = self.model.get_item(item_name)
        if item is None:
            raise ValueError(
                f"No item named {item_name!r} in model {self.model.name!r}"
            )
        if self.filters is None:
            self.filters = {}
        # The item may have had no filters yet
        old_filters = self.filters.get(item_name, [])
        self.filters[item_name] = new_filters
        item.item_data = extract_item_data(new_filters, self.crs())
        if self.connected_to_layers:
            self.update_connected_data_layers(old_filters, new_filters)
            QgsMessageLog.logMessage(
                f"{self.model.name} updated", item.name, level=Qgis.Info
            )

    def load_data_from_model(self, other_model_layer):
        """Replaces model data and filters from elements of other_model with the same name."""
        other_filters = other_model_layer.filters or {}
        for other_item in other_model_layer.model.dataset:
            if ((item := self.model.get_item(other_item.name)) is not None) and (
                (item_filters := other_filters.get(item.name)) is not None
            ):
                self.update_item(item.name, item_filters)

    def update_interpolators_from_faultnetwork(self):
        """Removes from the model interpolators the discontinuities that are
        faults of the fault network layer and are not active there.

        Does nothing when there is no model or no fault network layer id; logs a
        warning when the fault network layer is not in the project.
        """
        if self.faultnetlayer_id is None or self.model is None:
            return
        faultnet_layer = QgsProject.instance().mapLayer(self.faultnetlayer_id)
        if faultnet_layer is None:
            QgsMessageLog.logMessage(
                f"Fault network layer {self.faultnetlayer_id} not found",
                "",
                level=Qgis.Warning,
            )
            return
        fault_network = faultnet_layer.faultnet
        faults_names = {fault.name for fault in fault_network.dataset}
        for interp in self.model.interpolators:
            # Collect first: removing while iterating discontinuities is unsafe
            inactive = [
                disc_name
                for disc_name in interp.discontinuities
                if disc_name in faults_names and not fault_network.is_active(disc_name)
            ]
            for disc_name in inactive:
                interp.remove_discontinuity(disc_name)

    # ------------------------------------------------------------------
    # Elevation profile
    # ------------------------------------------------------------------

    def elevationProperties(self):
        """Returns the layer elevation properties.

        The instance is created once and cached on the layer. An object created
        on the fly, referenced only by the return value, may be garbage collected
        by Python while QGIS still holds it (same reason as for
        `self.profile_source`).
        """
        props = getattr(self, "_elevation_properties", None)
        if props is None:
            props = ModelLayerElevationProperties()
            self._elevation_properties = props
        return props

    def profileSource(self):
        return self.profile_source


class ModelLayerType(QgsPluginLayerType):
    """QGIS plugin layer type creating `ModelLayer` instances."""

    def __init__(self):
        super().__init__(ModelLayer.LAYER_TYPE)

    def createLayer(self):
        """Returns an empty `ModelLayer` (no model: its content is expected to
        be loaded afterwards, e.g. by `ModelLayer.readXml`)."""
        return ModelLayer()

    def showLayerProperties(self, layer):
        """Opens the model edition dialog for `layer` and returns True.

        Exceptions raised by the dialog propagate to the caller unchanged.
        """
        from ..widgets.model_widget import ModelEditionDialog  # noqa: PLC0415

        ModelEditionDialog.edit(layer)
        return True


class TemporaryModelLayer(ModelLayer):
    """`ModelLayer` variant without map renderer or elevation profile source.

    It does not read nor write any state from/to project files.
    """

    LAYER_TYPE = "TemporaryModelLayer"

    def __init__(self, model=None, pilelayer_id=None, faultnetlayer_id=None):
        super().__init__(
            model=model,
            pilelayer_id=pilelayer_id,
            faultnetlayer_id=faultnetlayer_id,
        )
        # Drop the profile source built by ModelLayer: profileSource() -> None
        self.profile_source = None

    def createMapRenderer(self, rendererContext):
        """No map renderer for temporary layers."""
        return None

    def readXml(self, node, *_):  # noqa: ARG002
        """Nothing to restore; reports success."""
        return True

    def writeXml(self, node, *_):  # noqa: ARG002
        """Nothing to save; reports success."""
        return True


class TemporaryModelLayerType(QgsPluginLayerType):
    """QGIS plugin layer type creating `TemporaryModelLayer` instances."""

    def __init__(self):
        super().__init__(TemporaryModelLayer.LAYER_TYPE)

    def createLayer(self):
        return TemporaryModelLayer()

    def showLayerProperties(self, layer=None):  # noqa: ARG002
        """Temporary layers have no properties dialog: returns False.

        `layer` is accepted (and ignored) because QGIS calls this method with
        the layer whose properties are requested; without it the call would
        raise a TypeError.
        """
        return False
