"""
Linkscape Corridor Analysis - Raster Workflow (v23.0)
-----------------------------------------------------
Runs the selected raster optimization workflow for corridor analysis.

The logic is adapted from the standalone raster script and packaged so it can
be invoked from the QGIS plugin with user-supplied parameters.
"""

from __future__ import annotations

import heapq
import math
import os
import tempfile
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple

import numpy as np
try:
    from scipy import ndimage

    HAS_NDIMAGE = True
except ImportError:  # pragma: no cover - scipy ships with QGIS, but guard anyway
    HAS_NDIMAGE = False
try:
    from osgeo import gdal
except ImportError:  # pragma: no cover
    gdal = None  # type: ignore

try:
    from qgis.core import QgsProject, QgsRasterLayer
except ImportError:  # pragma: no cover
    QgsProject = None  # type: ignore
    QgsRasterLayer = None  # type: ignore

# Optional OpenCV import for faster connected component labeling
try:
    import cv2

    HAS_CV2 = True
except ImportError:
    HAS_CV2 = False

from .linkscape_engine import NetworkOptimizer, UnionFind
from .utils import emit_progress, log_error

# GDAL GTiff creation options used by write_raster(): LZW compression, tiled
# layout, and BigTIFF only when the output might exceed the classic TIFF limit.
GTIFF_OPTIONS = ["COMPRESS=LZW", "TILED=YES", "BIGTIFF=IF_SAFER"]

PIXEL_COUNT_WARNING_THRESHOLD = 100_000_000  # warn when raster exceeds ~100 million pixels
PIXEL_SIZE_WARNING_THRESHOLD = 10.0  # warn when pixel size < 10 map units


class RasterAnalysisError(RuntimeError):
    """Signal that the raster corridor analysis could not be completed.

    Raised in this module when a raster chunk cannot be read, an output
    dataset cannot be created, or the patch configuration is unusable.
    """


@dataclass
class RasterRunParams:
    """User-supplied settings for one raster corridor-analysis run.

    Patch selection, obstacle selection and value tolerance are read by
    define_habitat() / define_obstacles() in this module. The remaining fields
    are consumed elsewhere; their comments below are name-based and should be
    confirmed against the caller.
    """

    # Pixel neighbourhood for labelling; presumably 4 or 8 (label_patches
    # treats anything other than 4 as 8) — confirm where it is passed.
    patch_connectivity: int
    # "value" or "range" (case-insensitive), read by define_habitat().
    patch_mode: str
    # Raster values counted as patch in "value" mode.
    patch_values: List[float]
    # Inclusive bounds for "range" mode; both must be set.
    range_lower: Optional[float]
    range_upper: Optional[float]
    # Obstacle selection, read by define_obstacles(); ignored when disabled.
    obstacle_enabled: bool
    obstacle_mode: str  # "value" or "range" (case-insensitive)
    obstacle_values: List[float]
    obstacle_range_lower: Optional[float]  # define_obstacles() swaps reversed bounds
    obstacle_range_upper: Optional[float]
    # Absolute tolerance for value matching: |pixel - value| < value_tolerance.
    value_tolerance: float
    # NOTE(review): the fields below are not read in this part of the module;
    # descriptions are inferred from names — confirm.
    nodata_fallback: float  # presumably used when the band has no nodata value
    min_patch_size: int  # presumably in pixels
    allow_sub_min_corridor: bool
    budget_pixels: int  # presumably the corridor pixel budget given to the optimizers
    max_search_distance: int  # presumably the max_width passed to the corridor search
    max_corridor_area: Optional[int]  # presumably the max_area filter (None = unlimited)
    min_corridor_width: int  # presumably passed to _corridor_offsets()
    allow_bottlenecks: bool
    stepping_enabled: bool
    hop_distance: int  # presumably in pixels, as used by _compute_hop_adjacency()


def _compute_patch_boundaries(labels: np.ndarray) -> Tuple[Dict[int, np.ndarray], Dict[int, Tuple[int, int, int, int]]]:
    """Extract boundary pixels and bounding boxes for each patch label."""
    mask = labels > 0
    rows, cols = labels.shape
    boundary_mask = np.zeros_like(mask, dtype=bool)
    padded = np.pad(labels, 1, mode="constant", constant_values=0)

    for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        neighbor = padded[1 + dr : 1 + dr + rows, 1 + dc : 1 + dc + cols]
        boundary_mask |= mask & (neighbor != labels)

    boundaries: Dict[int, List[Tuple[int, int]]] = {}
    bboxes: Dict[int, List[int]] = {}
    for r, c in np.argwhere(boundary_mask):
        pid = int(labels[r, c])
        if pid <= 0:
            continue
        boundaries.setdefault(pid, []).append((r, c))
        bbox = bboxes.get(pid)
        if bbox is None:
            bboxes[pid] = [r, r, c, c]
        else:
            bbox[0] = min(bbox[0], r)
            bbox[1] = max(bbox[1], r)
            bbox[2] = min(bbox[2], c)
            bbox[3] = max(bbox[3], c)

    boundary_arrays = {pid: np.asarray(coords, dtype=np.int32) for pid, coords in boundaries.items()}
    bbox_map = {pid: (vals[0], vals[1], vals[2], vals[3]) for pid, vals in bboxes.items()}
    return boundary_arrays, bbox_map


def _min_boundary_distance(coords1: np.ndarray, coords2: np.ndarray, max_distance: int) -> Optional[float]:
    """Return the minimum Euclidean distance between two boundary point sets, short-circuiting when within max_distance."""
    if coords1.size == 0 or coords2.size == 0:
        return None
    max_sq = float(max_distance * max_distance)
    if coords1.shape[0] > coords2.shape[0]:
        coords1, coords2 = coords2, coords1

    for r, c in coords1:
        dr = coords2[:, 0] - r
        dc = coords2[:, 1] - c
        dist_sq = dr * dr + dc * dc
        min_sq = float(np.min(dist_sq))
        if min_sq <= max_sq:
            return math.sqrt(min_sq)
    return None


def _compute_hop_adjacency(labels: np.ndarray, hop_distance: int) -> Dict[int, Set[int]]:
    """Find patch pairs whose boundaries come within ``hop_distance`` pixels.

    A cheap bounding-box gap test rejects clearly distant pairs before the
    boundary-to-boundary distance check via _min_boundary_distance().

    Returns:
        A symmetric ``defaultdict(set)`` mapping each patch id to its
        hop-reachable neighbours, or an empty dict when ``hop_distance <= 0``.
    """
    if hop_distance <= 0:
        return {}

    edge_coords, extents = _compute_patch_boundaries(labels)
    ordered_ids = sorted(edge_coords.keys())
    limit_sq = hop_distance * hop_distance
    missing_extent = (0, -1, 0, -1)
    no_coords = np.empty((0, 2))
    neighbours: Dict[int, Set[int]] = defaultdict(set)

    def gap(lo_a: int, hi_a: int, lo_b: int, hi_b: int) -> int:
        # Empty cells strictly between two 1-D intervals; 0 when they touch or overlap.
        if hi_a < lo_b:
            return lo_b - hi_a - 1
        if hi_b < lo_a:
            return lo_a - hi_b - 1
        return 0

    for position, first in enumerate(ordered_ids):
        top_a, bottom_a, left_a, right_a = extents.get(first, missing_extent)
        for second in ordered_ids[position + 1 :]:
            top_b, bottom_b, left_b, right_b = extents.get(second, missing_extent)
            dy = gap(top_a, bottom_a, top_b, bottom_b)
            dx = gap(left_a, right_a, left_b, right_b)
            if dy * dy + dx * dx > limit_sq:
                continue
            separation = _min_boundary_distance(
                edge_coords.get(first, no_coords),
                edge_coords.get(second, no_coords),
                hop_distance,
            )
            if separation is not None and separation <= hop_distance:
                neighbours[first].add(second)
                neighbours[second].add(first)
    return neighbours


def _build_virtual_components(
    patch_sizes: Dict[int, int],
    hop_adjacency: Dict[int, Set[int]],
) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, int]]:
    """Group patches linked by hop adjacency into "virtual" components.

    Returns:
        (component_map, component_sizes, component_counts):
        patch id -> component root id, root id -> summed patch size, and
        root id -> number of member patches.
    """
    forest = UnionFind()
    for patch_id, pixel_count in patch_sizes.items():
        # Touch the node before seeding size/count (assumes find() inserts
        # unknown nodes — see UnionFind).
        forest.find(patch_id)
        forest.size[patch_id] = pixel_count
        forest.count[patch_id] = 1

    # Union order is kept identical to the adjacency iteration order because
    # it can influence which patch becomes each component's root.
    for patch_id, linked in hop_adjacency.items():
        for other in linked:
            forest.union(patch_id, other)

    component_map: Dict[int, int] = {}
    component_sizes: Dict[int, int] = {}
    component_counts: Dict[int, int] = {}
    for patch_id, pixel_count in patch_sizes.items():
        root = forest.find(patch_id)
        component_map[patch_id] = root
        component_sizes[root] = component_sizes.get(root, 0) + pixel_count
        component_counts[root] = component_counts.get(root, 0) + 1
    return component_map, component_sizes, component_counts


def label_components_numpy(binary_array: np.ndarray, connectivity: int = 8) -> Tuple[np.ndarray, int]:
    """Label connected components without scipy or OpenCV (pure-Python fallback).

    Pass 1 unions every foreground pixel with its already-visited neighbours
    (above / left). Pass 2 numbers the resulting sets 1..N in row-major order
    of each set's first pixel.

    Returns:
        (labels, count): an int32 label image (0 = background) and N.
    """
    n_rows, n_cols = binary_array.shape
    forest = UnionFind()
    # Only "backward" neighbours are needed; the forward ones are visited later.
    if connectivity == 4:
        back_steps = [(-1, 0), (0, -1)]
    else:
        back_steps = [(-1, -1), (-1, 0), (-1, 1), (0, -1)]

    for r in range(n_rows):
        for c in range(n_cols):
            if not binary_array[r, c]:
                continue
            here = (r, c)
            forest.find(here)
            for dr, dc in back_steps:
                rr, cc = r + dr, c + dc
                if 0 <= rr < n_rows and 0 <= cc < n_cols and binary_array[rr, cc]:
                    forest.union(here, (rr, cc))

    labels_out = np.zeros_like(binary_array, dtype=np.int32)
    label_of_root: Dict[Tuple[int, int], int] = {}
    for r in range(n_rows):
        for c in range(n_cols):
            if not binary_array[r, c]:
                continue
            root = forest.find((r, c))
            label = label_of_root.get(root)
            if label is None:
                label = len(label_of_root) + 1
                label_of_root[root] = label
            labels_out[r, c] = label

    return labels_out, len(label_of_root)


def label_components_opencv(binary_array: np.ndarray, connectivity: int = 8) -> Tuple[np.ndarray, int]:
    """Label connected components with ``cv2.connectedComponents``.

    OpenCV's label count includes the background (label 0), so one is
    subtracted to report only foreground components.
    """
    neighbourhood = 8
    if connectivity == 4:
        neighbourhood = 4
    as_bytes = binary_array.astype(np.uint8)
    total_labels, label_img = cv2.connectedComponents(as_bytes, connectivity=neighbourhood)
    return label_img.astype(np.int32), total_labels - 1


def label_patches(binary_array: np.ndarray, connectivity: int = 8) -> Tuple[np.ndarray, int]:
    """Label connected patches with the fastest backend that is installed.

    Preference order: scipy.ndimage, then OpenCV, then the pure-Python
    fallback. ``connectivity == 4`` selects the von Neumann neighbourhood;
    any other value selects the full 3x3 (8-connected) neighbourhood.
    """
    if not HAS_NDIMAGE:
        if HAS_CV2:
            print("  Using OpenCV for labeling...")
            return label_components_opencv(binary_array, connectivity)
        print("  Using numpy for labeling...")
        return label_components_numpy(binary_array, connectivity)

    print("  Using scipy.ndimage for labeling...")
    if connectivity == 4:
        footprint = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
    else:
        footprint = np.ones((3, 3), dtype=np.uint8)
    labeled, n_features = ndimage.label(binary_array, structure=footprint)
    return labeled.astype(np.int32), int(n_features)


def read_band(
    band: gdal.Band,
    rows: int,
    cols: int,
    progress_cb: Optional[Callable[[int, Optional[str]], None]] = None,
    progress_start: int = 0,
    progress_end: int = 10,
) -> np.ndarray:
    """Read a whole GDAL band into a float32 array, one horizontal strip at a time.

    Strip height is roughly rows/50, capped at 1024 and at least 1. After each
    strip, progress is reported linearly between ``progress_start`` and
    ``progress_end`` when ``progress_cb`` is given.

    Raises:
        RasterAnalysisError: if GDAL returns no data for a strip.
    """
    out = np.empty((rows, cols), dtype=np.float32)
    strip = max(1, min(1024, rows // 50 or 1))
    progress_span = max(progress_end - progress_start, 1)
    denominator = max(rows, 1)

    row = 0
    while row < rows:
        height = min(strip, rows - row)
        raw = band.ReadRaster(0, row, cols, height, cols, height, gdal.GDT_Float32)
        if not raw:
            raise RasterAnalysisError("Failed to read raster data chunk.")
        strip_values = np.frombuffer(raw, dtype=np.float32, count=cols * height)
        out[row : row + height] = strip_values.reshape(height, cols)
        row += height

        if progress_cb is not None:
            done = row / denominator
            emit_progress(progress_cb, progress_start + done * progress_span, "Reading raster data…")

    return out


def write_raster(path: str, arr: np.ndarray, gt: Tuple[float, ...], proj: str, nodata: float = 0) -> None:
    """Write a 2-D array to a single-band Int32 GeoTIFF.

    Args:
        path: Output file path.
        arr: 2-D array; cast to int32 if it has another dtype.
        gt: GDAL geotransform.
        proj: Projection WKT passed to ``SetProjection``.
        nodata: Nodata value, stored as an int.

    Raises:
        RasterAnalysisError: if the dataset cannot be created or the pixel
            write reports an error code.
    """
    rows, cols = arr.shape
    drv = gdal.GetDriverByName("GTiff")
    ds = drv.Create(path, cols, rows, 1, gdal.GDT_Int32, options=GTIFF_OPTIONS)
    if ds is None:
        raise RasterAnalysisError(f"Unable to create output dataset: {path}")

    band = None
    try:
        ds.SetGeoTransform(gt)
        ds.SetProjection(proj)
        band = ds.GetRasterBand(1)
        band.SetNoDataValue(int(nodata))

        if arr.dtype != np.int32:
            arr = arr.astype(np.int32)

        buf = np.ascontiguousarray(arr).tobytes()
        err = band.WriteRaster(0, 0, cols, rows, buf, cols, rows)
        # Without gdal.UseExceptions() a failed write only returns a CPLErr code;
        # ignoring it would leave a silently incomplete file on disk.
        if err is not None and err != gdal.CE_None:
            raise RasterAnalysisError(f"Failed to write raster data: {path}")
        band.FlushCache()
    finally:
        # Dropping the last references closes the dataset (also on error paths).
        band = None
        ds = None


def define_habitat(data: np.ndarray, nodata_mask: np.ndarray, params: RasterRunParams) -> np.ndarray:
    """Identify patch pixels based on the selected configuration."""
    valid = ~nodata_mask
    patches = np.zeros(data.shape, dtype=np.uint8)
    mode = params.patch_mode.lower()
    tol = params.value_tolerance

    if mode == "value" and params.patch_values:
        for val in params.patch_values:
            patches |= (np.abs(data - val) < tol) & valid
        print(f"  Patch = values {params.patch_values}")
    elif mode == "range" and params.range_lower is not None and params.range_upper is not None:
        patches = ((data >= params.range_lower) & (data <= params.range_upper)) & valid
        print(
            "  Patch = range "
            f"{params.range_lower:.4f} - {params.range_upper:.4f}"
        )
    else:
        raise RasterAnalysisError("Patch configuration did not yield any valid pixels.")

    return patches


def define_obstacles(data: np.ndarray, nodata_mask: np.ndarray, patch_mask: np.ndarray, params: RasterRunParams) -> np.ndarray:
    """Create a boolean mask for obstacle pixels corridors must avoid."""
    if not params.obstacle_enabled:
        return np.zeros(data.shape, dtype=bool)

    mask = np.zeros(data.shape, dtype=bool)
    tol = params.value_tolerance
    mode = params.obstacle_mode.lower()

    if mode == "range" and params.obstacle_range_lower is not None and params.obstacle_range_upper is not None:
        lower = min(params.obstacle_range_lower, params.obstacle_range_upper)
        upper = max(params.obstacle_range_lower, params.obstacle_range_upper)
        mask = (data >= lower) & (data <= upper)
    elif mode == "value" and params.obstacle_values:
        for val in params.obstacle_values:
            mask |= np.abs(data - val) < tol
    else:
        return np.zeros(data.shape, dtype=bool)

    mask &= ~nodata_mask
    return mask


def find_shortest_corridor(
    start_patches: Set[int],
    labels: np.ndarray,
    habitat: np.ndarray,
    max_width: int,
    connectivity: int,
    obstacle_mask: Optional[np.ndarray] = None,
    passable_mask: Optional[np.ndarray] = None,
    hop_distance: int = 0,
) -> List[Tuple[frozenset, int, float]]:
    """
    Dijkstra search to find shortest corridors connecting start_patches to other patches.

    Seeds are non-habitat pixels adjacent (under ``connectivity``) to any start
    patch. Moving onto a non-habitat pixel costs 1 (sqrt(2) diagonally) and adds
    it to the path. Unlabelled habitat (label 0, stepping stones) is crossed at
    zero cost and is not added to the path. Up to ``hop_distance`` obstacle
    pixels may be crossed along one path; pixels rejected by ``passable_mask``
    are never entered. The first time each other patch is found adjacent to an
    expanded pixel, it is recorded once.

    Returns a list of (path_pixels, target_patch, length), where length is the
    accumulated cost and never exceeds ``max_width``.
    """
    rows, cols = labels.shape
    if connectivity == 4:
        moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        moves = [
            (-1, -1),
            (-1, 0),
            (-1, 1),
            (0, -1),
            (0, 1),
            (1, -1),
            (1, 0),
            (1, 1),
        ]

    # Seed pixels: non-habitat cells touching a start patch. Computed with a
    # vectorised dilation instead of a Python scan of the whole raster per
    # call; np.argwhere yields the same row-major order as the former loop.
    in_start = np.isin(labels, list(start_patches))
    if not in_start.any():
        return []
    padded = np.pad(in_start, 1, mode="constant", constant_values=False)
    touches_start = np.zeros((rows, cols), dtype=bool)
    for di, dj in moves:
        touches_start |= padded[1 + di : 1 + di + rows, 1 + dj : 1 + dj + cols]
    touches_start &= np.asarray(habitat) == 0
    start_positions: List[Tuple[int, int]] = [(int(i), int(j)) for i, j in np.argwhere(touches_start)]

    if not start_positions:
        return []

    heap: List[Tuple[float, int, int, frozenset, int]] = []
    best_cost: Dict[Tuple[int, int], float] = {}
    results: List[Tuple[frozenset, int, float]] = []
    visited_targets: Set[int] = set()
    for r, c in start_positions:
        obstacle_count = 1 if (obstacle_mask is not None and obstacle_mask[r, c]) else 0
        if obstacle_count > hop_distance:
            continue
        if passable_mask is not None and not passable_mask[r, c]:
            continue
        path = frozenset({(r, c)})
        heapq.heappush(heap, (0.0, r, c, path, obstacle_count))
        best_cost[(r, c)] = 0.0

    while heap:
        cost, r, c, path, obstacle_count = heapq.heappop(heap)
        if cost > max_width:
            continue

        for dr, dc in moves:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                lbl = labels[nr, nc]
                # Reached a new target patch: record the path that led here.
                if lbl > 0 and lbl not in start_patches and lbl not in visited_targets:
                    visited_targets.add(lbl)
                    results.append((path, lbl, cost))
                    continue
                new_obstacle_count = obstacle_count
                if obstacle_mask is not None and obstacle_mask[nr, nc]:
                    new_obstacle_count += 1
                    if new_obstacle_count > hop_distance:
                        continue
                if passable_mask is not None and not passable_mask[nr, nc]:
                    continue
                if habitat[nr, nc]:
                    if lbl == 0:
                        move_cost = 0.0
                        new_path = path  # do not draw corridors over stepping-stone habitat
                    else:
                        continue
                else:
                    move_cost = math.sqrt(2) if dr != 0 and dc != 0 else 1.0
                    new_path = path | frozenset({(nr, nc)})
                new_cost = cost + move_cost
                if new_cost > max_width:
                    continue
                # best_cost is keyed on position only (not obstacle count).
                prev_best = best_cost.get((nr, nc))
                if prev_best is not None and prev_best <= new_cost:
                    continue
                best_cost[(nr, nc)] = new_cost
                heapq.heappush(heap, (new_cost, nr, nc, new_path, new_obstacle_count))

    return results


def find_all_possible_corridors(
    labels: np.ndarray,
    habitat: np.ndarray,
    patch_sizes: Dict[int, int],
    max_width: int,
    min_corridor_width: int,
    max_area: Optional[int],
    connectivity: int,
    obstacle_mask: Optional[np.ndarray] = None,
    passable_mask: Optional[np.ndarray] = None,
    hop_distance: int = 0,
    hop_adjacency: Optional[Dict[int, Set[int]]] = None,
    stepping_enabled: bool = False,
    obstacles_present: bool = False,
    progress_cb: Optional[Callable[[int, Optional[str]], None]] = None,
    progress_start: int = 45,
    progress_end: int = 75,
) -> List[Dict]:
    """Find all possible corridors between patch pairs.

    find_shortest_corridor() is run with each patch (positive id in
    ``patch_sizes``) as the source. Each (path, target, length) result becomes
    a candidate unless the unordered pair was already recorded from the other
    side. Centerlines are widened with _inflate_corridor_pixels() using offsets
    for ``min_corridor_width``. Candidates whose widened area exceeds
    ``max_area`` are dropped, except hop edges (pairs listed in
    ``hop_adjacency``). When stepping is enabled, ``hop_distance > 0`` and no
    obstacles are present, hop edges get area 0.

    Progress is reported between ``progress_start`` and ``progress_end`` only
    when ``progress_cb`` is provided.

    Returns:
        Candidate dicts with keys patch1, patch2, centerline_pixels, pixels
        (frozenset of widened pixels), length, area and is_hop_edge.
    """
    print("  Finding all possible corridors...")
    print(f"  Hop distance for obstacles: {hop_distance}px")
    all_corridors: List[Dict] = []
    processed_pairs: Set[frozenset] = set()
    offsets = _corridor_offsets(min_corridor_width)
    rows, cols = labels.shape

    unique_patches = [p for p in patch_sizes.keys() if p > 0]
    total = len(unique_patches) or 1
    span = max(progress_end - progress_start, 1)
    for idx, patch_id in enumerate(unique_patches):
        if (idx + 1) % 10 == 0:
            print(f"    Analyzing patch {idx + 1}/{len(unique_patches)}...", end="\r")
        if progress_cb is not None:
            pre_value = progress_start + ((idx + 0.25) / total) * span
            emit_progress(
                progress_cb,
                pre_value,
                f"Analyzing patch {idx + 1}/{total}…",
            )

        results = find_shortest_corridor(
            {patch_id},
            labels,
            habitat,
            max_width,
            connectivity,
            obstacle_mask=obstacle_mask,
            passable_mask=passable_mask,
            hop_distance=hop_distance,
        )
        if not results:
            continue

        for path_pixels, target_id, path_len in results:
            pair = frozenset({patch_id, target_id})
            if pair in processed_pairs:
                continue
            processed_pairs.add(pair)

            is_hop_edge = bool(hop_adjacency) and target_id in hop_adjacency.get(patch_id, set())
            buffered = _inflate_corridor_pixels(set(path_pixels), offsets, rows, cols, obstacle_mask=obstacle_mask)
            area_px = len(buffered)

            if max_area is not None and area_px > max_area and not is_hop_edge:
                continue

            # Hop-adjacent patches are treated as near-contiguous when stepping is enabled and no obstacles intervene.
            if stepping_enabled and hop_distance > 0 and is_hop_edge and not obstacles_present:
                area_px = 0

            all_corridors.append(
                {
                    "patch1": patch_id,
                    "patch2": target_id,
                    "centerline_pixels": path_pixels,
                    "pixels": frozenset(buffered),
                    "length": path_len,
                    "area": area_px,
                    "is_hop_edge": is_hop_edge,
                }
            )
        if progress_cb is not None:
            post_value = progress_start + ((idx + 1) / total) * span
            emit_progress(
                progress_cb,
                post_value,
                f"Finished patch {idx + 1}/{total}",
            )

    # Guarded like every other progress call in this module; previously this
    # was the only emit_progress invoked with a possibly-None callback.
    if progress_cb is not None:
        emit_progress(progress_cb, progress_end, "Corridor candidates ready.")
    print(f"\n  ✓ Found {len(all_corridors)} possible corridors")
    return all_corridors


def _annotate_candidates_with_components(
    candidates: List[Dict],
    component_map: Dict[int, int],
) -> None:
    """Attach component IDs and patch ID sets to corridor candidates."""
    for cand in candidates:
        p1, p2 = cand.get("patch1"), cand.get("patch2")
        cand["patch_ids"] = set(cand.get("patch_ids", set()))
        if not cand["patch_ids"]:
            cand["patch_ids"] = {p1, p2}
        comp1 = component_map.get(p1, p1)
        comp2 = component_map.get(p2, p2)
        cand["comp1"] = comp1
        cand["comp2"] = comp2
        cand["component_ids"] = {component_map.get(pid, pid) for pid in cand["patch_ids"]}
        cand.setdefault("is_stepping_chain", False)


def _build_component_pair_index(candidates: List[Dict]) -> Set[frozenset]:
    """Return the set of component pairs that have a direct corridor candidate."""
    pairs: Set[frozenset] = set()
    for cand in candidates:
        comp1 = cand.get("comp1")
        comp2 = cand.get("comp2")
        if comp1 is None or comp2 is None or comp1 == comp2:
            continue
        pairs.add(frozenset({comp1, comp2}))
    return pairs


def _build_raster_adjacency(candidates: List[Dict]) -> Dict[int, Set[int]]:
    """Build a graph of component connectivity from candidate corridors."""
    adj: Dict[int, Set[int]] = defaultdict(set)
    for cand in candidates:
        c1 = cand.get("comp1")
        c2 = cand.get("comp2")
        if c1 is None or c2 is None or c1 == c2:
            continue
        adj[c1].add(c2)
        adj[c2].add(c1)
    return adj


def _compute_strategic_values_raster(
    component_sizes: Dict[int, int],
    adjacency: Dict[int, Set[int]],
    decay: float = 0.5,
    iterations: int = 2,
) -> Dict[int, float]:
    """Propagate pixel counts so small components inherit value from larger neighbours."""
    values = {k: float(v) for k, v in component_sizes.items()}
    for _ in range(iterations):
        current = values.copy()
        for cid, nbrs in adjacency.items():
            bonus = sum(current.get(n, 0.0) * decay for n in nbrs)
            values[cid] = max(values[cid], float(component_sizes.get(cid, 0)) + bonus)
    return values


def _build_hop_edge_lookup(
    candidates: List[Dict],
    hop_adjacency: Dict[int, Set[int]],
) -> Dict[frozenset, Dict]:
    """Index corridor candidates that connect hop-distance neighbours."""
    if not hop_adjacency:
        for cand in candidates:
            cand["is_hop_edge"] = False
        return {}

    lookup: Dict[frozenset, Dict] = {}
    for cand in candidates:
        p1, p2 = cand.get("patch1"), cand.get("patch2")
        is_hop = p2 in hop_adjacency.get(p1, set())
        cand["is_hop_edge"] = is_hop
        if is_hop:
            lookup[frozenset({p1, p2})] = cand
    return lookup


def create_step_stone_layer(
    labels: np.ndarray,
    component_map: Dict[int, int],
    component_sizes: Dict[int, int],
    gt: Tuple[float, ...],
    proj: str,
    output_dir: str,
    layer_name: str = "Step stone connectivity",
) -> str:
    """Create a raster showing merged patches within hop distance.

    Every patch pixel (label > 0) receives its component id from
    ``component_map`` (0 when unmapped). Background stays 0, which is written
    as nodata. The GeoTIFF ``step_stone_connectivity.tif`` is written to
    ``output_dir`` (created if needed) and, when QGIS is available, added to
    the current project as ``layer_name``. ``component_sizes`` is accepted for
    interface compatibility but is not used.

    Returns:
        Path of the written GeoTIFF.
    """
    step_stone_output = np.zeros_like(labels, dtype=np.int32)

    patch_pixels = labels > 0
    if patch_pixels.any():
        # Map each distinct patch id once, then broadcast back to pixels,
        # instead of a per-pixel Python dictionary lookup over the raster.
        patch_ids, inverse = np.unique(labels[patch_pixels], return_inverse=True)
        component_ids = np.array([component_map.get(pid, 0) for pid in patch_ids], dtype=np.int32)
        step_stone_output[patch_pixels] = component_ids[inverse.reshape(-1)]

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "step_stone_connectivity.tif")
    write_raster(out_path, step_stone_output, gt, proj, nodata=0)

    # Adding the layer is best effort: the GeoTIFF is already on disk.
    if QgsRasterLayer is None or QgsProject is None:
        return out_path
    try:
        result_layer = QgsRasterLayer(out_path, layer_name)
        if result_layer.isValid():
            QgsProject.instance().addMapLayer(result_layer)
    except Exception as exc:  # QGIS boundary: report instead of silently swallowing
        print(f"  Warning: could not add '{layer_name}' to the QGIS project: {exc}")

    return out_path


def _build_stepping_chains(
    hop_lookup: Dict[frozenset, Dict],
    component_map: Dict[int, int],
    component_sizes: Optional[Dict[int, int]] = None,
) -> List[Dict]:
    """Create multi-segment corridors using hop-distance stepping stones.

    Builds an undirected graph whose nodes are patch ids and whose edges are
    the two-element keys of ``hop_lookup``. It then enumerates every simple
    path from every node with an explicit DFS stack. Whenever a path's first
    patch and its current end patch map to different components in
    ``component_map``, a chain candidate is emitted. The chain's pixels,
    centerline and patch ids are unions over the hop edges used, and its
    length and area are sums over those edges.

    Args:
        hop_lookup: frozenset({patch_a, patch_b}) -> corridor candidate dict
            (as built by _build_hop_edge_lookup()).
        component_map: patch id -> component id; ids missing from the map are
            treated as their own component.
        component_sizes: optional component id -> size, used only to order
            start nodes (largest component first).

    Returns:
        Chain dicts flagged ``is_stepping_chain=True``. Every node is used as a
        start, so a chain between A and B can be emitted in both directions.

    NOTE(review): there is no depth or count limit, so on dense hop graphs the
    number of simple paths (and chains) can grow combinatorially.
    NOTE(review): if ``component_map`` is the hop-based map from
    _build_virtual_components(), hop-linked patches already share a component
    and ``comp_start != comp_end`` would never hold — confirm which map the
    caller passes.
    """
    if not hop_lookup:
        return []

    comp_sizes = component_sizes or {}
    # Undirected patch graph induced by the hop edges.
    graph: Dict[int, Set[int]] = defaultdict(set)
    for pair in hop_lookup.keys():
        if len(pair) != 2:
            continue
        a, b = tuple(pair)
        graph[a].add(b)
        graph[b].add(a)

    chains: List[Dict] = []
    seen_paths: Set[Tuple[int, ...]] = set()

    # Start from patches in the largest components first.
    start_nodes = sorted(graph.keys(), key=lambda pid: -comp_sizes.get(component_map.get(pid, pid), 0))

    for start in start_nodes:
        # Stack entries: (node, path, area_sum, length_sum, patch_ids, pixels, centerline).
        stack: List[Tuple[int, List[int], int, float, Set[int], Set[Tuple[int, int]], Set[Tuple[int, int]]]] = []
        stack.append((start, [start], 0, 0.0, {start}, set(), set()))
        while stack:
            node, path, cost_so_far, length_so_far, patch_ids, pixels_accum, centerline_accum = stack.pop()
            path_key = tuple(path)
            # NOTE(review): each simple path appears to be pushed at most once,
            # so this guard may never trigger — confirm before removing.
            if path_key in seen_paths:
                continue
            seen_paths.add(path_key)

            for nbr in graph.get(node, set()):
                if nbr in path:
                    continue  # keep paths simple (no revisits)
                edge = hop_lookup.get(frozenset({node, nbr}))
                if edge is None:
                    continue
                edge_cost = int(edge.get("area", 0))
                edge_length = float(edge.get("length", 0.0))
                new_cost = cost_so_far + edge_cost
                new_length = length_so_far + edge_length
                new_path = path + [nbr]
                new_pids = set(patch_ids)
                new_pids.update(edge.get("patch_ids", set()))

                new_pixels = set(pixels_accum)
                new_pixels |= set(edge.get("pixels", set()))

                new_centerline = set(centerline_accum)
                new_centerline |= set(edge.get("centerline_pixels", set()))

                # Emit a chain only when it links two different components.
                comp_start = component_map.get(path[0], path[0])
                comp_end = component_map.get(nbr, nbr)
                if comp_start != comp_end and len(new_path) >= 2:
                    chains.append(
                        {
                            "patch1": path[0],
                            "patch2": nbr,
                            "patch_ids": new_pids,
                            "component_ids": {component_map.get(pid, pid) for pid in new_pids},
                            "centerline_pixels": new_centerline,
                            "pixels": new_pixels,
                            "length": new_length,
                            "area": new_cost,
                            "segments": new_path,
                            "comp1": comp_start,
                            "comp2": comp_end,
                            "is_stepping_chain": True,
                        }
                    )
                stack.append((nbr, new_path, new_cost, new_length, new_pids, new_pixels, new_centerline))
    return chains


def optimize_most_connectivity(
    candidates: List[Dict],
    patch_sizes: Dict[int, int],
    budget: int,
    component_sizes: Optional[Dict[int, int]] = None,
    component_patch_counts: Optional[Dict[int, int]] = None,
) -> Tuple[Dict, Dict]:
    """Strategy 1: Most Connectivity - backbone + loop closure via shared optimizer.

    Every candidate is registered with NetworkOptimizer between its component
    ids (falling back to patch ids), using its area as cost, else its length,
    else 0. ``optimizer.solve(budget, loop_fraction=0.05)`` then picks the
    corridors. ``component_patch_counts`` is accepted but not used here.

    Returns:
        (selected, stats): ``selected`` maps 1..k (in the order returned by
        solve) to {pixels, patch_ids, length, connected_size}. ``stats``
        summarises budget use, group sizes, remaining components, redundant
        links and average degree.
    """
    print("  Strategy: MOST CONNECTIVITY (Backbone + Loops)")

    # Optimizer nodes are components when component_sizes is non-empty, else patches.
    node_map = component_sizes or patch_sizes
    optimizer = NetworkOptimizer(node_map)

    for idx, cand in enumerate(candidates):
        u = cand.get("comp1") or cand.get("patch1")
        v = cand.get("comp2") or cand.get("patch2")
        # NOTE(review): `or` also treats area == 0 as missing, so hop edges that
        # find_all_possible_corridors() zeroes out are costed by length here,
        # whereas optimize_largest_patch() keeps area 0 — confirm intent.
        cost = cand.get("area") or cand.get("length") or 0
        optimizer.add_candidate(int(u), int(v), idx, float(cost), bool(cand.get("is_stepping_chain")))

    # NOTE(review): loop_fraction presumably reserves ~5% of the budget for
    # loop-closing edges — confirm in NetworkOptimizer.solve.
    selected_ids, final_sizes, final_counts, budget_used = optimizer.solve(budget, loop_fraction=0.05)

    selected: Dict[int, Dict] = {}
    for sid in selected_ids:
        cand = candidates[sid]
        selected[len(selected) + 1] = {
            "pixels": cand["pixels"],
            "patch_ids": set(cand.get("patch_ids", set())),
            "length": cand.get("length", 0),
            "connected_size": 0,
        }

    # NOTE(review): patch_ids are patch labels while optimizer nodes may be
    # component ids; confirm optimizer.find() handles ids that are not nodes.
    for entry in selected.values():
        roots = [optimizer.find(pid) for pid in entry["patch_ids"]]
        if roots:
            entry["connected_size"] = int(final_sizes.get(roots[0], entry["length"]))

    root_sizes: Dict[int, int] = {root: int(size) for root, size in final_sizes.items()}
    root_counts: Dict[int, int] = {root: final_counts.get(root, 1) for root in final_counts}
    n_nodes = len(node_map)
    components = len(set(optimizer.find(n) for n in node_map))
    edges_used = len(selected_ids)
    # Edges beyond a spanning forest (n_nodes - components) close loops.
    redundancy = max(0, edges_used - (n_nodes - components))
    avg_degree = (2 * edges_used / n_nodes) if n_nodes > 0 else 0.0

    return selected, {
        "strategy": "resilient_network",
        "corridors_used": len(selected),
        "connections_made": len(selected),
        "budget_used": budget_used,
        "total_connected_size": sum(root_sizes.values()),
        "groups_created": len(root_sizes),
        "largest_group_size": max(root_sizes.values()) if root_sizes else 0,
        "largest_group_patches": max(root_counts.values()) if root_counts else 0,
        "patches_connected": sum(root_counts.values()),
        "patches_total": n_nodes,
        "components_remaining": components,
        "redundant_links": redundancy,
        "avg_degree": avg_degree,
    }


def optimize_largest_patch(
    candidates: List[Dict],
    patch_sizes: Dict[int, int],
    budget: int,
    component_sizes: Optional[Dict[int, int]] = None,
    component_patch_counts: Optional[Dict[int, int]] = None,
) -> Tuple[Dict, Dict]:
    """Strategy 3: Largest Patch - greedily grow the single largest component.

    Starting from the largest component (or patch), the function repeatedly
    buys the queued corridor promising the largest combined size: the current
    size, plus the newly joined components, plus the corridor area. Corridors
    that no longer add anything, or that exceed the remaining budget, are
    skipped. A corridor's priority is fixed when it is queued and is not
    refreshed.

    Returns:
        (selected, stats): ``selected`` maps 1..k (purchase order) to
        {pixels, patch_ids, length, connected_size=final size}. ``stats`` is
        ``{"strategy": "largest_patch", "corridors_used": 0}`` when nothing
        was bought.
    """
    print("  Strategy: LARGEST PATCH (priority growth)")
    if not candidates:
        return {}, {"strategy": "largest_patch", "corridors_used": 0}

    sizes = component_sizes or patch_sizes
    counts = component_patch_counts or {pid: 1 for pid in sizes}

    def joined_components(cand: Dict):
        # Components a corridor joins: component_ids, else its two endpoint components.
        return cand.get("component_ids") or {cand.get("comp1"), cand.get("comp2")}

    # Index every candidate under each component it touches.
    touching: Dict[int, List[Dict]] = {}
    for cand in candidates:
        for comp_id in cand.get("component_ids", {cand.get("comp1"), cand.get("comp2")}) or []:
            touching.setdefault(comp_id, []).append(cand)

    ranked = sorted(sizes.keys(), key=lambda k: sizes[k], reverse=True)

    print("  Testing 1 seed patch (largest)...")
    best = {"corridors": {}, "final_size": 0, "seed_id": None, "budget_used": 0}

    for seed in ranked[:1]:
        grown: Set[int] = {seed}
        grown_size = sizes.get(seed, 0)
        grown_patches = counts.get(seed, 1)
        budget_left = float(budget)
        picked: Dict[int, Dict] = {}
        frontier: List[Tuple[float, float, int, Dict]] = []
        tiebreak = [0]  # push counter; keeps dicts out of heap tuple comparisons

        def push_frontier(comp_id: int) -> None:
            for option in touching.get(comp_id, []):
                fresh = set(joined_components(option) or set()) - grown
                if not fresh:
                    continue
                area = option.get("area", option.get("length", 0))
                reach = grown_size + sum(sizes.get(tc, 0) for tc in fresh) + area
                heapq.heappush(frontier, (-reach, area, tiebreak[0], option))
                tiebreak[0] += 1

        push_frontier(seed)

        while frontier and budget_left > 0:
            _, price, _, choice = heapq.heappop(frontier)
            fresh = set(joined_components(choice) or set()) - grown
            if not fresh or price > budget_left:
                continue

            budget_left -= price
            grown.update(fresh)
            grown_size += sum(sizes.get(tc, 0) for tc in fresh) + price
            grown_patches += sum(counts.get(tc, 1) for tc in fresh)
            picked[len(picked) + 1] = {
                "pixels": choice["pixels"],
                "patch_ids": set(choice.get("patch_ids", set())),
                "length": choice.get("length", 0),
                "connected_size": grown_size,
            }
            for comp_id in fresh:
                push_frontier(comp_id)

        if grown_size > best["final_size"]:
            best = {
                "corridors": picked,
                "final_size": grown_size,
                "seed_id": seed,
                "budget_used": budget - budget_left,
                "patch_count": grown_patches,
            }

    print(f"\n  ✓ Best: Patch {best['seed_id']} -> {best['final_size']:,} px")

    if not best["corridors"]:
        return {}, {"strategy": "largest_patch", "corridors_used": 0}

    final_size = best["final_size"]
    selected: Dict[int, Dict] = {
        key: {
            "pixels": corr["pixels"],
            "patch_ids": corr["patch_ids"],
            "length": corr["length"],
            "connected_size": final_size,
        }
        for key, corr in best["corridors"].items()
    }

    return selected, {
        "strategy": "largest_patch",
        "seed_id": best["seed_id"],
        "final_patch_size": final_size,
        "corridors_used": len(selected),
        "budget_used": best["budget_used"],
        "groups_created": 1,
        "patches_connected": best.get("patch_count", len(selected) + 1),
        "largest_group_size": final_size,
    }


def _corridor_offsets(min_corridor_width: int) -> List[Tuple[int, int]]:
    """Precompute offsets used to inflate corridors to the minimum width."""
    width = max(1, int(min_corridor_width))
    if width <= 1:
        return [(0, 0)]

    radius = max(0.0, width / 2.0)
    max_offset = int(math.ceil(radius))
    radius_sq = radius * radius

    offsets: List[Tuple[int, int]] = []
    for dr in range(-max_offset, max_offset + 1):
        for dc in range(-max_offset, max_offset + 1):
            if dr * dr + dc * dc <= radius_sq + 1e-9:
                offsets.append((dr, dc))

    if not offsets:
        offsets.append((0, 0))
    return offsets


def _shift_mask(mask: np.ndarray, dr: int, dc: int) -> np.ndarray:
    """Return a shifted copy of mask aligned so (r, c) reads (r+dr, c+dc) from original."""
    rows, cols = mask.shape
    shifted = np.zeros_like(mask, dtype=bool)

    if abs(dr) >= rows or abs(dc) >= cols:
        return shifted

    if dr >= 0:
        src_r = slice(dr, rows)
        dst_r = slice(0, rows - dr)
    else:
        src_r = slice(0, rows + dr)
        dst_r = slice(-dr, rows)

    if dc >= 0:
        src_c = slice(dc, cols)
        dst_c = slice(0, cols - dc)
    else:
        src_c = slice(0, cols + dc)
        dst_c = slice(-dc, cols)

    shifted[dst_r, dst_c] = mask[src_r, src_c]
    return shifted


def _erode_mask(mask: np.ndarray, offsets: List[Tuple[int, int]]) -> np.ndarray:
    """Return *mask* eroded by the structuring element described by *offsets*.

    A pixel survives only if it is set in *mask* and every offset position
    ``(r + dr, c + dc)`` is also set; positions outside the array count as unset
    (``_shift_mask`` fills them with ``False``). An empty offset list returns a
    plain copy of *mask*.
    """
    result = mask.copy()
    for dr, dc in offsets or ():
        # The centre is already accounted for by starting from a copy of *mask*.
        if (dr, dc) == (0, 0):
            continue
        result &= _shift_mask(mask, dr, dc)
        if not result.any():
            # Once empty, further AND-ing cannot re-enable any pixel.
            break
    return result


def _build_passable_mask(
    habitat: np.ndarray,
    obstacle_mask: np.ndarray,
    min_corridor_width: int,
    allow_bottlenecks: bool,
) -> np.ndarray:
    """Derive which pixels can host corridor centerlines under width constraints."""
    habitat_bool = habitat.astype(bool)
    obstacle_bool = obstacle_mask.astype(bool) if obstacle_mask is not None else np.zeros_like(habitat_bool)
    base_passable = (~habitat_bool) & (~obstacle_bool)

    if allow_bottlenecks or min_corridor_width <= 1:
        return base_passable

    offsets = _corridor_offsets(min_corridor_width)
    # Ignore habitat when evaluating clearance so corridors can still touch patches.
    true_obstacles = obstacle_bool & (~habitat_bool)
    clearance_space = ~true_obstacles
    clearance_ok = _erode_mask(clearance_space, offsets)
    return base_passable & clearance_ok


def _inflate_corridor_pixels(
    pixels: Set[Tuple[int, int]],
    offsets: List[Tuple[int, int]],
    rows: int,
    cols: int,
    obstacle_mask: Optional[np.ndarray] = None,
) -> Set[Tuple[int, int]]:
    """Apply the minimum width offsets to a set of centerline pixels."""
    inflated: Set[Tuple[int, int]] = set()
    use_mask = obstacle_mask is not None

    for r, c in pixels:
        for dr, dc in offsets:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                if use_mask and obstacle_mask[nr, nc]:
                    continue
                inflated.add((nr, nc))
    return inflated


def create_output_raster(
    labels: np.ndarray,
    corridors: Dict[int, Dict],
    min_corridor_width: int,
    obstacle_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Rasterize corridors into an int32 grid shaped like *labels*.

    Each in-bounds corridor pixel receives that corridor's ``connected_size``;
    where corridors overlap, the largest value wins. Pixels flagged in
    *obstacle_mask* (if given) stay 0. ``min_corridor_width`` is currently not
    used by this function.
    """
    n_rows, n_cols = labels.shape
    raster = np.zeros_like(labels, dtype=np.int32)
    masked = obstacle_mask is not None

    for info in corridors.values():
        value = info["connected_size"]
        for r, c in info["pixels"]:
            if not (0 <= r < n_rows and 0 <= c < n_cols):
                continue
            if masked and obstacle_mask[r, c]:
                continue
            if raster[r, c] < value:
                raster[r, c] = value
    return raster


def _to_dataclass(params: Dict) -> RasterRunParams:
    """Convert a raw parameter dict into :class:`RasterRunParams`.

    Missing keys fall back to the defaults shown below. For the numeric, string
    and list fields, a key that is present but explicitly ``None`` is treated
    like a missing key. Callers can therefore pass "unset" values without
    ``int(None)`` / ``float(None)`` raising ``TypeError`` (``hop_distance``
    already behaved this way).

    Other fields:
      * Boolean flags keep plain ``bool()`` semantics.
      * The range bounds (``range_*``, ``obstacle_range_*``) are passed through
        unchanged.
      * ``max_corridor_area`` stays ``None`` when unset.
      * ``allow_sub_min_corridor`` is always ``True``.
    """

    def _get(key: str, default):
        # Value for *key*, or *default* when the key is missing or explicitly None.
        value = params.get(key)
        return default if value is None else value

    max_corridor_area = params.get("max_corridor_area")
    return RasterRunParams(
        patch_connectivity=int(_get("patch_connectivity", 4)),
        patch_mode=str(_get("patch_mode", "value")).lower(),
        patch_values=list(_get("patch_values", [])),
        range_lower=params.get("range_lower"),
        range_upper=params.get("range_upper"),
        obstacle_enabled=bool(params.get("obstacle_enabled", False)),
        obstacle_mode=str(_get("obstacle_mode", "value")).lower(),
        obstacle_values=list(_get("obstacle_values", [])),
        obstacle_range_lower=params.get("obstacle_range_lower"),
        obstacle_range_upper=params.get("obstacle_range_upper"),
        value_tolerance=float(_get("value_tolerance", 1e-6)),
        nodata_fallback=float(_get("nodata_fallback", -9999)),
        min_patch_size=int(_get("min_patch_size", 0)),
        # Sub-minimum patches are always usable by corridors.
        allow_sub_min_corridor=True,
        budget_pixels=int(_get("budget_pixels", 0)),
        max_search_distance=int(_get("max_search_distance", 50)),
        max_corridor_area=int(max_corridor_area) if max_corridor_area is not None else None,
        min_corridor_width=int(_get("min_corridor_width", 1)),
        allow_bottlenecks=bool(params.get("allow_bottlenecks", True)),
        stepping_enabled=bool(params.get("stepping_enabled", False)),
        hop_distance=int(params.get("hop_distance", 0) or 0),
    )


def _reject_oversized_raster(total_pixels: int, pixel_size: float, iface) -> None:
    """Raise RasterAnalysisError when the raster is too large or too fine to process."""
    warnings: List[str] = []
    if total_pixels >= PIXEL_COUNT_WARNING_THRESHOLD:
        warnings.append(
            "Large raster detected (>{:,} pixels).".format(
                PIXEL_COUNT_WARNING_THRESHOLD
            )
        )
    if 0 < pixel_size < PIXEL_SIZE_WARNING_THRESHOLD:
        warnings.append(
            f"High-resolution data detected (≈{pixel_size:.2f} map units per pixel). "
            "Consider resampling to a coarser resolution for faster corridor modelling."
        )
    if not warnings:
        return

    warning_text = " ".join(warnings)
    if iface and hasattr(iface, "messageBar"):
        try:
            iface.messageBar().pushWarning("Linkscape", warning_text)
        except Exception:  # noqa: BLE001 - the message bar is best-effort only
            print(f"WARNING: {warning_text}")
    else:
        print(f"WARNING: {warning_text}")
    raise RasterAnalysisError(
        f"Raster is too large/fine for Linkscape to process efficiently. {warning_text} "
        "Please resample to a coarser resolution or process in smaller chunks."
    )


def _compute_nodata_mask(data: np.ndarray, nodata: float, tolerance: float) -> np.ndarray:
    """Return a boolean mask of pixels that should be treated as NoData.

    NaN pixels are always NoData. When *nodata* is itself NaN, the tolerance
    comparison is skipped. ``abs(x - nan) < tol`` is never true, so relying on it
    would leave every NoData pixel unmasked.
    """
    mask = np.isnan(data)
    if not math.isnan(nodata):
        mask |= np.abs(data - nodata) < tolerance
    return mask


def _drop_small_patches(labels: np.ndarray, min_size: int) -> Tuple[np.ndarray, int]:
    """Zero out patches smaller than *min_size* pixels and renumber the survivors 1..K.

    Survivors keep their relative order (ascending original label). A lookup table
    makes this a single pass over the raster instead of one full-raster comparison
    per patch. Returns ``(new_labels, K)``; ``new_labels`` keeps the dtype of *labels*.
    """
    if labels.size == 0:
        return labels.copy(), 0
    unique_labels, counts = np.unique(labels[labels > 0], return_counts=True)
    valid_labels = unique_labels[counts >= min_size]
    lookup = np.zeros(max(int(labels.max()), 0) + 1, dtype=labels.dtype)
    lookup[valid_labels] = np.arange(1, len(valid_labels) + 1, dtype=labels.dtype)
    # Clip so any non-positive label maps to lookup[0] == 0 (background).
    return lookup[np.clip(labels, 0, None)], len(valid_labels)


def run_raster_analysis(
    layer: QgsRasterLayer,
    output_dir: str,
    raw_params: Dict,
    strategy: str = "most_connectivity",
    temporary: bool = False,
    iface=None,
    progress_cb: Optional[Callable[[int, Optional[str]], None]] = None,
) -> List[Dict]:
    """Execute the raster corridor analysis for the provided layer.

    Pipeline:
      1. Read band 1 of the layer's source through GDAL.
      2. Identify habitat patches and obstacles.
      3. Label patches, optionally dropping those smaller than ``min_patch_size``.
      4. Optionally merge patches within the stepping-stone hop distance into
         virtual components.
      5. Search candidate corridors.
      6. Run the selected optimization *strategy*.
      7. Write the result as a GeoTIFF and try to add it to the current QGIS
         project.

    Args:
        layer: Source raster layer (band 1 is analysed).
        output_dir: Directory for the output GeoTIFF. When empty, the source
            raster's directory is used.
        raw_params: Raw parameter dict, converted via ``_to_dataclass``.
        strategy: ``"most_connectivity"`` or ``"largest_patch"``.
        temporary: Write to a temporary file and report an empty ``output_path``.
        iface: Optional QGIS interface, used to push size warnings to the message bar.
        progress_cb: Optional progress callback receiving ``(percent, message)``.

    Returns:
        A one-element list of ``{"strategy", "stats", "output_path"}``.

    Raises:
        RasterAnalysisError: For any of the following:
            * missing GDAL/QGIS bindings or an invalid layer;
            * an unreadable source;
            * an oversized or over-fine raster;
            * an unsupported strategy;
            * no patches or corridors found.
    """
    if gdal is None or QgsRasterLayer is None:
        raise RasterAnalysisError(
            "GDAL and QGIS Python bindings are required for raster analysis."
        )
    if not isinstance(layer, QgsRasterLayer) or not layer.isValid():
        raise RasterAnalysisError("Selected layer is not a valid raster layer.")

    params = _to_dataclass(raw_params)
    overall_start = time.time()

    src_path = layer.source()
    ds = gdal.Open(src_path)
    if ds is None:
        raise RasterAnalysisError(f"Cannot open raster source: {src_path}")

    band = None
    try:
        rows, cols = ds.RasterYSize, ds.RasterXSize
        gt = ds.GetGeoTransform()
        proj = ds.GetProjection()
        print("=" * 70)
        print("LINKSCAPE RASTER ANALYSIS v23.0")
        print("=" * 70)
        print("\n1. Loading raster...")
        print(f"  ✓ Using layer: {layer.name()}")
        total_pixels = rows * cols
        print(f"  Size: {rows:,} x {cols:,} = {total_pixels:,} pixels")

        pixel_w = abs(gt[1])
        pixel_h = abs(gt[5]) if gt[5] != 0 else pixel_w
        _reject_oversized_raster(total_pixels, max(pixel_w, pixel_h), iface)
        emit_progress(progress_cb, 5, "Loading raster data…")

        band = ds.GetRasterBand(1)
        nodata = band.GetNoDataValue()
        if nodata is None:
            nodata = params.nodata_fallback

        print("  Reading data...")
        data = read_band(
            band,
            rows,
            cols,
            progress_cb=progress_cb,
            progress_start=5,
            progress_end=18,
        )
    finally:
        # Drop the GDAL handles once the pixels are in memory, or on failure.
        # Nothing below reads from the dataset again.
        band = None
        ds = None

    nodata_mask = _compute_nodata_mask(data, nodata, params.value_tolerance)
    emit_progress(progress_cb, 20, "Defining habitat patches…")

    print("\n2. Identifying patch pixels...")
    patch_mask = define_habitat(data, nodata_mask, params)
    habitat_mask = patch_mask.astype(np.uint8)
    patch_pixels = int(np.sum(habitat_mask))
    if patch_pixels == 0:
        raise RasterAnalysisError("No patch pixels found with the current configuration.")
    print(f"  ✓ Patch pixels: {patch_pixels:,}")
    emit_progress(progress_cb, 25, "Processing habitat patches…")

    obstacle_mask = define_obstacles(data, nodata_mask, habitat_mask, params)
    # The raw pixel values are not needed past this point; free them early.
    del data, nodata_mask
    if obstacle_mask is None:
        # Represent "no obstacles" as an all-False mask so it can be combined below.
        obstacle_mask = np.zeros(habitat_mask.shape, dtype=bool)
    obstacles_present = bool(np.any(obstacle_mask))
    if params.obstacle_enabled:
        obstacle_pixels = int(np.sum(obstacle_mask))
        if obstacle_pixels:
            print(f"  ✓ Obstacle pixels: {obstacle_pixels:,}")
        else:
            print("  ⚠ Obstacle configuration matched no pixels; proceeding without obstacles.")
    else:
        print("  ✓ Obstacles disabled.")
    emit_progress(progress_cb, 35, "Labeling patches…")

    print("\n3. Labeling patches...")
    t0 = time.time()
    labels, n_patches = label_patches(habitat_mask, params.patch_connectivity)
    print(f"  ✓ Patches: {n_patches:,} in {time.time() - t0:.2f}s")

    if params.min_patch_size > 0:
        labels, kept_patches = _drop_small_patches(labels, params.min_patch_size)
        patch_mask = (labels > 0).astype(np.uint8)
        print(f"  ✓ After filter: {kept_patches:,} patches")
    else:
        patch_mask = habitat_mask

    # Remaining habitat joins the obstacle mask so corridors avoid intact patches.
    # Patches removed by the size filter are labelled 0. They are therefore not in
    # patch_mask and stay passable, usable as stepping stones.
    obstacle_mask = obstacle_mask.astype(bool) | patch_mask.astype(bool)

    passable_mask = _build_passable_mask(
        patch_mask,
        obstacle_mask,
        params.min_corridor_width,
        params.allow_bottlenecks,
    )

    debug_label_env = os.environ.get("LINKSCAPE_SAVE_PATCH_LABELS")
    if debug_label_env:
        try:
            label_out_dir = output_dir or os.path.dirname(src_path)
            os.makedirs(label_out_dir, exist_ok=True)
            label_path = os.path.join(label_out_dir, "linkscape_patch_labels.tif")
            write_raster(label_path, labels, gt, proj, nodata=0)
            print(f"  ✓ Saved patch ID raster: {label_path}")
        except Exception as label_exc:  # noqa: BLE001
            print(f"  ⚠ Could not save patch ID raster: {label_exc}")

    unique_labels, counts = np.unique(labels[labels > 0], return_counts=True)
    patch_sizes = dict(zip(unique_labels.tolist(), counts.tolist()))
    if not patch_sizes:
        raise RasterAnalysisError("No valid patches remain after filtering.")
    hop_adjacency: Dict[int, Set[int]] = {}
    component_map: Dict[int, int] = {pid: pid for pid in patch_sizes}
    component_sizes: Dict[int, int] = dict(patch_sizes)
    component_counts: Dict[int, int] = {pid: 1 for pid in patch_sizes}

    stats_step: Dict = {}
    if params.stepping_enabled and params.hop_distance > 0:
        emit_progress(progress_cb, 42, "Checking stepping-stone proximity…")
        print("\n3b. Stepping-stone preprocessing...")
        hop_adjacency = _compute_hop_adjacency(labels, params.hop_distance)
        if hop_adjacency:
            component_map, component_sizes, component_counts = _build_virtual_components(patch_sizes, hop_adjacency)
            print(
                f"  ✓ Virtual merges: {len(component_sizes):,} starting components "
                f"(hop distance ≤ {params.hop_distance}px)"
            )
            try:
                step_out_dir = output_dir or tempfile.gettempdir()
                step_path = create_step_stone_layer(
                    labels,
                    component_map,
                    component_sizes,
                    gt,
                    proj,
                    step_out_dir,
                    "Step stone connectivity (Below hop distance)",
                )
                original_patches = len(patch_sizes)
                virtual_components = len(component_sizes)
                patches_merged = original_patches - virtual_components
                stats_step = {
                    "original_patches": original_patches,
                    "virtual_components": virtual_components,
                    "patches_merged": patches_merged,
                    "hop_distance_px": params.hop_distance,
                    "step_stone_layer": step_path if not temporary else "temporary",
                }
                print(
                    f"  ✓ Step stone merging: {patches_merged} patches merged into {virtual_components} virtual components"
                )
            except Exception as step_exc:  # noqa: BLE001
                stats_step = {}
                print(f"  ⚠ Could not create step stone layer: {step_exc}")
        else:
            print("  ✓ No patches within hop distance; proceeding without virtual merges.")

    emit_progress(progress_cb, 45, "Searching for corridors…")

    print("\n4. Finding possible corridors...")
    candidates = find_all_possible_corridors(
        labels,
        habitat_mask,
        patch_sizes,
        params.max_search_distance,
        params.min_corridor_width,
        params.max_corridor_area,
        params.patch_connectivity,
        obstacle_mask=obstacle_mask,
        passable_mask=passable_mask,
        hop_distance=params.hop_distance if params.stepping_enabled else 0,
        hop_adjacency=hop_adjacency if params.stepping_enabled else None,
        stepping_enabled=params.stepping_enabled,
        obstacles_present=obstacles_present,
        progress_cb=progress_cb,
        progress_start=45,
        progress_end=75,
    )

    _annotate_candidates_with_components(candidates, component_map)
    # NOTE(review): the returned index is unused here; the call is kept in case it
    # has side effects on `candidates` (not visible from this module chunk).
    _build_component_pair_index(candidates)
    hop_lookup = _build_hop_edge_lookup(candidates, hop_adjacency)
    if params.stepping_enabled and params.hop_distance > 0:
        chains = _build_stepping_chains(hop_lookup, component_map, component_sizes)
        if chains:
            print(f"  ✓ Added {len(chains):,} stepping-stone corridor chains")
            candidates.extend(chains)

    if not candidates:
        raise RasterAnalysisError("No feasible corridors found with the current configuration.")
    emit_progress(progress_cb, 78, "Optimizing corridor selection…")

    strategy = (strategy or "most_connectivity").lower()
    strategy_map = {
        "most_connectivity": (
            optimize_most_connectivity,
            "linkscape_most_connectivity.tif",
            "Corridors (Resilient Network)",
        ),
        "largest_patch": (
            optimize_largest_patch,
            "linkscape_largest_patch.tif",
            "Corridors (Largest Patch)",
        ),
    }

    if strategy not in strategy_map:
        raise RasterAnalysisError(f"Unsupported strategy '{strategy}'.")

    optimize_func, default_filename, layer_name = strategy_map[strategy]

    print("\n5. Running optimization...")
    print("=" * 70)
    print(f"--- {strategy.replace('_', ' ').upper()} ---")
    if strategy == "largest_patch":
        print("  Strategy: SINGLE DOMINANT NETWORK (grow one cohesive network from the largest patch)")

    corridors, stats = optimize_func(
        candidates,
        patch_sizes,
        params.budget_pixels,
        component_sizes=component_sizes,
        component_patch_counts=component_counts,
    )
    if not corridors:
        raise RasterAnalysisError("Selected optimization did not produce any corridors.")
    if params.stepping_enabled and params.hop_distance > 0:
        stats["note"] = f"Stepping-stone connectivity enabled (hop distance {params.hop_distance}px)"
        stats["patches_merged"] = len(patch_sizes) - len(component_sizes)
        stats["step_stone_stats"] = stats_step
    emit_progress(progress_cb, 85, "Rendering output raster…")

    print("  Creating output raster...")
    output = create_output_raster(
        labels, corridors, params.min_corridor_width, obstacle_mask=obstacle_mask
    )

    if temporary:
        temp_file = tempfile.NamedTemporaryFile(prefix="linkscape_", suffix=".tif", delete=False)
        out_path = temp_file.name
        temp_file.close()
    else:
        out_dir = output_dir or os.path.dirname(src_path)
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, default_filename)

    write_raster(out_path, output, gt, proj, nodata=0)
    print(f"  ✓ Saved: {out_path}")
    emit_progress(progress_cb, 95, "Finishing up…")

    try:
        result_layer = QgsRasterLayer(out_path, layer_name)
        if result_layer.isValid():
            QgsProject.instance().addMapLayer(result_layer)
            print("  ✓ Added to project")
    except Exception as add_exc:  # noqa: BLE001
        print(f"  ⚠ Could not add layer to project: {add_exc}")

    elapsed = time.time() - overall_start
    stats = dict(stats)
    stats["output_path"] = out_path if not temporary else ""
    stats["layer_name"] = layer_name
    stats["budget_total"] = params.budget_pixels

    print("\n" + "=" * 70)
    print("FINAL SUMMARY")
    print("=" * 70)
    stats_strategy = stats.get("strategy", strategy)
    strategy_label = (
        "Resilient Network"
        if stats_strategy == "resilient_network"
        else "Single Dominant Network"
        if stats_strategy == "largest_patch"
        else stats_strategy.replace("_", " ").title()
    )
    print(f"Strategy:          {strategy_label}")
    print(f"Corridors created: {stats.get('corridors_used', 0)}")
    if stats_strategy in ("most_connectivity", "resilient_network"):
        print(f"Connections:       {stats.get('connections_made', 0)}")
        print(f"Largest patch:     {stats.get('largest_group_size', 0):,} px")
        if "redundant_links" in stats:
            print(f"Redundant links:   {stats.get('redundant_links', 0)}")
        if "avg_degree" in stats:
            print(f"Average degree:    {stats.get('avg_degree', 0):.2f}")
    elif "connections_made" in stats:
        print(f"Connections:       {stats.get('connections_made', 0)}")
    if "seed_id" in stats:
        print(f"Seed patch:        {stats.get('seed_id')}")
    print(f"Final size:        {stats.get('final_patch_size', 0):,} px")
    print(f"Budget used:       {stats.get('budget_used', 0)}/{params.budget_pixels} px")
    print(f"Processing time:   {elapsed:.1f}s")
    if "step_stone_stats" in stats:
        ss = stats["step_stone_stats"]
        print(
            f"Step stone merge:  {ss.get('patches_merged', 0)} patches → {ss.get('virtual_components', 0)} components "
            f"(hop {ss.get('hop_distance_px', ss.get('hop_distance_m', 0))})"
        )
    if temporary:
        print("Output:            Temporary raster layer")
    else:
        print(f"Output GeoTIFF:    {out_path}")
    print("=" * 70)

    emit_progress(progress_cb, 100, "Raster analysis complete.")
    return [{"strategy": strategy, "stats": stats, "output_path": out_path if not temporary else ""}]
