from functools import cmp_to_key

import numpy as np
from forgeo.interpolation import BSPTreeBuilder
from forgeo.rigs import all_intersections
from forgeo.rigs.trirect import trigrid
from qgis.core import (
    QgsPluginLayer,
    QgsProfileRequest,
    QgsProject,
    QgsRasterDemTerrainProvider,
    QgsRenderContext,
)
from qgis.PyQt.QtCore import QPointF, QRectF
from qgis.PyQt.QtGui import QColor, QPen, QPolygonF

from ...utils import raster_layer_to_description


def layer_to_rigs_parameters(layer: QgsPluginLayer, request: QgsProfileRequest = None):
    """Build the rigs evaluation parameters for a model or fault network layer.

    Parameters
    ----------
    layer: QgsPluginLayer
        Either a layer exposing a ``model`` attribute (and a
        ``faultnetlayer_id`` attribute, possibly None) or a layer exposing a
        ``faultnet`` attribute.
    request: QgsProfileRequest, optional
        The request that was sent to the QgsAbstractProfileGenerator. When
        provided and its terrain provider is a QgsRasterDemTerrainProvider
        with a layer, that DEM is added as the topography.

    Returns
    -------
    The mapping returned by ``BSPTreeBuilder.from_params`` (it is read here
    through its "tree" and "ids" keys), extended with a "drawing_order" key
    holding the list of non-boundary ids in painting order.

    Raises
    ------
    TypeError
        If ``layer`` exposes neither a ``model`` nor a ``faultnet`` attribute.
    """
    # Cannot use isinstance(layer, [Model|Fault]Layer) to avoid circular imports...
    if (model := getattr(layer, "model", None)) is not None:
        params = {"model": model}
        if (faults := layer.faultnetlayer_id) is not None:
            params["fault_network"] = QgsProject.instance().mapLayer(faults).faultnet
    elif (faults := getattr(layer, "faultnet", None)) is not None:
        params = {"fault_network": faults}
    else:
        msg = f"Unknown layer type: '{type(layer)}'"
        raise TypeError(msg)

    # Topography: request is optional, so it must be checked before use.
    if request is not None:
        provider = request.terrainProvider()
        # isinstance() is False for None, so a missing provider is skipped too
        if isinstance(provider, QgsRasterDemTerrainProvider):
            dem_layer = provider.layer()
            if dem_layer is not None:
                params["topography"] = raster_layer_to_description(dem_layer)

    rigs_params = BSPTreeBuilder.from_params(**params)

    # To ensure nice intersections when drawing. Given S1 and S2 two surfaces,
    # if S1 is truncated by S2, we paint S1 first, so its termination is hidden
    # when painting S2 over it
    bsp_tree = rigs_params["tree"]
    order = [i for i in rigs_params["ids"].values() if not bsp_tree.is_boundary(i)]
    # The comparator places i after j whenever evaluation_order(i, j) is truthy
    order.sort(
        key=cmp_to_key(lambda i, j: 1 if bsp_tree.evaluation_order(i, j) else -1)
    )
    rigs_params["drawing_order"] = order
    return rigs_params


class ProfileCurve:
    """Stores the vertices of a polyline along with derived geometric information.

    Attributes
    ----------
    vertices: float64 array, one row per vertex
    segments: float64 array, differences between consecutive vertices
    lengths: float64 array, Euclidean length of each segment
    start_offset: float
        Curvilinear offset of the first vertex (0.0 unless a positive value
        was given at construction)
    vertices_offsets: float64 array
        Curvilinear offset of each vertex, i.e. ``start_offset`` plus the
        cumulated segment lengths up to that vertex
    """

    def __init__(self, curve, start_offset=0.0):
        """
        Parameters
        ----------
        curve: sequence of points
            The polyline vertices (at least one vertex is required)
        start_offset: float
            Curvilinear offset of the first vertex. Values that are not
            strictly positive are ignored and replaced by 0.0
        """
        self.vertices = np.asarray(curve, dtype=np.float64)
        self.segments = np.diff(self.vertices, axis=0)
        self.lengths = np.linalg.norm(self.segments, axis=1)
        offsets = np.empty(len(self.vertices), dtype=np.float64)
        offsets[0] = 0.0
        np.cumsum(self.lengths, out=offsets[1:])
        self.vertices_offsets = offsets
        if start_offset > 0.0:
            self.start_offset = start_offset
            self.vertices_offsets += start_offset
        else:
            self.start_offset = 0.0

    @property
    def nb_vertices(self):
        """Number of vertices of the polyline."""
        return len(self.vertices)

    @property
    def nb_segments(self):
        """Number of segments of the polyline."""
        return len(self.segments)

    def _point_at(self, segment_idx, u):
        """Linearly interpolate the point at offset ``u`` on a given segment."""
        offsets = self.vertices_offsets
        u0 = offsets[segment_idx]
        ratio = (u - u0) / (offsets[segment_idx + 1] - u0)
        v0 = self.vertices[segment_idx]
        return v0 + ratio * (self.vertices[segment_idx + 1] - v0)

    def clip(self, u_lower, u_upper):
        """Return the part of the curve lying in the offset range [u_lower, u_upper].

        Offsets are expressed in the same frame as ``vertices_offsets``, so
        an already clipped curve can be clipped again.

        Parameters
        ----------
        u_lower, u_upper: float
            Bounds of the range to keep. Bounds lying outside of the curve
            are snapped to the corresponding curve end.

        Returns
        -------
        ProfileCurve
            ``self`` when the range covers the whole curve, otherwise a new
            curve whose first vertex offset is the offset where the clipped
            part starts.

        Raises
        ------
        ValueError
            If the range is empty or does not overlap the curve.
        """
        offsets = self.vertices_offsets
        u_first, u_last = offsets[0], offsets[-1]
        # Compare against the actual curve ends (u_first is not 0 for a curve
        # that has a start offset)
        low_clip = u_lower > u_first
        high_clip = u_upper < u_last
        if not (low_clip or high_clip):
            return self
        if not u_lower < u_upper:
            msg = f"Empty clipping range: [{u_lower}, {u_upper}]"
            raise ValueError(msg)
        if u_lower >= u_last or u_upper <= u_first:
            msg = (
                f"Clipping range [{u_lower}, {u_upper}] does not overlap "
                f"the curve range [{u_first}, {u_last}]"
            )
            raise ValueError(msg)

        if low_clip:
            # Segment s such that offsets[s] <= u_lower < offsets[s + 1]:
            # a bound falling exactly on a vertex starts the next segment, so
            # that vertex is not duplicated in the output
            seg = int(np.searchsorted(offsets, u_lower, side="right")) - 1
            vstart = self._point_at(seg, u_lower)
            first_kept = seg + 1
            new_start = u_lower
        else:
            vstart = self.vertices[0]
            first_kept = 1
            new_start = u_first

        if high_clip:
            # Segment e such that offsets[e] < u_upper <= offsets[e + 1]
            seg = int(np.searchsorted(offsets, u_upper, side="left")) - 1
            vend = self._point_at(seg, u_upper)
            last_kept = seg + 1  # exclusive
        else:
            vend = self.vertices[-1]
            last_kept = len(offsets) - 1  # exclusive: last vertex is vend

        clipped = [vstart, *self.vertices[first_kept:last_kept], vend]
        return ProfileCurve(clipped, new_start)

    def __repr__(self):
        return (
            f"{self.__class__}\n"
            + f"Vertices: {list(self.vertices)}\n"
            + f"Offsets: {list(self.vertices_offsets)}\n"
        )


def create_rigs_evaluation_surfaces(curve: ProfileCurve, zmin, zmax, nu, nz):
    """Build one triangulated vertical grid per segment of the profile curve.

    Parameters
    ----------
    curve: ProfileCurve
        The 2D polyline along which to compute the grids
    zmin, zmax: float
        The vertical range of the grids
    nu, nz: int
        The number of cells to include along u and z axes. nu is the total
        number of cells along u and is split between the segments of the
        curve proportionally to their lengths, each segment getting at
        least one cell

    Returns
    -------
    vertices_3d: list
        One (n, 3) float64 array of grid points per segment
    triangles: list
        For each segment, the second value returned by ``trigrid``
    """
    assert nu > 0
    assert nz > 1
    curve_length = curve.vertices_offsets[-1] - curve.start_offset
    cells_per_segment = [
        max(int(share), 1) for share in (nu / curve_length) * curve.lengths
    ]

    vertices_3d, triangles = [], []
    # zip stops at the shortest input, i.e. once per segment
    for n_cells, origin, direction in zip(
        cells_per_segment, curve.vertices, curve.segments
    ):
        grid_uz, grid_tris = trigrid((n_cells, nz), (0, 1, zmin, zmax))
        grid_xyz = np.empty((len(grid_uz), 3), dtype=np.float64)
        # u in [0, 1] is mapped onto the segment: origin + u * direction
        grid_xyz[:, :2] = grid_uz[:, :1] * direction + origin
        grid_xyz[:, 2] = grid_uz[:, 1]
        vertices_3d.append(grid_xyz)
        triangles.append(grid_tris)
    return vertices_3d, triangles


def renderRigsResults(
    profile_curve: ProfileCurve,
    params: dict,
    context: QgsRenderContext,
    cache,
):
    """Paint the contact lines of a profile in the elevation profile viewport.

    Parameters
    ----------
    profile_curve: ProfileCurve
        The curve along which the profile is computed
    params: dict
        The rigs parameters; ``params["colors"]`` is indexed by element id
        to pick the pen color
    context:
        Provides ``distanceRange()``, ``elevationRange()`` and
        ``renderContext()``.
        NOTE(review): annotated as QgsRenderContext, but these accessors look
        like the QgsProfileRenderContext API -- confirm.
    cache: dict or None
        Previously computed lines. When falsy (None or empty), the lines are
        recomputed with ``discretize_along_profile_with_rigs``.

    Returns
    -------
    None if nothing was drawn because the curve is outside of the viewport,
    or if ``cache`` is None; otherwise the drawn lines (so that the caller
    can cache them).
    """
    # Elevation profile ranges
    u = context.distanceRange()
    z = context.elevationRange()
    # Do not draw anything if the profile curve is entirely outside of the
    # elevation profile viewport
    if u.upper() <= 0.0:
        return None
    if u.lower() >= profile_curve.vertices_offsets[-1]:
        return None
    if cache:
        all_lines = cache
    else:
        # cache is None: do not use cache, or
        # cache = {}: cache results, but not computed yet
        all_lines = discretize_along_profile_with_rigs(profile_curve, params, context)
    # QGsRenderContext https://qgis.org/pyqgis/3.40/core/QgsRenderContext.html#qgis.core.QgsRenderContext
    rc = context.renderContext()
    painter = rc.painter()  # https://doc.qt.io/archives/qt-5.15/qpainter.html
    painter.save()  # Backup the previous painter state
    try:
        # QRectF(u0, z0, width, height): the window starts at z.upper() and
        # its height is z.lower() - z.upper()
        window = QRectF(
            u.lower(), z.upper(), u.upper() - u.lower(), z.lower() - z.upper()
        )
        painter.setWindow(window.toRect())
        colormap = [QColor(c) for c in params["colors"]]
        for uid, lines in all_lines.items():
            pen = QPen(colormap[uid])
            # Do not rely on pen.setWidth as pen width is not scale invariant
            pen.setCosmetic(True)  # Width is always 1 pixel when displayed
            painter.setPen(pen)
            for line in lines:
                painter.drawPolyline(line)
    finally:
        # Always restore the previous painter state, even if drawing failed,
        # so the painter is not left with our window and pen
        painter.restore()
    if cache is None:
        return None
    return all_lines


def discretize_along_profile_with_rigs(
    profile_curve: ProfileCurve,
    params: dict,
    context: QgsRenderContext,
):
    """Compute the contact polylines along the visible part of a profile curve.

    The curve is clipped to the displayed distance range, a vertical grid is
    built on each of its segments, and the contacts returned by
    ``all_intersections`` on each grid are converted to (u, z) coordinates.

    Parameters
    ----------
    profile_curve: ProfileCurve
        The curve along which the profile is computed
    params: dict
        The rigs parameters, forwarded as keyword arguments to
        ``all_intersections``; ``params["drawing_order"]`` lists the element
        ids of the output
    context:
        Provides ``distanceRange()``, ``elevationRange()`` and
        ``renderContext()``.
        NOTE(review): annotated as QgsRenderContext, but these accessors look
        like the QgsProfileRenderContext API -- confirm.

    Returns
    -------
    dict
        Maps each id of ``params["drawing_order"]`` (in that order) to a list
        of QPolygonF whose points are (offset along the curve, elevation)
    """
    # Elevation profile ranges
    u = context.distanceRange()
    z = context.elevationRange()
    # Viewport size (used to compute the number of cells for discretization)
    viewport = context.renderContext().painter().viewport()
    delta_u = viewport.right() - viewport.left()
    delta_z = viewport.bottom() - viewport.top()
    assert delta_u > 0
    assert delta_z > 0
    nmax = 100  # TODO Relation with discretization parameters?
    if delta_z < delta_u:
        nu = nmax
        nz = int(nu * delta_z / delta_u) + 1
    else:
        nz = nmax
        nu = int(nz * delta_u / delta_z) + 1
    # create_rigs_evaluation_surfaces requires nz > 1, which the formula
    # above does not guarantee for very wide and flat viewports
    nz = max(nz, 2)
    # Recompute the profile on the fly
    profile_curve = profile_curve.clip(u.lower(), u.upper())
    vertices_3d, triangles = create_rigs_evaluation_surfaces(
        profile_curve, z.lower(), z.upper(), nu, nz
    )
    vertices = profile_curve.vertices
    vertices_offsets = profile_curve.vertices_offsets

    # Prepare output
    all_lines = {i: [] for i in params["drawing_order"]}  # i: element id in rigs

    # Iterate over the curve segments (one grid per segment)
    for i, (grid_vertices, grid_triangles) in enumerate(zip(vertices_3d, triangles)):
        # Extract intersections in 3D
        contacts = all_intersections(
            grid_vertices, grid_triangles, **params, return_contact_polylines=True
        ).contacts
        pts3d = contacts.vertices
        if len(pts3d) == 0:
            continue
        # Convert 3D points to 2D points along the profile curve: u is the
        # horizontal distance to the segment start plus the segment start offset
        pts2d = np.empty((len(pts3d), 2), dtype=np.float64)
        pts2d[:, 0] = np.linalg.norm(pts3d[:, :2] - vertices[i], axis=1)
        pts2d[:, 0] += vertices_offsets[i]
        pts2d[:, 1] = pts3d[:, 2]
        # Get 2D lines to plot
        # NOTE(review): an id reported in contacts.lines but absent from
        # params["drawing_order"] raises KeyError here -- confirm that
        # all_intersections never reports such ids
        for uid, lines in contacts.lines.items():
            for line in lines:
                all_lines[uid].append(
                    QPolygonF([QPointF(pt[0], pt[1]) for pt in pts2d[line]])
                )
    return all_lines
