from __future__ import annotations

import numpy as np
from osgeo import gdal
from qgis.core import QgsCoordinateReferenceSystem


def array2geotiff(
    filename: str,
    array: np.ndarray,
    origin: tuple,
    spacing: tuple,
    crs: str = "local",
    nodata: float = np.nan,
    indexing: str = "xy",
):
    """Write a numpy array as a single- or multi-band Float32 GeoTIFF raster.

    Args:
        filename (str): Output raster file.
        array (np.ndarray): Raster data, either 2D (single band) or 3D with
            bands along the last axis. Values are cast to float and stored as
            32-bit floats.
        origin (tuple): Raster origin ``(x, y)``, i.e. the top-left corner.
        spacing (tuple): Pixel size ``(dx, dy)`` in CRS units. Used as-is in
            the geotransform, so ``dy`` is normally negative for north-up
            rasters (GDAL geotransform convention).
        crs (str, optional): Coordinate Reference System definition, passed to
            ``QgsCoordinateReferenceSystem`` and written as WKT. Defaults to
            "local" (NOTE(review): presumably this yields an empty WKT, i.e.
            no projection -- confirm against QGIS).
        nodata (float, optional): Raster 'no-data' value set on every band.
            Defaults to np.nan.
        indexing (str, optional): "xy" if the first array axis runs along
            rows (y), "ij" if it runs along columns (x); "ij" arrays have
            their first two axes swapped before writing. Defaults to "xy".

    Returns:
        gdal.Dataset: The written raster, re-opened read-only (the dataset
        used for writing is closed so that all data reaches the disk).

    Raises:
        ValueError: If ``indexing`` is not "xy"/"ij", or ``array`` is not 2D
            or 3D, or it contains no band.
        RuntimeError: If GDAL fails to create the output file.
    """
    # check inputs (explicit raises: asserts vanish under `python -O`)
    if indexing not in ("xy", "ij"):
        raise ValueError(f"indexing must be 'xy' or 'ij', got {indexing!r}")
    array = np.asarray(array).astype(float)
    if array.ndim not in (2, 3):
        raise ValueError("array must be 2D (single band) or 3D (multiband)")
    if indexing == "ij":
        # "ij" arrays carry x along axis 0: swap so that axis 0 becomes rows
        array = array.swapaxes(0, 1)
    if array.ndim == 2:
        # view a single band as a one-band stack so every case shares one loop
        # (this also makes (rows, cols, 1) inputs work, which used to crash)
        array = array[:, :, np.newaxis]
    rows, cols, nbands = array.shape
    if nbands == 0:
        raise ValueError("array must contain at least one band")
    wkt = QgsCoordinateReferenceSystem(crs).toWkt()
    x, y = origin
    dx, dy = spacing

    # create raster
    driver = gdal.GetDriverByName("GTiff")
    raster = driver.Create(filename, cols, rows, nbands, gdal.GDT_Float32)
    if raster is None:  # GDAL returns None on failure when exceptions are off
        raise RuntimeError(f"GDAL could not create raster {filename!r}")
    # leaving the `with` closes the dataset, which flushes everything to disk
    with raster:
        raster.SetGeoTransform((x, dx, 0, y, 0, dy))
        raster.SetProjection(wkt)
        for b in range(nbands):
            band = raster.GetRasterBand(b + 1)  # GDAL band indices are 1-based
            band.SetNoDataValue(nodata)
            band.WriteArray(array[:, :, b])
            band.FlushCache()
    # the writing handle is closed at this point; return a usable one instead
    return gdal.Open(filename)
