# Treat all annotations as strings to avoid circular imports.
# Once PEP 649 is available (Python 3.14), this can be removed,
# since annotations will then be evaluated lazily by default.
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, cast

from qgis.core import (
    Qgis,
    QgsAbstractProfileGenerator,
    QgsAbstractProfileResults,
    QgsAbstractProfileSource,
    QgsCoordinateTransform,
    QgsCsException,
    QgsDoubleRange,
    QgsFeedback,
    QgsGeometry,
    QgsLineString,
    QgsProfileGenerationContext,
    QgsProfileRenderContext,
    QgsProfileRequest,
)
from qgis.PyQt.QtCore import QRect
from qgis.PyQt.QtGui import QImage

from woody.toolbelt.log_handler import PlgLogger

if TYPE_CHECKING:
    from woody.layer.image_layer import ImageLayer


class ImageLayerProfileSource(QgsAbstractProfileSource):
    """Elevation profile source exposing an image layer to the QGIS profile tool.

    Every profile request is served by a new :class:`ImageLayerProfileGenerator`
    bound to the wrapped layer.
    """

    def __init__(self, layer: ImageLayer) -> None:
        super().__init__()
        # The image layer wrapped by this source
        self._layer = layer

    def createProfileGenerator(self, request: QgsProfileRequest) -> ImageLayerProfileGenerator:
        """Create a generator computing the image profile for ``request``."""
        # No Python-side reference is kept: the C++ method is tagged SIP_FACTORY,
        # so ownership of the returned generator is transferred to the caller.
        generator = ImageLayerProfileGenerator(self._layer, request)
        return generator


class ImageLayerProfileGenerator(QgsAbstractProfileGenerator):
    """Profile generator locating an image layer's geometry along a profile curve.

    The buffered profile curve is intersected with the layer geometry (reprojected to
    the profile CRS). For each intersecting line, the generator computes its extent
    along the profile curve (in map units, plus orientation) and along the layer
    geometry (as a fraction of its length). These values feed an
    :class:`ImageLayerProfileResults`, which draws the matching image slices.
    """

    def __init__(self, layer: ImageLayer, request: QgsProfileRequest) -> None:
        super().__init__()

        self._layer = layer
        # Filled by generateProfile(); stays None when there is nothing to draw
        self._results: Optional[ImageLayerProfileResults] = None
        self._feedback = QgsFeedback()
        self._request = request

    def feedback(self) -> QgsFeedback:
        """Access to feedback object of the generator"""
        return self._feedback

    def sourceId(self) -> str:
        """Returns a unique identifier representing the source of the profile"""
        return self._layer.id()

    def generateProfile(self, context: QgsProfileGenerationContext) -> bool:
        """
        Generate the profile - called from a worker thread.

        :param context: generation context provided by QGIS (not used here).
        :return: True if the profile was generated successfully (i.e. the generation
            was not cancelled early). Returning True without storing results means
            "nothing to draw". False is returned on cancellation or when the layer
            geometry cannot be transformed to the profile curve CRS.
        """
        if self._feedback.isCanceled():
            return False

        # No image if no image geometry or no image path or no z_range or no profile curve
        if (
            self._layer.geometry() is None
            or (filename := self._layer.getImagePath()) is None
            or (z_values := self._layer.getImageZ()) is None
            or (profile_curve := self._request.profileCurve()) is None
        ):
            return True

        # Clone profile curve to avoid working on the original object
        # (risks of segfaults when a method takes ownership of the object)
        linestring_profile = QgsGeometry(profile_curve.clone())

        layer_geometry = QgsGeometry(self._layer.geometry())

        # Transform the layer geometry to the profile curve CRS
        try:
            layer_geometry.transform(
                QgsCoordinateTransform(
                    self._layer.crs(), self._request.crs(), self._request.transformContext()
                )
            )
        except QgsCsException:
            PlgLogger.log(
                message="ImageLayerProfileGenerator: unable to apply coordinate transform "
                "from  image layer CRS to profile curve CRS",
                log_level=Qgis.MessageLevel.Critical,
            )
            return False

        # Reproduction of the buffered profile curve as we see it in the map canvas
        buffered_profile = linestring_profile.buffer(
            self._request.tolerance(), 8, Qgis.EndCapStyle.Flat, Qgis.JoinStyle.Round, 2
        )

        # No image if image geometry does not intersect the profile curve
        if not buffered_profile.intersects(layer_geometry):
            return True

        # Intersection of the buffered profile with the layer geometry
        # (expected to be a (Multi)LineString)
        intersections = buffered_profile.intersection(layer_geometry)

        # Individual non-empty LineStrings of the intersection. Other part types
        # (e.g. points where the geometries only touch) cannot be located as
        # sections along the reference lines, so they are ignored.
        intersections_list: list[QgsLineString] = [
            part
            for part in intersections.parts()
            if isinstance(part, QgsLineString) and not part.isEmpty()
        ]
        if not intersections_list:
            return True

        # List of [start, end, reversed] distances along the profile for each intersection
        profile_distances = self.sectionsAlongLine(
            linestring_profile, intersections_list, percentage=False, withOrientation=True
        )

        # List of [start, end] image portions (as a fraction of the layer geometry length,
        # i.e. of the image width) for each intersection
        image_ranges = self.sectionsAlongLine(
            layer_geometry, intersections_list, percentage=True, withOrientation=False
        )

        self._results = ImageLayerProfileResults(
            filename,
            z_values,
            profile_distances,
            image_ranges,
            self._layer.isImageFlipped(),
            self._layer.getImageOpacity(),
            self._layer.getImageBounds(),
        )

        return not self._feedback.isCanceled()

    def takeResults(self) -> Optional[QgsAbstractProfileResults]:
        """Takes results from the generator"""
        return self._results

    def sectionsAlongLine(
        self,
        reference_line: QgsGeometry,
        sections: list[QgsLineString],
        percentage=False,
        withOrientation=False,
    ) -> list[list]:
        """
        Locate each section along a reference line.

        :param reference_line: line along which distances are measured.
        :param sections: non-empty line strings to locate along ``reference_line``.
        :param percentage: when True, distances are divided by the length of
            ``reference_line`` (fraction in [0, 1]); otherwise they are in map units.
            A zero-length reference line leaves the distances unscaled.
        :param withOrientation: when True, a third item tells whether the section
            runs in the opposite direction to the reference line.
        :return: one ``[min distance, max distance]`` list per section, or
            ``[min distance, max distance, reversed]`` when ``withOrientation`` is True.
        """
        reference_length = reference_line.length()
        # Guard against a zero-length reference line (ZeroDivisionError)
        divisor = reference_length if percentage and reference_length > 0 else 1.0

        result = []
        for linestring in sections:
            # Distance from the start of the reference line to each vertex of the section
            distances = [
                reference_line.lineLocatePoint(QgsGeometry.fromPoint(point))
                for point in linestring.points()
            ]
            section = [min(distances) / divisor, max(distances) / divisor]
            if withOrientation:
                # The section is reversed when its first vertex lies further along
                # the reference line than its last one
                section.append(distances[0] > distances[-1])
            result.append(section)

        return result


class ImageLayerProfileResults(QgsAbstractProfileResults):
    """Profile results drawing an image as a background of the elevation profile.

    Vertically, the image spans the ``z_range`` elevations. Horizontally, it is cut
    into portions: ``image_ranges[i]`` is the ``[start, end]`` slice of the image (as
    fractions of its width) drawn over ``profile_distances[i]``, which holds the
    ``[start, end, reversed]`` distances along the profile curve.
    """

    def __init__(
        self,
        filename: Path,
        z_range: QgsDoubleRange,
        profile_distances: list[list],
        image_ranges: list[list],
        is_image_flipped: bool,
        image_opacity: int,
        image_bounds: QRect,
    ) -> None:
        super().__init__()

        self.z_range = z_range
        self.profile_distances = profile_distances
        self.image_ranges = image_ranges
        # Opacity in percent (divided by 100 when rendering)
        self.image_opacity = image_opacity
        self.image_bounds = image_bounds

        self.image = QImage(str(filename))
        # crop the image if necessary
        if self.image_bounds.isValid():
            self.image = self.image.copy(self.image_bounds)

        # Horizontally mirrored copy, used for sections running against the image direction
        self.flipped_image = self.image.mirrored(True, False)

        # Flip the entire image if the layer property is set by the user.
        # In fact, we swap our 2 variables ;-)
        if is_image_flipped:
            self.image, self.flipped_image = self.flipped_image, self.image

    def type(self) -> str:  # noqa: A003
        return "ImageProfileResults"

    def asGeometries(self) -> List[QgsGeometry]:
        """Returns a list of geometries representing the calculated elevation results."""
        # Not needed as we display a background image
        return []

    def asFeatures(
        self,
        type: Qgis.ProfileExportType,  # noqa: A002
        feedback: Optional[QgsFeedback] = None,
    ) -> List[QgsAbstractProfileResults.Feature]:
        """Returns a list of features representing the calculated elevation results"""
        # Not needed as we display a background image
        return []

    def renderResults(self, context: QgsProfileRenderContext) -> None:
        """Renders the image portions onto the profile plot.

        Nothing is drawn when there is no painter, when the plot elevation or
        distance range is degenerate, or when the image could not be loaded.
        The painter state (opacity) is restored before returning.
        """
        if (painter := context.renderContext().painter()) is None:
            return

        elevation_range = context.elevationRange()
        distance_range = context.distanceRange()
        elevation_span = elevation_range.upper() - elevation_range.lower()
        distance_span = distance_range.upper() - distance_range.lower()

        # A degenerate plot range cannot be mapped to pixels (would divide by zero),
        # and a null image (e.g. unreadable file) has nothing to draw.
        if elevation_span <= 0 or distance_span <= 0 or self.image.isNull():
            return

        # Pixels per elevation unit / per distance unit
        y_ratio = painter.viewport().height() / elevation_span
        x_ratio = painter.viewport().width() / distance_span

        image_width = self.image.width()

        # Save the painter state so the opacity change does not leak into
        # whatever is rendered with this painter afterwards.
        painter.save()
        try:
            # Manage the image opacity
            painter.setOpacity(self.image_opacity / 100.0)

            # Loop on each image portion.
            # Each range-pair corresponds to a distance-pair.
            # When is_reversed is True the current image portion must be flipped.
            for (range_min, range_max), (dist_min, dist_max, is_reversed) in zip(
                self.image_ranges, self.profile_distances
            ):
                # The pixel coordinates of the viewport in which the image portion
                # will be rendered
                target = QRect(
                    round(x_ratio * (dist_min - distance_range.lower())),
                    round(y_ratio * (elevation_range.upper() - self.z_range.upper())),
                    round(x_ratio * (dist_max - dist_min)),
                    round(y_ratio * (self.z_range.upper() - self.z_range.lower())),
                )

                # The pixel coordinates of the image portion.
                # The width and height do not change even if the portion must be flipped.
                source = QRect()
                source.setHeight(self.image.height())
                source.setWidth(round(image_width * (range_max - range_min)))

                # If the image portion must be flipped, use the flipped image and
                # mirror the location of the source rectangle accordingly.
                if is_reversed:
                    source.translate(round(image_width * (1 - range_max)), 0)
                    painter.drawImage(target, self.flipped_image, source)
                else:
                    source.translate(round(image_width * range_min), 0)
                    painter.drawImage(target, self.image, source)
        finally:
            painter.restore()

    def zRange(self) -> QgsDoubleRange:
        """Returns the range of the retrieved elevation values"""
        return self.z_range
