import numpy as np
from qgis.core import (
    QgsFeature,
    QgsFields,
    QgsPoint,
    QgsPointXY,
    QgsProcessingException,
    QgsProcessingParameterExtent,
    QgsProcessingParameterNumber,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorDestination,
    QgsRaster,
    QgsWkbTypes,
)

from . import UtilsAlgorithm


class CreateGrid3d(UtilsAlgorithm):
    """Processing algorithm creating a vector layer of 3D points along a grid.

    The grid nodes are generated with ``numpy.arange`` on each axis, so each
    axis starts at its minimum and stops *before* its maximum. When a
    topography raster is supplied, only the nodes whose z is lower than or
    equal to the value of band 1 at their (x, y) location are kept.
    """

    # Input parameters
    EXTENT = "EXTENT"
    DX = "DX"
    DY = "DY"
    DZ = "DZ"
    ZMIN = "ZMIN"
    ZMAX = "ZMAX"
    # Optional input parameters
    TOPO = "TOPO"
    # Output parameters
    OUTPUT = "OUTPUT"

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr("Create 3D grid")

    def initAlgorithm(self, config=None):  # noqa: ARG002
        """
        Declare the inputs and output of the algorithm: the horizontal
        extent, the vertical range (zmin, zmax), the grid steps (dx, dy, dz),
        an optional topography raster and the destination point layer.
        """

        self.addParameter(
            QgsProcessingParameterExtent(
                self.EXTENT, self.tr("Emprise (xmin, xmax, ymin, ymax)")
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.ZMIN,
                self.tr("zmin (float)"),
                QgsProcessingParameterNumber.Double,
                defaultValue=-5000,
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.ZMAX,
                self.tr("zmax (float)"),
                QgsProcessingParameterNumber.Double,
                defaultValue=0,
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.DX,
                self.tr("dx (float)"),
                QgsProcessingParameterNumber.Double,
                defaultValue=500,
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.DY,
                self.tr("dy (float)"),
                QgsProcessingParameterNumber.Double,
                defaultValue=500,
            )
        )
        self.addParameter(
            QgsProcessingParameterNumber(
                self.DZ,
                self.tr("dz (float)"),
                QgsProcessingParameterNumber.Double,
                defaultValue=500,
            )
        )
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.TOPO, self.tr("Cuts points above this topography:"), optional=True
            )
        )
        self.addParameter(
            QgsProcessingParameterVectorDestination(
                self.OUTPUT,
                self.tr("Points output"),
                optional=True,
                createByDefault=True,
            )
        )

    @staticmethod
    def _elevation_at(provider, x, y):
        """Return the band 1 value of ``provider`` at (x, y), or None if unavailable.

        None is returned when the identify result is invalid or when band 1
        holds no numeric value at that location (e.g. a null result).
        """
        identified = provider.identify(QgsPointXY(x, y), QgsRaster.IdentifyFormatValue)
        if not identified.isValid():
            return None
        value = identified.results().get(1)
        # A null band value is not a number: comparing it with a float would
        # raise, so report it as "no elevation" instead.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        return value

    def processAlgorithm(self, parameters, context, feedback) -> dict:
        """Build the 3D grid and write its nodes to the output sink.

        Returns:
            A dict mapping ``OUTPUT`` to the identifier of the created layer.

        Raises:
            QgsProcessingException: if dx, dy or dz is not strictly positive.
        """

        # Inputs
        crs = context.project().crs()
        # Ask for the extent in the CRS of the output layer so that the
        # generated coordinates match the CRS the sink is declared with.
        extent = self.parameterAsExtent(parameters, self.EXTENT, context, crs)
        xmin, xmax, ymin, ymax = (
            extent.xMinimum(),
            extent.xMaximum(),
            extent.yMinimum(),
            extent.yMaximum(),
        )
        zmin = self.parameterAsDouble(parameters, self.ZMIN, context)
        zmax = self.parameterAsDouble(parameters, self.ZMAX, context)
        dx = self.parameterAsDouble(parameters, self.DX, context)
        dy = self.parameterAsDouble(parameters, self.DY, context)
        dz = self.parameterAsDouble(parameters, self.DZ, context)
        topo = self.parameterAsRasterLayer(parameters, self.TOPO, context)

        # A zero step makes numpy.arange fail and a negative one silently
        # yields an empty grid: reject both with an explicit message.
        if dx <= 0 or dy <= 0 or dz <= 0:
            raise QgsProcessingException(
                self.tr("dx, dy and dz must be strictly positive")
            )

        # Create grid points layer
        sink, output_id = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            QgsFields(),
            QgsWkbTypes.PointZ,
            crs,
        )
        if sink is None:
            # The output is optional: nothing to write when it was skipped.
            feedback.pushInfo(self.tr("No output layer requested, nothing to do"))
            return {self.OUTPUT: output_id}

        xs = np.arange(xmin, xmax, dx, dtype=float)
        ys = np.arange(ymin, ymax, dy, dtype=float)
        zs = np.arange(zmin, zmax, dz, dtype=float)
        grid_x, grid_y, grid_z = np.meshgrid(xs, ys, zs, indexing="ij")
        points = np.column_stack((grid_x.ravel(), grid_y.ravel(), grid_z.ravel()))

        # NOTE(review): the raster is sampled with the grid coordinates as-is,
        # i.e. it is assumed to share the CRS of the grid - confirm.
        topo_provider = topo.dataProvider() if topo is not None else None

        total = len(points)
        progress_step = 100.0 / total if total else 0.0
        # tolist() yields plain Python floats, which is what the QGIS
        # constructors expect (rather than numpy scalars).
        for index, (x, y, z) in enumerate(points.tolist()):
            if feedback.isCanceled():
                break
            feedback.setProgress(index * progress_step)

            if topo_provider is not None:
                # Remove points above topography (uses value of band 1);
                # points without a known elevation are dropped as well.
                elevation = self._elevation_at(topo_provider, x, y)
                if elevation is None or not z <= elevation:
                    continue

            feature = QgsFeature()
            feature.setGeometry(QgsPoint(x, y, z))
            sink.addFeature(feature)

        return {self.OUTPUT: output_id}
