# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from urllib.parse import quote
import re
from datetime import datetime, timezone

import requests

from qgis.PyQt.QtCore import QDateTime, QTime, Qt
from qgis.core import (
    QgsMessageLog,
    QgsProcessingException,
    QgsProject,
    QgsRasterLayer,
    QgsDateTimeRange,
    QgsRasterLayerTemporalProperties,
    QgsCoordinateReferenceSystem,
    Qgis,
)
from .arraystore_classifications import apply_classification_to_layer


# -------------------------
# Typed interfaces / models
# -------------------------


class Feedback(Protocol):
    """
    Structural type for progress reporting.

    Any object exposing ``pushInfo(str)`` satisfies it, for example a
    QgsProcessingFeedback instance.
    """

    def pushInfo(self, msg: str) -> None:
        """Emit one informational message."""


# How RasterRequest.identifier is interpreted: a timeseries id used directly,
# or a readable product path that is resolved to a timeseries id first.
IdentType = Literal["ts_id", "path"]
# Export formats accepted by load_datasphere_raster (currently GeoTIFF only).
RasterFormat = Literal["geotiff"]


@dataclass(frozen=True)
class RasterRequest:
    """
    Immutable description of one raster slice for load_datasphere_raster().

    Only the basic-auth user (never the password) is later copied onto the
    resulting layer as a custom property.
    """

    base_url: str  # Datasphere REST root; a trailing "/" is stripped before use
    ds_id: int  # datasource id used in the /raster/datasources/{dsId}/... URLs
    ident_type: IdentType  # "path" identifiers are resolved to a ts_id first
    identifier: str  # ts_id or path depending on ident_type
    time_dt: QDateTime  # local tz ok; will be floored to past hour UTC
    t0_dt: Optional[QDateTime] = None  # optional; local tz ok; None/invalid -> no t0 selection
    ds_user: Optional[str] = None  # basic auth is used only when user AND password are set
    ds_pass: Optional[str] = None  # basic-auth password; not stored on the layer
    target_file: str = ""  # may be temp output without extension (".tif" is appended)
    format: RasterFormat = "geotiff"  # any other value is rejected
    activate_temporal: bool = True  # whether the layer's fixed temporal range is activated


@dataclass(frozen=True)
class RasterLoadResult:
    """
    Outcome of load_datasphere_raster().
    """

    tif_path: str  # GeoTIFF file written by the download
    ts_id: str  # timeseries id actually used (resolved when a path was requested)
    time_iso: str  # requested time floored to the hour, UTC ISO-8601 with explicit offset
    t0_iso: Optional[str]  # floored t0 in the same format, or None when no t0 was given
    layer: QgsRasterLayer  # created layer; added to the project only when requested


# -------------------------
# Low-level REST utilities
# -------------------------


def make_session(ds_user: Optional[str], ds_pass: Optional[str]) -> requests.Session:
    """
    Build a requests session preconfigured for the Datasphere REST API.

    Every request sends ``accept: application/json``. HTTP basic auth is
    attached only when both the user name and the password are non-empty.
    """
    session = requests.Session()
    session.headers.update({"accept": "application/json"})
    credentials_given = bool(ds_user) and bool(ds_pass)
    if credentials_given:
        session.auth = (ds_user, ds_pass)  # basic auth
    return session


def list_raster_timeseries(
    base_url: str,
    session: requests.Session,
    *,
    ds_id: int = 0,
) -> list[dict]:
    """
    List the raster timeseries of a Datasphere datasource.

    GET <base_url>/raster/datasources/{dsId}/timeSeries

    Args:
        base_url: API root; a trailing "/" is tolerated.
        session: Session used for the request.
        ds_id: Datasource id.

    Returns:
        The decoded JSON list (entries are not validated here).

    Raises:
        QgsProcessingException: on HTTP >= 400, a body that is not valid JSON,
            or a JSON value that is not a list.
    """
    base = base_url.rstrip("/")
    url = f"{base}/raster/datasources/{ds_id}/timeSeries"
    r = session.get(url, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to list raster products: {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as e:
        # A 2xx with a non-JSON body (e.g. an HTML proxy page) must surface as
        # the same exception type as every other failure of this call.
        raise QgsProcessingException(
            f"Unexpected response for raster timeSeries list (invalid JSON): {e}"
        ) from e
    if not isinstance(data, list):
        raise QgsProcessingException("Unexpected response for raster timeSeries list.")
    return data


def list_dimension_values(
    base_url: str,
    ds_id: int,
    ts_id: str,
    dim_name: str,
    session: requests.Session,
    *,
    selection: Optional[dict[str, object]] = None,
) -> list:
    """
    List the values of one dimension of a raster timeseries.

    GET <base_url>/raster/datasources/{dsId}/timeSeries/{tsId}/dimensions/{dimName}/data
    Returns the raw JSON list for the given dimension (e.g. t0).

    Args:
        base_url: API root; a trailing "/" is tolerated.
        ds_id: Datasource id.
        ts_id: Timeseries id (inserted into the URL as-is).
        dim_name: Dimension name; fully percent-encoded into the URL.
        session: Session used for the request.
        selection: Optional selection, sent JSON-encoded as the ``selection``
            query parameter.

    Raises:
        QgsProcessingException: on HTTP >= 400, a body that is not valid JSON,
            or a JSON value that is not a list.
    """
    base = base_url.rstrip("/")
    url = (
        f"{base}/raster/datasources/{ds_id}/timeSeries/{ts_id}/dimensions/"
        f"{quote(dim_name, safe='')}/data"
    )
    params: dict[str, str] = {}
    if selection:
        params["selection"] = json.dumps(selection)

    r = session.get(url, params=params, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to list dimension '{dim_name}' values: {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as e:
        # Non-JSON 2xx body: report it like every other failure of this call.
        raise QgsProcessingException(
            f"Unexpected response for dimension '{dim_name}' values (invalid JSON): {e}"
        ) from e
    if not isinstance(data, list):
        raise QgsProcessingException(
            f"Unexpected response for dimension '{dim_name}' values."
        )
    return data


def list_timestamps_datasphere(
    base_url: str,
    ds_id: int,
    ts_id: str,
    session: requests.Session,
    *,
    selection: Optional[dict] = None,
    period: Optional[str] = None,
) -> list:
    """
    List the available timestamps of a raster timeseries.

    GET <base_url>/raster/datasources/{dsId}/timeSeries/{tsId}/timeStamps
    Optionally passes selection (e.g., {"t0": "<iso>"}) and period.

    Args:
        base_url: API root; a trailing "/" is tolerated.
        ds_id: Datasource id.
        ts_id: Timeseries id (inserted into the URL as-is).
        session: Session used for the request.
        selection: Optional selection, sent JSON-encoded as ``selection``.
        period: Optional value forwarded verbatim as the ``period`` parameter.

    Raises:
        QgsProcessingException: on HTTP >= 400, a body that is not valid JSON,
            or a JSON value that is not a list.
    """
    base = base_url.rstrip("/")
    url = f"{base}/raster/datasources/{ds_id}/timeSeries/{ts_id}/timeStamps"
    params: dict[str, str] = {}
    if selection:
        params["selection"] = json.dumps(selection)
    if period:
        params["period"] = period

    r = session.get(url, params=params, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to list timestamps: {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as e:
        # Non-JSON 2xx body: report it like every other failure of this call.
        raise QgsProcessingException(
            f"Unexpected response for timestamps list (invalid JSON): {e}"
        ) from e
    if not isinstance(data, list):
        raise QgsProcessingException("Unexpected response for timestamps list.")
    return data


def list_locations(
    base_url: str,
    session: requests.Session,
) -> list[dict]:
    """
    List all locations known to the Datasphere service.

    GET <base_url>/locations

    Args:
        base_url: API root; a trailing "/" is tolerated.
        session: Session used for the request.

    Raises:
        QgsProcessingException: on HTTP >= 400, a body that is not valid JSON,
            or a JSON value that is not a list.
    """
    base = base_url.rstrip("/")
    url = f"{base}/locations"
    r = session.get(url, timeout=30)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to list locations: {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as e:
        # Non-JSON 2xx body: report it like every other failure of this call.
        raise QgsProcessingException(
            f"Unexpected response for locations list (invalid JSON): {e}"
        ) from e
    if not isinstance(data, list):
        raise QgsProcessingException("Unexpected response for locations list.")
    return data


def get_datasphere_classification(
    base_url: str,
    ds_id: int,
    classification_name: str,
    session: requests.Session,
) -> dict:
    """
    Fetch one classification definition of a datasource.

    GET <base_url>/raster/datasources/{dsId}/classifications/{classificationName}

    Args:
        base_url: API root; a trailing "/" is tolerated.
        ds_id: Datasource id.
        classification_name: Classification name; fully percent-encoded.
        session: Session used for the request.

    Raises:
        QgsProcessingException: on HTTP >= 400, a body that is not valid JSON,
            or a JSON value that is not an object.
    """
    base = base_url.rstrip("/")
    url = f"{base}/raster/datasources/{ds_id}/classifications/{quote(classification_name, safe='')}"
    r = session.get(url, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to fetch classification '{classification_name}': {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as e:
        # Non-JSON 2xx body: report it like every other failure of this call.
        raise QgsProcessingException(
            f"Unexpected response for classification '{classification_name}' (invalid JSON): {e}"
        ) from e
    if not isinstance(data, dict):
        raise QgsProcessingException(
            f"Unexpected response for classification '{classification_name}'."
        )
    return data


def resolve_path_to_ts_id(
    base_url: str,
    ds_id: int,
    path_identifier: str,
    session: requests.Session,
) -> str:
    """
    Resolve a readable product path to its timeseries id.

    Lists the datasource's raster timeseries and returns the id of the first
    entry whose (stripped) "path" equals the (stripped) *path_identifier* and
    that carries a non-empty "timeseriesId", "timeSeriesId" or "id".

    Raises:
        QgsProcessingException: when no matching entry with an id exists (or
            when listing the timeseries fails).
    """
    wanted = path_identifier.strip()
    for item in list_raster_timeseries(base_url, session, ds_id=ds_id):
        # The list is server-provided; ignore anything that is not an object
        # instead of crashing with AttributeError.
        if not isinstance(item, dict):
            continue
        if str(item.get("path", "")).strip() != wanted:
            continue
        tsid = item.get("timeseriesId") or item.get("timeSeriesId") or item.get("id")
        if tsid:
            return str(tsid)
    raise QgsProcessingException(
        f"Path '{path_identifier}' not found in raster products for dsId={ds_id}."
    )


def download_raster_geotiff(
    base_url: str,
    ts_id: str,
    raster_time_iso: str,
    t0_iso: Optional[str],
    target_file: str,
    session: requests.Session,
    *,
    ds_id: int = 0,
    extra_selection: Optional[dict[str, object]] = None,
    feedback: Optional[Feedback] = None,
) -> str:
    """
    Download a GeoTIFF raster for given timeseries id and time.

    GET <base_url>/raster/datasources/{dsId}/timeSeries/{tsId}/data/{time}
        ?format=geotiff&extractMode=fill_with_nan[&selection=<json>]

    Args:
        base_url: API root; a trailing "/" is tolerated.
        ts_id: Timeseries id (inserted into the URL as-is).
        raster_time_iso: ISO-8601 time of the slice; ":" is percent-encoded.
        t0_iso: Optional t0, added to the selection as {"t0": t0_iso}.
        target_file: Output path; missing parent directories are created.
        session: Session used for the request.
        ds_id: Datasource id.
        extra_selection: Extra selection entries, merged after "t0" (they win
            on key clashes).
        feedback: Optional sink for progress messages.

    Returns:
        The final file path written (*target_file*).

    Raises:
        QgsProcessingException: if the server answers with HTTP >= 400.
    """
    base = base_url.rstrip("/")

    # Encode the time segment similar to example (keep T, -, +; encode :)
    time_seg = quote(raster_time_iso, safe="T-+")
    url = f"{base}/raster/datasources/{ds_id}/timeSeries/{ts_id}/data/{time_seg}"

    params: dict[str, str] = {
        "format": "geotiff",
        "extractMode": "fill_with_nan",
    }
    selection: dict[str, object] = {}
    if t0_iso is not None:
        selection["t0"] = t0_iso
    if extra_selection:
        selection.update(extra_selection)
    if selection:
        params["selection"] = json.dumps(selection)

    if feedback:
        feedback.pushInfo(f"Downloading raster: {url}")
        feedback.pushInfo(f"Params: {params}")

    # With stream=True the connection stays checked out until the body is
    # consumed or the response is closed; the with-block always releases it.
    with session.get(url, params=params, stream=True, timeout=60) as r:
        if r.status_code >= 400:
            raise QgsProcessingException(
                f"Raster download failed: {r.status_code} {r.text}"
            )

        # A bare file name has no directory part, and os.makedirs("") raises.
        target_dir = os.path.dirname(target_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with open(target_file, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:  # skip keep-alive empty chunks
                    f.write(chunk)

    return target_file


# -------------------------
# Time helpers
# -------------------------


def qdt_to_iso_with_tz(qdt: QDateTime) -> str:
    """
    Render *qdt* as a UTC ISO-8601 string with milliseconds and an explicit offset.

    Datasphere rejects naive timestamps. Qt usually marks UTC with a trailing
    "Z"; that suffix is rewritten as the numeric offset "+00:00".
    """
    iso_text = qdt.toUTC().toString(Qt.ISODateWithMs)
    if not iso_text.endswith("Z"):
        return iso_text
    return iso_text[:-1] + "+00:00"


def floor_to_past_hour_utc(qdt: QDateTime) -> QDateTime:
    """
    Convert *qdt* to UTC and truncate it to the start of its hour.

    Minutes, seconds and milliseconds are zeroed; the result carries a UTC
    time spec.
    """
    utc = qdt.toUTC()
    top_of_hour = QTime(utc.time().hour(), 0, 0, 0)
    return QDateTime(utc.date(), top_of_hour, Qt.UTC)


def ensure_tif_extension(path: str) -> str:
    """
    Append ".tif" to *path* when it has no file extension; otherwise return it unchanged.
    """
    _stem, extension = os.path.splitext(path)
    return path if extension else f"{path}.tif"


def _parse_iso_duration_seconds(value: str) -> Optional[int]:
    """
    Minimal ISO8601 duration parser: supports PnDTnHnMnS or PTnHnMnS or PTnM, etc.
    """
    if not value or not isinstance(value, str):
        return None

    m = re.match(
        r"^P(?:(?P<days>\d+)D)?(?:T(?:(?P<hours>\d+)H)?(?:(?P<mins>\d+)M)?(?:(?P<secs>\d+(?:\.\d+)?)S)?)?$",
        value,
    )
    if not m:
        return None

    try:
        days = int(m.group("days") or 0)
        hours = int(m.group("hours") or 0)
        mins = int(m.group("mins") or 0)
        secs = float(m.group("secs") or 0)
        total = days * 86400 + hours * 3600 + mins * 60 + secs
        return int(total)
    except Exception:
        return None


def _apply_temporal_range(
    layer: QgsRasterLayer,
    time_iso: str,
    value_distance: Optional[str],
    *,
    activate: bool = True,
) -> None:
    """
    Attach a fixed temporal range to the raster layer so Temporal Controller picks it up.

    The range starts at *time_iso*. When *value_distance* parses as an ISO-8601
    duration of N > 0 seconds, the range spans max(N - 1, 1) seconds, so
    adjacent slices do not overlap. Otherwise it spans 1 second.

    Args:
        layer: Raster layer whose temporal properties are modified.
        time_iso: Start of the slice (ISO-8601). If it cannot be parsed, the
            layer's temporal properties are only deactivated.
        value_distance: Optional ISO-8601 duration between consecutive slices.
        activate: Whether to activate the temporal properties afterwards.
    """
    tp = layer.temporalProperties()
    # Deactivate first so an unparseable time leaves the layer non-temporal.
    tp.setIsActive(False)
    start_dt = _parse_iso_qdatetime(time_iso)
    if not start_dt.isValid():
        return

    duration_secs = _parse_iso_duration_seconds(str(value_distance or ""))
    if duration_secs and duration_secs > 0:
        # Avoid overlap with the next slice: stop 1 second before the next valueDistance
        span_secs = max(duration_secs - 1, 1)
    else:
        span_secs = 1
    end_dt = QDateTime(start_dt).addSecs(span_secs)

    tp.setMode(QgsRasterLayerTemporalProperties.ModeFixedTemporalRange)
    tp.setFixedTemporalRange(QgsDateTimeRange(start_dt, end_dt))

    # NOTE(review): setReferenceTime may not exist (or may take other arguments)
    # on QgsRasterLayerTemporalProperties in some QGIS versions. It is optional
    # metadata, so it must not prevent the activation below.
    set_reference_time = getattr(tp, "setReferenceTime", None)
    if callable(set_reference_time):
        try:
            set_reference_time(start_dt)
        except TypeError:
            pass  # incompatible signature: skip the optional reference time

    tp.setIsActive(bool(activate))

    # temporalProperties() presumably returns the layer's live object, so the
    # changes above already apply. Call a setter only if the binding has one.
    set_temporal_properties = getattr(layer, "setTemporalProperties", None)
    if callable(set_temporal_properties):
        set_temporal_properties(tp)


def _parse_iso_qdatetime(value: str) -> QDateTime:
    """
    Best-effort ISO parser that tolerates offsets and milliseconds.

    Returns a QDateTime in UTC, or an invalid QDateTime when nothing parses.
    Text that carries a zone ("Z", "+02:00", "+0200", "-05:00") is converted to
    the equivalent UTC instant. Text without any zone designator is
    interpreted as UTC wall-clock time, independent of the machine's time zone.
    """
    if not value:
        return QDateTime()

    def _as_utc(qdt: QDateTime) -> QDateTime:
        # Qt reports LocalTime when the parsed text had no zone designator;
        # otherwise convert the instant instead of discarding the offset.
        if qdt.timeSpec() == Qt.LocalTime:
            qdt.setTimeSpec(Qt.UTC)
            return qdt
        return qdt.toUTC()

    # Try Python's ISO parser first for robustness
    try:
        dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        dt = dt.astimezone(timezone.utc)
        qdt = QDateTime(dt)
        # The wall-clock fields are UTC at this point; mark them as such
        # (QDateTime(datetime) presumably ignores tzinfo).
        qdt.setTimeSpec(Qt.UTC)
        return qdt
    except (ValueError, TypeError, OverflowError):
        pass

    # Qt fallback, e.g. for "+0200" offsets that fromisoformat rejects before 3.11.
    attempts = [value, value.replace("Z", "+00:00"), value.replace("Z", "")]
    for candidate in attempts:
        variants = [candidate]
        if "+" in candidate:
            head, _, offset = candidate.partition("+")
            variants.append(f"{head}+{offset.replace(':', '')}")
        for text in variants:
            for fmt in (Qt.ISODateWithMs, Qt.ISODate):
                qdt = QDateTime.fromString(text, fmt)
                if qdt.isValid():
                    return _as_utc(qdt)
    return QDateTime()


def _z_level_count(meta: Optional[dict]) -> Optional[int]:
    dims = meta.get("dimensions") if isinstance(meta, dict) else []
    for dim in dims or []:
        if not isinstance(dim, dict):
            continue
        name = str(dim.get("name") or "").lower()
        if name != "z":
            continue
        try:
            size_val = dim.get("size")
            if size_val is None:
                continue
            size_int = int(size_val)
            return size_int
        except Exception:
            continue
    return None


def _crs_from_meta(meta: Optional[dict]) -> Optional[QgsCoordinateReferenceSystem]:
    if not isinstance(meta, dict):
        return None
    ident = str(meta.get("projectionIdentifier") or "").strip()
    proj4 = str(meta.get("projectionProj4") or "").strip()
    if ident:
        crs = QgsCoordinateReferenceSystem(ident)
        if crs.isValid():
            return crs
    if proj4:
        crs = QgsCoordinateReferenceSystem.fromProj(proj4)
        if crs.isValid():
            return crs
    return None


# -------------------------
# High-level reusable entrypoint
# -------------------------


def _fetch_timeseries_meta(
    base_url: str,
    ds_id: int,
    ts_id: str,
    session: requests.Session,
    feedback: Optional[Feedback] = None,
) -> Optional[dict]:
    """Best-effort GET of timeseries metadata; None when unavailable or not an object."""
    url = f"{base_url.rstrip('/')}/raster/datasources/{ds_id}/timeSeries/{ts_id}"
    try:
        r = session.get(url, timeout=20)
        if r.status_code >= 400:
            if feedback:
                feedback.pushInfo(f"Datasphere metadata fetch failed: HTTP {r.status_code}")
            return None
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        if feedback:
            feedback.pushInfo(f"Datasphere metadata fetch failed: {e}")
        return None
    # Everything downstream only understands a JSON object.
    return data if isinstance(data, dict) else None


def _classification_id_from_meta(meta: dict) -> Optional[str]:
    """Return the first non-empty string/int classification reference in *meta*."""
    lookups: list[tuple[dict, tuple[str, ...]]] = [
        (meta, ("classification_id", "classificationId", "classification")),
    ]
    attrs = meta.get("attributes")
    if isinstance(attrs, dict):
        lookups.append(
            (attrs, ("classification", "classification_id", "classificationId"))
        )
    for source, keys in lookups:
        for key in keys:
            value = source.get(key)
            # Only a plain string or integer can become a URL path segment;
            # skip empty values, bools and nested objects.
            if not value or isinstance(value, bool) or not isinstance(value, (str, int)):
                continue
            text = str(value).strip()
            if text:
                return text
    return None


def _apply_crs_from_meta(
    rlayer: QgsRasterLayer, meta: dict, feedback: Optional[Feedback]
) -> None:
    """Assign the metadata CRS when the GeoTIFF's own CRS is missing or unknown."""
    try:
        crs = rlayer.crs()
        crs_desc = str(crs.description() or "").strip().lower()
        if crs.isValid() and crs_desc != "unknown":
            return
        ds_crs = _crs_from_meta(meta)
        if ds_crs and ds_crs.isValid():
            rlayer.setCrs(ds_crs)
            if feedback:
                feedback.pushInfo(
                    f"Applied CRS from metadata: {ds_crs.authid() or ds_crs.description()}"
                )
    except Exception as e:  # noqa: BLE001 - CRS fix-up is best-effort
        if feedback:
            feedback.pushInfo(f"CRS fixup skipped: {e}")


def _apply_classification_from_meta(
    rlayer: QgsRasterLayer,
    meta: dict,
    base_url: str,
    ds_id: int,
    session: requests.Session,
    feedback: Optional[Feedback],
) -> None:
    """Record and apply the classification referenced by *meta* (best-effort)."""
    classification_id = _classification_id_from_meta(meta)
    if not classification_id:
        return
    rlayer.setCustomProperty("datasphere:classification_id", classification_id)
    try:
        classification = get_datasphere_classification(
            base_url, int(ds_id), classification_id, session=session
        )
        apply_classification_to_layer(rlayer, classification)
    except Exception as e:  # noqa: BLE001 - styling must not abort the load
        if feedback:
            feedback.pushInfo(
                f"Classification '{classification_id}' could not be applied: {e}"
            )
        return
    if feedback:
        feedback.pushInfo(
            f"Applied Datasphere classification '{classification_id}' to raster layer"
        )


def load_datasphere_raster(
    req: RasterRequest,
    *,
    feedback: Optional[Feedback] = None,
    add_to_project: bool = True,
) -> RasterLoadResult:
    """
    Download one Datasphere raster slice and load it as a QGIS raster layer.

    End-to-end core:
      - validate the request; floor time (and t0) to the past full hour in UTC
      - resolve path -> ts_id if needed
      - fetch timeseries metadata (best-effort)
      - download the GeoTIFF
      - build the QgsRasterLayer, fix a missing/unknown CRS from metadata, and
        tag multiband layers with "(MB)"
      - attach tracking custom properties (never the password)
      - apply the metadata classification (best-effort)
      - attach a fixed temporal range (best-effort, also without metadata)
      - optionally add the layer to the current project

    Safe to call from both Processing and custom dialogs.

    Args:
        req: What to download and where to write it.
        feedback: Optional sink for progress messages.
        add_to_project: Add the layer to QgsProject.instance() when True.

    Returns:
        RasterLoadResult with the file path, resolved ids/times and the layer.

    Raises:
        QgsProcessingException: for an invalid request (format, identifier,
            time, target file), an unresolvable path, a failed download, or a
            GeoTIFF that QGIS cannot open. Network errors raised by requests
            while resolving or downloading propagate unchanged.
    """
    if req.format != "geotiff":
        raise QgsProcessingException("Only GeoTIFF export is supported currently.")

    identifier = req.identifier.strip()
    if not identifier:
        raise QgsProcessingException(
            "Timeseries identifier (ts_id or path) is required."
        )

    if req.time_dt is None:
        raise QgsProcessingException("Raster time is required.")
    time_dt = floor_to_past_hour_utc(req.time_dt)
    if not time_dt.isValid():
        raise QgsProcessingException("Raster time is required.")
    time_iso = qdt_to_iso_with_tz(time_dt)

    t0_iso: Optional[str] = None
    if req.t0_dt is not None and req.t0_dt.isValid():
        t0_iso = qdt_to_iso_with_tz(floor_to_past_hour_utc(req.t0_dt))
        if feedback:
            feedback.pushInfo(f"Using floored t0 (UTC): {t0_iso}")
    elif feedback:
        feedback.pushInfo("No t0 provided (non-ensemble raster).")

    if feedback:
        feedback.pushInfo(f"Using floored time (UTC): {time_iso}")

    # Fail fast, before any network traffic, when there is nowhere to write.
    raw_target = req.target_file.strip()
    if not raw_target:
        raise QgsProcessingException("An output file path for the raster is required.")
    target_file = ensure_tif_extension(raw_target)

    mb_tag = "(MB)"

    # The context manager closes the session's pooled connections on every exit path.
    with make_session(req.ds_user, req.ds_pass) as session:
        # Resolve ts_id if path is provided
        if req.ident_type == "path":
            ts_id = resolve_path_to_ts_id(req.base_url, req.ds_id, identifier, session)
        else:
            ts_id = identifier

        meta = _fetch_timeseries_meta(req.base_url, req.ds_id, ts_id, session, feedback)

        z_size = _z_level_count(meta)
        has_z_dim = z_size is not None and z_size >= 2
        if feedback and has_z_dim:
            feedback.pushInfo(f"Detected multiband z dimension with size={z_size}")

        # Readable product path: the requested one, else the one from metadata.
        if req.ident_type == "path":
            path_value = identifier
        elif meta is not None:
            path_value = str(meta.get("path") or "").strip()
        else:
            path_value = ""
        base_name = path_value or ts_id
        name_base = f"{base_name} {mb_tag}" if has_z_dim else base_name
        layer_name = f"{name_base}_{time_iso}"

        tif_path = download_raster_geotiff(
            req.base_url,
            ts_id,
            time_iso,
            t0_iso,
            target_file,
            session,
            ds_id=req.ds_id,
            feedback=feedback,
        )

        rlayer = QgsRasterLayer(tif_path, layer_name, "gdal")
        if not rlayer.isValid():
            raise QgsProcessingException(
                f"Downloaded GeoTIFF could not be loaded as a raster layer: {tif_path}"
            )

        if meta is not None:
            _apply_crs_from_meta(rlayer, meta, feedback)

        # Also tag layers whose band count reveals multiband data without metadata.
        is_multiband = has_z_dim or rlayer.bandCount() > 1
        if is_multiband and mb_tag.lower() not in rlayer.name().lower():
            rlayer.setName(f"{rlayer.name()} {mb_tag}")

        # Datasphere tracking metadata (NO password)
        rlayer.setCustomProperty("datasphere:time", time_iso)
        rlayer.setCustomProperty("datasphere:t0", t0_iso or "")
        rlayer.setCustomProperty("datasphere:ts_id", ts_id)
        rlayer.setCustomProperty("datasphere:path", path_value)
        rlayer.setCustomProperty("datasphere:base_url", req.base_url)
        rlayer.setCustomProperty("datasphere:user", req.ds_user or "")
        # Intentionally misspelled duplicate of "datasphere:user", kept as an
        # alias for consumers that read this key.
        rlayer.setCustomProperty("dataspehre:user", req.ds_user or "")

        if meta is not None:
            rlayer.setCustomProperty("datasphere:metadata", json.dumps(meta))
            _apply_classification_from_meta(
                rlayer, meta, req.base_url, req.ds_id, session, feedback
            )

    # Applied with or without metadata: the slice time is always known, and
    # valueDistance (when present) only sets the length of the range.
    value_distance = meta.get("valueDistance") if meta is not None else None
    try:
        _apply_temporal_range(
            rlayer,
            time_iso,
            value_distance,
            activate=req.activate_temporal,
        )
    except Exception as e:  # noqa: BLE001 - temporal setup is best-effort
        if feedback:
            feedback.pushInfo(f"Temporal range could not be applied: {e}")

    if add_to_project:
        QgsProject.instance().addMapLayer(rlayer)

    QgsMessageLog.logMessage(
        f"Loaded Datasphere raster ts_id={ts_id} time={time_iso} -> {tif_path}",
        "Datasphere",
        Qgis.Info,
    )

    return RasterLoadResult(
        tif_path=tif_path,
        ts_id=ts_id,
        time_iso=time_iso,
        t0_iso=t0_iso,
        layer=rlayer,
    )
