# -*- coding: utf-8 -*-
"""
Dialog to load a raster from ArrayStorage with forecast-aware UX.
"""

import json
import os
from typing import Any, Optional, cast

from qgis.PyQt import uic, QtWidgets
from qgis.PyQt.QtCore import QDateTime, Qt, QRegularExpression
from qgis.PyQt.QtGui import QColor, QFont, QSyntaxHighlighter, QTextCharFormat
from qgis.PyQt.QtWidgets import QFileDialog, QMessageBox, QDialogButtonBox
from qgis.core import (
    QgsApplication,
    QgsAuthMethodConfig,
    QgsProcessingException,
    QgsProcessingUtils,
    QgsProject,
    QgsLayerTreeGroup,
)

from .arraystorage_core import (
    ArrayStorageRasterRequest,
    list_raster_timeseries,
    list_forecast_t0_options,
    list_timestamps,
    load_arraystorage_raster,
    make_session,
)
from .arraystorage_settings import ArrayStorageSettingsStore

# Qt Designer form shipped next to this module. uic.loadUiType() compiles it at
# import time; only the first element of the returned pair (the generated form
# class used as a mixin by the dialog below) is needed here.
_UI_PATH = os.path.join(os.path.dirname(__file__), "arraystorage_dialog_load_raster.ui")
FORM_CLASS_ARRAYSTORE_LOAD, _ = uic.loadUiType(_UI_PATH)


class _JsonHighlighter(QSyntaxHighlighter):
    """
    Lightweight syntax highlighter for JSON to improve readability of metadata.

    Rules are applied in list order and each rule overwrites the formatting of
    earlier rules for the characters it matches. Generic tokens (numbers,
    literals, brackets) therefore come first and strings afterwards, so digits,
    ``true``/``null`` or brackets that occur *inside* a string keep the string
    colour. Keys are matched last so they override the generic string colour.
    """

    def __init__(self, parent=None):
        super().__init__(parent)

        def _fmt(color: str, bold: bool = False) -> QTextCharFormat:
            f = QTextCharFormat()
            f.setForeground(QColor(color))
            if bold:
                f.setFontWeight(QFont.Bold)
            return f

        self._rules: list[tuple[QRegularExpression, QTextCharFormat]] = [
            # Numbers: optional leading minus, fraction and exponent. The
            # minus sits before \b because there is no word boundary between
            # a blank and "-".
            (
                QRegularExpression(
                    r"-?\b(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?\b"
                ),
                _fmt("#9c27b0"),
            ),
            # Booleans / null
            (
                QRegularExpression(r"\btrue\b|\bfalse\b|\bnull\b"),
                _fmt("#00897b", bold=True),
            ),
            # Braces / brackets
            (QRegularExpression(r"[{}\[\]]"), _fmt("#555", bold=True)),
            # Any string (object values and array elements alike)
            (
                QRegularExpression(r'"(?:[^"\\]|\\.)*"'),
                _fmt("#c23b22"),
            ),
            # Keys: "name":  -- re-coloured on top of the generic string rule
            (
                QRegularExpression(r'"[^"\\]*(?:\\.[^"\\]*)*"\s*(?=:)'),
                _fmt("#1e6dd6", bold=True),
            ),
        ]

    def highlightBlock(self, text: str) -> None:  # noqa: N802
        """Apply every rule to one text block; later rules win on overlap."""
        for pattern, fmt in self._rules:
            it = pattern.globalMatch(text)
            while it.hasNext():
                m = it.next()
                self.setFormat(m.capturedStart(), m.capturedLength(), fmt)


class ArrayStorageDialogLoadRaster(QtWidgets.QDialog, FORM_CLASS_ARRAYSTORE_LOAD):
    """
    Dialog to browse ArrayStorage raster time series and load selected rasters.

    Workflow:
      1. Choose a saved connection, or enter the Base URL and credentials.
      2. Refresh the list of raster time series and pick one.
      3. Optionally restrict the forecast t0 range.
      4. Check rasters in the tree (t0 -> "dispatch_info, member" -> valid time).
      5. Press "Load raster".

    Loaded layers are grouped the same way in the project layer tree.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setupUi(self)

        # State is initialised before any signal is wired so that no slot can
        # observe a half-constructed dialog.
        self._timeseries: list[dict[str, Any]] = []
        # True while the tree is rebuilt or check states are propagated
        # programmatically; suppresses _on_tree_item_changed re-entrancy.
        self._tree_signal_block = False

        # Wire buttons / signals
        self.btnRefreshTimeseries.clicked.connect(self._refresh_timeseries)
        self.comboTimeseries.currentIndexChanged.connect(self._update_metadata_view)
        self.btnBrowseTarget.clicked.connect(self._browse_target_file)
        # clicked(bool) would otherwise deliver its "checked" flag positionally
        # into ``auto``; make the intended explicit refresh obvious.
        self.btnRefreshTime.clicked.connect(
            lambda _checked=False: self._refresh_raster_tree(auto=False)
        )
        self.treeRasters.itemChanged.connect(self._on_tree_item_changed)
        self.comboConfig.currentIndexChanged.connect(self._apply_selected_config)

        # setupUi() may already have connected accepted->accept and
        # rejected->reject (the standard Designer dialog templates do).
        # Replace both so "OK" only closes the dialog after a successful load
        # and "Cancel" does not invoke reject() twice. PyQt raises TypeError
        # when a signal has no connection to remove.
        for signal in (self.buttonBox.accepted, self.buttonBox.rejected):
            try:
                signal.disconnect()
            except TypeError:
                pass
        self.buttonBox.accepted.connect(self._on_load_clicked)
        self.buttonBox.rejected.connect(self.reject)
        ok_btn = self.buttonBox.button(QDialogButtonBox.Ok)
        if ok_btn is not None:
            ok_btn.setText(self.tr("Load raster"))

        # Default values (a saved connection, if any, overrides the Base URL)
        self.lineBaseUrl.setText("http://localhost:8013")
        for w in (self.dtT0From, self.dtT0To):
            w.setDisplayFormat("yyyy-MM-dd HH:mm:ss")
        now = QDateTime.currentDateTimeUtc()
        self.dtT0To.setDateTime(now)
        self.dtT0From.setDateTime(now.addDays(-2))
        self._reload_configs()

        # Tree / progress defaults
        self.treeRasters.setUniformRowHeights(True)
        self.treeRasters.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
        self.treeRasters.setHeaderHidden(False)
        self.progressBarLoad.setVisible(False)
        self._json_highlighter = _JsonHighlighter(self.textMetadata.document())

    # -------------------------
    # UI helpers
    # -------------------------

    def _show_error(self, title: str, msg: str) -> None:
        """Show a modal critical message box."""
        QMessageBox.critical(self, title, msg)

    def _show_info(self, title: str, msg: str) -> None:
        """Show a modal information message box."""
        QMessageBox.information(self, title, msg)

    def _browse_target_file(self) -> None:
        """Let the user pick the GeoTIFF target path and store it in the form."""
        path, _ = QFileDialog.getSaveFileName(
            self,
            self.tr("Select target GeoTIFF"),
            self.lineTargetFile.text().strip() or "",
            self.tr("GeoTIFF (*.tif *.tiff)"),
        )
        if path:
            self.lineTargetFile.setText(path)

    @staticmethod
    def _timeseries_id(ts: dict[str, Any]) -> str:
        """Return the timeseries id, accepting every key spelling seen in metadata."""
        return str(
            ts.get("timeseriesId") or ts.get("timeSeriesId") or ts.get("id") or ""
        ).strip()

    def _connection_params(self) -> tuple[str, Optional[str], Optional[str]]:
        """
        Read (base_url, user, password) from the form.

        Empty user/password become None. The password is deliberately not
        stripped: leading/trailing blanks can be part of a valid password,
        including one loaded from a QGIS authentication config.
        """
        base_url = self.lineBaseUrl.text().strip()
        as_user = self.lineUser.text().strip() or None
        as_pass = self.linePass.text() or None
        return base_url, as_user, as_pass

    def _resolve_credentials_from_authcfg(
        self, authcfg: Optional[str]
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Look up username/password stored under a QGIS authcfg id.

        Returns (None, None) when no id is given, the config cannot be loaded,
        or the auth manager raises. This is intentionally best-effort: missing
        stored credentials just leave the fields empty for manual input.
        """
        if not authcfg:
            return None, None
        try:
            cfg = QgsAuthMethodConfig()
            if QgsApplication.authManager().loadAuthenticationConfig(
                authcfg, cfg, True
            ):
                return cfg.config("username") or None, cfg.config("password") or None
        except Exception:  # noqa: BLE001 - best-effort lookup, see docstring
            pass
        return None, None

    def _reload_configs(self) -> None:
        """Fill the saved-connection combo and apply the first entry, if any."""
        store = ArrayStorageSettingsStore()
        configs = store.list()
        self.comboConfig.blockSignals(True)
        try:
            self.comboConfig.clear()
            for cfg in configs:
                label = f"{cfg.name} ({cfg.base_url})"
                self.comboConfig.addItem(label, cfg)
        finally:
            # Never leave the combo permanently signal-blocked.
            self.comboConfig.blockSignals(False)
        if self.comboConfig.count() > 0:
            self.comboConfig.setCurrentIndex(0)
            self._apply_selected_config()

    def _apply_selected_config(self) -> None:
        """
        Copy the selected saved connection into the Base URL / credential fields.

        Credentials are always overwritten (cleared when the config has no
        usable authcfg). This prevents a previously selected connection's
        credentials from being sent to a different server.
        """
        idx = self.comboConfig.currentIndex()
        cfg = self.comboConfig.itemData(idx)
        if not cfg:
            return
        self.lineBaseUrl.setText(str(cfg.base_url or "").strip())
        user_from_auth, pass_from_auth = self._resolve_credentials_from_authcfg(
            getattr(cfg, "authcfg", None)
        )
        self.lineUser.setText(user_from_auth or "")
        self.linePass.setText(pass_from_auth or "")

    def _selected_timeseries(self) -> Optional[dict[str, Any]]:
        """Return the metadata dict of the selected timeseries, or None."""
        idx = self.comboTimeseries.currentIndex()
        if idx < 0:
            return None
        data = self.comboTimeseries.itemData(idx)
        return data if isinstance(data, dict) else None

    def _is_forecast(self, item: dict[str, Any]) -> bool:
        """
        Return whether a timeseries is a forecast.

        Uses the explicit "is_forecast" flag when present. Otherwise the
        series is treated as a forecast when its "selection" contains "t0".
        """
        if not item:
            return False
        if "is_forecast" in item:
            return bool(item.get("is_forecast"))
        selection = item.get("selection") or {}
        return "t0" in selection

    def _is_multiband(self, item: dict[str, Any]) -> bool:
        """
        Detect multiband rasters by presence of a 'z' dimension in metadata.
        """
        dims = item.get("dimensions") if isinstance(item, dict) else []
        for dim in dims or []:
            if not isinstance(dim, dict):
                continue
            name = str(dim.get("name") or "").lower()
            if name == "z":
                return True
        return False

    def _update_metadata_view(self, idx: int) -> None:
        """
        Show the selected timeseries' metadata as pretty JSON and refresh the tree.

        ``idx`` (the combo index from the signal) is ignored; the current
        selection is read directly. With no selection both views are cleared.
        """
        item = self._selected_timeseries()
        if not item:
            self.textMetadata.clear()
            self._clear_raster_tree()
            return
        try:
            pretty = json.dumps(item, indent=2, sort_keys=True, ensure_ascii=False)
        except Exception:  # noqa: BLE001 - e.g. unsortable keys / non-JSON values
            pretty = str(item)
        self.textMetadata.setPlainText(pretty)
        self._refresh_raster_tree(auto=True)

    # -------------------------
    # Data fetches
    # -------------------------

    def _populate_timeseries_combo(self) -> None:
        """Fill the timeseries combo from ``self._timeseries`` without emitting signals."""
        self.comboTimeseries.blockSignals(True)
        try:
            self.comboTimeseries.clear()
            for ts in self._timeseries:
                path = str(ts.get("path") or "").strip()
                ts_id = self._timeseries_id(ts)
                fcst_obs = "fcst" if self._is_forecast(ts) else "obs"
                band_kind = "mb" if self._is_multiband(ts) else "sb"
                label_kind = f"{fcst_obs}/{band_kind}"
                display = f"{path} ({label_kind})" if path else ts_id or "<unknown>"
                self.comboTimeseries.addItem(display, ts)
            if self.comboTimeseries.count() > 0:
                self.comboTimeseries.setCurrentIndex(0)
        finally:
            # An exception must not leave the combo permanently signal-blocked.
            self.comboTimeseries.blockSignals(False)

    def _refresh_timeseries(self) -> None:
        """Fetch the raster time series list from the server and show the first one."""
        base_url, as_user, as_pass = self._connection_params()

        if not base_url:
            self._show_error(
                self.tr("Missing Base URL"), self.tr("Enter the Base URL first.")
            )
            return

        try:
            session = make_session(as_user, as_pass)
            items = list_raster_timeseries(base_url, session, store_id=None)

            self._timeseries = cast(list[dict[str, Any]], items)
            self._populate_timeseries_combo()

            # Signals were blocked while populating, so sync the metadata view
            # and raster tree exactly once here. This also clears stale views
            # when the server returned no series.
            self._update_metadata_view(self.comboTimeseries.currentIndex())

            self._show_info(
                self.tr("Timeseries refreshed"),
                self.tr("Loaded {} raster time series.").format(len(self._timeseries)),
            )
        except QgsProcessingException as e:
            self._show_error(self.tr("Refresh failed"), str(e))
        except Exception as e:  # noqa: BLE001 - UI boundary, report everything
            self._show_error(
                self.tr("Refresh failed"),
                self.tr("Unexpected error:\n{}").format(e),
            )

    # -------------------------
    # Tree-based selection
    # -------------------------

    def _clear_raster_tree(self) -> None:
        """Remove all tree items without triggering check-state propagation."""
        self._tree_signal_block = True
        try:
            self.treeRasters.clear()
        finally:
            self._tree_signal_block = False

    def _format_dt_label(
        self, value: Optional[str], *, none_label: str = "None"
    ) -> str:
        """
        Format an ISO-8601 string as "yyyy-MM-dd HH:mm:ss UTC".

        Returns ``none_label`` for empty input and the raw value when it
        cannot be parsed.
        """
        if not value:
            return none_label
        qdt = QDateTime.fromString(value, Qt.ISODate)
        if qdt.isValid():
            qdt = qdt.toUTC()
            return qdt.toString("yyyy-MM-dd HH:mm:ss 'UTC'")
        return value

    def _set_children_check_state(
        self, item: QtWidgets.QTreeWidgetItem, state: Qt.CheckState
    ) -> None:
        """Recursively apply a definite (checked/unchecked) state to all descendants."""
        if state == Qt.PartiallyChecked:
            return
        for i in range(item.childCount()):
            child = item.child(i)
            child.setCheckState(0, state)
            self._set_children_check_state(child, state)

    def _update_parent_state(self, item: QtWidgets.QTreeWidgetItem) -> None:
        """Recompute ancestors' states: checked/unchecked if uniform, else partial."""
        parent = item.parent()
        if parent is None:
            return
        states = {parent.child(i).checkState(0) for i in range(parent.childCount())}
        if not states:
            return
        if states == {Qt.Checked}:
            parent.setCheckState(0, Qt.Checked)
        elif states == {Qt.Unchecked}:
            parent.setCheckState(0, Qt.Unchecked)
        else:
            parent.setCheckState(0, Qt.PartiallyChecked)
        self._update_parent_state(parent)

    def _on_tree_item_changed(
        self, item: QtWidgets.QTreeWidgetItem, column: int
    ) -> None:
        """
        Keep check states consistent after an item in column 0 changes.

        A definite state is pushed down to all descendants and the ancestors
        are recomputed. Changes made here are guarded so they do not recurse.
        """
        if self._tree_signal_block or column != 0:
            return
        self._tree_signal_block = True
        try:
            state = item.checkState(column)
            # Do not propagate partial state downwards; parents turn partial
            # when children differ, but leaves stay two-state.
            if state != Qt.PartiallyChecked:
                self._set_children_check_state(item, state)
            self._update_parent_state(item)
        finally:
            self._tree_signal_block = False

    def _collect_checked_leaf_items(self) -> list[dict[str, Any]]:
        """
        Return the data dicts of all checked leaves (t0 -> group -> time).

        Leaves are returned in tree order. Each dict carries t0,
        dispatch_info, member and time.
        """
        checked: list[dict[str, Any]] = []
        root = self.treeRasters.invisibleRootItem()
        for i in range(root.childCount()):
            t0_item = root.child(i)
            for j in range(t0_item.childCount()):
                group_item = t0_item.child(j)
                for k in range(group_item.childCount()):
                    leaf = group_item.child(k)
                    if leaf.checkState(0) == Qt.Checked:
                        data = leaf.data(0, Qt.UserRole)
                        if isinstance(data, dict):
                            checked.append(data)
        return checked

    def _build_tree_from_runs(self, runs: list[dict[str, Any]]) -> None:
        """
        Rebuild the tree from run dicts (keys: t0, dispatch_info, member, times).

        Runs are grouped by t0 (first-seen order). Each run becomes a
        "dispatch_info, member" node whose children are checkable time leaves.
        All items start unchecked and expanded.
        """
        self._tree_signal_block = True
        try:
            self.treeRasters.clear()
            grouped: dict[Any, list[dict[str, Any]]] = {}
            for run in runs:
                key = run.get("t0")
                grouped.setdefault(key, []).append(run)

            for t0_raw, t0_runs in grouped.items():
                t0_label = self._format_dt_label(t0_raw)
                t0_item = QtWidgets.QTreeWidgetItem([t0_label])
                t0_item.setData(0, Qt.UserRole, {"t0": t0_raw})
                t0_item.setFlags(
                    t0_item.flags() | Qt.ItemIsUserCheckable | Qt.ItemIsTristate
                )
                t0_item.setCheckState(0, Qt.Unchecked)
                self.treeRasters.addTopLevelItem(t0_item)

                for run in t0_runs:
                    dispatch_info = run.get("dispatch_info")
                    member = run.get("member")
                    times: list[dict[str, Any]] = run.get("times") or []

                    group_label = f"{dispatch_info or 'None'}, {member or 'None'}"
                    group_item = QtWidgets.QTreeWidgetItem([group_label])
                    group_item.setData(
                        0,
                        Qt.UserRole,
                        {
                            "t0": t0_raw,
                            "dispatch_info": dispatch_info,
                            "member": member,
                        },
                    )
                    group_item.setFlags(
                        group_item.flags() | Qt.ItemIsUserCheckable | Qt.ItemIsTristate
                    )
                    group_item.setCheckState(0, Qt.Unchecked)
                    t0_item.addChild(group_item)

                    for t_entry in times:
                        time_iso = str(t_entry.get("time") or "").strip()
                        if not time_iso:
                            continue
                        time_label = self._format_dt_label(
                            time_iso, none_label=time_iso
                        )
                        leaf = QtWidgets.QTreeWidgetItem([time_label])
                        leaf.setData(
                            0,
                            Qt.UserRole,
                            {
                                "t0": t0_raw,
                                "dispatch_info": dispatch_info,
                                "member": member,
                                "time": time_iso,
                            },
                        )
                        leaf.setFlags(
                            (leaf.flags() | Qt.ItemIsUserCheckable) & ~Qt.ItemIsTristate
                        )
                        leaf.setCheckState(0, Qt.Unchecked)
                        group_item.addChild(leaf)

                    group_item.setExpanded(True)
                t0_item.setExpanded(True)

            self.treeRasters.expandToDepth(2)
            self.treeRasters.resizeColumnToContents(0)
        finally:
            self._tree_signal_block = False

    def _refresh_raster_tree(self, auto: bool = False) -> None:
        """
        Rebuild the raster tree for the selected timeseries.

        For forecast series, the (t0, dispatch_info, member) runs are listed
        and filtered by the t0 range widgets. Observation series are treated
        as one run without t0. The timestamps of every run become the tree
        leaves.

        Args:
            auto: True when triggered implicitly (selection change, series
                refresh). Validation problems, errors and result summaries are
                then not shown in message boxes.
        """
        self._clear_raster_tree()
        ts = self._selected_timeseries()
        if not ts:
            return

        base_url, as_user, as_pass = self._connection_params()
        ts_id = self._timeseries_id(ts)
        if not base_url or not ts_id:
            if not auto:
                self._show_error(
                    self.tr("Missing information"),
                    self.tr("Base URL and timeseries are required to list rasters."),
                )
            return

        is_fcst = self._is_forecast(ts)
        path = str(ts.get("path") or "").strip()
        if is_fcst and not path:
            if not auto:
                self._show_error(
                    self.tr("Missing path"),
                    self.tr("Forecast rasters need a path to fetch t0/ensemble runs."),
                )
            return

        try:
            session = make_session(as_user, as_pass)
            runs: list[dict[str, Any]] = []

            from_dt = self.dtT0From.dateTime()
            to_dt = self.dtT0To.dateTime()
            if from_dt.isValid() and to_dt.isValid() and from_dt > to_dt:
                raise QgsProcessingException("t0 filter start must be before end.")

            if is_fcst:
                t0_entries = list_forecast_t0_options(
                    base_url, path, session, store_id=None
                )
            else:
                # Observations: a single pseudo-run without t0 / ensemble keys.
                t0_entries = [{"t0": None, "dispatch_info": None, "member": None}]

            def _passes_filter(t0_iso: Optional[str]) -> bool:
                # Entries without t0 are never filtered out; t0 values that
                # cannot be parsed are dropped.
                if not t0_iso:
                    return True
                qdt = QDateTime.fromString(t0_iso, Qt.ISODate)
                if not qdt.isValid():
                    return False
                if from_dt.isValid() and qdt < from_dt:
                    return False
                if to_dt.isValid() and qdt > to_dt:
                    return False
                return True

            for entry in t0_entries:
                t0_iso = str(entry.get("t0") or "").strip() or None
                if not _passes_filter(t0_iso):
                    continue
                dispatch_info = entry.get("dispatch_info")
                member = entry.get("member")

                selection: dict[str, str] = {}
                if t0_iso:
                    selection["t0"] = t0_iso
                if dispatch_info:
                    selection["dispatch_info"] = str(dispatch_info)
                if member:
                    selection["member"] = str(member)

                times_raw = list_timestamps(
                    base_url,
                    ts_id,
                    session,
                    selection=selection or None,
                    period="complete",
                    by_field="timeseriesId",
                )
                times_norm: list[dict[str, Any]] = []
                for t in times_raw:
                    # Entries are either plain timestamps or
                    # [timestamp, completeness] sequences.
                    if isinstance(t, (list, tuple)) and t:
                        time_str = str(t[0])
                        completeness = t[1] if len(t) > 1 else None
                    else:
                        time_str = str(t)
                        completeness = None
                    times_norm.append({"time": time_str, "complete": completeness})

                runs.append(
                    {
                        "t0": t0_iso,
                        "dispatch_info": dispatch_info,
                        "member": member,
                        "times": times_norm,
                    }
                )

            self._build_tree_from_runs(runs)
            total_times = sum(len(r.get("times") or []) for r in runs)
            if not auto:
                # Exactly one summary dialog per explicit refresh.
                if total_times == 0:
                    self._show_info(
                        self.tr("No rasters in filter"),
                        self.tr("No timestamps matched the current t0 filter."),
                    )
                else:
                    self._show_info(
                        self.tr("Rasters listed"),
                        self.tr("Found {} rasters grouped by t0/ensemble.").format(
                            total_times
                        ),
                    )
        except QgsProcessingException as e:
            if not auto:
                self._show_error(self.tr("Refresh raster list failed"), str(e))
        except Exception as e:  # noqa: BLE001
            if not auto:
                self._show_error(
                    self.tr("Refresh raster list failed"),
                    self.tr("Unexpected error:\n{}").format(e),
                )

    def _ensure_group(self, parent: QgsLayerTreeGroup, name: str) -> QgsLayerTreeGroup:
        """Return the direct child group ``name`` of ``parent``, creating it if missing."""
        for child in parent.children():
            if isinstance(child, QgsLayerTreeGroup) and child.name() == name:
                return child
        return parent.addGroup(name)

    def _safe_slug(self, text: str) -> str:
        """Replace non-alphanumerics with "-", trim, cap at 60 chars; never empty."""
        cleaned = "".join(ch if ch.isalnum() else "-" for ch in str(text))
        cleaned = cleaned.strip("-")
        return cleaned[:60] or "value"

    def _target_file_for_selection(
        self,
        base_target: str,
        idx: int,
        total: int,
        time_iso: str,
        t0_iso: Optional[str],
        dispatch_info: Optional[Any] = None,
        member: Optional[Any] = None,
    ) -> str:
        """
        Derive the output path for one selected raster.

        - Single raster with a target: use it, adding ".tif" if it has no extension.
        - Several rasters with a target: append a suffix built from t0,
          dispatch_info, member and time, so every tree leaf maps to a distinct
          file (two ensemble members at the same t0/time must not overwrite
          each other).
        - No target: a QGIS processing temporary file name.
        """
        root, ext = os.path.splitext(base_target)
        ext = ext or ".tif"

        if base_target and total == 1:
            return root + ext if not base_target.endswith(ext) else base_target

        if base_target:
            suffix_parts = [
                self._safe_slug(part) for part in (t0_iso, dispatch_info, member) if part
            ]
            suffix_parts.append(self._safe_slug(time_iso))
            suffix = "_".join(suffix_parts) or str(idx)
            return f"{root}_{suffix}{ext}"

        name = f"arraystorage_raster_{self._safe_slug(time_iso) or idx}{ext}"
        return QgsProcessingUtils.generateTempFilename(name)

    def _add_layer_with_grouping(
        self, layer, ts_group_name: str, selection: dict[str, Any]
    ) -> None:
        """
        Add ``layer`` to the project under <timeseries>/<t0>/<dispatch_info, member>.

        The layer is renamed to its formatted valid time. If grouping fails,
        the layer is placed at the layer tree root instead.
        """
        project = QgsProject.instance()
        time_iso = selection.get("time") or ""
        if time_iso:
            layer.setName(self._format_dt_label(time_iso, none_label=time_iso))

        # Register without a legend node; the node is created in the group below.
        project.addMapLayer(layer, False)
        try:
            root = project.layerTreeRoot()
            ts_group = self._ensure_group(root, ts_group_name)
            t0_group = self._ensure_group(
                ts_group, self._format_dt_label(selection.get("t0"))
            )
            member_group = self._ensure_group(
                t0_group,
                f"{selection.get('dispatch_info') or 'None'}, {selection.get('member') or 'None'}",
            )
            member_group.addLayer(layer)
        except Exception:  # noqa: BLE001 - grouping is best-effort
            # The layer is already registered, so calling addMapLayer() again
            # is not a reliable way to get a legend node. Insert it into the
            # root tree directly unless it already got one.
            root = project.layerTreeRoot()
            if root.findLayer(layer.id()) is None:
                root.addLayer(layer)

    # -------------------------
    # Load / raster load
    # -------------------------

    def _on_load_clicked(self) -> None:
        """OK handler: close the dialog only when loading succeeded."""
        if self.load_raster():
            self.accept()

    def load_raster(self) -> bool:
        """
        Download every checked raster and add it to the current project.

        Returns:
            True when all selected rasters were loaded. False when nothing was
            selected, validation failed, the user cancelled the large-selection
            prompt, or a load error was reported. Rasters loaded before an
            error stay in the project.
        """
        selections = self._collect_checked_leaf_items()
        if not selections:
            self._show_error(
                self.tr("Nothing selected"),
                self.tr("Select at least one raster time in the tree."),
            )
            return False

        base_url, as_user, as_pass = self._connection_params()

        selected = self._selected_timeseries()
        if not selected:
            self._show_error(
                self.tr("Load raster failed"), self.tr("Select a timeseries first.")
            )
            return False

        ts_id = self._timeseries_id(selected)
        if not base_url or not ts_id:
            self._show_error(
                self.tr("Load raster failed"),
                self.tr("Base URL and timeseries are required."),
            )
            return False

        is_fcst = self._is_forecast(selected)
        base_group = str(selected.get("path") or ts_id or "ArrayStorage")
        total = len(selections)

        if total > 10:
            res = QMessageBox.warning(
                self,
                self.tr("Large selection"),
                self.tr(
                    "You selected {} rasters. Loading many rasters may take time. Continue?"
                ).format(total),
                QMessageBox.Ok | QMessageBox.Cancel,
            )
            if res != QMessageBox.Ok:
                return False

        target_base = self.lineTargetFile.text().strip()
        self.progressBarLoad.setMinimum(0)
        self.progressBarLoad.setMaximum(total)
        self.progressBarLoad.setValue(0)
        self.progressBarLoad.setVisible(True)
        # processEvents() keeps the UI responsive but also delivers clicks;
        # disable the button box so "Load raster" cannot re-enter this method.
        self.buttonBox.setEnabled(False)

        try:
            QtWidgets.QApplication.processEvents()
            for idx, sel in enumerate(selections, start=1):
                time_iso = str(sel.get("time") or "").strip()
                if not time_iso:
                    raise QgsProcessingException("Selected item has no time value.")
                time_dt = QDateTime.fromString(time_iso, Qt.ISODate)
                if not time_dt.isValid():
                    raise QgsProcessingException(f"Invalid raster time: {time_iso}")

                t0_iso = sel.get("t0")
                t0_dt: Optional[QDateTime] = None
                if t0_iso:
                    t0_dt = QDateTime.fromString(str(t0_iso), Qt.ISODate)
                    if not t0_dt.isValid():
                        raise QgsProcessingException(f"Invalid t0 value: {t0_iso}")

                dispatch_info = sel.get("dispatch_info")
                member = sel.get("member")

                target_file = self._target_file_for_selection(
                    target_base,
                    idx,
                    total,
                    time_iso,
                    t0_iso,
                    dispatch_info=dispatch_info,
                    member=member,
                )

                req = ArrayStorageRasterRequest(
                    base_url=base_url,
                    ident_type=cast(Any, "ts_id"),
                    identifier=ts_id,
                    time_dt=time_dt,
                    # t0_dt is either None or an already validated QDateTime.
                    t0_dt=t0_dt if is_fcst else None,
                    dispatch_info=dispatch_info if is_fcst else None,
                    member=member if is_fcst else None,
                    as_user=as_user,
                    as_pass=as_pass,
                    target_file=target_file,
                    store_id=None,
                    format="geotiff",
                    activate_temporal=self.chkActivateTemporal.isChecked(),
                )

                result = load_arraystorage_raster(
                    req, feedback=None, add_to_project=False
                )
                self._add_layer_with_grouping(
                    result.layer,
                    base_group,
                    {
                        "t0": t0_iso,
                        "dispatch_info": dispatch_info,
                        "member": member,
                        "time": time_iso,
                    },
                )
                self.progressBarLoad.setValue(idx)
                QtWidgets.QApplication.processEvents()
            return True

        except QgsProcessingException as e:
            self._show_error(self.tr("Load raster failed"), str(e))
            return False
        except Exception as e:  # noqa: BLE001 - UI boundary, report everything
            self._show_error(
                self.tr("Load raster failed"),
                self.tr("Unexpected error:\n{}").format(e),
            )
            return False
        finally:
            self.progressBarLoad.setValue(0)
            self.progressBarLoad.setVisible(False)
            self.buttonBox.setEnabled(True)