import math
import os

import numpy as np

from qgis.core import (
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterNumber,
    QgsProcessingParameterVectorDestination,
    QgsProcessingException,
    QgsProcessingUtils,
    QgsRasterLayer,
    QgsProcessingContext,
    QgsProcessingFeedback,
    QgsVectorLayer,
    QgsMemoryProviderUtils,
    QgsFeature,
    QgsGeometry,
    QgsWkbTypes,
)

import processing
from osgeo import gdal


class DirectionalSlopePolygonsAlg(QgsProcessingAlgorithm):
    """QGIS Processing algorithm producing polygons of binned directional slope.

    Pipeline: read band 1 of a DEM -> slope toward a user azimuth and across it
    (from the numpy gradient) -> keep pixels whose absolute cross slope is within
    limits -> bin the directional slope -> GDAL polygonize -> optional simplify ->
    optional minimum-area filter -> optional vertex thinning -> save.
    """

    # Parameter keys.
    P_DEM = "DEM"
    P_AZIMUTH = "AZIMUTH"
    P_PERP_MIN = "PERP_MIN"
    P_PERP_MAX = "PERP_MAX"
    P_BIN_HALF = "BIN_HALF"          # deg
    P_MIN_AREA = "MIN_AREA"          # map units^2
    P_SIMPLIFY_TOL = "SIMPLIFY_TOL"  # map units
    P_DEF_DIST = "DEF_DIST"          # labelled meters; applied in map units
    P_OUTPUT = "OUTPUT"

    # NoData value of the temporary Int16 classified raster.
    _NODATA_OUT = -32768

    def name(self):
        """Return the algorithm id."""
        return "directional_slope_polygons"

    def displayName(self):
        """Return the user-visible algorithm name."""
        return "Directional slope polygons (bin, polygonize, simplify, filter, thin vertices)"

    def group(self):
        """Return the user-visible group name."""
        return "Terrain"

    def groupId(self):
        """Return the group id."""
        return "terrain"

    def shortHelpString(self):
        """Return the help text shown in the Processing dialog."""
        return (
            "Computes directional slope toward a chosen azimuth and perpendicular slope.\n"
            "Keeps pixels where abs(perpendicular slope) is within [min,max].\n"
            "Bins directional slope into classes using a user-defined half-width (e.g. ±2.5° => 5° bins),\n"
            "polygonizes them, optionally simplifies geometry, removes polygons smaller than the minimum area,\n"
            "and thins vertices to satisfy a 'definition' distance.\n\n"
            "Azimuth: degrees clockwise from North (0=N, 90=E).\n"
            "Perpendicular limits are ABSOLUTE.\n"
            "Definition distance: e.g. 6 m => enforce consecutive vertices >= 3 m apart.\n"
            "Min area is in map units² (with meters CRS this is m²)."
        )

    def createInstance(self):
        """Return a fresh instance (required by the Processing framework)."""
        return DirectionalSlopePolygonsAlg()

    def initAlgorithm(self, config=None):
        """Declare the input parameters and the vector output."""
        self.addParameter(QgsProcessingParameterRasterLayer(
            self.P_DEM, "Input DEM (single-band raster)"
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_AZIMUTH,
            "Direction azimuth (degrees clockwise from North)",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=0.0,
            minValue=0.0,
            maxValue=360.0
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_PERP_MIN,
            "Perpendicular slope ABS min (degrees)",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=0.0,
            minValue=0.0
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_PERP_MAX,
            "Perpendicular slope ABS max (degrees)",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=5.0,
            minValue=0.0
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_BIN_HALF,
            "Directional slope bin half-width (degrees). Example: 2.5 gives 5° bins",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=2.5,
            minValue=0.01
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_MIN_AREA,
            "Minimum polygon area to keep (map units², 0 = keep all)",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=0.0,
            minValue=0.0
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_SIMPLIFY_TOL,
            "Simplify tolerance (map units, 0 = no simplification)",
            type=QgsProcessingParameterNumber.Double,
            defaultValue=0.0,
            minValue=0.0
        ))

        self.addParameter(QgsProcessingParameterNumber(
            self.P_DEF_DIST,
            'Geometry "definition" distance (meters). Example: 6 means max ~2 points per 6 m',
            type=QgsProcessingParameterNumber.Double,
            defaultValue=6.0,
            minValue=0.0
        ))

        self.addParameter(QgsProcessingParameterVectorDestination(
            self.P_OUTPUT, "Output polygons"
        ))

    @staticmethod
    def _azimuth_to_unit_vectors(azimuth_deg: float):
        """Return the (east, north) unit vector for an azimuth in degrees clockwise from North."""
        angle_rad = math.radians(90.0 - azimuth_deg)  # azimuth CW from North -> math angle CCW from East
        ux = math.cos(angle_rad)  # East
        uy = math.sin(angle_rad)  # North
        return ux, uy

    @staticmethod
    def _thin_ring(points_xy, min_dist):
        """Drop ring vertices so consecutive kept vertices are at least *min_dist* apart.

        *points_xy* is one ring as returned by ``QgsGeometry.asPolygon()`` /
        ``asMultiPolygon()`` (a list of ``QgsPointXY``). The first vertex is always
        kept and the result is closed; the closing segment (last kept vertex back
        to the first) also honours *min_dist*. If fewer than three distinct
        vertices would remain, the (closed) input ring is returned instead.
        Rings with fewer than 4 points, or ``min_dist <= 0``, are returned as-is.
        """
        if min_dist <= 0 or len(points_xy) < 4:
            return points_xy

        def gap(a, b):
            return math.hypot(a.x() - b.x(), a.y() - b.y())

        pts = list(points_xy)

        # Ensure closed
        if pts[0] != pts[-1]:
            pts.append(pts[0])

        start = pts[0]
        kept = [start]
        for p in pts[1:-1]:
            if gap(p, kept[-1]) >= min_dist:
                kept.append(p)

        # Trim trailing vertices that sit too close to the ring's start.
        while len(kept) > 1 and gap(kept[-1], start) < min_dist:
            kept.pop()
        kept.append(start)

        if len(kept) < 4:
            return pts  # fallback: too few vertices for a valid ring

        return kept

    @classmethod
    def _thin_polygon_geom(cls, geom: QgsGeometry, min_dist: float) -> QgsGeometry:
        """Return *geom* with every ring thinned by ``_thin_ring``.

        Non-polygon, empty or ``None`` geometries (and ``min_dist <= 0``) are
        returned unchanged. The result is rebuilt from XY coordinates only.
        """
        if min_dist <= 0 or geom is None or geom.isEmpty():
            return geom
        if geom.type() != QgsWkbTypes.PolygonGeometry:
            return geom

        def thin_poly(poly):
            # poly = [exterior, hole, ...]; ring = list[QgsPointXY]
            return [cls._thin_ring(ring, min_dist) for ring in poly]

        if geom.isMultipart():
            return QgsGeometry.fromMultiPolygonXY(
                [thin_poly(poly) for poly in geom.asMultiPolygon() if poly]
            )

        poly = geom.asPolygon()
        if not poly:
            return geom
        return QgsGeometry.fromPolygonXY(thin_poly(poly))

    @staticmethod
    def _read_dem(dem_path):
        """Read band 1 of the raster at *dem_path* with GDAL.

        Returns ``(z, gt, proj)``: *z* is a float32 array in which NoData and
        non-finite cells are NaN (so they poison any gradient stencil touching
        them), *gt* is the geotransform and *proj* the projection WKT.

        Raises:
            QgsProcessingException: if the file cannot be opened or read, the pixel
                size is zero, the raster is smaller than 2x2, or no cell is valid.
        """
        try:
            ds = gdal.Open(dem_path, gdal.GA_ReadOnly)
        except RuntimeError as err:  # raised instead of returning None under gdal.UseExceptions()
            raise QgsProcessingException(f"Could not open DEM with GDAL: {dem_path}") from err
        if ds is None:
            raise QgsProcessingException(f"Could not open DEM with GDAL: {dem_path}")

        band = None
        try:
            band = ds.GetRasterBand(1)
            nodata = band.GetNoDataValue()
            gt = ds.GetGeoTransform()
            proj = ds.GetProjection()

            if float(gt[1]) == 0 or float(gt[5]) == 0:
                raise QgsProcessingException("Invalid pixel size in DEM geotransform.")

            try:
                raw = band.ReadAsArray()
            except RuntimeError as err:
                raise QgsProcessingException("Failed to read DEM into array.") from err
        finally:
            # Release GDAL handles as soon as the pixels are in memory.
            band = None
            ds = None

        # Check before any array method is called on the result.
        if raw is None:
            raise QgsProcessingException("Failed to read DEM into array.")
        if min(raw.shape) < 2:
            raise QgsProcessingException("DEM must be at least 2 x 2 pixels to compute slopes.")

        z = raw.astype(np.float32)
        mask = np.isfinite(z)
        if nodata is not None:
            # Compare in the source dtype so the float32 cast cannot blur NoData.
            mask &= (raw != nodata)
        del raw

        if not np.any(mask):
            raise QgsProcessingException("DEM contains no valid (non-NoData) pixels.")

        z[~mask] = np.nan
        return z, gt, proj

    @classmethod
    def _classify_directional_slope(cls, z, dx, dy_signed, azimuth, perp_min, perp_max,
                                    bin_width, store_bin_index, feedback):
        """Return an Int16 array of directional-slope classes (``_NODATA_OUT`` elsewhere).

        *z* is the DEM (float32, NaN for invalid cells). *dx* and *dy_signed* are
        the geotransform pixel sizes gt[1] and gt[5]. A cell is kept when its
        absolute perpendicular slope (degrees) is within [perp_min, perp_max] and
        its gradient is finite. Kept cells get the bin index
        ``round(slope_dir / bin_width)`` when *store_bin_index* is true, otherwise
        the bin centre ``index * bin_width`` in degrees. The bin centre is exact
        only when *bin_width* is integer-valued.
        Positive directional slope means elevation rises toward the azimuth.
        """
        # With the signed row spacing (gt[5] < 0 for north-up rasters) the row
        # derivative is already dz/dNorth; the column derivative is dz/dEast.
        dz_dy_north, dz_dx_east = np.gradient(z, dy_signed, dx)

        feedback.setProgress(30)

        ux, uy = cls._azimuth_to_unit_vectors(azimuth)
        ux_p, uy_p = -uy, ux  # perpendicular (90° CCW); its sign is irrelevant since abs() is used

        d_dir = dz_dx_east * ux + dz_dy_north * uy
        d_perp = dz_dx_east * ux_p + dz_dy_north * uy_p
        del dz_dx_east, dz_dy_north

        slope_dir = np.degrees(np.arctan(d_dir))
        abs_perp = np.abs(np.degrees(np.arctan(d_perp)))

        feedback.setProgress(50)

        # NaN compares False, so cells next to NoData drop out here.
        with np.errstate(invalid="ignore"):
            keep = np.isfinite(slope_dir) & (abs_perp >= perp_min) & (abs_perp <= perp_max)

        bin_idx = np.rint(slope_dir[keep] / bin_width)
        values = bin_idx if store_bin_index else bin_idx * bin_width

        classified = np.full(z.shape, cls._NODATA_OUT, dtype=np.int16)
        classified[keep] = np.clip(values, -32760, 32760).astype(np.int16)
        return classified

    @classmethod
    def _write_classified_raster(cls, classified, gt, proj):
        """Write *classified* to a uniquely named temporary Int16 GeoTIFF and return its path.

        Raises:
            QgsProcessingException: if GDAL cannot create the file.
        """
        # Unique per call: avoids clobbering between runs and never lands in the CWD.
        tmp_raster = QgsProcessingUtils.generateTempFilename("directional_slope_classified.tif")
        rows, cols = classified.shape

        driver = gdal.GetDriverByName("GTiff")
        try:
            out_ds = driver.Create(
                tmp_raster,
                cols,
                rows,
                1,
                gdal.GDT_Int16,
                options=["COMPRESS=LZW", "TILED=YES"]
            )
        except RuntimeError as err:
            raise QgsProcessingException("Failed to create temporary classified GeoTIFF.") from err
        if out_ds is None:
            raise QgsProcessingException("Failed to create temporary classified GeoTIFF.")

        try:
            out_ds.SetGeoTransform(gt)
            out_ds.SetProjection(proj)
            out_band = out_ds.GetRasterBand(1)
            out_band.SetNoDataValue(cls._NODATA_OUT)
            out_band.WriteArray(classified)
            out_band.FlushCache()
            out_band = None
            out_ds.FlushCache()
        finally:
            out_ds = None  # dereferencing closes the dataset and finalises the file

        return tmp_raster

    @staticmethod
    def _run_vector_child(alg_id, params, context, feedback, step_name):
        """Run child algorithm *alg_id* and return its ``OUTPUT`` as a valid vector layer.

        Runs with ``is_child_algorithm=True`` (the documented way to call an
        algorithm from inside another one). The returned ``OUTPUT`` may then be a
        layer ID or a file path, so it is resolved through *context*.

        Raises:
            QgsProcessingException: if the output is not a valid vector layer.
        """
        result = processing.run(
            alg_id,
            params,
            context=context,
            feedback=feedback,
            is_child_algorithm=True
        )
        output = result["OUTPUT"]
        if isinstance(output, QgsVectorLayer):
            layer = output
        else:
            layer = QgsProcessingUtils.mapLayerFromString(str(output), context)
        if not isinstance(layer, QgsVectorLayer) or not layer.isValid():
            raise QgsProcessingException(f"{step_name} created an invalid layer: {output}")
        return layer

    @classmethod
    def _thin_layer(cls, layer, min_dist, feedback):
        """Return a memory copy of *layer* with each polygon's rings thinned to *min_dist*.

        Fields, attributes, wkbType and CRS are taken from *layer*. Features whose
        thinned geometry is empty are dropped. If *feedback* reports cancellation,
        iteration stops early and a partial layer is returned.

        Raises:
            QgsProcessingException: if the memory layer cannot be created or filled.
        """
        mem = QgsMemoryProviderUtils.createMemoryLayer(
            "thinned", layer.fields(), layer.wkbType(), layer.crs()
        )
        if mem is None or not mem.isValid():
            raise QgsProcessingException("Failed to create memory layer for vertex thinning.")

        new_feats = []
        for feat in layer.getFeatures():
            if feedback.isCanceled():
                break
            thinned = cls._thin_polygon_geom(feat.geometry(), min_dist)
            if thinned is None or thinned.isEmpty():
                continue
            new_feat = QgsFeature(mem.fields())
            new_feat.setAttributes(feat.attributes())
            new_feat.setGeometry(thinned)
            new_feats.append(new_feat)

        ok, _ = mem.dataProvider().addFeatures(new_feats)
        if not ok:
            raise QgsProcessingException("Failed to write thinned polygons to the memory layer.")
        mem.updateExtents()
        return mem

    def processAlgorithm(self, parameters, context: QgsProcessingContext, feedback: QgsProcessingFeedback):
        """Run the DEM -> classified raster -> polygon pipeline.

        Returns ``{OUTPUT: ...}`` holding the ``OUTPUT`` reported by
        ``native:savefeatures``, or ``{}`` if the run was cancelled.

        Raises:
            QgsProcessingException: on invalid parameters, an unreadable or empty
                DEM, or an intermediate step producing an invalid layer.
        """
        dem_layer: QgsRasterLayer = self.parameterAsRasterLayer(parameters, self.P_DEM, context)
        if dem_layer is None:
            raise QgsProcessingException("Invalid DEM layer.")

        azimuth = float(self.parameterAsDouble(parameters, self.P_AZIMUTH, context))
        perp_min = float(self.parameterAsDouble(parameters, self.P_PERP_MIN, context))
        perp_max = float(self.parameterAsDouble(parameters, self.P_PERP_MAX, context))
        bin_half = float(self.parameterAsDouble(parameters, self.P_BIN_HALF, context))
        min_area = float(self.parameterAsDouble(parameters, self.P_MIN_AREA, context))
        simplify_tol = float(self.parameterAsDouble(parameters, self.P_SIMPLIFY_TOL, context))
        def_dist = float(self.parameterAsDouble(parameters, self.P_DEF_DIST, context))
        out_vec = self.parameterAsOutputLayer(parameters, self.P_OUTPUT, context)

        if perp_min > perp_max:
            raise QgsProcessingException("Perpendicular ABS min cannot be greater than max.")

        bin_width = 2.0 * bin_half
        if bin_width <= 0:
            raise QgsProcessingException("Bin width must be > 0.")

        min_vertex_spacing = (def_dist / 2.0) if def_dist > 0 else 0.0

        # Integer bin widths give integer class centres that fit the Int16 raster.
        # Fractional widths would be truncated by the cast (widths < 2° would even
        # merge bins), so store the bin index and convert to degrees at the end.
        store_bin_index = not float(bin_width).is_integer()

        dem_path = dem_layer.source()
        z, gt, proj = self._read_dem(dem_path)

        dx = float(gt[1])
        dy_raw = float(gt[5])
        dy = abs(dy_raw)

        if dem_layer.crs().isGeographic():
            feedback.reportError(
                "DEM CRS is geographic: slopes, areas and distances are computed in degrees "
                "and will not be meaningful. Reproject the DEM to a projected CRS.",
                False,
            )
        if gt[2] != 0 or gt[4] != 0:
            feedback.reportError("DEM geotransform is rotated; rotation terms are ignored.", False)

        feedback.pushInfo(f"DEM: {dem_path}")
        feedback.pushInfo(f"Pixel size: dx={dx}, dy={dy} (gt[5]={dy_raw})")
        feedback.pushInfo(f"Azimuth: {azimuth} deg CW from North")
        feedback.pushInfo(f"Perp abs limits: [{perp_min}, {perp_max}] deg")
        feedback.pushInfo(f"Slope bin half-width: ±{bin_half} deg (bin width {bin_width} deg)")
        feedback.pushInfo(f"Min area: {min_area} map_units²")
        feedback.pushInfo(f"Simplify tol: {simplify_tol} map_units")
        feedback.pushInfo(f"Definition distance: {def_dist} (min vertex spacing {min_vertex_spacing})")

        feedback.setProgress(10)

        classified = self._classify_directional_slope(
            z, dx, dy_raw, azimuth, perp_min, perp_max, bin_width, store_bin_index, feedback
        )
        del z  # free the float DEM before vectorising

        feedback.setProgress(65)
        if feedback.isCanceled():
            return {}

        tmp_raster = self._write_classified_raster(classified, gt, proj)
        del classified

        feedback.setProgress(75)

        # --- Polygonize (must be a real output, not "memory:") ---
        value_field = "bin_idx" if store_bin_index else "slope_deg"
        poly_layer = self._run_vector_child(
            "gdal:polygonize",
            {
                "INPUT": tmp_raster,
                "BAND": 1,
                "FIELD": value_field,
                "EIGHT_CONNECTEDNESS": False,
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT
            },
            context,
            feedback,
            "Polygonize",
        )

        feedback.setProgress(82)

        # Optional simplify
        if simplify_tol > 0:
            poly_layer = self._run_vector_child(
                "native:simplifygeometries",
                {
                    "INPUT": poly_layer,
                    "METHOD": 0,
                    "TOLERANCE": simplify_tol,
                    "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT
                },
                context,
                feedback,
                "Simplify",
            )

        feedback.setProgress(88)

        # Optional minimum area filter. area($geometry) is planar in layer units,
        # matching the documented "map units²" ($area follows the project ellipsoid).
        if min_area > 0:
            poly_layer = self._run_vector_child(
                "native:extractbyexpression",
                {
                    "INPUT": poly_layer,
                    "EXPRESSION": f"area($geometry) >= {min_area}",
                    "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT
                },
                context,
                feedback,
                "Minimum area filter",
            )

        feedback.setProgress(92)

        # Vertex thinning (definition distance)
        if min_vertex_spacing > 0:
            poly_layer = self._thin_layer(poly_layer, min_vertex_spacing, feedback)
            if feedback.isCanceled():
                return {}

        # Convert bin indices to degrees (only for fractional bin widths).
        if store_bin_index:
            poly_layer = self._run_vector_child(
                "native:refactorfields",
                {
                    "INPUT": poly_layer,
                    "FIELDS_MAPPING": [{
                        "name": "slope_deg",
                        "type": 6,  # QVariant.Double / QMetaType.Double
                        "length": 20,
                        "precision": 6,
                        "expression": f'"bin_idx" * {bin_width!r}',
                    }],
                    "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT
                },
                context,
                feedback,
                "Slope label conversion",
            )

        feedback.setProgress(96)

        save_res = processing.run(
            "native:savefeatures",
            {
                "INPUT": poly_layer,
                "OUTPUT": out_vec
            },
            context=context,
            feedback=feedback,
            is_child_algorithm=True
        )

        feedback.setProgress(100)
        return {self.P_OUTPUT: save_res["OUTPUT"]}

