from __future__ import annotations

from pathlib import Path
from typing import Optional, cast

from qgis.core import (
    Qgis,
    QgsAbstractProfileSource,
    QgsCoordinateReferenceSystem,
    QgsCoordinateTransform,
    QgsCoordinateTransformContext,
    QgsCsException,
    QgsDoubleRange,
    QgsGeometry,
    QgsLineSymbol,
    QgsMapLayerFactory,
    QgsMapLayerRenderer,
    QgsPluginLayer,
    QgsPluginLayerType,
    QgsReadWriteContext,
    QgsRectangle,
    QgsRenderContext,
    QgsSymbolLayerUtils,
    QgsTextFormat,
    QgsTextRenderer,
    QgsVectorLayerElevationProperties,
    QgsXmlUtils,
)
from qgis.PyQt.QtCore import QRect, QVariant
from qgis.PyQt.QtGui import QPolygonF
from qgis.PyQt.QtWidgets import QDialog
from qgis.PyQt.QtXml import QDomDocument, QDomNode

from woody.gui.dlg_image_layer_properties import ImageLayerPropertiesDialog
from woody.layer.image_layer_profile import ImageLayerProfileSource
from woody.toolbelt import PlgLogger
from woody.utils import refreshAllElevationProfiles


class ImageLayer(QgsPluginLayer):
    """Plugin layer holding a profile image and the geometry it is attached to.

    Display settings live in an ``ImageLayerProperties`` instance and the geometry in
    an optional ``QgsGeometry``. Every setter calls ``refreshAllElevationProfiles`` so
    that the change is propagated.
    """

    LAYER_TYPE = "ImageLayer"

    def __init__(self, layer_name: str) -> None:
        """Create a valid layer named *layer_name* with default properties."""
        super().__init__(self.LAYER_TYPE, layer_name)

        self.setValid(True)
        # Objects handed over to QGIS are created lazily and referenced here so that
        # they are not garbage collected (see profileSource / elevationProperties).
        self._profile_source: Optional[ImageLayerProfileSource] = None
        self._elevation_properties: Optional[ImageLayerElevationProperties] = None
        self._geometry: Optional[QgsGeometry] = None
        self._properties = ImageLayerProperties(self)

    def setImagePath(self, image_path: Path) -> None:
        """Store the image path and refresh the elevation profiles.

        A critical message is logged when the path does not exist, but the path is
        stored anyway.
        """
        if not image_path.exists():
            PlgLogger.log(
                "ImageLayer: image path does not exist", log_level=Qgis.MessageLevel.Critical
            )
        self._properties.image_path = image_path
        refreshAllElevationProfiles(self)

    def getImagePath(self) -> Optional[Path]:
        """Return the image path, or ``None`` when it is unset or unusable.

        A path that is a directory or does not exist is reported through a pushed
        critical log message and ``None`` is returned.
        """
        image_path = self._properties.image_path
        if image_path is None:
            return None

        if image_path.is_dir() or not image_path.exists():
            PlgLogger.log(
                self.tr(
                    "Profile image not found for {}! The image path may be invalid.\n{}"
                ).format(self.name(), image_path.resolve()),
                log_level=Qgis.MessageLevel.Critical,
                push=True,
            )
            return None

        return image_path

    def setImageZ(self, z_range: QgsDoubleRange) -> None:
        """Store the Z range covered by the image and refresh the elevation profiles."""
        self._properties.z_range = z_range
        refreshAllElevationProfiles(self)

    def getImageZ(self) -> Optional[QgsDoubleRange]:
        """Return the image Z range, or ``None`` when the stored range is infinite."""
        if self._properties.z_range.isInfinite():
            return None
        return self._properties.z_range

    def setTransformContext(self, transform_context: QgsCoordinateTransformContext) -> None:
        """Describe how to handle coordinate transformation (nothing to do here)."""
        pass

    def createMapRenderer(self, rendererContext: QgsRenderContext) -> QgsMapLayerRenderer:
        """Return a new renderer drawing this layer's geometry on the map canvas."""
        return ImageLayerRenderer(
            self.id(), rendererContext, self._properties, self.crs(), self._geometry
        )

    def elevationProperties(self) -> QgsVectorLayerElevationProperties:
        """Return the layer's elevation properties, creating them on first use."""

        # Keep a reference, as done for the profile source, so that the object handed
        # to QGIS is not garbage collected once this method returns.
        if self._elevation_properties is None:
            self._elevation_properties = ImageLayerElevationProperties()

        return self._elevation_properties

    def profileSource(self) -> QgsAbstractProfileSource:
        """Return the layer's profile source, creating it on first use."""

        # keep a reference to avoid garbage collection
        if self._profile_source is None:
            self._profile_source = ImageLayerProfileSource(self)

        return self._profile_source

    def setImageFlipped(self, image_flipped: bool) -> None:
        """Store whether the image is flipped and refresh the elevation profiles."""
        self._properties.image_flipped = image_flipped
        refreshAllElevationProfiles(self)

    def isImageFlipped(self) -> bool:
        """Return whether the image is flagged as flipped."""
        return self._properties.image_flipped

    def setImageBounds(self, image_bounds: QRect) -> None:
        """Store the image bounds and refresh the elevation profiles."""
        self._properties.image_bounds = image_bounds
        refreshAllElevationProfiles(self)

    def getImageBounds(self) -> QRect:
        """Return the stored image bounds."""
        return self._properties.image_bounds

    def getImageOpacity(self) -> int:
        """Return the stored image opacity."""
        return self._properties.image_opacity

    def setImageOpacity(self, opacity: int) -> None:
        """Store the image opacity and refresh the elevation profiles."""
        self._properties.image_opacity = opacity
        refreshAllElevationProfiles(self)

    def setGeometry(self, geometry: Optional[QgsGeometry]) -> None:
        """Store the layer geometry and update the layer extent accordingly.

        The extent becomes the geometry bounding box, or an empty ``QgsRectangle``
        when the geometry is ``None`` or evaluates to false. A repaint is requested
        and the elevation profiles are refreshed.
        """
        self._geometry = geometry
        extent = self._geometry.boundingBox() if self._geometry else QgsRectangle()
        self.setExtent(extent)
        self.repaintRequested.emit()  # for map canvas
        refreshAllElevationProfiles(self)

    def geometry(self) -> Optional[QgsGeometry]:
        """Return the layer geometry, if any."""
        return self._geometry

    def writeXml(
        self, layer_node: QDomNode, document: QDomDocument, context: QgsReadWriteContext
    ) -> bool:
        """
        Write all that is related to the layer into an XML node of an XML document.
        Called when the QGIS project is saved.

        Returns False when *layer_node* is not a "maplayer" element, True otherwise.
        """

        layer_elmt = layer_node.toElement()
        if layer_elmt.isNull() or layer_elmt.nodeName() != "maplayer":
            return False

        # Set plugin layer attributes
        layer_elmt.setAttribute("type", QgsMapLayerFactory.typeToString(Qgis.LayerType.Plugin))
        layer_elmt.setAttribute("name", ImageLayer.LAYER_TYPE)

        # Save image geometry
        image_geometry_elmt = QgsXmlUtils.writeVariant(QVariant(self._geometry), document)
        image_geometry_elmt.setTagName("image_geometry")
        layer_elmt.appendChild(image_geometry_elmt)

        # Save layer properties
        self._properties.writeXml(layer_node, document, context)

        return True

    def readXml(self, layer_node: QDomNode, context: QgsReadWriteContext) -> bool:
        """
        Read all that is related to the layer from an XML node.
        Called when the QGIS project is loaded.

        Returns False when *layer_node* does not describe an ImageLayer, True otherwise.
        """

        layer_elmt = layer_node.toElement()

        # Be sure that the node corresponds to an ImageLayer from our plugin
        if (
            layer_elmt.isNull()
            or layer_elmt.nodeName() != "maplayer"
            or layer_elmt.attribute("type")
            != QgsMapLayerFactory.typeToString(Qgis.LayerType.Plugin)
            or layer_elmt.attribute("name") != ImageLayer.LAYER_TYPE
        ):
            return False

        # Load the layer properties first: setGeometry() below triggers a repaint and a
        # profile refresh, which should see the properties read from the project.
        self._properties.readXml(layer_node, context)

        # Load the geometry. Anything that is not a QgsGeometry (missing element,
        # null or unexpected variant) is treated as "no geometry".
        geometry = QgsXmlUtils.readVariant(layer_elmt.firstChildElement("image_geometry"))
        self.setGeometry(geometry if isinstance(geometry, QgsGeometry) else None)

        return True


class ImageLayerProperties:
    """This class lists all the properties associated with the ImageLayer type"""

    def __init__(self, layer: ImageLayer) -> None:
        """Initialise the properties with default values.

        The map label defaults to the name of *layer*.

        Raises:
            RuntimeError: if the default line symbol cannot be created.
        """
        # The main 2D map canvas symbol. We keep a default symbol to be reused if needed.
        if (default_symbol := QgsLineSymbol.createSimple({"color": "black"})) is None:
            raise RuntimeError("Error while creating default line symbol layer")
        self.default_symbol = default_symbol
        # Work on a copy so that in-place edits of the map symbol never alter the default.
        self.map_symbol = default_symbol.clone()

        self.map_label_format = QgsTextFormat()
        self.map_label = layer.name()
        self.image_flipped = False
        self.image_opacity: int = 100
        self.image_path: Optional[Path] = None
        self.image_bounds = QRect()
        self.z_range = QgsDoubleRange()

    def __str__(self) -> str:
        """
        Convenient pretty print method.
        """

        return (
            f"map_label: {self.map_label}\n"
            f"map_label_format: {self.map_label_format.asCSS()}\n"
            f"map_symbol: {self.map_symbol.dump()}\n"
            f"image_flipped: {self.image_flipped}\n"
            f"image_opacity: {self.image_opacity}%\n"
            f"image_bounds: {self.image_bounds}\n"
            f"image_path: {self.image_path}\n"
            f"z_range : [{self.z_range.lower()}, {self.z_range.upper()}]"
        )

    def __eq__(self, other: object) -> bool:
        """
        Equality based on some trivial comparisons.

        The text format and the symbol are compared through their ``asCSS()`` and
        ``dump()`` output; the Z range through its lower and upper bounds.
        """

        if not isinstance(other, ImageLayerProperties):
            return False

        return (
            self.map_label == other.map_label
            and self.map_label_format.asCSS() == other.map_label_format.asCSS()
            and self.map_symbol.dump() == other.map_symbol.dump()
            and self.image_flipped == other.image_flipped
            and self.image_opacity == other.image_opacity
            and self.image_bounds == other.image_bounds
            and self.image_path == other.image_path
            and self.z_range.lower() == other.z_range.lower()
            and self.z_range.upper() == other.z_range.upper()
        )

    def writeXml(
        self, node: QDomNode, document: QDomDocument, context: QgsReadWriteContext
    ) -> None:
        """
        Write the properties in an XML node of an XML document.

        Each property is appended to *node* as a child element named after it.
        """

        map_label_format_elmt = document.createElement("map_label_format")
        map_label_format_elmt.appendChild(self.map_label_format.writeXml(document, context))
        node.appendChild(map_label_format_elmt)

        map_label_elmt = QgsXmlUtils.writeVariant(self.map_label, document)
        map_label_elmt.setTagName("map_label")
        node.appendChild(map_label_elmt)

        map_symbol_elmt = document.createElement("map_symbol")
        map_symbol_elmt.appendChild(
            QgsSymbolLayerUtils.saveSymbol("image_layer_symbol", self.map_symbol, document, context)
        )
        node.appendChild(map_symbol_elmt)

        image_flipped_elmt = QgsXmlUtils.writeVariant(self.image_flipped, document)
        image_flipped_elmt.setTagName("image_flipped")
        node.appendChild(image_flipped_elmt)

        image_opacity_elmt = QgsXmlUtils.writeVariant(self.image_opacity, document)
        image_opacity_elmt.setTagName("image_opacity")
        node.appendChild(image_opacity_elmt)

        # Null bounds are not written at all; readXml() maps a missing element back to
        # a null QRect.
        if not self.image_bounds.isNull():
            image_bounds_elmt = document.createElement("image_bounds")
            image_bounds_elmt.setAttribute("xMin", str(self.image_bounds.left()))
            image_bounds_elmt.setAttribute("yMin", str(self.image_bounds.top()))
            image_bounds_elmt.setAttribute("width", str(self.image_bounds.width()))
            image_bounds_elmt.setAttribute("height", str(self.image_bounds.height()))
            node.appendChild(image_bounds_elmt)

        # The path goes through the context's path resolver before being written.
        if self.image_path is not None:
            image_path_elmt = QgsXmlUtils.writeVariant(
                context.pathResolver().writePath(str(self.image_path)), document
            )
        else:
            image_path_elmt = QgsXmlUtils.writeVariant(self.image_path, document)
        image_path_elmt.setTagName("image_path")
        node.appendChild(image_path_elmt)

        z_range_elmt = document.createElement("z_range")
        z_range_elmt.setAttribute("upper", round(self.z_range.upper(), 8))
        z_range_elmt.setAttribute("lower", round(self.z_range.lower(), 8))
        node.appendChild(z_range_elmt)

    def readXml(self, node: QDomNode, context: QgsReadWriteContext) -> None:
        """
        Read the properties from an XML node.

        A symbol that cannot be loaded is replaced by a copy of the default symbol,
        and a missing or unreadable Z range by a default-constructed range.
        """

        self.map_label = QgsXmlUtils.readVariant(node.firstChildElement("map_label"))
        self.image_flipped = QgsXmlUtils.readVariant(node.firstChildElement("image_flipped"))
        self.image_opacity = QgsXmlUtils.readVariant(node.firstChildElement("image_opacity"))

        bounds_elt = node.firstChildElement("image_bounds")
        if not bounds_elt.isNull():
            xMin = int(bounds_elt.attribute("xMin"))
            yMin = int(bounds_elt.attribute("yMin"))
            width = int(bounds_elt.attribute("width"))
            height = int(bounds_elt.attribute("height"))
            self.image_bounds = QRect(xMin, yMin, width, height)
        else:
            self.image_bounds = QRect()

        # Image path can be None which is an accepted value
        if (
            image_path := QgsXmlUtils.readVariant(node.firstChildElement("image_path"))
        ) is not None:
            self.image_path = Path(context.pathResolver().readPath(image_path))
        else:
            self.image_path = image_path

        self.map_label_format.readXml(node.firstChildElement("map_label_format"), context)

        # Symbol can be None but it is not an accepted value so we set it to a default value.
        if (
            symbol := QgsSymbolLayerUtils.loadSymbol(
                node.firstChildElement("map_symbol").firstChildElement(), context
            )
        ) is not None:
            self.map_symbol = cast(QgsLineSymbol, symbol)
        else:
            # Use a copy so that later edits of the map symbol leave the default intact.
            self.map_symbol = self.default_symbol.clone()
            PlgLogger.log(
                "ImageLayerProperties: could not load symbol from project file. "
                "Default symbol used.",
                log_level=Qgis.MessageLevel.Warning,
            )

        # A missing element or attribute yields an empty string, which float() rejects:
        # fall back to a default-constructed range instead of aborting the whole read.
        z_range_elmt = node.firstChildElement("z_range")
        try:
            lower = float(z_range_elmt.attribute("lower"))
            upper = float(z_range_elmt.attribute("upper"))
        except ValueError:
            self.z_range = QgsDoubleRange()
        else:
            self.z_range = QgsDoubleRange(lower, upper)


class ImageLayerType(QgsPluginLayerType):
    """Plugin layer type able to create ImageLayer instances and edit their properties."""

    def __init__(self) -> None:
        """Register the type under the ImageLayer type name."""
        super().__init__(ImageLayer.LAYER_TYPE)

    def createLayer(self) -> ImageLayer:
        """Return a new ImageLayer with an empty name."""
        empty_name = ""
        return ImageLayer(empty_name)

    def showLayerProperties(self, layer: Optional[QgsPluginLayer]) -> bool:
        """Open the properties dialog of *layer*.

        Returns False when *layer* is not an ImageLayer. Otherwise the dialog is
        executed and True is returned; when the dialog is accepted, a repaint is
        requested and the elevation profiles are refreshed.
        """
        if isinstance(layer, ImageLayer):
            properties_dialog = ImageLayerPropertiesDialog(layer._properties)
            properties_dialog.setWindowTitle(f"Properties - {layer.name()} (ImageLayer)")

            accepted = properties_dialog.exec() == QDialog.DialogCode.Accepted
            if accepted:
                # for map canvas
                layer.repaintRequested.emit()
                refreshAllElevationProfiles(layer)

            return True

        return False


class ImageLayerRenderer(QgsMapLayerRenderer):
    """This class defines how to draw on the canvas"""

    def __init__(
        self,
        layerId: str,
        rendererContext: QgsRenderContext,
        layer_properties: ImageLayerProperties,
        layer_crs: Optional[QgsCoordinateReferenceSystem],
        layer_geometry: Optional[QgsGeometry],
    ) -> None:
        """Prepare a renderer for the layer identified by *layerId*.

        Everything needed by render() is copied here, so that later changes to the
        layer or its properties cannot affect a rendering job that is already set up.
        """
        super().__init__(layerId, rendererContext)
        self._layer_properties = layer_properties
        self._layer_crs = layer_crs

        # Snapshot of the data used by render().
        # NOTE(review): render() is expected to possibly run outside the main thread,
        # hence the copies - confirm against the QgsMapLayerRenderer documentation.
        self._layer_geometry = QgsGeometry(layer_geometry) if layer_geometry is not None else None
        self._map_symbol = layer_properties.map_symbol.clone()
        self._map_label = layer_properties.map_label
        self._map_label_format = QgsTextFormat(layer_properties.map_label_format)

        self._log = PlgLogger().log

    def render(self) -> bool:
        """Do the rendering

        Draws the layer geometry as a line with the map symbol, then the map label
        at the start of the line. Returns False when no render context is available,
        when the coordinate transform fails or when the geometry cannot be read as a
        polyline; True otherwise (including when there is nothing to draw).
        """

        if (context := self.renderContext()) is None:
            return False

        # No geometry defined, don't do anything
        if self._layer_geometry is None or self._layer_crs is None:
            return True

        # Transform the curve into the canvas CRS
        map_transform = QgsCoordinateTransform(
            self._layer_crs,
            context.coordinateTransform().destinationCrs(),
            context.transformContext(),
        )
        geom_transformed = QgsGeometry(self._layer_geometry)

        try:
            geom_transformed.transform(map_transform)
        except QgsCsException:
            self._log(
                message="ImageLayerRenderer, unable to apply coordinate transform",
                log_level=Qgis.MessageLevel.Critical,
                push=False,
            )
            return False

        # Geometry is not visible, don't do anything
        if not context.mapExtent().intersects(geom_transformed.boundingBox()):
            return True

        # asPolyline() raises when the geometry is null or not a single line
        try:
            polyline = geom_transformed.asPolyline()
        except (TypeError, ValueError):
            self._log(
                message="ImageLayerRenderer, geometry is not a single polyline",
                log_level=Qgis.MessageLevel.Critical,
                push=False,
            )
            return False

        to_pixel = context.mapToPixel()
        points = QPolygonF([to_pixel.transform(xy).toQPointF() for xy in polyline])

        # Draw line symbol with all its symbol layers; stopRender() must always
        # follow startRender(), even if drawing fails.
        self._map_symbol.startRender(context)
        try:
            self._map_symbol.renderPolyline(points=points, f=None, context=context)
        finally:
            self._map_symbol.stopRender(context)

        # Draw text at the beginning of the line
        QgsTextRenderer.drawTextOnLine(
            points,
            self._map_label,
            context,
            self._map_label_format,
            offsetAlongLine=0,
            offsetFromLine=-self._map_label_format.font().pointSize(),
        )

        return True


class ImageLayerElevationProperties(QgsVectorLayerElevationProperties):
    """Elevation properties used by ImageLayer."""

    def __init__(self) -> None:
        """Create the properties without a parent object."""
        parent = None
        super().__init__(parent)

    def showByDefaultInElevationProfilePlots(self) -> bool:
        """Always report that the layer is shown by default in elevation profile plots."""
        shown_by_default = True
        return shown_by_default
