# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import math
import re
import uuid
from typing import Optional
from urllib.parse import quote

import requests
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtCore import QVariant
from qgis.core import (
    QgsProcessingException,
    QgsRasterLayer,
    QgsSingleBandPseudoColorRenderer,
    QgsColorRampShader,
    QgsRasterShader,
    QgsVectorLayer,
    QgsFields,
    QgsField,
    QgsFeature,
)
from osgeo import gdal, osr  # QGIS ships GDAL
import numpy as np

from .arraystorage_core import make_session


# ------------------------------------------------------------
# Helpers: color conversions and validation
# ------------------------------------------------------------


def _color_to_rgb_int(color: QColor) -> int:
    return (color.red() << 16) + (color.green() << 8) + color.blue()


def _color_from_rgb_int(value: int, alpha: int) -> QColor:
    """Unpack a 0xRRGGBB integer into a QColor with the given *alpha*.

    Both arguments are coerced with ``int()``; bits above the low 24 are
    ignored.
    """
    packed = int(value)
    return QColor(
        (packed >> 16) & 0xFF,
        (packed >> 8) & 0xFF,
        packed & 0xFF,
        int(alpha),
    )


def _color_shader_from_layer(layer: QgsRasterLayer) -> QgsColorRampShader:
    """Return the color ramp shader function behind a pseudocolor raster layer.

    Raises:
        QgsProcessingException: at the first failed check. The checks are, in
            order: *layer* is a valid ``QgsRasterLayer``; it is rendered with a
            ``QgsSingleBandPseudoColorRenderer``; that renderer has a shader;
            and the shader's function is a ``QgsColorRampShader``.
    """
    is_valid_raster = isinstance(layer, QgsRasterLayer) and layer.isValid()
    if not is_valid_raster:
        raise QgsProcessingException("Layer must be a valid raster layer")

    pseudo_renderer = layer.renderer()
    if not isinstance(pseudo_renderer, QgsSingleBandPseudoColorRenderer):
        raise QgsProcessingException(
            "Raster layer must use a SingleBandPseudoColor renderer"
        )

    raster_shader = pseudo_renderer.shader()
    if raster_shader is None:
        raise QgsProcessingException("Raster renderer has no shader")

    ramp_function = raster_shader.rasterShaderFunction()
    if isinstance(ramp_function, QgsColorRampShader):
        return ramp_function
    raise QgsProcessingException(
        "Raster renderer must use a color ramp shader (pseudocolor)"
    )


# ------------------------------------------------------------
# REST wrappers
# ------------------------------------------------------------


def list_classifications(base_url: str, session: Optional[requests.Session] = None):
    """Fetch the list of classifications from the ArrayStorage REST API.

    Args:
        base_url: Server root URL; a trailing slash is tolerated.
        session: Optional requests session. When omitted, one is created with
            ``make_session(None, None)``.

    Returns:
        The decoded JSON list returned by ``GET .../rest/arrayStorage/classifications``.

    Raises:
        QgsProcessingException: on an HTTP status >= 400, a response body that
            is not valid JSON, or a JSON body that is not a list.
    """
    base = base_url.rstrip("/")
    url = f"{base}/rest/arrayStorage/classifications"
    session = session or make_session(None, None)

    r = session.get(url, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to list classifications: {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as exc:
        # requests reports an undecodable body with a ValueError subclass
        # (JSONDecodeError); surface it like every other failure here.
        raise QgsProcessingException(
            f"Classification list response is not valid JSON: {r.text[:200]}"
        ) from exc
    if not isinstance(data, list):
        raise QgsProcessingException("Unexpected classification list response")
    return data


def get_classification(
    base_url: str, classification_name: str, session: Optional[requests.Session] = None
) -> dict:
    """Fetch a single classification by name from the ArrayStorage REST API.

    Args:
        base_url: Server root URL; a trailing slash is tolerated.
        classification_name: Classification name. It is percent-encoded as a
            single URL path segment.
        session: Optional requests session. When omitted, one is created with
            ``make_session(None, None)``.

    Returns:
        The decoded JSON object describing the classification.

    Raises:
        QgsProcessingException: on an HTTP status >= 400, a response body that
            is not valid JSON, or a JSON body that is not an object.
    """
    base = base_url.rstrip("/")
    # safe="" also escapes "/", so names containing "/", "?", "#", "%" or
    # spaces stay one path segment and address the intended resource.
    name_segment = quote(str(classification_name), safe="")
    url = f"{base}/rest/arrayStorage/classifications/{name_segment}"
    session = session or make_session(None, None)

    r = session.get(url, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to fetch classification '{classification_name}': {r.status_code} {r.text}"
        )
    try:
        data = r.json()
    except ValueError as exc:
        # requests reports an undecodable body with a ValueError subclass.
        raise QgsProcessingException(
            f"Classification '{classification_name}' response is not valid JSON: "
            f"{r.text[:200]}"
        ) from exc
    if not isinstance(data, dict):
        raise QgsProcessingException("Unexpected classification response body")
    return data


def delete_classification(
    base_url: str, classification_name: str, session: Optional[requests.Session] = None
) -> dict:
    """Delete a classification by name via the ArrayStorage REST API.

    Args:
        base_url: Server root URL; a trailing slash is tolerated.
        classification_name: Classification name. It is percent-encoded as a
            single URL path segment.
        session: Optional requests session. When omitted, one is created with
            ``make_session(None, None)``.

    Returns:
        The decoded JSON response. If the body is not JSON (e.g. empty), the
        result is ``{"status_code": ..., "text": ...}``.

    Raises:
        QgsProcessingException: on an HTTP status >= 400.
    """
    base = base_url.rstrip("/")
    # safe="" also escapes "/", keeping the whole name in one path segment.
    name_segment = quote(str(classification_name), safe="")
    url = f"{base}/rest/arrayStorage/classifications/{name_segment}"
    session = session or make_session(None, None)

    r = session.delete(url, timeout=20)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to delete classification '{classification_name}': {r.status_code} {r.text}"
        )
    try:
        return r.json()
    except ValueError:
        # Response.json() raises a ValueError subclass for empty/non-JSON bodies.
        return {"status_code": r.status_code, "text": r.text}


def upload_classification(
    base_url: str, classification: dict, session: Optional[requests.Session] = None
) -> dict:
    """Create or replace a classification with ``PUT .../classifications``.

    Args:
        base_url: Server root URL; a trailing slash is tolerated.
        classification: Classification dict, sent as the JSON request body.
        session: Optional requests session. When omitted, one is created with
            ``make_session(None, None)``.

    Returns:
        The decoded JSON response. If the body is not JSON (e.g. empty), the
        result is ``{"status_code": ..., "text": ...}``.

    Raises:
        QgsProcessingException: on an HTTP status >= 400.
    """
    base = base_url.rstrip("/")
    url = f"{base}/rest/arrayStorage/classifications"
    session = session or make_session(None, None)

    r = session.put(url, json=classification, timeout=30)
    if r.status_code >= 400:
        raise QgsProcessingException(
            f"Failed to upload classification '{classification.get('name', '')}': {r.status_code} {r.text}"
        )
    try:
        return r.json()
    except ValueError:
        # Response.json() raises a ValueError subclass for empty/non-JSON bodies.
        return {"status_code": r.status_code, "text": r.text}


# ------------------------------------------------------------
# Conversions: QGIS <-> ArrayStorage classification
# ------------------------------------------------------------


def _finite_bound(value) -> Optional[float]:
    """Return *value* as a float, or None when it is None, NaN or infinite.

    QGIS color ramps can carry +/-inf item values for open-ended classes.
    JSON cannot represent non-finite floats, and recent requests versions
    refuse to serialize them in ``json=`` bodies, so such bounds are treated
    as absent.
    """
    if value is None:
        return None
    bound = float(value)
    return bound if math.isfinite(bound) else None


def raster_style_to_classification(
    layer: QgsRasterLayer,
    classification_name: str,
    *,
    parameter_name: str = "*",
    description: Optional[str] = None,
) -> dict:
    """Convert a pseudocolor raster layer's color ramp into an ArrayStorage classification.

    Color ramp items are sorted by value and numbered from 1. Item ``i``
    yields color entry ``i`` and range ``i``. The range runs from the value of
    item ``i-1`` ("from") to the value of item ``i`` ("to").

    - The first range has no lower bound.
    - The last range, when there are several, has no upper bound.
    - Non-finite bounds are omitted.

    Args:
        layer: Valid raster layer rendered with a SingleBandPseudoColor
            renderer backed by a color ramp shader.
        classification_name: Name stored in the classification.
        parameter_name: Parameter name stored in the classification; an empty
            value becomes "*".
        description: Optional description; defaults to *classification_name*.

    Returns:
        The classification dict (``version`` 1, ``type`` "Raster").

    Raises:
        QgsProcessingException: if the layer is not a suitably styled raster
            layer or its color ramp has no items.
    """
    func = _color_shader_from_layer(layer)
    items = sorted(func.colorRampItemList(), key=lambda i: i.value)
    if not items:
        raise QgsProcessingException("Raster layer has no color ramp items")

    desc = description or classification_name
    count = len(items)

    colors = []
    ranges = []
    for idx, item in enumerate(items, start=1):
        colors.append(
            {
                "index": idx,
                "alpha": int(item.color.alpha()),
                "rgb": _color_to_rgb_int(item.color),
                "name": idx,
                "description": item.label or str(idx),
            }
        )

        rng: dict[str, object] = {"index": idx, "description": idx}
        if idx > 1:
            # The lower bound is the previous item's value.
            lower = _finite_bound(items[idx - 2].value)
            if lower is not None:
                rng["from"] = lower
        # The last of several classes is left open-ended upward.
        if idx < count or count == 1:
            upper = float(idx) if item.value is None else _finite_bound(item.value)
            if upper is not None:
                rng["to"] = upper
        ranges.append(rng)

    classification = {
        "version": 1,
        "name": classification_name,
        "description": desc,
        "type": "Raster",
        "color": colors,
        "parameter": [
            {
                "name": parameter_name or "*",
                "ascending": "true",
                "classLimits": "lowerIncluded",
                "channel": [
                    {
                        "index": 1,
                        "range": ranges,
                    }
                ],
            }
        ],
    }

    return classification


def classification_to_shader_items(classification: dict):
    """Translate an ArrayStorage classification dict into QGIS color ramp items.

    Only the first channel of the first parameter that has any channels is
    read. Each range in that channel is paired with the color entry that has
    the same integer ``index``. Ranges without a matching color are skipped.

    Args:
        classification: Decoded classification JSON. The ``color`` and
            ``parameter`` keys must be present and non-empty.

    Returns:
        A tuple ``(items, min_val, max_val)``:

        - ``items``: list of ``QgsColorRampShader.ColorRampItem`` in ascending
          range-index order.
        - ``min_val`` / ``max_val``: the smallest and largest bounds collected
          (``None`` when no bound could be collected).

    Raises:
        QgsProcessingException: if ``color`` or ``parameter`` is missing or
            empty, no parameter has a channel, or a chosen range value cannot
            be converted to float.
    """
    colors_raw = classification.get("color") or []
    parameters = classification.get("parameter") or []
    if not colors_raw or not parameters:
        raise QgsProcessingException("Classification is missing 'color' or 'parameter'")

    # Index color definitions by integer "index"; entries without a usable
    # integer index are skipped.
    color_map = {}
    for c in colors_raw:
        try:
            idx = int(c.get("index"))
        except Exception:
            continue
        color_map[idx] = c

    # Use the first channel of the first parameter that has any channels.
    channel = None
    for p in parameters:
        chans = p.get("channel") or []
        if chans:
            channel = chans[0]
            break
    if channel is None:
        raise QgsProcessingException("Classification has no channel information")

    # Process ranges in ascending "index" order (a missing index sorts as 0).
    ranges_raw = sorted(channel.get("range") or [], key=lambda r: r.get("index", 0))
    items: list[QgsColorRampShader.ColorRampItem] = []
    min_val: Optional[float] = None
    max_val: Optional[float] = None

    for rng in ranges_raw:
        # A range without "index" falls back to the next 1-based position
        # after the items produced so far.
        try:
            idx = int(rng.get("index", len(items) + 1))
        except Exception:
            continue

        # Ranges are paired with colors by index; unmatched ranges are dropped.
        color_def = color_map.get(idx)
        if color_def is None:
            continue

        alpha = color_def.get("alpha", 255)
        rgb_val = color_def.get("rgb", 0)
        # Label preference: range description, then color description, then index.
        # NOTE(review): raster_style_to_classification writes each range's
        # "description" as its integer index, so after a round-trip the label
        # becomes the index rather than the original color-ramp label.
        label = str(rng.get("description") or color_def.get("description") or idx)

        # Item value: the range's "to" bound, else its "from" bound, else the index.
        value = rng.get("to")
        if value is None:
            value = rng.get("from")
        if value is None:
            value = idx

        try:
            value_f = float(value)
        except Exception as exc:  # noqa: BLE001
            raise QgsProcessingException(
                "Invalid range value in classification"
            ) from exc

        items.append(
            QgsColorRampShader.ColorRampItem(
                value_f,
                _color_from_rgb_int(rgb_val, alpha),
                label,
            )
        )

        # Track the overall minimum: an explicit "from" wins; otherwise the
        # item value is used. A "from" that fails to parse is ignored here.
        if rng.get("from") is not None:
            try:
                from_val = float(rng["from"])
                min_val = from_val if min_val is None else min(min_val, from_val)
            except Exception:
                pass
        else:
            min_val = value_f if min_val is None else min(min_val, value_f)

        # Track the overall maximum the same way, using "to".
        if rng.get("to") is not None:
            try:
                to_val = float(rng["to"])
                max_val = to_val if max_val is None else max(max_val, to_val)
            except Exception:
                pass
        else:
            max_val = value_f if max_val is None else max(max_val, value_f)

    return items, min_val, max_val


def apply_classification_to_layer(layer: QgsRasterLayer, classification: dict) -> None:
    """Style *layer* with a discrete pseudocolor renderer built from *classification*.

    Band selection:

    - If the current renderer exposes ``band()`` and it returns an
      int-convertible value, that band is reused.
    - Otherwise band 1 is used.

    When the classification provides min/max bounds, they are set on the
    color ramp function and, best-effort, on the renderer. The layer is then
    repainted.

    Raises:
        QgsProcessingException: if the classification yields no color ramp
            items (errors from ``classification_to_shader_items`` propagate).
    """
    ramp_items, lower, upper = classification_to_shader_items(classification)
    if not ramp_items:
        raise QgsProcessingException("Classification contains no usable ranges")

    ramp = QgsColorRampShader()
    ramp.setColorRampType(QgsColorRampShader.Discrete)
    ramp.setColorRampItemList(ramp_items)
    if lower is not None:
        ramp.setMinimumValue(lower)
    if upper is not None:
        ramp.setMaximumValue(upper)

    raster_shader = QgsRasterShader()
    raster_shader.setRasterShaderFunction(ramp)

    current = layer.renderer()
    band_no = 1
    if current is not None and hasattr(current, "band"):
        try:
            band_no = int(current.band())
        except Exception:  # noqa: BLE001 - keep band 1 if the accessor misbehaves
            pass

    pseudo = QgsSingleBandPseudoColorRenderer(layer.dataProvider(), band_no, raster_shader)
    for setter_name, bound in (
        ("setClassificationMin", lower),
        ("setClassificationMax", upper),
    ):
        if bound is None:
            continue
        try:
            getattr(pseudo, setter_name)(bound)
        except Exception:  # noqa: BLE001 - best effort
            pass

    layer.setRenderer(pseudo)
    layer.triggerRepaint()


def fetch_and_apply_classification(
    base_url: str,
    classification_name: str,
    layer: QgsRasterLayer,
    *,
    session: Optional[requests.Session] = None,
) -> dict:
    """Download the named classification and apply it to *layer*.

    A session is created with ``make_session(None, None)`` when none is
    given. The classification dict, as returned by the server, is returned.
    """
    active_session = session if session else make_session(None, None)
    fetched = get_classification(base_url, classification_name, active_session)
    apply_classification_to_layer(layer, fetched)
    return fetched


# ------------------------------------------------------------
# Layer builder for table view
# ------------------------------------------------------------


def build_classifications_layer(items: list[dict]) -> QgsVectorLayer:
    """Build an in-memory, geometry-less table layer listing classifications.

    Args:
        items: Classification dicts, such as the list returned by
            ``list_classifications``. Entries that are not dicts are skipped.

    Returns:
        A ``memory`` layer with one feature per classification. Its columns
        are name, description, type, parameters (comma-separated parameter
        names), colors (number of color entries), version and raw_json.

        - Missing or null text values are stored as empty strings.
        - Uncountable or non-numeric numeric values are stored as NULL.

        The layer carries the custom property
        ``arraystorage:kind = classifications_list``.
    """
    layer = QgsVectorLayer("None", "ArrayStorage classifications", "memory")
    prov = layer.dataProvider()

    fields = QgsFields()
    for field_name, field_type in (
        ("name", QVariant.String),
        ("description", QVariant.String),
        ("type", QVariant.String),
        ("parameters", QVariant.String),
        ("colors", QVariant.Int),
        ("version", QVariant.Int),
        ("raw_json", QVariant.String),
    ):
        fields.append(QgsField(field_name, field_type))
    prov.addAttributes(fields)
    layer.updateFields()

    def _text(value) -> str:
        # JSON null (None) would otherwise be stored as the literal "None".
        return "" if value is None else str(value)

    feats = []
    for item in items:
        if not isinstance(item, dict):
            # Malformed entries have no .get(); leave them out of the table.
            continue

        f = QgsFeature()
        f.setFields(fields)

        params = item.get("parameter") or []
        if not isinstance(params, list):
            params = []
        param_names = [_text(p.get("name")) for p in params if isinstance(p, dict)]

        f["name"] = _text(item.get("name"))
        f["description"] = _text(item.get("description"))
        f["type"] = _text(item.get("type"))
        f["parameters"] = ", ".join(param_names)

        try:
            f["colors"] = len(item.get("color") or [])
        except TypeError:
            # A non-sized "color" value cannot be counted.
            f["colors"] = None
        try:
            f["version"] = int(item.get("version"))
        except (TypeError, ValueError):
            # Missing (None) or non-numeric version.
            f["version"] = None
        f["raw_json"] = json.dumps(item, ensure_ascii=False)

        feats.append(f)

    if feats:
        prov.addFeatures(feats)
        layer.updateExtents()

    layer.setCustomProperty("arraystorage:kind", "classifications_list")
    return layer


def create_dummy_raster_layer(
    layer_name: str, min_val: float | None = None, max_val: float | None = None
) -> QgsRasterLayer:
    """
    Create a tiny in-memory GeoTIFF and return it as a QgsRasterLayer.

    The layer is used to host symbology for ArrayStorage classifications.
    Pixels are 0.1 degrees in EPSG:4326, and the single band is Float32.

    - If both min and max are provided and finite, a 1x256 gradient from
      min to max is written so classification colors render meaningfully.
    - Otherwise a single 0-valued pixel is written.

    Each call writes to its own uniquely named ``/vsimem`` file.

    Raises:
        QgsProcessingException: if the dataset cannot be created or the
            resulting layer is invalid.
    """

    safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", layer_name) or "classification"
    # A unique suffix keeps layers from sharing (and overwriting) one /vsimem
    # file. Distinct names can sanitize to the same string, and the same
    # classification may be loaded more than once.
    path = f"/vsimem/{safe_name}_{uuid.uuid4().hex}.tif"

    width = 1
    height = 1
    write_data = None
    if (
        min_val is not None
        and max_val is not None
        and math.isfinite(min_val)
        and math.isfinite(max_val)
    ):
        width = 256
        height = 1
        if max_val == min_val:
            write_data = np.full((height, width), float(min_val), dtype=np.float32)
        else:
            write_data = np.linspace(min_val, max_val, num=width, dtype=np.float32)
            write_data = write_data.reshape((height, width))
    if write_data is None:
        write_data = np.zeros((height, width), dtype=np.float32)

    drv = gdal.GetDriverByName("GTiff")
    ds = drv.Create(path, width, height, 1, gdal.GDT_Float32)
    if ds is None:
        raise QgsProcessingException("Failed to create dummy raster dataset")

    # 0.1 degree pixel spacing to keep footprint small but visible
    ds.SetGeoTransform([0, 0.1, 0, 0, 0, -0.1])
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    ds.SetProjection(srs.ExportToWkt())

    band = ds.GetRasterBand(1)
    band.WriteArray(write_data)
    band.FlushCache()
    ds.FlushCache()
    # Release the band before the dataset so no lingering handle keeps the
    # dataset open; the file is finalized before QGIS opens it.
    band = None
    ds = None

    rlayer = QgsRasterLayer(path, layer_name, "gdal")
    if not rlayer.isValid():
        # The layer failed to load; free the in-memory file before failing.
        gdal.Unlink(path)
        raise QgsProcessingException("Dummy raster layer could not be loaded")
    return rlayer
