from __future__ import annotations

import math
import os
import re
from typing import Dict, List, Tuple

import numpy as np
import processing
from qgis.core import (
    QgsFeature,
    QgsGeometry,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingOutputFolder,
    QgsProcessingParameterDistance,
    QgsProcessingParameterEnum,
    QgsProcessingParameterFolderDestination,
    QgsProcessingParameterMultipleLayers,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorLayer,
    QgsSpatialIndex,
    QgsVectorLayer,
    QgsRasterLayer,
)

try:
    from scipy import ndimage

    HAS_NDIMAGE = True
except Exception:
    ndimage = None  # type: ignore
    HAS_NDIMAGE = False


def _safe_name(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_-]+", "_", name).strip("_") or "layer"


class ConnectivityAnalysisAlgorithm(QgsProcessingAlgorithm):
    """
    Lightweight connectivity summary for raster/vector habitat layers.

    Writes a text report per layer into the chosen output folder.  A "patch"
    is one polygon feature (vector input) or one 8-connected group of habitat
    pixels (raster input).  Two patches are linked when their centroids are at
    most the neighbor radius apart; the resulting graph provides the
    component, edge, redundancy and degree metrics in the report.
    """

    INPUT_TYPE = "INPUT_TYPE"
    INPUT_VECTORS = "INPUT_VECTORS"
    INPUT_RASTERS = "INPUT_RASTERS"
    NEIGHBOR_RADIUS = "NEIGHBOR_RADIUS"
    OUTPUT_FOLDER = "OUTPUT_FOLDER"

    def name(self) -> str:  # noqa: D401
        return "linkscape_connectivity_analysis"

    def displayName(self) -> str:  # noqa: D401
        return "Connectivity Analysis"

    def group(self) -> str:
        return "SORUS"

    def groupId(self) -> str:
        return "sorus"

    def shortHelpString(self) -> str:
        return (
            "Quick connectivity summaries for one or more raster/vector habitat layers. "
            "Outputs a text report per layer with patch/network metrics and an overall score."
        )

    def createInstance(self):
        return ConnectivityAnalysisAlgorithm()

    def initAlgorithm(self, config=None):  # noqa: D401
        self.addParameter(
            QgsProcessingParameterEnum(
                self.INPUT_TYPE,
                "Input type",
                options=["Vector polygons", "Raster habitat"],
                defaultValue=0,
            )
        )
        # QgsProcessingParameterVectorLayer / QgsProcessingParameterRasterLayer
        # accept a single layer only (they have no ``allowMultiple`` argument),
        # so multi-layer input goes through QgsProcessingParameterMultipleLayers.
        # Vectors and rasters get separate parameters; INPUT_TYPE picks which
        # list processAlgorithm reads.
        self.addParameter(
            QgsProcessingParameterMultipleLayers(
                self.INPUT_VECTORS,
                "Habitat polygon layer(s)",
                layerType=QgsProcessing.TypeVectorPolygon,
                optional=True,
            )
        )
        self.addParameter(
            QgsProcessingParameterMultipleLayers(
                self.INPUT_RASTERS,
                "Habitat raster layer(s)",
                layerType=QgsProcessing.TypeRaster,
                optional=True,
            )
        )
        self.addParameter(
            QgsProcessingParameterDistance(
                self.NEIGHBOR_RADIUS,
                "Neighbor radius (m) for connectivity/degree",
                defaultValue=500.0,
                minValue=0.0,  # a negative radius can never link anything
            )
        )
        self.addParameter(
            QgsProcessingParameterFolderDestination(
                self.OUTPUT_FOLDER,
                "Output folder for reports",
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """Analyze every selected layer and write one report file per layer.

        Returns a dict with the output folder and a ``"reports"`` mapping of
        layer name -> report path (duplicate names get a `` (N)`` suffix).

        Raises:
            QgsProcessingException: no valid layer of the selected input type
                was supplied, or a layer fails analysis.
        """
        input_type = self.parameterAsEnum(parameters, self.INPUT_TYPE, context)
        out_folder = self.parameterAsFile(parameters, self.OUTPUT_FOLDER, context)
        neighbor_radius = float(self.parameterAsDouble(parameters, self.NEIGHBOR_RADIUS, context))

        os.makedirs(out_folder, exist_ok=True)

        if input_type == 0:
            list_param, wanted_cls = self.INPUT_VECTORS, QgsVectorLayer
        else:
            list_param, wanted_cls = self.INPUT_RASTERS, QgsRasterLayer
        candidates = self.parameterAsLayerList(parameters, list_param, context) or []
        layers: List[Tuple[str, QgsVectorLayer | QgsRasterLayer]] = [
            (lyr.name(), lyr) for lyr in candidates if isinstance(lyr, wanted_cls) and lyr.isValid()
        ]

        if not layers:
            raise QgsProcessingException("No valid layers to analyze.")

        reports: Dict[str, str] = {}
        used_stems: set = set()
        for position, (layer_name, layer) in enumerate(layers):
            if feedback.isCanceled():
                break
            feedback.pushInfo(f"Analyzing {layer_name}…")
            # Different layers may share a name (or sanitize to the same token);
            # de-duplicate so one report never silently overwrites another.
            stem = self._unique_stem(f"linkscape_connectivity_{_safe_name(layer_name)}", used_stems)
            report_path = os.path.join(out_folder, f"{stem}.txt")
            if isinstance(layer, QgsVectorLayer):
                metrics = self._analyze_vector(layer, neighbor_radius)
            else:
                metrics = self._analyze_raster(layer, neighbor_radius, feedback)
            self._write_report(layer_name, metrics, report_path)

            report_key, dup = layer_name, 2
            while report_key in reports:
                report_key = f"{layer_name} ({dup})"
                dup += 1
            reports[report_key] = report_path
            feedback.pushInfo(f"  ✓ Report: {report_path}")
            feedback.setProgress(100.0 * (position + 1) / len(layers))

        return {self.OUTPUT_FOLDER: out_folder, "reports": reports}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _unique_stem(stem: str, used: set) -> str:
        """Return *stem*, or *stem* plus ``_N``, so that it is not already in *used*.

        The chosen stem is recorded in *used* (lower-cased, because Windows and
        macOS file systems compare names case-insensitively).
        """
        candidate, counter = stem, 2
        while candidate.lower() in used:
            candidate = f"{stem}_{counter}"
            counter += 1
        used.add(candidate.lower())
        return candidate

    @staticmethod
    def _require_projected(layer) -> None:
        """Raise unless *layer* has a valid, non-geographic CRS.

        Areas, perimeters and the neighbor radius are all interpreted in CRS
        units, which the report labels as meters.
        """
        crs = layer.crs()
        if not crs.isValid() or crs.isGeographic():
            raise QgsProcessingException("Layer must be in a projected CRS (meters).")

    def _analyze_vector(self, layer: QgsVectorLayer, neighbor_radius: float) -> Dict[str, float]:
        """Compute connectivity metrics where each polygon feature is one patch.

        Features with a null/empty geometry, or whose centroid is empty, are
        skipped because they cannot be placed in the patch graph.

        Raises:
            QgsProcessingException: invalid/geographic CRS or no usable features.
        """
        self._require_projected(layer)

        areas: List[float] = []
        perimeter = 0.0
        centroids: List[Tuple[float, float]] = []
        for feat in layer.getFeatures():
            geom = feat.geometry()
            if geom.isNull() or geom.isEmpty():
                continue  # a feature without geometry cannot form a patch
            centre = geom.centroid()
            if centre.isNull() or centre.isEmpty():
                continue  # degenerate geometry with no centroid
            point = centre.asPoint()
            areas.append(geom.area())
            perimeter += geom.length()  # polygon length() is its perimeter
            centroids.append((point.x(), point.y()))

        if not areas:
            raise QgsProcessingException("No features with a non-empty geometry found.")
        return self._summarize(areas, perimeter, centroids, neighbor_radius)

    def _analyze_raster(self, layer: QgsRasterLayer, neighbor_radius: float, feedback) -> Dict[str, float]:
        """Compute connectivity metrics for band 1 of a habitat raster.

        Pixels that are > 0 and not close to the band's nodata value are
        habitat; 8-connected habitat pixels form one patch.  A patch centroid is
        the mean pixel-centre position mapped through the GDAL geotransform.
        Perimeter counts every pixel side that borders non-habitat or the
        raster edge.

        Raises:
            QgsProcessingException: SciPy missing, invalid/geographic CRS,
                unreadable source, or no habitat pixels.
        """
        if not HAS_NDIMAGE:
            raise QgsProcessingException("SciPy (ndimage) is required for raster connectivity analysis.")
        self._require_projected(layer)

        # The top-level ``gdal`` module is a long-deprecated alias that modern
        # GDAL releases no longer ship; the bindings live in ``osgeo``.
        from osgeo import gdal  # type: ignore

        source = layer.dataProvider().dataSourceUri()
        ds = gdal.Open(source)
        if ds is None:
            raise QgsProcessingException(f"GDAL could not open raster: {source}")
        # Take the geotransform from the same dataset the pixels are read from.
        gt = ds.GetGeoTransform()
        band = ds.GetRasterBand(1)
        nodata = band.GetNoDataValue()
        raw = band.ReadAsArray()
        band = ds = None  # release the GDAL handles
        if raw is None:
            raise QgsProcessingException(f"GDAL could not read band 1 of raster: {source}")
        arr = raw.astype(np.float32)

        px_w = abs(gt[1])
        px_h = abs(gt[5]) if gt[5] != 0 else px_w
        cell_area = px_w * px_h

        # Treat zero/negative (and NaN) as non-habitat; positive values are habitat.
        habitat = arr > 0
        if nodata is not None:
            habitat &= ~np.isclose(arr, nodata)
        if not habitat.any():
            raise QgsProcessingException("Raster contains no habitat pixels (>0 and not nodata).")

        # Label patches (8-connected).
        labels, n_labels = ndimage.label(habitat, structure=np.ones((3, 3), dtype=np.uint8))
        if n_labels == 0:
            raise QgsProcessingException("No habitat components found in raster.")
        feedback.pushInfo(f"  {n_labels} habitat patch(es) found")

        # Per-patch pixel counts and mean row/column, vectorised over habitat
        # pixels only.  ndimage.label gives every label 1..n_labels at least one
        # pixel, so ``counts`` has no zeros.
        rows, cols = np.nonzero(labels)
        patch_of = labels[rows, cols]
        bins = n_labels + 1
        counts = np.bincount(patch_of, minlength=bins)[1:].astype(np.float64)
        mean_r = np.bincount(patch_of, weights=rows, minlength=bins)[1:] / counts
        mean_c = np.bincount(patch_of, weights=cols, minlength=bins)[1:] / counts
        # Pixel centres through the full (signed) affine geotransform.
        xs = gt[0] + (mean_c + 0.5) * gt[1] + (mean_r + 0.5) * gt[2]
        ys = gt[3] + (mean_c + 0.5) * gt[4] + (mean_r + 0.5) * gt[5]
        centroids = list(zip(xs.tolist(), ys.tolist()))
        areas = (counts * cell_area).tolist()

        # Perimeter: pad with non-habitat so raster borders count on all four
        # sides.  An exposed top/bottom pixel side is a horizontal segment of
        # pixel-width length; an exposed left/right side is pixel-height long.
        padded = np.pad(habitat, 1, mode="constant", constant_values=False)
        core = padded[1:-1, 1:-1]
        exposed_top_bottom = np.count_nonzero(core & ~padded[:-2, 1:-1]) + np.count_nonzero(
            core & ~padded[2:, 1:-1]
        )
        exposed_left_right = np.count_nonzero(core & ~padded[1:-1, :-2]) + np.count_nonzero(
            core & ~padded[1:-1, 2:]
        )
        perimeter = float(exposed_top_bottom * px_w + exposed_left_right * px_h)

        return self._summarize(areas, perimeter, centroids, neighbor_radius)

    @classmethod
    def _summarize(
        cls,
        areas: List[float],
        perimeter: float,
        centroids: List[Tuple[float, float]],
        neighbor_radius: float,
    ) -> Dict[str, float]:
        """Turn per-patch areas/centroids and the total perimeter into report metrics.

        Shared by the vector and raster paths so both use identical metric
        definitions.  ``areas`` and ``centroids`` are parallel lists with one
        entry per patch, in CRS units.
        """
        n = len(areas)
        total_area = float(sum(areas))
        largest_area = float(max(areas)) if areas else 0.0
        mean_area = total_area / n if n else 0.0
        edge_density = (perimeter / total_area * 10000.0) if total_area > 0 else 0.0
        # Cohesion proxy: area relative to area+perimeter, scaled by 2/(n+1),
        # which shrinks as the habitat splits into more patches.
        cohesion = (
            (total_area / (total_area + perimeter)) * (1 - (n - 1) / (n + 1))
            if total_area > 0
            else 0.0
        )

        edges_used, degrees, components = cls._radius_graph(centroids, neighbor_radius)
        # A spanning forest needs n - components links; anything beyond is a loop.
        redundant_links = max(0, edges_used - (n - components))
        avg_degree = (sum(degrees) / n) if n > 0 else 0.0

        largest_frac = largest_area / total_area if total_area > 0 else 0.0
        component_score = 1.0 / components if components > 0 else 0.0
        overall_score = (0.4 * largest_frac + 0.3 * component_score + 0.3 * cohesion) * 100.0

        return {
            "patches_total": n,
            "total_area": total_area,
            "largest_area": largest_area,
            "mean_area": mean_area,
            "edge_density": edge_density,
            "cohesion": cohesion,
            "components": components,
            "edges_used": edges_used,
            "redundant_links": redundant_links,
            "avg_degree": avg_degree,
            "overall_score": overall_score,
        }

    @staticmethod
    def _radius_graph(
        points: List[Tuple[float, float]], radius: float
    ) -> Tuple[int, List[int], int]:
        """Link every pair of points whose Euclidean distance is <= *radius*.

        Returns ``(edge_count, degrees, component_count)``.  Each unordered pair
        is counted once, so ``sum(degrees) == 2 * edge_count``.  Points are
        bucketed into square cells at least *radius* wide, so only the 3x3
        neighbouring cells are compared rather than every pair.
        """
        n = len(points)
        parent = list(range(n))

        def find(x: int) -> int:
            # Iterative with path halving: no recursion-depth limit on long chains.
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        # Cells slightly wider than the radius keep any in-range pair in the
        # same or an adjacent cell despite floating-point rounding of x / cell.
        cell = radius * 1.001 if radius > 0 else 1.0
        buckets: Dict[Tuple[int, int], List[int]] = {}
        for idx, (x, y) in enumerate(points):
            buckets.setdefault((math.floor(x / cell), math.floor(y / cell)), []).append(idx)

        degrees = [0] * n
        edges = 0
        for (bx, by), members in buckets.items():
            for dx in (-1, 0, 1):
                for dy in (-1, 0, 1):
                    others = buckets.get((bx + dx, by + dy))
                    if not others:
                        continue
                    for i in members:
                        xi, yi = points[i]
                        for j in others:
                            if j <= i:
                                continue  # visit each unordered pair exactly once
                            xj, yj = points[j]
                            if math.hypot(xi - xj, yi - yj) <= radius:
                                edges += 1
                                degrees[i] += 1
                                degrees[j] += 1
                                root_i, root_j = find(i), find(j)
                                if root_i != root_j:
                                    parent[root_j] = root_i

        components = len({find(i) for i in range(n)})
        return edges, degrees, components

    def _write_report(self, layer_name: str, m: Dict[str, float], path: str) -> None:
        """Write the metrics dict *m* for *layer_name* as a UTF-8 text report at *path*.

        Areas are converted from square CRS units to hectares (/10000).
        """
        with open(path, "w", encoding="utf-8") as fh:
            fh.write("Linkscape Connectivity Analysis\n")
            fh.write("================================\n")
            fh.write(f"Layer:                {layer_name}\n")
            fh.write(f"Overall score:        {m.get('overall_score', 0):.1f}\n\n")

            fh.write("PATCH CONNECTIVITY\n")
            fh.write("------------------\n")
            fh.write(f"Total patches:        {m.get('patches_total', 0)}\n")
            fh.write(f"Components:           {m.get('components', 0)}\n")
            fh.write(f"Largest patch (ha):   {m.get('largest_area', 0)/10000:.2f}\n")
            fh.write(f"Mean patch (ha):      {m.get('mean_area', 0)/10000:.2f}\n")
            fh.write(f"Total habitat (ha):   {m.get('total_area', 0)/10000:.2f}\n")
            fh.write(f"Edge density (m/ha):  {m.get('edge_density', 0):.2f}\n")
            fh.write(f"Cohesion (proxy):     {m.get('cohesion', 0):.3f}\n\n")

            fh.write("CORRIDOR GRAPH (radius-based)\n")
            fh.write("-----------------------------\n")
            fh.write(f"Edges used:           {m.get('edges_used', 0)}\n")
            fh.write(f"Redundant links:      {m.get('redundant_links', 0)}\n")
            fh.write(f"Average degree:       {m.get('avg_degree', 0):.2f}\n")
