"""
/***************************************************************************
 Segmenter
                     A QGIS plugin
 This plugin segments the map into discrete buckets
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                  -------------------
     begin                : 2023-05-26
     git sha              : $Format:%H$
     copyright            : (C) 2023 by Quant Civil
     email                : joshua.herrera@quantcivil.ai
 ***************************************************************************/
"""
from collections import deque
import math
import os
import re
import threading
import traceback
import weakref
from typing import Optional
from qgis.PyQt.QtCore import (
    QSettings,
    QTranslator,
    QCoreApplication,
    QThread,
    QObject,
    pyqtSignal,
    QUrl,
    Qt,
    QEvent,
)
from qgis.PyQt.QtGui import QIcon, QPixmap, QDesktopServices, QColor, QCursor
from qgis.PyQt.QtWidgets import (
    QAction,
    QMessageBox,
    QGraphicsDropShadowEffect,
)
from qgis.core import (
    QgsTask,
    QgsApplication,
    QgsMessageLog,
    QgsProject,
    QgsRasterLayer,
    QgsMapLayer,
    Qgis,
)

# Initialize Qt resources from file resources.py
from .resources import *  # noqa: F401,F403 - Qt resource imports

# Import the code for the dialog
from .segmenter_dialog import SegmenterDialog

from datetime import datetime
from .dependency_manager import ensure_dependencies

ensure_dependencies()

import torch  # noqa: E402 - must import after ensure_dependencies
import numpy as np  # noqa: E402
from .funcs import (  # noqa: E402
    execute_kmeans_segmentation,
    SegmentationCanceled,
    _apply_optional_blur,
)
from .qgis_funcs import render_raster  # noqa: E402
from .map_to_raster import (  # noqa: E402
    is_file_backed_gdal_raster,
    is_renderable_non_file_layer,
    build_convert_map_to_raster_params,
    extract_layer_metadata,
    get_canvas_extent_tuple,
    open_convert_map_to_raster_dialog,
)

SUPPORTED_RASTER_EXTENSIONS = {".tif", ".tiff"}
GITHUB_ISSUES_URL = "https://github.com/sirebellum/qgis-segmentation/issues/new/choose"
# NOTE(review): mirrors RESOLUTION_VALUE_MAP and is not referenced in this part
# of the file -- confirm it is unused elsewhere before removing it.
PROFILE_MODEL_SIZES = {"high": 4, "medium": 8, "low": 16}
BUY_ME_A_COFFEE_URL = "https://buymeacoffee.com/sirebellum"
# (label, block size in pixels) pairs offered in the resolution dropdown;
# a smaller block size means a finer ("higher") resolution.
RESOLUTION_CHOICES = (
    ("low", 16),
    ("medium", 8),
    ("high", 4),
)
DEFAULT_RESOLUTION_LABEL = "high"
RESOLUTION_VALUE_MAP = dict(RESOLUTION_CHOICES)

# Discrete slider levels: 0=low, 1=medium, 2=high
SLIDER_LEVEL_LOW = 0
SLIDER_LEVEL_MEDIUM = 1
SLIDER_LEVEL_HIGH = 2
SLIDER_LEVEL_NAMES = {SLIDER_LEVEL_LOW: "low", SLIDER_LEVEL_MEDIUM: "medium", SLIDER_LEVEL_HIGH: "high"}
# Base smoothing kernel sizes at resolution=4 (high). Kernels scale proportionally
# with resolution and are then bumped to the next odd size (minimum 3px):
#   resolution 4:  low=5px,  medium=11px, high=21px
#   resolution 8:  low=11px, medium=23px, high=43px
#   resolution 16: low=21px, medium=45px, high=85px
SMOOTHING_BASE_KERNELS = {
    SLIDER_LEVEL_LOW: {"base_kernel": 5, "iterations": 1},
    SLIDER_LEVEL_MEDIUM: {"base_kernel": 11, "iterations": 1},
    SLIDER_LEVEL_HIGH: {"base_kernel": 21, "iterations": 2},
}
SMOOTHING_BASE_RESOLUTION = 4  # Reference resolution for base kernel sizes

# Progress parsing: "NN%" tokens and "step X/Y" counters in status messages.
PROGRESS_PERCENT_PATTERN = re.compile(r"(?P<percent>\d{1,3})\s*%")
PROGRESS_STEP_PATTERN = re.compile(r"step\s+(?P<current>\d+)\s*/\s*(?P<total>\d+)", re.IGNORECASE)
PROGRESS_TEXT_LIMIT = 80  # max characters shown on the progress bar label
# stage -> (overall start %, overall end %, default label); stages tile 0-100.
PROGRESS_STAGE_MAP = {
    "prepare": (0.0, 20.0, "Preparing input..."),
    "chunk_plan": (20.0, 35.0, "Estimating chunk plan..."),
    "queue": (35.0, 50.0, "Starting segmentation task..."),
    "inference": (50.0, 80.0, "Running segmentation..."),
    "latent": (80.0, 90.0, "Refining latent features..."),
    "smooth": (90.0, 97.0, "Smoothing segmentation map..."),
    "render": (97.0, 100.0, "Rendering output..."),
}


class StatusEmitter(QObject):
    """Qt object whose signal carries status payloads to a connected handler."""

    # Emitted with a ``(category, message)`` tuple by Segmenter._emit_status_message.
    message = pyqtSignal(object)


class CancellationToken:
    """Cooperative cancellation flag backed by ``threading.Event``.

    The token can optionally be bound to a QgsTask (held by weak reference);
    once bound, a cancellation requested on the task is mirrored into the
    token the next time it is polled.
    """

    def __init__(self):
        self._event = threading.Event()
        self._task_ref: Optional[weakref.ReferenceType] = None

    def bind_task(self, task: QgsTask) -> None:
        """Track *task* weakly so its canceled state can be polled."""
        try:
            ref = weakref.ref(task)
        except TypeError:
            # The object does not support weak references; stay unbound.
            ref = None
        self._task_ref = ref

    def cancel(self) -> None:
        """Request cancellation."""
        self._event.set()

    def is_cancelled(self) -> bool:
        """Return True once cancellation was requested here or on the bound task."""
        if self._event.is_set():
            return True
        ref = self._task_ref
        if ref is None:
            return False
        task = ref()
        if task is None or not task.isCanceled():
            # Re-read the flag: cancel() may have run since the first check.
            return self._event.is_set()
        # Latch the task's cancellation so later polls skip the task lookup.
        self._event.set()
        return True

    def raise_if_cancelled(self) -> None:
        """Raise SegmentationCanceled if cancellation has been requested."""
        if not self.is_cancelled():
            return
        raise SegmentationCanceled()


class _LogoHoverController(QObject):
    """Event filter giving a label a proximity-based golden glow and click action.

    The glow grows stronger the closer the pointer is to the label's centre.
    A left-button release on the label invokes *click_callback* and is consumed.
    """

    def __init__(self, label, click_callback):
        super().__init__(label)
        self._label = label
        self._click_callback = click_callback
        # Distance from the centre at which the glow fades out completely.
        shortest_side = min(label.width(), label.height())
        self._hover_radius = max(60.0, shortest_side * 0.6)
        glow = QGraphicsDropShadowEffect(label)
        self._glow = glow
        glow.setOffset(0, 0)
        glow.setBlurRadius(0)
        glow.setColor(QColor(255, 215, 0, 0))
        label.setGraphicsEffect(glow)
        label.setCursor(Qt.ArrowCursor)

    def eventFilter(self, obj, event):
        """Update the glow on pointer movement and handle left clicks."""
        if obj is not self._label:
            return False
        kind = event.type()
        if kind == QEvent.MouseMove:
            self._update_glow(event.pos())
        elif kind == QEvent.Enter:
            self._update_glow(self._label.mapFromGlobal(QCursor.pos()))
        elif kind == QEvent.Leave:
            self._reset_glow()
        elif kind == QEvent.MouseButtonRelease:
            read_button = getattr(event, "button", lambda: None)
            if read_button() == Qt.LeftButton:
                if callable(self._click_callback):
                    self._click_callback()
                return True
        return False

    def _update_glow(self, pos):
        """Scale blur, alpha and cursor by how close *pos* is to the label centre."""
        centre = self._label.rect().center()
        distance = math.hypot(pos.x() - centre.x(), pos.y() - centre.y())
        intensity = max(0.0, 1.0 - distance / max(self._hover_radius, 1.0))
        alpha = min(255, max(0, int(80 + 120 * intensity)))
        tint = QColor(255, 215, 0, alpha)
        self._glow.setBlurRadius((8 + 28 * intensity) if intensity > 0 else 0)
        self._glow.setColor(tint)
        cursor = Qt.PointingHandCursor if intensity > 0.1 else Qt.ArrowCursor
        self._label.setCursor(cursor)

    def _reset_glow(self):
        """Hide the glow and restore the default arrow cursor."""
        self._glow.setBlurRadius(0)
        self._glow.setColor(QColor(255, 215, 0, 0))
        self._label.setCursor(Qt.ArrowCursor)


class _ComboRefreshController(QObject):
    """Event filter that calls a refresh callback whenever its combo box is pressed.

    The press event is never consumed, so the combo box still handles it.
    """

    def __init__(self, combo_box, refresh_callback):
        super().__init__(combo_box)
        self._combo = combo_box
        self._refresh_callback = refresh_callback

    def eventFilter(self, obj, event):
        """Refresh on mouse press over the watched combo box; always pass the event on."""
        pressed_on_combo = obj is self._combo and event.type() == QEvent.MouseButtonPress
        if pressed_on_combo and callable(self._refresh_callback):
            self._refresh_callback()
        return False


# Background execution: QgsTask wrapper that runs segmentation off the GUI thread
class Task(QgsTask):
    """QgsTask that runs a segmentation callable in a worker thread.

    ``run`` executes ``function(*args, cancel_token=...)``. ``finished``
    (invoked by QGIS once ``run`` returns) applies optional smoothing, renders
    the result as a layer and restores the plugin UI.

    Keyword arguments other than ``cancel_token`` are NOT forwarded to
    ``function``. They are kept in ``self.kwargs`` for ``finished`` and status
    reporting (``layer``, ``num_segments``, ``resolution``, ``blur_config``,
    ``status_callback``, ``segmenter``, ...).
    """

    def __init__(self, function, *args, **kwargs):
        super().__init__()
        self.function = function
        self.args = args
        cancel_token = kwargs.pop("cancel_token", None)
        self.cancel_token = cancel_token or CancellationToken()
        self.cancel_token.bind_task(self)
        self.kwargs = kwargs
        self.result = None
        QgsMessageLog.logMessage("Task initialized", "Segmenter", level=Qgis.Info)
        self._status("Task queued")

    def run(self):
        """Execute the wrapped function; return True on success, False otherwise."""
        QgsMessageLog.logMessage("Running task", "Segmenter", level=Qgis.Info)
        self._status("Processing started")
        if self.isCanceled():
            self._status("Task canceled before execution")
            return False
        try:
            self.result = self.function(*self.args, cancel_token=self.cancel_token)
        except SegmentationCanceled:
            self._status("Processing canceled")
            return False
        except Exception as e:
            # Include the traceback: the worker thread has no other place to report it.
            QgsMessageLog.logMessage(
                f"Exception in task: {e}\n{traceback.format_exc()}", "Segmenter", level=Qgis.Critical
            )
            self._status(f"Processing failed: {e}")
            return False
        if self.isCanceled():
            self._status("Processing canceled")
            return False
        self._status("Processing completed successfully")
        return True

    def cancel(self):
        """Trip the cooperative token first, then mark the QgsTask canceled."""
        self.cancel_token.cancel()
        return super().cancel()

    def finished(self, result):
        """Post-process the worker output and restore the plugin UI.

        On success the result is optionally smoothed and rendered. Whatever
        happens -- including an exception while smoothing or rendering -- the
        segmenter's task handle is cleared, the stop button is disabled and the
        progress bar is finalized, so the dialog never stays stuck "running".

        :param result: the boolean returned by :meth:`run`.
        """
        QgsMessageLog.logMessage("Task finished", "Segmenter", level=Qgis.Info)
        segmenter = self.kwargs.get("segmenter")
        outcome = self._outcome(result)
        try:
            if outcome == "success":
                self._render_output(segmenter)
                self._status("Segmentation layer rendered")
            elif outcome == "canceled":
                self._status("Segmentation task canceled")
            else:
                self._status("Segmentation task failed")
            # Re-classify: cancellation may have been requested meanwhile.
            outcome = self._outcome(result)
        except SegmentationCanceled:
            outcome = "canceled"
            self._status("Segmentation task canceled")
        except Exception as exc:
            # Last stop before QGIS: report the failure instead of leaving the UI stuck.
            outcome = "error"
            QgsMessageLog.logMessage(
                f"Exception while finalizing task: {exc}\n{traceback.format_exc()}",
                "Segmenter",
                level=Qgis.Critical,
            )
            self._status(f"Segmentation task failed: {exc}")
        finally:
            if segmenter:
                segmenter.task = None
                segmenter._set_stop_enabled(False)
                segmenter._finalize_progress(outcome)

    def _outcome(self, result):
        """Classify the run as "success", "canceled" or "error"."""
        if result and not self.isCanceled():
            return "success"
        if self.isCanceled():
            return "canceled"
        return "error"

    def _render_output(self, segmenter):
        """Apply optional post-smoothing to ``self.result`` and render it as a layer."""
        blur_config = self.kwargs.get("blur_config")
        output = self.result
        if blur_config is not None:
            if segmenter:
                segmenter._update_overall_progress("smooth", 0, "Smoothing segmentation map...")
            output = _apply_optional_blur(
                output,
                blur_config,
                self.kwargs.get("status_callback"),
                cancel_token=self.cancel_token,
            )
            if segmenter:
                segmenter._update_overall_progress("smooth", 100, "Smoothing complete.")
        if segmenter:
            segmenter._update_overall_progress("render", 20, "Rendering output...")
        layer = self.kwargs["layer"]
        render_raster(
            output,
            layer.extent(),
            f"{layer.name()}_kmeans_{self.kwargs['num_segments']}_{self.kwargs['resolution']}",
            # The extent is in the input layer's CRS, so tag the output with that
            # same CRS (not whichever layer happens to be on top of the canvas).
            layer.crs().postgisSrid(),
        )

    def _status(self, message):
        """Forward *message* to the optional ``status_callback`` (best effort)."""
        callback = self.kwargs.get("status_callback")
        if not callback:
            return
        try:
            callback(message)
        except Exception:  # pragma: no cover  # nosec B110 - best effort status callback
            pass


def run_task(function, *args, **kwargs):
    """Wrap *function* in a Task, hand it to the QGIS task manager and return it."""
    task = Task(function, *args, **kwargs)
    manager = QgsApplication.taskManager()
    manager.addTask(task)
    return task


class Segmenter:
    """QGIS Plugin Implementation."""

    def __init__(self, iface):
        """Constructor.

        :param iface: An interface instance that will be passed to this class
            which provides the hook by which you can manipulate the QGIS
            application at run time.
        :type iface: QgsInterface
        """
        # Save reference to the QGIS interface
        self.iface = iface
        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)
        # initialize locale; the setting may be unset (None), so coerce to str
        # before slicing out the two-letter language code.
        locale = str(QSettings().value("locale/userLocale") or "")[0:2]
        locale_path = os.path.join(
            self.plugin_dir, "i18n", "Segmenter_{}.qm".format(locale)
        )

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Declare instance attributes
        self.actions = []
        self.menu = self.tr("&Map Segmenter")

        # Check if plugin was started the first time in current QGIS session
        # Must be set in initGui() to survive plugin reloads
        self.first_start = None

        # NOTE: these write application-wide QGIS settings, not plugin-local ones.
        QSettings().setValue("/qgis/parallel_rendering", True)
        threadcount = QThread.idealThreadCount()
        QgsApplication.setMaxThreads(threadcount)

        self.task = None
        # (category, message, timestamp) tuples received before the dialog exists.
        self._status_buffer = []
        self._log_history = deque(maxlen=50)
        self.status_emitter = StatusEmitter()
        self.status_emitter.message.connect(self._handle_status_message)
        self._logged_missing_layers = False
        self._logo_hover = None
        self._layer_refresh_controller = None
        self._progress_last_value = 0.0
        self._progress_active = False
        self._progress_stage = "idle"
        # Map-to-raster assist: track last triggered layer to avoid dialog spam
        self._last_map_assist_layer_id = None
        self._layer_selection_controller = None
        # Flag to suppress map-to-raster assist during programmatic dropdown changes
        self._suppress_layer_assist = False

    # noinspection PyMethodMayBeStatic
    def tr(self, message):
        """Translate *message* in the ``Segmenter`` translation context.

        Implemented by hand because this class does not derive from QObject.

        :param message: String for translation.
        :type message: str, QString

        :returns: Translated version of message.
        :rtype: QString
        """
        context = "Segmenter"
        # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
        return QCoreApplication.translate(context, message)

    def log_status(self, message):
        """Publish a general status message (always shown in the dialog log)."""
        kind = "general"
        self._emit_status_message(message, category=kind)

    def worker_status(self, message):
        """Publish a status message coming from the segmentation worker."""
        kind = "worker"
        self._emit_status_message(message, category=kind)

    def _emit_status_message(self, message, category):
        """Emit ``(category, message)`` through the status emitter's Qt signal."""
        payload = (category, message)
        self.status_emitter.message.emit(payload)

    def _handle_status_message(self, payload):
        """Route a ``(category, message)`` payload to the dialog, or buffer it.

        With a dialog, the message is timestamped, shown if it passes
        ``_should_display_log`` and inspected for progress hints. Without one,
        the raw pieces are buffered for ``_flush_status_buffer``.
        """
        category, message = payload
        stamp = datetime.now().strftime("%H:%M:%S")
        if not getattr(self, "dlg", None):
            self._status_buffer.append((category, message, stamp))
            return
        if self._should_display_log(category, message):
            self._append_log_entry(f"[{stamp}] {message}")
        self._maybe_update_progress_from_message(category, message)

    def _flush_status_buffer(self):
        """Replay messages buffered before the dialog existed into its log, oldest first."""
        if not getattr(self, "dlg", None):
            return
        pending = self._status_buffer
        while pending:
            category, message, stamp = pending.pop(0)
            if not self._should_display_log(category, message):
                continue
            self._append_log_entry(f"[{stamp}] {message}")

    def _should_display_log(self, category, message):
        """Decide whether a status message appears in the dialog log.

        General messages are always shown. Worker messages are shown only when
        they contain an ``N/M``-style counter. Other categories are hidden.
        """
        if category == "worker":
            return re.search(r"\b\d+\s*/\s*\d+\b", message) is not None
        return category == "general"

    def _append_log_entry(self, entry):
        self._log_history.append(entry)
        if getattr(self, "dlg", None):
            lines = list(self._log_history)[::-1]
            self.dlg.inputBox.setPlainText("\n".join(lines))

    def _progress_widget(self):
        """Return the dialog's ``jobProgress`` bar, or None if unavailable."""
        dialog = getattr(self, "dlg", None)
        return getattr(dialog, "jobProgress", None) if dialog else None

    def _reset_progress_bar(self, text="Idle"):
        """Clear progress bookkeeping and show *text* on an empty, determinate bar."""
        self._progress_last_value, self._progress_active, self._progress_stage = 0.0, False, "idle"
        bar = self._progress_widget()
        if bar:
            bar.setRange(0, 100)
            bar.setValue(0)
            bar.setFormat(text)

    def _start_progress_cycle(self, message="Preparing segmentation..."):
        """Mark progress as active and move to the start of the "prepare" stage."""
        self._progress_active = True
        self._update_overall_progress("prepare", local_percent=0, message=message)

    def _update_overall_progress(self, stage: str, local_percent: float = 0.0, message: Optional[str] = None):
        """Map a stage-local percentage onto the overall 0-100 progress bar.

        Each stage in PROGRESS_STAGE_MAP owns a ``(start, end)`` slice. The
        clamped *local_percent* is interpolated inside that slice. Unknown
        stages are ignored. An empty *message* falls back to the stage's
        default label.
        """
        stage_info = PROGRESS_STAGE_MAP.get(stage)
        if not stage_info:
            return
        start, end, default_text = stage_info
        fraction = float(np.clip(local_percent, 0.0, 100.0)) / 100.0
        overall = start + max(end - start, 1e-6) * fraction
        self._progress_stage = stage
        self._apply_progress_update(overall, message or default_text)

    def _apply_progress_update(self, percent, message=None):
        """Advance the progress bar to *percent* (clamped) without moving backwards."""
        bar = self._progress_widget()
        if not bar:
            return
        self._progress_active = True
        bar.setRange(0, 100)
        clamped = float(np.clip(percent, 0.0, 100.0))
        # Progress is monotonic: never drop below the last value shown.
        value = self._progress_last_value if clamped < self._progress_last_value else clamped
        self._progress_last_value = value
        bar.setValue(int(round(value)))
        if message:
            bar.setFormat(self._format_progress_text(message))

    def _format_progress_text(self, message):
        """Return *message* stripped and truncated to fit the progress bar label."""
        text = (message or "").strip()
        if not text:
            return "Working..."
        if len(text) <= PROGRESS_TEXT_LIMIT:
            return text
        return text[: PROGRESS_TEXT_LIMIT - 3] + "..."

    def _set_progress_message(self, message, indeterminate=False):
        """Update the progress label, optionally switching the bar to busy mode."""
        bar = self._progress_widget()
        if not (bar and message):
            return
        if indeterminate:
            # A 0..0 range makes QProgressBar show its busy indicator.
            bar.setRange(0, 0)
        elif bar.maximum() == 0:
            # Leave busy mode and return to a normal 0-100 range.
            bar.setRange(0, 100)
        bar.setFormat(self._format_progress_text(message))

    def _finalize_progress(self, status="idle"):
        bar = self._progress_widget()
        if not bar:
            return
        self._progress_active = False
        bar.setRange(0, 100)
        if status == "success":
            self._update_overall_progress("render", 100, "Segmentation complete")
        elif status == "canceled":
            bar.setValue(int(round(self._progress_last_value)))
            bar.setFormat("Canceled")
        elif status == "error":
            value = self._progress_last_value if self._progress_last_value > 0 else 0
            bar.setValue(int(round(value)))
            bar.setFormat("Failed")
        else:
            self._reset_progress_bar()
            return
        if status in {"success", "canceled", "error"}:
            self._progress_stage = "idle"

    def _maybe_update_progress_from_message(self, category, message):
        """Turn a progress hint embedded in a worker/general message into a bar update."""
        if category in {"worker", "general"}:
            stage, percent = self._extract_progress_hint(message)
            if stage is not None and percent is not None:
                self._update_overall_progress(stage, percent, message)

    def _extract_progress_hint(self, message):
        """Parse ``(stage, percent)`` from a status message, else ``(None, None)``.

        The percentage comes from an ``NN%`` token, otherwise from a
        ``step X/Y`` counter. The stage is chosen from keywords in priority
        order: latent/knn, smooth/blur, render, prepare, chunk+plan, queue.
        It defaults to ``"inference"``.
        """
        if not message:
            return None, None
        lowered = message.lower()
        percent = self._extract_percent_token(message)
        if percent is None:
            percent = self._extract_step_percent(message)
        if percent is None:
            return None, None
        # (stage, keywords, combiner): any() = one keyword is enough, all() = every keyword required.
        stage_rules = (
            ("latent", ("latent", "knn"), any),
            ("smooth", ("smooth", "blur"), any),
            ("render", ("render",), any),
            ("prepare", ("prepare",), any),
            ("chunk_plan", ("chunk", "plan"), all),
            ("queue", ("queue",), any),
        )
        for stage, keywords, combine in stage_rules:
            if combine(keyword in lowered for keyword in keywords):
                return stage, percent
        return "inference", percent

    def _extract_percent_token(self, message):
        """Return the first ``NN%`` value in *message*, clamped to 0-100, or None."""
        found = PROGRESS_PERCENT_PATTERN.search(message)
        if found is None:
            return None
        try:
            raw = int(found.group("percent"))
        except (TypeError, ValueError):
            return None
        return int(np.clip(raw, 0, 100))

    def _extract_step_percent(self, message):
        """Convert a ``step X/Y`` counter in *message* to a 0-100 percentage, or None."""
        found = PROGRESS_STEP_PATTERN.search(message)
        if found is None:
            return None
        try:
            current, total = int(found.group("current")), int(found.group("total"))
        except (TypeError, ValueError):
            return None
        if total <= 0:
            # A zero total gives no meaningful ratio.
            return None
        fraction = max(0.0, min(1.0, current / total))
        return int(round(fraction * 100))

    def _set_stop_enabled(self, enabled):
        """Enable or disable the dialog's Stop button, if the dialog has one."""
        dialog = getattr(self, "dlg", None)
        stop_button = getattr(dialog, "buttonStop", None) if dialog else None
        if stop_button is not None:
            stop_button.setEnabled(bool(enabled))

    def _is_smoothing_enabled(self) -> bool:
        """Return whether the smoothing checkbox is checked (False without a dialog/checkbox)."""
        dialog = getattr(self, "dlg", None)
        checkbox = getattr(dialog, "checkSmoothing", None) if dialog else None
        return bool(checkbox.isChecked()) if checkbox else False

    def _blur_config(self, resolution_label: Optional[str] = None) -> Optional[dict]:
        """Build the post-smoothing config from the smoothing controls.

        Returns None when the smoothing checkbox is unchecked (or there is no
        dialog). Otherwise the slider level picks a base kernel and pass count
        from SMOOTHING_BASE_KERNELS. The level is clamped to low/medium/high
        and defaults to medium when there is no slider. The kernel is scaled
        proportionally with resolution:

        - At high resolution (4px blocks): base kernel sizes
        - At low resolution (16px blocks): 4x larger kernels

        The result is then forced to an odd size of at least 3px.

        :param resolution_label: resolution name such as ``"high"``, matched
            case-insensitively. Missing or unknown labels use the base resolution.
        :returns: ``{"kernel_size": int, "iterations": int}``, or None when disabled.
        """
        if not self._is_smoothing_enabled():
            return None
        # getattr on None returns the default, so a missing dialog also lands on medium.
        slider = getattr(getattr(self, "dlg", None), "sliderSmoothness", None)
        if slider:
            level = int(np.clip(slider.value(), SLIDER_LEVEL_LOW, SLIDER_LEVEL_HIGH))
        else:
            level = SLIDER_LEVEL_MEDIUM
        base = SMOOTHING_BASE_KERNELS.get(level, SMOOTHING_BASE_KERNELS[SLIDER_LEVEL_MEDIUM])

        resolution = SMOOTHING_BASE_RESOLUTION
        if resolution_label:
            label = str(resolution_label).strip().lower()
            resolution = RESOLUTION_VALUE_MAP.get(label, SMOOTHING_BASE_RESOLUTION)
        scale = resolution / SMOOTHING_BASE_RESOLUTION
        # "| 1" makes the size odd; then enforce the 3px minimum.
        kernel = max(3, int(round(base["base_kernel"] * scale)) | 1)
        return {"kernel_size": kernel, "iterations": base["iterations"]}

    def open_feedback_link(self):
        """Open the GitHub issue chooser in the browser; warn if it cannot be launched."""
        if not getattr(self, "dlg", None):
            return
        if not QDesktopServices.openUrl(QUrl(GITHUB_ISSUES_URL)):
            QMessageBox.warning(
                self.dlg,
                "Unable to open link",
                "Could not launch the browser. Please visit the issues page manually.",
            )
            return
        self.log_status(f"Opening feedback page: {GITHUB_ISSUES_URL}")

    def _open_support_link(self):
        """Open the Buy Me a Coffee page in the browser; warn if it cannot be launched."""
        if not getattr(self, "dlg", None):
            return
        if not QDesktopServices.openUrl(QUrl(BUY_ME_A_COFFEE_URL)):
            QMessageBox.warning(
                self.dlg,
                "Unable to open link",
                "Could not launch the browser. Please try again later.",
            )
            return
        self.log_status("Thanks for considering supporting development!")

    def stop_current_task(self):
        """Request cancellation of the running segmentation task, if there is one."""
        if not self.task:
            notice = "No active segmentation task to cancel."
        elif self.task.isCanceled():
            notice = "Cancellation already requested."
        else:
            notice = None
        if notice is not None:
            self.log_status(notice)
            return
        # NOTE(review): QgsTask may not provide taskId(); when it does not, the
        # AttributeError path skips the task manager and relies on cancel() below.
        try:
            task_id = self.task.taskId()
        except AttributeError:
            task_id = None
        if task_id is not None:
            QgsApplication.taskManager().cancelTask(task_id)
        # Task.cancel() also trips the task's CancellationToken.
        self.task.cancel()
        self._set_stop_enabled(False)
        self._set_progress_message("Cancelling task...")
        self.log_status("Cancellation requested; attempting to stop the worker immediately.")

    def _init_logo_interactions(self):
        """Attach the hover-glow/click controller to the dialog's large logo (once)."""
        logo = getattr(self.dlg, "imageLarge", None)
        if not logo:
            return
        # Hover attribute plus mouse tracking deliver move events without a pressed button.
        logo.setAttribute(Qt.WA_Hover, True)
        logo.setMouseTracking(True)
        if self._logo_hover is not None:
            return
        self._logo_hover = _LogoHoverController(logo, self._open_support_link)
        logo.installEventFilter(self._logo_hover)

    def add_action(
        self,
        icon_path,
        text,
        callback,
        enabled_flag=True,
        add_to_menu=True,
        add_to_toolbar=True,
        status_tip=None,
        whats_this=None,
        parent=None,
    ):
        """Create a QAction, wire it up and register it with the QGIS interface.

        :param icon_path: Icon location; a Qt resource path (e.g.
            ':/plugins/foo/bar.png') or a regular file-system path.
        :type icon_path: str

        :param text: Text shown for the action in menus.
        :type text: str

        :param callback: Slot connected to the action's ``triggered`` signal.
        :type callback: function

        :param enabled_flag: Initial enabled state. Defaults to True.
        :type enabled_flag: bool

        :param add_to_menu: Also add the action under this plugin's menu
            (``self.menu``). Defaults to True.
        :type add_to_menu: bool

        :param add_to_toolbar: Also add the action to the Plugins toolbar.
            Defaults to True.
        :type add_to_toolbar: bool

        :param status_tip: Optional status-bar text (``QAction.setStatusTip``).
        :type status_tip: str

        :param whats_this: Optional "What's This?" help text
            (``QAction.setWhatsThis``).
        :type whats_this: str

        :param parent: Parent widget for the new action. Defaults to None.
        :type parent: QWidget

        :returns: The created action; it is also appended to ``self.actions``.
        :rtype: QAction
        """
        action = QAction(QIcon(icon_path), text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)

        optional_texts = (
            (status_tip, action.setStatusTip),
            (whats_this, action.setWhatsThis),
        )
        for value, apply_text in optional_texts:
            if value is not None:
                apply_text(value)

        if add_to_toolbar:
            self.iface.addToolBarIcon(action)
        if add_to_menu:
            self.iface.addPluginToMenu(self.menu, action)

        self.actions.append(action)
        return action

    def initGui(self):
        """Register the toolbar icon and plugin-menu entry with QGIS."""
        self.add_action(
            ":/plugins/segmenter/icon.png",
            text=self.tr("Segment the map"),
            callback=self.run,
            parent=self.iface.mainWindow(),
        )
        # Reset to False by run() on first use.
        self.first_start = True

    def unload(self):
        """Remove the plugin's menu entries and toolbar icons from the QGIS GUI."""
        for action in self.actions:
            # Use the same menu title add_action() registered the action under.
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        # Forget the removed actions so a repeated unload() cannot remove them twice.
        self.actions.clear()

    # Predict coverage map
    def predict(self):
        """Validate the dialog inputs and queue a background K-Means segmentation.

        Reads the input layer, segment count, resolution and smoothing options
        from ``self.dlg`` and dispatches :func:`execute_kmeans_segmentation`
        through :func:`run_task`. Problems the user can fix are reported via
        :meth:`log_status` and reset the progress bar instead of raising.
        Completion, failure and cancellation are reported by ``Task.finished``.

        :raises ValueError: if the selected raster layer is not valid.
        :raises AttributeError: if ``self.device`` has not been set.
        """
        self._start_progress_cycle("Preparing segmentation...")
        self._update_overall_progress("prepare", 10, "Validating layer selection...")

        layer = self._resolve_input_layer()
        if layer is None:
            return
        raster_source = layer.source().split("|")[0]
        self._update_overall_progress("prepare", 60, "Raster scheduled for background loading.")
        self.log_status("Raster IO deferred to the worker thread to keep QGIS responsive.")

        num_segments = self._read_num_segments()
        if num_segments is None:
            return
        self._update_overall_progress("prepare", 80, "Segmentation parameters verified.")

        resolution_label = self._selected_resolution_label()
        resolution = RESOLUTION_VALUE_MAP.get(resolution_label, RESOLUTION_VALUE_MAP[DEFAULT_RESOLUTION_LABEL])

        blur_config = self._blur_config(resolution_label)
        if blur_config:
            self.log_status(
                f"Post-smoothing configured: {blur_config['kernel_size']}px kernel, {blur_config['iterations']} pass(es)."
            )
        else:
            self.log_status("Post-smoothing disabled (checkbox unchecked).")

        if not hasattr(self, "device"):
            # Don't leave the bar stuck mid-"prepare" when aborting.
            self._reset_progress_bar()
            raise AttributeError("Segmenter instance must have a 'device' attribute set before calling predict().")

        kwargs = {
            "layer": layer,
            "canvas": self.canvas,
            "dlg": self.dlg,
            "num_segments": num_segments,
            "resolution": resolution,
            "status_callback": self.log_status,
            "segmenter": self,
            "blur_config": blur_config,
        }

        func = execute_kmeans_segmentation
        args = (
            raster_source,
            num_segments,
            resolution,
            None,  # chunk_plan - let pipeline compute it
            self.worker_status,
            1.0,  # sample_scale - fixed at 1.0 (speed/accuracy sliders removed)
            self.device,
        )

        self._update_overall_progress("queue", 90, "Dispatching segmentation task...")
        self.log_status(
            f"Queued K-Means segmentation with {num_segments} segments at {self.dlg.inputRes.currentText()} resolution."
        )
        self._update_overall_progress("queue", 100, "Task queued; starting worker...")
        self._update_overall_progress("inference", 0, "Running segmentation...")
        self.task = run_task(func, *args, **kwargs)
        self._set_stop_enabled(True)
        # Task.finished() reports success/failure/cancellation. Blocking the GUI
        # thread here to poll the task would misreport any task that finishes
        # quickly (even successfully) as an error.

    def _resolve_input_layer(self):
        """Return the selected, supported input layer, or None after logging why not.

        :raises ValueError: if the layer is supported but not valid.
        """
        combo = self.dlg.inputLayer
        layer_name = combo.currentText()
        project = QgsProject.instance()
        layer = None
        # render_layers() stores each layer's ID as item data. Prefer it, because
        # layer names are not unique within a project.
        layer_id = combo.currentData()
        if isinstance(layer_id, str) and layer_id:
            layer = project.mapLayer(layer_id)
        if layer is None:
            matches = project.mapLayersByName(layer_name)
            layer = matches[0] if matches else None
        if layer is None:
            self.log_status("Selected layer is no longer available. Please choose another layer.")
            self._reset_progress_bar()
            return None
        if not self._is_supported_raster_layer(layer):
            # Check if this is a renderable layer that needs conversion
            meta = extract_layer_metadata(layer)
            if is_renderable_non_file_layer(
                meta["layer_type"],
                meta["provider_name"],
                meta["source_path"],
            ):
                self.log_status(
                    "Selected layer is a map service. Please use 'Convert map to raster' first "
                    "(dialog should have opened when you selected this layer), then select the "
                    "generated GeoTIFF as input."
                )
            else:
                self.log_status("Selected layer is not a supported 3-band GeoTIFF raster.")
            self._reset_progress_bar()
            return None
        if not layer.isValid():
            self._reset_progress_bar()
            raise ValueError(f"Invalid raster layer! \n{layer_name}")
        return layer

    def _read_num_segments(self):
        """Parse the positive segment count from the dialog, or return None after logging."""
        segments_raw = (self.dlg.inputSegments.text() or "").strip()
        error = None
        num_segments = None
        if not segments_raw:
            error = "Please enter the desired number of segments and try again."
        else:
            try:
                num_segments = int(segments_raw)
            except ValueError:
                error = "Number of segments must be an integer value."
            else:
                if num_segments <= 0:
                    error = "Number of segments must be a positive integer."
        if error:
            self.log_status(error)
            self.dlg.inputSegments.setFocus()
            self._reset_progress_bar()
            return None
        return num_segments

    def _selected_resolution_label(self):
        """Return the lower-cased resolution label selected in ``self.dlg.inputRes``.

        Uses the item data stored by render_resolutions(). When that is missing,
        the display text is used instead, reduced to its first word if the full
        text (e.g. ``"High (4)"``) is not a known label.
        """
        combo = self.dlg.inputRes
        label = combo.currentData()
        if not label:
            label = combo.currentText() or DEFAULT_RESOLUTION_LABEL
        label = str(label).strip().lower()
        if label not in RESOLUTION_VALUE_MAP:
            head = label.partition(" ")[0]
            if head in RESOLUTION_VALUE_MAP:
                label = head
        return label

    # Process user input box
    def submit(self):
        """Placeholder handler for the input box; intentionally does nothing."""
        return None

    # Display layers in dropdown
    def render_layers(self):
        """Repopulate the input-layer combo box with the project's selectable layers.

        Raster layers (GDAL-backed as well as WMS/XYZ and other providers) and
        vector layers are listed in case-insensitive name order. Each entry stores
        its layer ID as item data so later lookups are unambiguous.

        The previous selection is restored by layer ID first and only falls back
        to the display text. Layers in a project may share a name, and matching
        on text alone could move the selection to a different layer.

        The map-to-raster assist is suppressed for the whole call, so the index
        changes caused by clearing/refilling the combo never open the dialog.
        """
        self._suppress_layer_assist = True
        try:
            combo = self.dlg.inputLayer
            # Include raster layers (file-backed and web services) plus vectors,
            # which can be rendered to a raster.
            all_layers = [
                layer
                for layer in QgsProject.instance().mapLayers().values()
                if isinstance(layer, QgsMapLayer)
                and (
                    isinstance(layer, QgsRasterLayer)
                    or layer.type() == QgsMapLayer.VectorLayer
                )
            ]
            all_layers.sort(key=lambda lyr: lyr.name().lower())

            # Remember the current selection before clearing the combo.
            current_text = combo.currentText()
            current_id = combo.currentData()
            combo.clear()
            if not all_layers:
                # Log only once until layers reappear, to avoid spamming the log
                # on every refresh.
                if not self._logged_missing_layers:
                    self.log_status("No layers detected in the project.")
                    self._logged_missing_layers = True
                return

            self._logged_missing_layers = False
            for layer in all_layers:
                # Store layer ID as item data for reliable lookup
                combo.addItem(layer.name(), layer.id())

            # Prefer the unique layer ID; fall back to the visible name.
            index = combo.findData(current_id) if current_id else -1
            if index < 0 and current_text:
                index = combo.findText(current_text)
            if index >= 0:
                combo.setCurrentIndex(index)
        finally:
            self._suppress_layer_assist = False

    # Display resolutions in dropdown
    def render_resolutions(self):
        """Fill the resolution combo box and preselect the default resolution.

        Every entry shows ``"<Label> (<value>)"`` and keeps the raw label as its
        item data. The default entry is located by item data first, then by a
        fixed-string text match.
        """
        combo = self.dlg.inputRes
        combo.clear()
        for label, value in RESOLUTION_CHOICES:
            combo.addItem(f"{label.title()} ({value})", label)

        default_index = combo.findData(DEFAULT_RESOLUTION_LABEL)
        if default_index < 0:
            default_index = combo.findText(DEFAULT_RESOLUTION_LABEL, Qt.MatchFixedString)
        if default_index >= 0:
            combo.setCurrentIndex(default_index)

    def _is_supported_raster_layer(self, layer):
        """Return True if *layer* can be segmented directly.

        A layer qualifies when it meets all of these conditions:
        - It is a ``QgsRasterLayer``.
        - It reports exactly three bands.
        - Its data provider is named ``gdal``.
        - Its source path (the part before any ``|`` option suffix) ends with
          one of ``SUPPORTED_RASTER_EXTENSIONS``.

        If reading the band count raises, the layer is treated as unsupported.
        """
        if not isinstance(layer, QgsRasterLayer):
            return False
        try:
            bands = layer.bandCount()
        except Exception:
            return False
        if bands != 3:
            return False

        provider = layer.dataProvider()
        if not provider or provider.name().lower() != "gdal":
            return False

        path = layer.source().partition("|")[0]
        extension = os.path.splitext(path)[1]
        return extension.lower() in SUPPORTED_RASTER_EXTENSIONS

    def run(self):
        """Show the Segmenter dialog, building and wiring it on the first call.

        First call only:
        - Create the dialog and cache the map canvas.
        - Choose the torch device: CUDA, then Apple MPS, then CPU.
        - Populate the layer/resolution drop-downs and seed the segment count
          with 8 when it is empty.
        - Log the device status and connect the widget signals.
        - Render the logo and install the layer-refresh hook.

        Every call:
        - Refresh the layer list.
        - Reset the progress bar when no task is active.
        - Show the dialog.
        - Check whether the map-to-raster assist should open automatically.
        """

        # Create the dialog with elements (after translation) and keep reference
        # Only create GUI ONCE in callback, so that it will only load when the plugin is started
        if self.first_start:
            self.first_start = False
            self.dlg = SegmenterDialog()
            self.canvas = self.iface.mapCanvas()
            self._reset_progress_bar()

            # Set device: CUDA, then Apple MPS (Metal Performance Shaders), then
            # CPU. ``torch.backends.mps`` only exists in torch >= 1.12, so look
            # it up defensively. Otherwise an older install would raise
            # AttributeError here and abort plugin start-up.
            mps_backend = getattr(torch.backends, "mps", None)
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif mps_backend is not None and mps_backend.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")

            # Populate drop down menus
            self.render_layers()
            self.render_resolutions()
            if not (self.dlg.inputSegments.text() or "").strip():
                self.dlg.inputSegments.setText("8")

            # Set gpu message
            gpu_msg = "GPU available."
            if self.device == torch.device("cpu"):
                gpu_msg = "GPU not available. Using CPU instead."

            self._log_history.clear()
            self.dlg.inputBox.clear()
            self._flush_status_buffer()
            self.log_status(gpu_msg)
            self.log_status("K-Means segmentation runtime ready.")

            # Attach inputs
            self.dlg.inputBox.textChanged.connect(self.submit)
            self.dlg.buttonPredict.clicked.connect(self.predict)
            self.dlg.buttonFeedback.clicked.connect(self.open_feedback_link)
            self.dlg.buttonStop.clicked.connect(self.stop_current_task)
            self._set_stop_enabled(False)

            # Wire layer selection change for map-to-raster assist
            self.dlg.inputLayer.currentIndexChanged.connect(self._on_layer_selection_changed)

            # Render logo
            img_path = os.path.join(self.plugin_dir, "logo.png")
            pix = QPixmap(img_path)
            self.dlg.imageLarge.setPixmap(pix)
            self._init_logo_interactions()
            self._init_layer_refresh()

        # show the dialog
        self.render_layers()
        if not self.task:
            self._reset_progress_bar()
        self.dlg.show()
        # Check if we should auto-trigger map-to-raster assist (no compatible rasters available)
        self._check_auto_map_assist()

    def _check_auto_map_assist(self) -> None:
        """Auto-trigger map-to-raster assist if no compatible rasters exist but map layers do."""
        dlg = getattr(self, "dlg", None)
        if not dlg:
            return

        # Check if there are any layers at all
        if dlg.inputLayer.count() == 0:
            return

        # Check if any layer is a supported raster
        has_supported_raster = False
        first_map_layer = None
        first_map_layer_id = None

        for i in range(dlg.inputLayer.count()):
            layer_id = dlg.inputLayer.itemData(i)
            if not layer_id:
                continue
            layer = QgsProject.instance().mapLayer(layer_id)
            if not layer:
                continue

            if self._is_supported_raster_layer(layer):
                has_supported_raster = True
                break

            # Track first map/web service layer
            if first_map_layer is None:
                meta = extract_layer_metadata(layer)
                if is_renderable_non_file_layer(
                    meta["layer_type"],
                    meta["provider_name"],
                    meta["source_path"],
                ):
                    first_map_layer = layer
                    first_map_layer_id = layer_id

        # If no supported rasters but we have a map layer, auto-trigger assist
        if not has_supported_raster and first_map_layer is not None:
            self._last_map_assist_layer_id = first_map_layer_id
            self._open_convert_map_to_raster_assist(first_map_layer)

    def _init_layer_refresh(self):
        dlg = getattr(self, "dlg", None)
        if not dlg:
            return
        combo = getattr(dlg, "inputLayer", None)
        if not combo:
            return
        if self._layer_refresh_controller is None:
            self._layer_refresh_controller = _ComboRefreshController(combo, self.render_layers)
            combo.installEventFilter(self._layer_refresh_controller)
    def _on_layer_selection_changed(self, index: int) -> None:
        """Handle layer dropdown selection change for map-to-raster assist.

        Opens the Convert map to raster dialog only if:
        1. The selection was made by the user (not programmatic), AND
        2. The selected layer is a web service/vector (not a file-backed GDAL raster)
        """
        if index < 0:
            return

        # Skip if this is a programmatic change (e.g., during render_layers)
        if getattr(self, "_suppress_layer_assist", False):
            return

        dlg = getattr(self, "dlg", None)
        if not dlg:
            return

        # Get layer ID from item data (set in render_layers)
        layer_id = dlg.inputLayer.itemData(index)
        if not layer_id:
            return

        # Avoid dialog spam: don't trigger if we already triggered for this layer
        if layer_id == self._last_map_assist_layer_id:
            return

        # Lookup layer by ID
        layer = QgsProject.instance().mapLayer(layer_id)
        if not layer:
            return

        # Check if this is a supported raster (no assist needed)
        if self._is_supported_raster_layer(layer):
            # Clear the assist state for raster layers
            self._last_map_assist_layer_id = None
            return

        # Extract metadata for detection
        meta = extract_layer_metadata(layer)

        # Check if this is a renderable non-file layer
        if not is_renderable_non_file_layer(
            meta["layer_type"],
            meta["provider_name"],
            meta["source_path"],
        ):
            # Not a recognized web service or vector - may be unsupported
            self._last_map_assist_layer_id = None
            return

        # User explicitly selected a non-file layer - trigger map-to-raster assist
        self._last_map_assist_layer_id = layer_id
        self._open_convert_map_to_raster_assist(layer)

    def _open_convert_map_to_raster_assist(self, layer) -> None:
        """Open the Convert map to raster dialog prefilled for the given layer."""
        canvas = getattr(self, "canvas", None)
        if not canvas:
            self.log_status("Map canvas not available. Cannot determine extent.")
            return

        # Build parameters
        extent_tuple = get_canvas_extent_tuple(canvas)
        layer_id = layer.id() if hasattr(layer, "id") else layer.name()
        params = build_convert_map_to_raster_params(extent_tuple, layer_id)

        # Log the assist action
        self.log_status(
            "Selected layer is a map service. Opening Convert map to raster dialog "
            "(prefilled: current extent, 1 map unit/pixel). Adjust settings and run "
            "to create a GeoTIFF, then select it as input."
        )

        # Open the dialog
        dlg = getattr(self, "dlg", None)
        parent = dlg if dlg else None
        opened = open_convert_map_to_raster_dialog(params, parent=parent)

        if not opened:
            self.log_status(
                "Could not open Convert map to raster dialog. "
                "You can find it in Processing > Toolbox > Rasterize (raster saving)."
            )