from __future__ import annotations

import codecs
import pathlib
import pickle
from typing import ClassVar

from forgeo.gmlib.GeologicalModel3D import GeologicalModel
from qgis.core import (
    Qgis,
    QgsBox3d,
    QgsCategorizedSymbolRenderer,
    QgsCoordinateReferenceSystem,
    QgsCoordinateTransformContext,
    QgsFeature,
    QgsFeatureRenderer,
    QgsField,
    QgsFields,
    QgsFillSymbol,
    QgsGeometry,
    QgsLineSymbol,
    QgsMapLayer,
    QgsMapLayerRenderer,
    QgsMarkerSymbol,
    QgsMessageLog,
    QgsPluginLayer,
    QgsPluginLayerType,
    QgsPoint,
    QgsProject,
    QgsReadWriteContext,
    QgsRectangle,
    QgsRenderContext,
    QgsRendererCategory,
    QgsSymbol,
)
from qgis.gui import QgsOptionsDialogBase
from qgis.PyQt.QtCore import QObject, QVariant
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtXml import QDomDocument, QDomNode
from qgis.utils import iface

from ..settings import PluginSettings


class GmLayer(QgsPluginLayer):
    """Custom `qgis.core.QgsMapLayer` to display/interact with GeoModels

    GeoModels are displayed by their extent only, but this allows one to
    visualize the location of the model and to investigate/edit its properties.

    Note that the model is recomputed at loading !

    Args:
        source (str): GeoModel project file path (`*.xml`), or base64 text
            encoding a pickle of a supported source (see `decodeSource`).
    """

    LAYER_TYPE = PluginSettings().value("package_name")
    # pickle protocol used when embedding a model in a project file
    DEFAULT_PROTOCOL = 1
    # Template of the metadata produced by `decodeSource`. It is copied before
    # being filled in, so per-layer values never leak into this shared dict.
    DEFAULT_PARAMS: ClassVar[dict] = {
        "name": f"{LAYER_TYPE} layer",
        "crs": QgsProject.instance().crs(),
        "source": "",
    }

    def __init__(self, **kwargs):
        """Create the layer and load its GeoModel from `kwargs["source"]`.

        Only the `source` keyword is read. Other keywords (e.g. `name`) are
        accepted but ignored: the layer name is derived from the decoded source.
        """
        # decode source (if there is one)
        source = kwargs.get("source") or ""
        layerName = (
            pathlib.Path(source).stem
            if source
            else self.__class__.DEFAULT_PARAMS["name"]
        )
        super().__init__(GmLayer.LAYER_TYPE, layerName)
        self.setSource(source)

    def renderer(self) -> QgsFeatureRenderer:
        """Returns the feature renderer used for rendering the layer features in 2D map views."""
        return getattr(self, "_renderer", None)

    def setRenderer(self, renderer: QgsFeatureRenderer):
        """Sets the feature renderer which will be invoked to represent this layer in 2D map views. Ownership is transferred.

        Raises:
            TypeError: If `renderer` is neither falsy nor a `QgsFeatureRenderer`.
        """
        if renderer and not isinstance(renderer, QgsFeatureRenderer):
            msg = f"renderer should be QgsFeatureRenderer, not: {type(renderer)}"
            raise TypeError(msg)
        self._renderer = renderer

    @property
    def model(self) -> GeologicalModel:
        """Returns the GeoModel (stored as the `model` custom property), or None."""
        return self.customProperty("model", None)

    def box(self) -> QgsBox3d:
        """Returns the GeoModel 3D extent, or None without a model."""
        if not self.model:
            return None
        box = self.model.getbox()
        return QgsBox3d(
            QgsPoint(box.xmin, box.ymin, box.zmin),
            QgsPoint(box.xmax, box.ymax, box.zmax),
        )

    def extent(self) -> QgsRectangle:
        """Returns the GeoModel 2D extent, or None without a model."""
        # NOTE(review): QGIS calls extent() from C++ and expects a QgsRectangle;
        # returning None may be rejected there — confirm before relying on it.
        if not self.model:
            return None
        box = self.model.getbox()
        return QgsRectangle(
            box.xmin,
            box.ymin,
            box.xmax,
            box.ymax,
        )

    def setExtent(self, rect: QgsRectangle):
        """The extent is derived from the GeoModel and cannot be set."""
        raise NotImplementedError()

    def crs(self) -> QgsCoordinateReferenceSystem:
        """Returns layer's spatial reference system."""
        return super().crs()

    def setCrs(self, crs: QgsCoordinateReferenceSystem, force: bool = True):
        """Sets layer's spatial reference system.

        Args:
            crs: Anything `QgsCoordinateReferenceSystem` can be built from.
            force: When False, an invalid `crs` is ignored.
        """
        crs = QgsCoordinateReferenceSystem(crs)
        if not force and not crs.isValid():
            return
        super().setCrs(crs)

    def createMapRenderer(
        self, rendererContext: QgsRenderContext
    ) -> QgsMapLayerRenderer:
        """Returns new instance of QgsMapLayerRenderer that will be used for rendering of given context"""
        return GmLayerRenderer(self.id(), rendererContext)

    @staticmethod
    def _modelCrs(model: GeologicalModel) -> QgsCoordinateReferenceSystem:
        """Return the CRS of `model`, preferring its QGIS definition over the GeoModeller one."""
        if model.crs.qgis:
            return QgsCoordinateReferenceSystem(model.crs.qgis)
        return QgsCoordinateReferenceSystem(model.crs.geomodeller)

    @staticmethod
    def decodeSource(
        source: str | bytes | GeologicalModel,
    ) -> tuple[dict, GeologicalModel]:
        """Will parse the XML project file metadata and load the GeoModel.

        Args:
            source: Either a path to a GeoModel project file (`*.xml`), base64
                text encoding a pickle of a supported source, pickled bytes of a
                supported source, or a `gmlib.GeologicalModel` instance.

        Returns:
            tuple[dict, GeologicalModel]: A fresh metadata dict with `name`,
            `crs` and `source` keys (defaults from `DEFAULT_PARAMS`) and the
            GeoModel, or None when none could be loaded (load/decode errors are
            pushed to the QGIS message bar).

        Raises:
            TypeError: If `source` has an unsupported type.
        """
        # Copy the template: filling DEFAULT_PARAMS in place would leak this
        # source's name/crs/path into every layer decoded afterwards.
        meta, model = dict(GmLayer.DEFAULT_PARAMS), None
        if not source:  # let it dirty and erase source/data
            return meta, model
        if isinstance(source, str):  # either a file or encoded pickle
            # paths are limited to 260 char !
            if len(source) < 260 and pathlib.Path(source).is_file():
                try:
                    meta["name"] = pathlib.Path(source).stem
                    model = GeologicalModel(source)
                    meta["source"] = pathlib.Path(source).as_posix()
                    meta["crs"] = GmLayer._modelCrs(model)
                except OSError as err:
                    iface.messageBar().pushMessage(
                        "IOError", str(err), level=Qgis.Critical
                    )
                return meta, model
            # Not a file: expect base64 text. The base64 codec only accepts
            # bytes-like input, so the text must be encoded first.
            try:
                dump = codecs.decode(source.encode("latin-1"), "base64")
            except ValueError as err:  # binascii.Error / UnicodeEncodeError
                iface.messageBar().pushMessage(
                    err.__class__.__name__, str(err), level=Qgis.Critical
                )
                return meta, model
            return GmLayer.decodeSource(dump)
        if isinstance(source, bytes):  # try unpickle bytes
            # SECURITY: unpickling can execute arbitrary code; only decode
            # sources coming from trusted project files.
            try:
                return GmLayer.decodeSource(pickle.loads(source))
            except Exception as err:  # report any decoding failure, don't crash
                iface.messageBar().pushMessage(
                    err.__class__.__name__, str(err), level=Qgis.Critical
                )
                return meta, model
        if isinstance(source, GeologicalModel):  # wrapper around gmlib.GeologicalModel
            meta["crs"] = GmLayer._modelCrs(source)
            return meta, source
        msg = f"source should be [str|bytes|GeologicalModel], not: {type(source)}"
        raise TypeError(msg)

    def source(self) -> str:
        """Returns the source for the layer. This source may contain usernames, passwords and other sensitive information."""
        return super().source()

    def _setModel(self, model: GeologicalModel | None) -> bool:
        """Attach `model` (with a fresh renderer) and mark the layer valid, or detach and invalidate.

        Returns:
            bool: Whether a GeoModel is now attached.
        """
        if isinstance(model, GeologicalModel):
            self.setCustomProperty("model", model)
            self.setRenderer(GmStackRenderer(self, self.name()))
            self.setValid(True)
            return True
        self.setCustomProperty("model", None)
        self.setRenderer(None)
        self.setValid(False)
        return False

    def setSource(self, source: str):
        """Set the data source for the layer. Emit a signal that layer's data changed."""
        super().setSource(source)

        params, model = self.__class__.decodeSource(self.source())
        self.setName(params["name"])
        self.setCrs(params["crs"])
        self._setModel(model)
        self.dataChanged.emit()

    def setTransformContext(self, transformContext: QgsCoordinateTransformContext):  # noqa: ARG002
        """Sets the coordinate transform context to `transformContext` (ignored by this layer)."""
        return

    def readXml(self, node: QDomNode, context: QgsReadWriteContext) -> bool:
        """Called by readLayerXML(), used by children to read state specific to them from project files.

        A layer with a source is reloaded from it; otherwise the GeoModel is
        restored from the `model` attribute written by `writeXml`.
        """
        super().readXml(node, context)

        def decode(text: str):
            # inverse of writeXml's `encode`: base64 text -> pickle bytes -> object
            # SECURITY: unpickling runs arbitrary code; only open trusted projects.
            dump = codecs.decode(text.encode("latin-1"), "base64")
            return pickle.loads(dump)

        # custom properties
        if self.source():
            self.setSource(self.source())
            return True
        encoded = node.toElement().attribute("model")
        model = None
        if encoded:
            try:
                model = decode(encoded)
            except Exception as err:  # corrupt attribute: report, leave layer invalid
                iface.messageBar().pushMessage(
                    err.__class__.__name__, str(err), level=Qgis.Critical
                )
        self._setModel(model)
        return True

    def writeXml(
        self, node: QDomNode, doc: QDomDocument, context: QgsReadWriteContext
    ) -> bool:
        """Called by writeLayerXML(), used by children to write state specific to them to project files.

        Without a source, the GeoModel is embedded as a base64-encoded pickle
        in the `model` attribute.
        """
        super().writeXml(node, doc, context)

        def encode(obj) -> str:
            # object -> pickle bytes -> base64 bytes -> text (XML attribute)
            dump = pickle.dumps(obj, self.DEFAULT_PROTOCOL)
            return codecs.encode(dump, "base64").decode("latin-1")

        element = node.toElement()
        # write plugin layer type to project  (essential to be read from project)
        element.setAttribute("type", "plugin")
        element.setAttribute("name", self.LAYER_TYPE)
        # embed the model only when there is no source to reload it from
        if not self.source():
            element.setAttribute("model", encode(self.model))
        return True

    @classmethod
    def isinstance(cls, obj: object) -> bool:
        """Return whether an object is an instance of `GmLayer`.

        `obj` may also be the id of a layer registered in the current project.
        """
        # handle search by layerId
        if isinstance(obj, str):
            obj = QgsProject.instance().layerStore().mapLayers().get(obj, obj)
        # layer type can be return as baseclass QgsPluginLayer...
        if isinstance(obj, cls):
            return True
        if isinstance(obj, QgsPluginLayer):
            return obj.pluginLayerType() == cls.LAYER_TYPE
        return False


class GmLayerType(QgsPluginLayerType):
    """Plugin layer type that lets QGIS create and inspect `GmLayer` instances."""

    def __init__(self):
        super().__init__(GmLayer.LAYER_TYPE)

    def createLayer(self, source: str = "", name: str = "") -> GmLayer:
        """Creates a GmLayer, given a uri and name."""
        return GmLayer(source=source, name=name)

    @staticmethod
    def isinstance(obj: object) -> bool:
        """Return whether an object is an instance of `GmLayer`."""
        return GmLayer.isinstance(obj)

    # You can also add GUI code for displaying custom information in the layer properties
    def showLayerProperties(self, layer: GmLayer) -> bool:
        """Display the GeoModel properties.

        Returns:
            bool: True if the dialog was shown, False if it could not be built
            (the error is pushed to the QGIS message bar).
        """
        try:
            # GmLayerProperties takes (layer, parent): parent it to the main window
            dlg = GmLayerProperties(layer, iface.mainWindow())
            dlg.exec()
        except Exception as err:  # report instead of silently swallowing
            iface.messageBar().pushMessage(
                err.__class__.__name__, str(err), level=Qgis.Critical
            )
            return False
        return True


class GmLayerProperties(QgsOptionsDialogBase):
    """Options/properties dialog bound to a single GmLayer.

    Args:
        layer (qgis.core.QgsMapLayer): The layer whose options are shown.
        parent (QObject, optional): Parent of the dialog. Defaults to None.
    """

    def __init__(self, layer: QgsMapLayer, parent: QObject = None):
        # settings key handed to QgsOptionsDialogBase
        settingsKey = "GeomodelLayerProperties"
        super().__init__(settingsKey, parent)
        # keep a handle on the layer being edited
        self.layer = layer


class GmLayerRenderer(QgsMapLayerRenderer):
    """Utility class that encapsulate information necessary for the rendering of GmLayer.

    Args:
        layerId (str): The id of the layer to be rendered.
        rendererContext (qgis.core.QgsRenderContext): The rendering context.
    """

    def __init__(self, layerId: str, rendererContext: QgsRenderContext):
        super().__init__(layerId, rendererContext)

    def _fail(self, message: str) -> bool:
        """Log a rendering problem for this layer and return False (render failed)."""
        # `self.errors()` hands back a converted Python list (a copy) in PyQGIS,
        # so appending to it never reaches QGIS; log the message instead.
        QgsMessageLog.logMessage(
            f"{self.layerId()}: {message}",
            GmLayer.LAYER_TYPE,
            level=Qgis.Warning,
            notifyUser=False,
        )
        return False

    def render(self) -> bool:
        """Display the extent if visible in the canvas current state.

        Returns:
            bool: True when the extent was drawn or lies outside the map view;
            False when the layer could not be rendered.
        """
        # NOTE(review): QGIS may run render() in a worker thread; looking up the
        # project layers here may not be thread-safe — consider snapshotting in __init__.
        # required validations of the layer to be rendered
        layer = QgsProject.instance().mapLayers().get(self.layerId())
        if layer is None:
            return self._fail("Unknown layer.")
        layerExtent = layer.extent()
        if not layer.isValid() or not layerExtent:
            return self._fail("Invalid layer.")
        crs = layer.crs()
        if not crs or not crs.isValid():
            return self._fail("Invalid CRS.")
        if not layer.renderer():
            return self._fail("No renderer for drawing.")

        # check for visibility in the canvas current state (map coordinates);
        # `contains` implies `intersects`, so one test is enough
        context = self.renderContext()
        mapLayerExtent = context.coordinateTransform().transformBoundingBox(
            layerExtent
        )
        if not context.mapExtent().intersects(mapLayerExtent):
            return True  # Not an error

        # create a feature symbolizing the extent
        renderer = layer.renderer().clone()
        attribute = renderer.classAttribute()
        fields = QgsFields()
        fields.append(QgsField(attribute, QVariant.Int))
        feature = QgsFeature(fields)
        feature.setGeometry(QgsGeometry.fromRect(layerExtent))
        # -1 is the value of the dummy "Extent" category
        feature.setAttribute(attribute, -1)

        # display the extent feature; always pair startRender with stopRender
        renderer.startRender(context, feature.fields())
        try:
            return renderer.renderFeature(feature, context)
        finally:
            renderer.stopRender(context)


class GmStackRenderer(QgsCategorizedSymbolRenderer):
    """Categorized renderer with one category per GeoModel formation.

    An optional dummy "Extent" category (value -1) is used to draw the model extent.
    """

    # symbol type -> symbol class able to build a simple symbol of that type
    SymbolType: ClassVar[dict] = {
        QgsSymbol.SymbolType.Marker: QgsMarkerSymbol,
        QgsSymbol.SymbolType.Line: QgsLineSymbol,
        QgsSymbol.SymbolType.Fill: QgsFillSymbol,
    }

    def __init__(
        self,
        layer: QgsMapLayer,
        attrName: str | None = None,
        withExtent: bool = True,
        symbolType: QgsSymbol.SymbolType = QgsSymbol.SymbolType.Fill,
    ):
        """The GmLayer symbol renderer based on qgis.core.QgsCategorizedSymbolRenderer.

        Args:
            layer (qgis.core.QgsMapLayer): The GeoModel wrapping layer.
            attrName (str, optional): Specifies the layer's field name, or expression, which the categories will be matched against. Defaults to `layer`.name().
            withExtent (bool, optional): Whether to create a dummy category for extent. Defaults to True.
            symbolType (qgis.core.QgsSymbol.SymbolType, optional): The type of rendering symbol. Defaults to Fill.
        """
        # match layer formations against attrName
        if not attrName:
            attrName = layer.name()
        super().__init__(attrName)
        # remember the formations' symbol type: `cloneToType` compares against it
        self.symbolType = symbolType
        # create a dummy category for the extent (value = -1)
        if withExtent:
            color = QColor(PluginSettings().value("gui/default_color"))
            color.setAlphaF(0.0)
            self.addCategory(
                QgsRendererCategory(
                    QVariant(-1),  # value
                    self.SymbolType[QgsSymbol.SymbolType.Fill].createSimple(
                        # Extent will always be a Fill symbol
                        {
                            "color": color.name(QColor.NameFormat.HexArgb),
                            "outline_color": color.name(QColor.NameFormat.HexRgb),
                            "outline_width": "5",
                            "outline_width_unit": "Pixel",
                            "joinstyle": "miter",
                        }
                    ),
                    "Extent",
                ),
            )
        # create a 'category' per formation (values "1", "2", ... in pile order)
        model = layer.model
        for value, label in enumerate(model.pile_formations):
            rgbF = model.formation_colors[label]  # float RGB array
            rgbI = [int(x * 255) for x in rgbF]
            rgbI += [255] if len(rgbI) < 4 else []
            color = QColor(*rgbI)
            symbol = self.SymbolType[symbolType].createSimple(
                {
                    "color": color.name(QColor.NameFormat.HexArgb),
                    "outline_color": color.name(QColor.NameFormat.HexArgb),
                }
            )
            category = QgsRendererCategory(str(value + 1), symbol, label)
            self.addCategory(category)

    def cloneToType(
        self, symbolType: QgsSymbol.SymbolType, hideExtent: bool = True
    ) -> QgsCategorizedSymbolRenderer:
        """Helper to adapt renderer to another geometry type.

        Returns a clone (see `clone`) whose category symbols are rebuilt as
        `symbolType` symbols from the properties of their first symbol layer.
        """
        clone = self.clone(hideExtent)
        if self.symbolType != symbolType:
            symbolClass = self.SymbolType[symbolType]
            # `categories()` returns copies: the new symbol must be pushed back
            # into the renderer by index, not set on the copied category.
            for index, category in enumerate(clone.categories()):
                oldSymbol = category.symbol()
                props = (
                    oldSymbol.symbolLayer(0).properties()
                    if oldSymbol is not None and oldSymbol.symbolLayerCount() > 0
                    else {}
                )
                clone.updateCategorySymbol(index, symbolClass.createSimple(props))
        return clone

    def clone(self, hideExtent: bool = False) -> QgsCategorizedSymbolRenderer:
        """Return a copy of this renderer, optionally without the "Extent" category."""
        clone = super().clone()
        if hideExtent:
            extentId = clone.categoryIndexForLabel("Extent")
            if extentId >= 0:
                clone.deleteCategory(extentId)
        return clone
