# import rifter
# from forgeo.gmlib.architecture import from_GeoModeller, make_evaluator
# from .algorithm import Algorithm
from pathlib import Path

import numpy as np
from qgis import processing
from qgis.core import (
    QgsCoordinateTransform,
    QgsProcessingException,
    QgsProcessingParameterBand,
    QgsProcessingParameterExtent,
    QgsProcessingParameterRasterLayer,
    QgsRasterLayer,
)

from .featureextracter import FeatureExtracter


class GmRasterSlicer(
    FeatureExtracter,
    name=Path(__file__).stem,
    display_name="draw model on elevation map",
):
    """Processing algorithm drawing a geomodel on an elevation map.

    One band of an elevation raster is clipped to a user extent. The clipped
    band is turned into a triangulated surface: its nodes are the raster cell
    centres, lifted to the elevation read from the band. That surface is then
    handed to ``extract_features_once`` (an inherited helper).
    """

    ELEVATION_RASTER = "ELEVATION_RASTER"
    ELEVATION_BAND = "ELEVATION_BAND"
    EXTENT = "EXTENT"
    # NX / NY are not referenced anywhere in this class.
    NX = "NX"
    NY = "NY"

    def initAlgorithm(self, config=None):  # noqa : ARG002
        """Declare the algorithm parameters.

        The parameters are:

        - the geomodel parameter and the output sinks, both added by
          inherited helpers;
        - a mandatory elevation raster;
        - the band of that raster holding elevations;
        - the extent to process.
        """
        self.add_geomodel_parameter()
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                name=self.ELEVATION_RASTER,
                description=self.tr("Elevation raster layer"),
                optional=False,
            )
        )
        self.addParameter(
            QgsProcessingParameterBand(
                self.ELEVATION_BAND,
                self.tr("Raster band containing elevation"),
                parentLayerParameterName=self.ELEVATION_RASTER,
                optional=False,
            )
        )
        self.addParameter(
            QgsProcessingParameterExtent(self.EXTENT, self.tr("Extent"), optional=False)
        )
        self.add_parameter_sinks()

    def processAlgorithm(self, parameters, context, feedback):
        """Triangulate the clipped elevation raster and extract the model on it.

        Steps:

        1. Reproject EXTENT into the elevation raster CRS.
        2. Clip the raster with ``gdal:cliprasterbyextent``.
        3. Read the selected band of the clipped raster.
        4. Place its cell centres on a regular grid spanning the clipped
           extent, reprojected into the model CRS.
        5. Triangulate that grid (two triangles per cell) and pass the mesh
           to ``extract_features_once``.

        Returns:
            Whatever ``extract_features_once`` returns.

        Raises:
            QgsProcessingException: raised when any of these holds:

            - the elevation raster cannot be loaded;
            - the elevation band does not exist;
            - the clipped raster is invalid;
            - the clipped raster has fewer than 2 rows or 2 columns, so no
              triangle can be built.
        """
        input_layer = self.model_layer(parameters)
        dtm = self.parameterAsRasterLayer(parameters, self.ELEVATION_RASTER, context)
        if dtm is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.ELEVATION_RASTER)
            )
        # QGIS band numbers are 1-based (dtm.bandName below uses the same).
        zband = self.parameterAsInt(parameters, self.ELEVATION_BAND, context)
        if not 1 <= zband <= dtm.bandCount():
            raise QgsProcessingException(
                self.tr("Invalid elevation band: {}").format(zband)
            )
        extent = self.parameterAsExtent(parameters, self.EXTENT, context)
        extent_crs = self.parameterAsExtentCrs(parameters, self.EXTENT, context)
        # Identical CRSs need no special case: QGIS short-circuits the transform.
        # transformBoundingBox encloses the whole reprojected area, not only
        # the reprojected corners.
        to_dtm = QgsCoordinateTransform(
            extent_crs, dtm.crs(), context.transformContext()
        )
        clip_extent = to_dtm.transformBoundingBox(extent)
        crop = QgsRasterLayer(
            processing.run(
                "gdal:cliprasterbyextent",
                {"INPUT": dtm, "PROJWIN": clip_extent, "OUTPUT": "TEMPORARY_OUTPUT"},
                context=context,
                feedback=feedback,
            )["OUTPUT"]
        )
        if not crop.isValid():
            raise QgsProcessingException(
                self.tr("Could not clip the elevation raster to the requested extent")
            )
        # as_numpy stacks bands along axis 0. The clip keeps every band, so
        # the selected band has the same (1-based) number in the cropped raster.
        # NOTE(review): nodata cells are not filtered out and may reach the
        # mesh as raw nodata elevations. Confirm whether they should be masked.
        zmap = crop.as_numpy()[zband - 1]
        ny, nx = zmap.shape
        if nx < 2 or ny < 2:
            raise QgsProcessingException(
                self.tr("Clipped elevation raster needs at least 2 rows and 2 columns")
            )
        # Express the grid in the model CRS. When the CRSs differ, the
        # reprojected raster is only approximately a regular axis-aligned grid.
        # We therefore take the bounding box of the reprojected crop and
        # recompute the resolution from it.
        to_model = QgsCoordinateTransform(
            dtm.crs(), input_layer.crs(), context.transformContext()
        )
        extent = to_model.transformBoundingBox(crop.extent())
        dx = extent.width() / nx
        dy = extent.height() / ny
        x = np.linspace(
            extent.xMinimum() + 0.5 * dx, extent.xMaximum() - 0.5 * dx, nx, dtype="d"
        )
        y = np.linspace(
            extent.yMinimum() + 0.5 * dy, extent.yMaximum() - 0.5 * dy, ny, dtype="d"
        )
        # y is reversed so that raster row 0 sits at the largest y (north-up
        # raster). With that layout, node (row, col) has index row * nx + col
        # and matches zmap[row, col].
        xx, yy = np.meshgrid(x, y[::-1], indexing="xy")
        pts = np.column_stack([xx.ravel(), yy.ravel(), zmap.ravel()])
        cells = self._grid_triangles(nx, ny)

        return self.extract_features_once(
            mesh=(pts, cells),
            parameters=parameters,
            context=context,
            feedback=feedback,
            support_name=dtm.bandName(zband),
        )

    @staticmethod
    def _grid_triangles(nx, ny):
        """Return the triangle connectivity of a row-major ``ny`` x ``nx`` grid.

        Nodes are numbered row by row: index = row * nx + col. Each grid cell
        is split into two triangles along the diagonal joining node (r, c) to
        node (r + 1, c + 1). Triangles are listed cell by cell in row-major
        order. The result is an int array of shape
        ``(2 * (nx - 1) * (ny - 1), 3)``.
        """
        # Both triangles of the cell whose first node is node 0.
        quad = np.array([(0, 1, nx + 1), (0, nx + 1, nx)], dtype=int)
        rows, cols = np.meshgrid(
            np.arange(ny - 1, dtype=int), np.arange(nx - 1, dtype=int), indexing="ij"
        )
        # Index of the (r, c) node of every cell, in row-major order.
        origins = (rows * nx + cols).ravel()
        return (origins[:, None, None] + quad[None, :, :]).reshape(-1, 3)
