from collections.abc import Callable, Iterable

import numpy as np
from qgis.core import (
    QgsCategorizedSymbolRenderer,
    QgsCoordinateTransform,
    QgsGeometry,
    QgsLineSegment2D,
    QgsLineString,
    QgsLineSymbol,
    QgsPointXY,
    QgsRasterLayer,
    QgsRectangle,
    QgsRendererCategory,
    QgsWkbTypes,
)
from qgis.PyQt.QtGui import QColor


def sampleGridPoints(
    extent: QgsRectangle,
    du: float,
    z: float | list[float] | Callable | None = None,
    return_quads: bool = True,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Build surface mesh primitives (vertices and cells) from section parameters.

    The extent is resampled at resolution ``du`` and tessellated with quads.

    0. create 2 orthogonal segments from the extent (for metric resampling purpose)
    1. resample the segments to match the resolution ``du`` using QGIS routines
    2. expand these 2 axes into a 2d grid of vertices (x index varies fastest)
    3. [only ``if z is not None``] append a z coordinate to every vertex
    4. generate quad cells from vertices indexes

    Args:
        extent (QgsRectangle): the extent of the section domain
        du (float): resolution - in the XY plane [meters]
        z (None|float|Iterable[float]|Callable, optional): z coordinate.
            ``None`` (default) returns a 2d grid; a scalar is used for every
            vertex; a callable receives the ``(n, 2)`` array of XY vertices and
            its result is stacked as the z column; a sequence must hold exactly
            one z value per vertex.
        return_quads (bool, optional): Whether to return the cells. Defaults to ``True``

    Returns:
        the ``(nx * ny, 2)`` array of vertices (``(nx * ny, 3)`` when ``z`` is given),
        and, only if ``return_quads``, the ``((nx - 1) * (ny - 1), 4)`` array of
        quad vertex indexes.

    Raises:
        ValueError: if ``z`` is a sequence whose length differs from the number of vertices.
    """
    # define the axes of the grid
    ux = QgsLineSegment2D(
        extent.xMinimum(), extent.yMinimum(), extent.xMaximum(), extent.yMinimum()
    )
    uy = QgsLineSegment2D(
        extent.xMinimum(), extent.yMinimum(), extent.xMinimum(), extent.yMaximum()
    )
    # resample the axes to match resolution
    xx = QgsGeometry.fromPolyline(
        QgsLineString([ux.start(), ux.end()])
    ).densifyByDistance(du)
    yy = QgsGeometry.fromPolyline(
        QgsLineString([uy.start(), uy.end()])
    ).densifyByDistance(du)
    # determine the number and size of cells
    # (assumes the densified nodes are evenly spaced along each axis)
    nx, ny = len(list(xx.vertices())), len(list(yy.vertices()))
    dx, dy = ux.length() / (nx - 1), uy.length() / (ny - 1)
    del ux, uy, xx, yy
    # create the mesh grid: vertex k = j * nx + i holds the indexes (i, j)
    vertices = np.mgrid[0:nx, 0:ny].T.reshape(-1, 2).astype(float)
    # rescale/translate grid
    vertices[:, 0] = vertices[:, 0] * dx + extent.xMinimum()
    vertices[:, 1] = vertices[:, 1] * dy + extent.yMinimum()
    # fill Z component
    if z is not None:
        if callable(z):
            zcol = z(vertices)
        elif np.ndim(z) == 0:
            zcol = np.full(len(vertices), z, dtype=float)
        else:
            zcol = np.asarray(z, dtype=float)
            # one z per vertex (rows), not per coordinate (columns)
            if len(zcol) != len(vertices):
                msg = f"expected {len(vertices)} z values (one per vertex), got {len(zcol)}"
                raise ValueError(msg)
        vertices = np.column_stack((vertices, zcol))
    if not return_quads:
        return vertices
    # build the quad cells
    cells = np.array([[0, 1, nx + 1, nx]], dtype=int)  # single quad
    cells = np.vstack([cells + k for k in range(nx - 1)])  # strip of quads
    cells = np.vstack([cells + nx * k for k in range(ny - 1)])  # matrix of quads

    return vertices, cells


def verticalSectionFromPolyLine(
    geometry: QgsGeometry,
    zmin: float,
    zmax: float,
    du: float,
    dz: float,
    topo: Callable | None = None,
    triangles: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Build surface mesh primitives (vertices and cells) of a vertical section.

    The geometry is treated as a single polyline, multiparts will be merged!

    1. densify the trace with ``densifyByDistance(du)``
    2. for each resampled vertex along the trace, define the top of the section:
       ``topo(x, y)`` if given, else the vertex z if the geometry has Z, else ``zmax``
    3. define the number of vertex rows ``nz = ceil((max(top) - zmin) / dz) + 1``
    4. for each resampled vertex, sample ``nz`` elevations evenly spaced from its
       top down to ``zmin``
    5. generate triangle (or quad) cells from vertices indexes

    Args:
        geometry (QgsGeometry): the trace of the section (might have a Z); any other
            value is converted with ``QgsGeometry.fromPolyline``
        zmin (float): bottom elevation of the section
        zmax (float): top elevation used when neither ``topo`` nor a Z coordinate is available
        du (float): horizontal resolution - along the trace (in meters)
        dz (float): vertical resolution - along the section (in meters)
        topo (callable, optional): ``topo(x, y)`` top elevation calculator. Defaults to None
        triangles (bool, optional): build triangles (2 per grid cell) instead of quads.
            Defaults to True

    Returns:
        the ``(nu * nz, 3)`` array of vertices, where row ``k`` of the grid (top to
        bottom) occupies indexes ``k * nu`` to ``(k + 1) * nu - 1``,
        the array of cells (vertex indexes): ``(2 * (nu - 1) * (nz - 1), 3)``
        triangles or ``((nu - 1) * (nz - 1), 4)`` quads.

    Raises:
        ValueError: if ``dz`` is not positive, if the densified trace has fewer than
            2 vertices, or if the section top is not above ``zmin``.
    """
    if dz <= 0:
        msg = f"vertical resolution must be positive, got dz={dz}"
        raise ValueError(msg)
    # densify linestring
    if not isinstance(geometry, QgsGeometry):
        geometry = QgsGeometry.fromPolyline(geometry)
    geometry = geometry.densifyByDistance(du)
    # cast linestring vertices to arrays, computing the top of each column
    hasZ = QgsWkbTypes.hasZ(geometry.wkbType())
    use_topo = callable(topo)
    x, y, tops = [], [], []
    for pt in geometry.vertices():
        x.append(pt.x())
        y.append(pt.y())
        if use_topo:
            tops.append(topo(pt.x(), pt.y()))
        else:
            tops.append(pt.z() if hasZ else zmax)
    nu = len(x)
    if nu < 2:
        msg = f"the section trace needs at least 2 vertices, got {nu}"
        raise ValueError(msg)
    tops = np.asarray(tops)
    nz = int(np.ceil((np.max(tops) - zmin) / dz) + 1)
    if nz < 2:
        # nothing between the top and the bottom: no cell can be built
        msg = f"the section has no vertical extent (top={np.max(tops)}, zmin={zmin})"
        raise ValueError(msg)
    # each column goes from its own top down to zmin with the same vertex count
    zz = np.empty((nz, nu))
    for i, top in enumerate(tops):
        zz[:, i] = np.linspace(top, zmin, num=nz, endpoint=True)
    vertices = np.empty((nu * nz, 3))
    vertices[:, 0] = np.tile(x, nz)
    vertices[:, 1] = np.tile(y, nz)
    vertices[:, 2] = zz.ravel()
    # build the cells: one grid cell, then a strip along the trace, then all rows
    if triangles:
        cell = np.array([(0, 1, nu + 1), (0, nu + 1, nu)], dtype=int)
    else:
        cell = np.array([[0, 1, nu + 1, nu]], dtype=int)
    strip = np.vstack([cell + k for k in range(nu - 1)])
    cells = np.vstack([strip + nu * k for k in range(nz - 1)])

    return vertices, cells


def sampleRaster(
    raster: QgsRasterLayer,
    points: Iterable[float],
    transform: QgsCoordinateTransform | Callable | None = None,
) -> np.ndarray:
    """Helper function to sample a raster at points locations.

    Args:
        raster (QgsRasterLayer): the raster to sample
        points (Iterable): points as ``(x, y[, ...])`` sequences; only x and y
            locate the sample, unless ``transform`` is a plain callable
        transform (QgsCoordinateTransform|Callable, optional): either a
            ``QgsCoordinateTransform`` applied to each ``QgsPointXY(x, y)``, or a
            callable ``transform(*pt)`` returning the sampling location.
            Defaults to None (points are used as is).

    Returns:
        a ``(len(points), 1)`` float array; each entry is the sum of the valid
        samples of all raster bands at that point (0 when no band is valid).
    """
    points = list(points)
    provider = raster.dataProvider()
    # QGIS raster bands are numbered from 1 to bandCount()
    bands = range(1, raster.bandCount() + 1)
    array = np.zeros((len(points), 1))
    for i, pt in enumerate(points):
        if transform is None:
            loc = QgsPointXY(pt[0], pt[1])
        elif isinstance(transform, QgsCoordinateTransform):
            loc = transform.transform(QgsPointXY(pt[0], pt[1]))
        else:
            loc = transform(*pt)
        for band in bands:
            value, ok = provider.sample(loc, band)
            if ok:
                array[i] += value
    return array


def pointsToWkt(points: Iterable[float]) -> str:
    """Helper function to format points list as WKT str.

    Each point is written with all of its coordinates. A single point gives
    ``POINT (x y)``, several points give ``MULTIPOINT ((x1 y1), (x2 y2))``.
    A flat sequence ``(x, y[, z])`` is read as one point, and an empty input
    gives ``MULTIPOINT EMPTY``.
    """
    points = np.asarray(points)
    if points.size == 0:
        return "MULTIPOINT EMPTY"
    # a flat (x, y[, z]) sequence is a single point
    points = np.atleast_2d(points)
    coords = [f"({' '.join(map(str, pt))})" for pt in points]
    if len(coords) == 1:
        return f"POINT {coords[0]}"
    return f"MULTIPOINT ({', '.join(coords)})"


def polygonsToWkt(points: Iterable[float], cells: Iterable[Iterable[int]]) -> str:
    """Helper function to format polygons (defined by a mesh vertices/cells lists) as WKT str.

    Each cell lists the indexes (into ``points``) of the single ring of one
    polygon. The ring is closed by repeating its first vertex when needed, as
    WKT requires closed rings. One cell gives ``POLYGON ((...))``, several give
    ``MULTIPOLYGON (((...)), ((...)))``, and no cell gives ``MULTIPOLYGON EMPTY``.
    """
    points = np.array(points, dtype=float)
    polygons = []
    for cell in cells:
        ring = np.asarray(cell, dtype=int)
        if ring.size and ring[0] != ring[-1]:
            ring = np.append(ring, ring[0])
        coords = ", ".join(" ".join(map(str, pt)) for pt in points[ring])
        polygons.append(f"(({coords}))")
    if not polygons:
        return "MULTIPOLYGON EMPTY"
    if len(polygons) == 1:
        return f"POLYGON {polygons[0]}"
    return f"MULTIPOLYGON ({', '.join(polygons)})"


def custom_symbol_renderer(
    field_name, symbology, symbol=QgsLineSymbol, **symbol_properties
):
    """Build a categorized renderer from a ``{value: (color, label)}`` mapping.

    Args:
        field_name: name of the field the categories are computed on
        symbology: mapping ``value -> (color, label)``. ``color`` may be a string
            (used as is), a ``QColor``, ``None`` (grey ``(0.5, 0.5, 0.5)``) or a
            sequence of RGB(A) components scaled in [0, 1]; an opaque alpha is
            added when only RGB is given. A ``None`` label becomes ``"unknown"``.
        symbol: symbol class whose ``createSimple`` builds each category symbol.
            Defaults to ``QgsLineSymbol``.
        **symbol_properties: extra properties passed to ``createSimple``; they
            override the ``color``/``outline_color`` entries.

    Returns:
        QgsCategorizedSymbolRenderer with one category per ``symbology`` entry,
        keyed by ``str(value)``.
    """
    renderer = QgsCategorizedSymbolRenderer(field_name)
    for value, (color, label) in symbology.items():
        if label is None:
            label = "unknown"
        if not isinstance(color, str):
            if color is None:
                color = (0.5,) * 3
            if not isinstance(color, QColor):
                rgbI = [int(x * 255) for x in color]  # float RGB -> integer RGB
                rgbI += [255] if len(rgbI) < 4 else []
                color = QColor(*rgbI)
            color = color.name(QColor.NameFormat.HexArgb)
        properties = {
            "color": color,
            "outline_color": color,  # for polygons
        }
        properties.update(symbol_properties)
        # keep `symbol` bound to the class: don't overwrite it with an instance
        category_symbol = symbol.createSimple(properties)
        category = QgsRendererCategory(str(value), category_symbol, label)
        renderer.addCategory(category)
    return renderer


def colors_as_strings(colors: list):
    """Convert colors of mixed kinds to color-name strings.

    Args:
        colors: iterable of colors; each may be a string (kept as is), a
            ``QColor``, ``None`` (grey ``QColor(127, 127, 127)``), a sequence of
            ints in [0, 255], or a sequence of floats in [0, 1]. Sequences are
            passed as components to the ``QColor`` constructor; floats are
            scaled by 255 and truncated.

    Returns:
        list of strings, one per input color; non-string colors are formatted
        with ``QColor.name()``.

    Raises:
        AssertionError: if a color is of an unsupported kind or has a component
            out of range.
    """
    result = []
    for color in colors:
        if isinstance(color, str):
            result.append(color)
            continue
        # convert to QColor first
        if isinstance(color, QColor):
            qcolor = color  # already a QColor: nothing to convert
        elif color is None:
            qcolor = QColor(127, 127, 127)
        elif all(isinstance(ck, int) for ck in color):
            if not all(0 <= ck <= 255 for ck in color):
                msg = f"integer color components must be in [0, 255]: {color}"
                raise AssertionError(msg)
            qcolor = QColor(*color)
        elif all(isinstance(ck, float) for ck in color):
            if not all(0 <= ck <= 1 for ck in color):
                msg = f"float color components must be in [0, 1]: {color}"
                raise AssertionError(msg)
            qcolor = QColor(*(int(ck * 255) for ck in color))
        else:
            msg = f"could not convert {color} to QColor"
            raise AssertionError(msg)
        result.append(qcolor.name())
    return result
