# -*- coding: utf-8 -*-

"""
/***************************************************************************
 kisters_processing
                                 A QGIS plugin
 This is the processing toolkit for the network store
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2022-05-28
        copyright            : (C) 2022 by Attila Bibok / KISTERS North America
        email                : attila.bibok@kisters.net
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Module metadata (Plugin Builder boilerplate).
__author__ = "Attila Bibok / KISTERS North America"
__date__ = "2022-05-28"
__copyright__ = "(C) 2022 by Attila Bibok / KISTERS North America"

# This will get replaced with a git SHA1 when you do a git archive
# (git only expands "$Format:...$" placeholders in files that carry the
# `export-subst` attribute in .gitattributes; otherwise it stays literal).

__revision__ = "$Format:%H$"

import os
import json
from typing import Optional
import concurrent.futures

import requests


from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtCore import (
    QCoreApplication,
    # Qt, QDateTime, QTime,
    QSettings,
)
from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterString,
    QgsProcessingParameterEnum,
    QgsProcessingParameterFileDestination,
    QgsProcessingParameterDateTime,
    QgsProcessingParameterCrs,
    QgsProcessing,
    # QgsProject,
    # QgsRasterLayer,
    # QgsMessageLog,
    # Qgis,
)
from qgis.core import (
    QgsProcessingParameterFeatureSink,
    QgsFeatureSink,
    QgsFields,
    QgsField,
    QgsFeature,
    QgsGeometry,
    QgsPointXY,
    QgsWkbTypes,
    QgsCoordinateReferenceSystem,
)
from qgis.PyQt.QtCore import QVariant

from .datasphere_core import (
    RasterRequest,
    load_datasphere_raster,
    list_raster_timeseries,
    make_session,
    floor_to_past_hour_utc,
    qdt_to_iso_with_tz,
    get_datasphere_classification,
    list_locations,
)
from .arraystore_classifications import (
    apply_classification_to_layer,
    create_dummy_raster_layer,
    classification_to_shader_items,
)
from qgis.core import QgsProject


class _BaseDatasphereAlgorithm(QgsProcessingAlgorithm):
    """
    Shared base class for the Datasphere processing algorithms.

    Provides the common toolbox group and icon, factories for the connection
    parameters (base URL, username, password) and forces main-thread
    execution so subclasses may modify the project / layer tree safely.
    """

    PARAM_BASE_URL = "BASE_URL"
    PARAM_DS_USER = "DS_USER"
    PARAM_DS_PASS = "DS_PASS"

    def createInstance(self):
        # QGIS calls this to clone the algorithm when running it; using
        # self.__class__ gives every subclass a working clone for free.
        return self.__class__()

    def tr(self, string: str) -> str:
        """Translate *string* in the "Processing" translation context."""
        return QCoreApplication.translate("Processing", string)

    # Every subclass is listed under the "Datasphere" group of the provider.
    def groupId(self) -> str:
        return "Datasphere"

    def group(self) -> str:
        return self.tr("Datasphere")

    def icon(self):
        """
        Icon shown for Datasphere algorithms in the Processing toolbox.
        """
        plugin_path = os.path.dirname(__file__)
        return QIcon(os.path.join(plugin_path, "icon_datasphere.png"))

    def base_url_param(self):
        """Return the string parameter holding the Datasphere base URL."""
        return QgsProcessingParameterString(
            self.PARAM_BASE_URL,
            self.tr("Base URL"),
            defaultValue="http://na.datasphere.online/external",
        )

    def ds_user_param(self):
        """Return the string parameter holding the Datasphere username."""
        return QgsProcessingParameterString(
            self.PARAM_DS_USER,
            self.tr("Username"),
        )

    def ds_pass_param(self):
        """
        Return the string parameter holding the Datasphere password.

        NOTE(review): a plain string parameter is shown unmasked in the dialog
        and its value may end up in the Processing history; consider an
        authentication-config parameter instead.
        """
        return QgsProcessingParameterString(
            self.PARAM_DS_PASS,
            self.tr("Password"),
        )

    def flags(self):
        # Run on the main thread so layer tree changes are safe
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading


class DataSphereLoadRasterAlgorithm(_BaseDatasphereAlgorithm):
    """
    Download a Datasphere raster product and load it into the current project.

    The product is addressed either by its timeseries id (``ts_id``) or by its
    readable path. Download and layer creation are delegated to
    ``load_datasphere_raster``.
    """

    PARAM_DSID = "DS_ID"
    PARAM_IDENT_TYPE = "IDENT_TYPE"
    PARAM_TS_PATH_ENUM = "TS_PATH_ENUM"
    PARAM_IDENTIFIER = "IDENTIFIER"
    PARAM_T0 = "RASTER_T0"
    PARAM_TIME = "RASTER_TIME"
    PARAM_TARGET_CRS = "TARGET_CRS"
    PARAM_FORMAT = "RASTER_FORMAT"
    PARAM_TARGET_FILE = "TARGET_FILE"

    IDENT_TYPE_OPTIONS = ["ts_id", "path"]
    FORMAT_OPTIONS = [
        "geotiff"
    ]  # png is pointless, but let's leave this here for future extensions

    # Class-level cache of raster product paths for the (currently disabled)
    # dropdown parameter. Processing UIs are static, so it is filled once in
    # initAlgorithm and shared by all instances. It is rebound, never mutated.
    _cached_paths: list[str] = []

    def name(self) -> str:
        # internal id, no spaces
        return "load_raster_from_datasphere"

    def displayName(self) -> str:
        return self.tr("Load Raster Product")

    def shortHelpString(self) -> str:
        # Each fragment ends with a space so the sentences do not run together.
        return self.tr(
            "Downloads a raster product (GeoTIFF) from Datasphere by ts_id or path "
            "and loads it into the current project. "
            "Give t0 to retrieve ensemble raster timeseries. "
            "Omitting t0 indicates that it's an observation, not a forecast. "
            "Currently, only geotiff output is supported."
        )

    def icon(self):
        plugin_path = os.path.dirname(__file__)
        return QIcon(os.path.join(plugin_path, "icon_datasphere_with_raster.png"))

    def dsid_param(self):
        return QgsProcessingParameterString(
            self.PARAM_DSID,
            self.tr("Datasource ID (dsId)"),
            defaultValue="0",
        )

    def ident_type_param(self):
        return QgsProcessingParameterEnum(
            self.PARAM_IDENT_TYPE,
            self.tr("Identifier type"),
            options=self.IDENT_TYPE_OPTIONS,
            defaultValue=1,  # path
        )

    def ts_path_enum_param(self):
        """
        Dropdown of cached product paths.

        Not registered by initAlgorithm at the moment; kept for re-enabling.
        """
        # Populated from cached list. If empty, user can still type manual id/path.
        return QgsProcessingParameterEnum(
            self.PARAM_TS_PATH_ENUM,
            self.tr("Raster product (path)"),
            options=self._cached_paths or [""],
            defaultValue=0,
            optional=True,
        )

    def identifier_param(self):
        return QgsProcessingParameterString(
            self.PARAM_IDENTIFIER,
            self.tr("Timeseries identifier (ts_id or path)"),
            optional=True,
        )

    def t0_param(self):
        return QgsProcessingParameterDateTime(
            self.PARAM_T0,
            self.tr("Ensemble reference time (t0) in LOCAL TIMEZONE"),
            optional=True,
        )

    def time_param(self):
        return QgsProcessingParameterDateTime(
            self.PARAM_TIME,
            self.tr("Raster layer time in LOCAL TIMEZONE"),
        )

    def crs_param(self):
        return QgsProcessingParameterCrs(
            self.PARAM_TARGET_CRS,
            self.tr("Target CRS"),
            optional=True,
        )

    def format_param(self):
        return QgsProcessingParameterEnum(
            self.PARAM_FORMAT,
            self.tr("Format"),
            options=self.FORMAT_OPTIONS,
            defaultValue=0,
        )

    def target_file_param(self):
        return QgsProcessingParameterFileDestination(
            self.PARAM_TARGET_FILE,
            self.tr("Target GeoTIFF file (or temporary)"),
            fileFilter="GeoTIFF (*.tif *.tiff)",
            defaultValue=QgsProcessing.TEMPORARY_OUTPUT,
        )

    @classmethod
    def _populate_path_cache(cls) -> None:
        """
        Best-effort fill of ``_cached_paths`` using credentials from QSettings.

        Does nothing if the cache is already filled or no credentials are
        stored. Any failure leaves the cache empty; manual entry still works.

        NOTE(review): the dropdown parameter is not registered at the moment,
        so this (blocking) request only warms the cache. Consider skipping it
        until the dropdown is re-enabled.
        """
        if cls._cached_paths:
            return
        try:
            settings = QSettings()
            # use the same keys your plugin uses for DS creds
            # adjust if your real keys differ
            ds_user = settings.value("network_store/ds_user", "", type=str) or None
            ds_pass = settings.value("network_store/ds_pass", "", type=str) or None
            base_url = settings.value(
                "network_store/ds_base_url",
                "http://na.datasphere.online/",
                type=str,
            )
            ds_id = settings.value("network_store/ds_id", "0", type=str)

            if ds_user and ds_pass:
                session = make_session(ds_user, ds_pass)
                items = list_raster_timeseries(base_url, session, ds_id=int(ds_id))
                cls._cached_paths = sorted(
                    {
                        str(it.get("path", "")).strip()
                        for it in items
                        if it.get("path")
                    }
                )
        except Exception:  # noqa: BLE001 - deliberate best-effort warm-up
            # leave list empty; manual entry still works
            cls._cached_paths = []

    def initAlgorithm(self, config):
        self._populate_path_cache()

        self.addParameter(self.base_url_param())
        self.addParameter(self.ds_user_param())
        self.addParameter(self.ds_pass_param())
        self.addParameter(self.dsid_param())
        self.addParameter(self.ident_type_param())
        # ts_path_enum_param() is intentionally not registered: the dropdown
        # is disabled and the identifier always comes from the text parameter.
        self.addParameter(self.identifier_param())
        self.addParameter(self.format_param())
        self.addParameter(self.t0_param())
        self.addParameter(self.time_param())
        self.addParameter(self.crs_param())
        self.addParameter(self.target_file_param())

    def resolve_path_to_ts_id(
        self,
        base_url: str,
        ds_id: int,
        path_identifier: str,
        session: requests.Session,
    ) -> str:
        """
        Resolve a readable product path to its timeseries id.

        Scans the raster timeSeries list of *ds_id* for an entry whose
        (whitespace-stripped) ``path`` equals *path_identifier* and returns the
        first truthy value of ``timeseriesId``, ``timeSeriesId`` or ``id``.

        Raises:
            QgsProcessingException: if no matching entry with an id exists.
        """
        wanted = path_identifier.strip()
        for it in list_raster_timeseries(base_url, session, ds_id=ds_id):
            if str(it.get("path", "")).strip() != wanted:
                continue
            tsid = it.get("timeseriesId") or it.get("timeSeriesId") or it.get("id")
            if tsid:
                return str(tsid)

        raise QgsProcessingException(
            f"Path '{path_identifier}' not found in raster products for dsId={ds_id}."
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Validate the inputs, floor the times to the hour and load the raster.

        Raises:
            QgsProcessingException: on a non-integer dsId, a missing
                identifier, an unsupported format or a missing raster time.
        """
        base_url = self.parameterAsString(parameters, self.PARAM_BASE_URL, context)
        ds_user = self.parameterAsString(parameters, self.PARAM_DS_USER, context)
        ds_pass = self.parameterAsString(parameters, self.PARAM_DS_PASS, context)

        # DS_ID is a *string* parameter; parameterAsInt would silently turn
        # non-numeric input into 0, so parse explicitly and fail loudly.
        ds_id_raw = (
            self.parameterAsString(parameters, self.PARAM_DSID, context).strip() or "0"
        )
        try:
            ds_id = int(ds_id_raw)
        except ValueError as err:
            raise QgsProcessingException(
                f"Datasource ID (dsId) must be an integer, got '{ds_id_raw}'."
            ) from err

        ident_type_idx = self.parameterAsEnum(
            parameters, self.PARAM_IDENT_TYPE, context
        )
        ident_type = self.IDENT_TYPE_OPTIONS[ident_type_idx]

        # The path dropdown is disabled, so the identifier always comes from
        # the free-text parameter.
        identifier = self.parameterAsString(
            parameters, self.PARAM_IDENTIFIER, context
        ).strip()
        if not identifier:
            raise QgsProcessingException(
                "You must select a raster product path or provide an identifier manually."
            )

        fmt_idx = self.parameterAsEnum(parameters, self.PARAM_FORMAT, context)
        fmt = self.FORMAT_OPTIONS[fmt_idx]
        if fmt != "geotiff":
            raise QgsProcessingException("Only GeoTIFF export is supported currently.")

        t0_dt = self.parameterAsDateTime(parameters, self.PARAM_T0, context)
        time_dt = self.parameterAsDateTime(parameters, self.PARAM_TIME, context)

        # Validate before flooring: flooring an invalid QDateTime is meaningless.
        if not time_dt.isValid():
            raise QgsProcessingException("Raster time is required.")
        time_dt = floor_to_past_hour_utc(time_dt)

        if t0_dt.isValid():
            t0_dt = floor_to_past_hour_utc(t0_dt)
            t0_iso = qdt_to_iso_with_tz(t0_dt)
            if feedback:
                feedback.pushInfo(f"Using floored t0 (UTC): {t0_iso}")
        elif feedback:
            feedback.pushInfo("No t0 provided (non-ensemble raster).")
        time_iso = qdt_to_iso_with_tz(time_dt)

        if feedback:
            feedback.pushInfo(f"Using floored time (UTC): {time_iso}")

        target_file = self.parameterAsFileOutput(
            parameters, self.PARAM_TARGET_FILE, context
        ).strip()

        # Core handles ts_id resolution, download, layer creation & props.
        req = RasterRequest(
            base_url=base_url,
            ds_id=ds_id,
            ident_type=ident_type,
            identifier=identifier,
            time_dt=time_dt,
            t0_dt=t0_dt if t0_dt.isValid() else None,
            ds_user=ds_user,
            ds_pass=ds_pass,
            target_file=target_file,
            format="geotiff",
        )

        load_datasphere_raster(req, feedback=feedback, add_to_project=True)
        return {}


# ---------------------------------------------------------------------------
# Build classification dummy layers for Datasphere
# ---------------------------------------------------------------------------


class DatasphereBuildClassificationLayersAlgorithm(_BaseDatasphereAlgorithm):
    """
    Creates dummy raster layers for Datasphere classifications by scanning
    raster time series metadata for referenced classifications.
    """

    PARAM_IDENT_TYPE = "IDENT_TYPE"  # defined for parity; unused here
    PARAM_DSID = "DS_ID"
    PARAM_NAMES_FILTER = "CLASSIFICATION_NAMES"

    # Metadata keys that may reference a classification, in priority order.
    _CLASSIFICATION_KEYS = ("classification", "classification_id", "classificationId")

    def name(self) -> str:
        return "datasphere_build_classification_layers"

    def displayName(self) -> str:
        return self.tr("Load Datasphere Classifications as Dummy Layers")

    def shortHelpString(self) -> str:
        return self.tr(
            "Fetches raster metadata to discover classifications, downloads each "
            "classification definition, and creates 1x1 dummy raster layers with "
            "the corresponding symbology applied. Layers are placed into the group "
            "'Datasphere - Classifications' so they can be edited."
        )

    def icon(self):
        plugin_path = os.path.dirname(__file__)
        return QIcon(os.path.join(plugin_path, "icon_datasphere.png"))

    def dsid_param(self):
        return QgsProcessingParameterString(
            self.PARAM_DSID,
            self.tr("Datasource ID (dsId)"),
            defaultValue="0",
        )

    def names_filter_param(self):
        return QgsProcessingParameterString(
            self.PARAM_NAMES_FILTER,
            self.tr("Classification names (comma/newline separated, optional)"),
            optional=True,
            multiLine=True,
        )

    def createInstance(self):
        return DatasphereBuildClassificationLayersAlgorithm()

    def _split_names(self, raw: str) -> set[str]:
        """Split a comma/newline separated string into trimmed, non-empty names."""
        names: set[str] = set()
        for part in raw.replace(",", "\n").splitlines():
            p = part.strip()
            if p:
                names.add(p)
        return names

    @staticmethod
    def _parse_ds_id(raw: str) -> int:
        """
        Parse the dsId text parameter; blank input means 0.

        Raises:
            QgsProcessingException: if the value is not an integer.
        """
        text = raw.strip() or "0"
        try:
            return int(text)
        except ValueError as err:
            raise QgsProcessingException(
                f"Datasource ID (dsId) must be an integer, got '{text}'."
            ) from err

    @classmethod
    def _extract_classification(cls, meta) -> Optional[str]:
        """
        Return the classification referenced by one timeseries metadata entry.

        Checks the entry itself first, then its ``attributes`` dict; the first
        truthy value of ``_CLASSIFICATION_KEYS`` wins. Returns None if absent.
        """
        if not isinstance(meta, dict):
            return None
        for source in (meta, meta.get("attributes")):
            if not isinstance(source, dict):
                continue
            for key in cls._CLASSIFICATION_KEYS:
                value = source.get(key)
                if value:
                    return str(value)
        return None

    @staticmethod
    def _fetch_classifications(
        base_url: str,
        ds_id: int,
        names: set[str],
        ds_user: Optional[str],
        ds_pass: Optional[str],
        feedback,
    ) -> dict[str, dict]:
        """
        Download classification definitions concurrently (HTTP only).

        Failed downloads are reported through *feedback* and skipped; the
        returned dict maps each successfully fetched name to its definition.
        """
        if not names:
            # ThreadPoolExecutor rejects max_workers=0.
            return {}

        def _fetch_one(name: str):
            # One session per task, as the original design did; sessions are
            # not guaranteed to be safe to share between threads.
            sess = make_session(ds_user, ds_pass)
            return get_datasphere_classification(base_url, ds_id, name, session=sess)

        fetched: dict[str, dict] = {}
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=min(8, len(names))
        ) as ex:
            future_map = {ex.submit(_fetch_one, n): n for n in names}
            for fut in concurrent.futures.as_completed(future_map):
                name = future_map[fut]
                try:
                    fetched[name] = fut.result()
                except Exception as err:  # noqa: BLE001 - skip, keep the rest
                    if feedback:
                        feedback.reportError(
                            f"Skipping classification '{name}': {err}",
                            fatalError=False,
                        )
        return fetched

    def initAlgorithm(self, config):
        self.addParameter(self.base_url_param())
        self.addParameter(self.ds_user_param())
        self.addParameter(self.ds_pass_param())
        self.addParameter(self.dsid_param())
        self.addParameter(self.names_filter_param())

    def processAlgorithm(self, parameters, context, feedback):
        """
        Discover classifications, fetch them and (re)build one dummy layer each.

        Returns:
            ``{"created": n}`` with the number of layers added to the project.

        Raises:
            QgsProcessingException: on a non-integer dsId or when no
                classification (matching the optional filter) is referenced.
        """
        base_url = self.parameterAsString(parameters, self.PARAM_BASE_URL, context)
        ds_user = (
            self.parameterAsString(parameters, self.PARAM_DS_USER, context) or None
        )
        ds_pass = (
            self.parameterAsString(parameters, self.PARAM_DS_PASS, context) or None
        )
        ds_id = self._parse_ds_id(
            self.parameterAsString(parameters, self.PARAM_DSID, context)
        )

        names_raw = self.parameterAsString(parameters, self.PARAM_NAMES_FILTER, context)
        name_filter = self._split_names(names_raw) if names_raw else None

        session = make_session(ds_user, ds_pass)
        items = list_raster_timeseries(base_url, session, ds_id=ds_id)

        classifications: set[str] = {
            found for found in map(self._extract_classification, items) if found
        }
        if name_filter:
            classifications &= name_filter

        if not classifications:
            if name_filter:
                raise QgsProcessingException(
                    "No classifications found with given filter"
                )
            raise QgsProcessingException(
                "No classifications referenced by the raster timeseries metadata."
            )

        project = QgsProject.instance()
        root = project.layerTreeRoot()
        group_name = "Datasphere - Classifications"
        group = root.findGroup(group_name)
        if group is None:
            group = root.addGroup(group_name)

        # Fetch concurrently (HTTP only); layers are built below on this thread.
        fetched = self._fetch_classifications(
            base_url, ds_id, classifications, ds_user, ds_pass, feedback
        )

        created = 0
        # insertLayer(0, ...) puts each new layer on top of the group, so walk
        # the names in reverse order to end up with an alphabetical group.
        for cls_name, classification in sorted(fetched.items(), reverse=True):
            if feedback and feedback.isCanceled():
                break

            # remove existing layer with same name in the target group
            for child in list(group.children()):
                if child.name() == cls_name and hasattr(child, "layerId"):
                    project.removeMapLayer(child.layerId())

            try:
                _, min_val, max_val = classification_to_shader_items(classification)
                layer = create_dummy_raster_layer(
                    cls_name, min_val=min_val, max_val=max_val
                )
                layer.setCustomProperty("datasphere:classification_id", cls_name)
                layer.setCustomProperty("datasphere:base_url", base_url)
                layer.setCustomProperty("datasphere:kind", "classification_dummy")
                layer.setCustomProperty(
                    "datasphere:metadata", json.dumps(classification)
                )

                apply_classification_to_layer(layer, classification)

                # Add without auto-placement, then insert into our group.
                project.addMapLayer(layer, False)
                group.insertLayer(0, layer)
                created += 1
                if feedback:
                    feedback.pushInfo(f"Loaded classification '{cls_name}'")
            except Exception as e:  # noqa: BLE001
                if feedback:
                    feedback.reportError(
                        f"Failed to build layer for classification '{cls_name}': {e}",
                        fatalError=False,
                    )
                continue

        return {"created": created}


# ---------------------------------------------------------------------------
# Load locations as point layer
# ---------------------------------------------------------------------------


class DatasphereLoadLocationsAlgorithm(_BaseDatasphereAlgorithm):
    """
    Loads Datasphere locations as a point layer (EPSG:4326).
    """

    PARAM_OUTPUT = "OUTPUT"

    # (output field name, key in the location JSON); all fields are strings.
    _FIELD_SOURCES = (
        ("id", "id"),
        ("name", "name"),
        ("name2", "name2"),
        ("key", "key"),
        ("org", "organization"),
        ("creator", "creator"),
        ("editor", "editor"),
        ("creationTime", "creationTime"),
        ("editTime", "editTime"),
        ("geometryType", "geometryType"),
    )

    def name(self) -> str:
        return "datasphere_load_locations"

    def displayName(self) -> str:
        return self.tr("Load Datasphere Locations")

    def shortHelpString(self) -> str:
        return self.tr(
            "Fetches /datasphere/api/locations and creates a point layer with attributes."
        )

    def initAlgorithm(self, config):
        self.addParameter(self.base_url_param())
        self.addParameter(self.ds_user_param())
        self.addParameter(self.ds_pass_param())
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.PARAM_OUTPUT,
                self.tr("Output locations"),
                type=QgsProcessing.TypeVectorPoint,
                defaultValue=None,
                optional=True,
            )
        )

    @staticmethod
    def _as_text(value) -> str:
        """Stringify a JSON value; only None becomes '' (0/False are kept)."""
        return "" if value is None else str(value)

    def processAlgorithm(self, parameters, context, feedback):
        """
        Write one point feature per location that has latitude/longitude.

        Locations without coordinates are skipped silently. Locations with
        non-numeric coordinates are reported and skipped.

        Raises:
            QgsProcessingException: if the output sink could not be created
                (e.g. the optional output was skipped).
        """
        base_url = self.parameterAsString(parameters, self.PARAM_BASE_URL, context)
        ds_user = (
            self.parameterAsString(parameters, self.PARAM_DS_USER, context) or None
        )
        ds_pass = (
            self.parameterAsString(parameters, self.PARAM_DS_PASS, context) or None
        )

        session = make_session(ds_user, ds_pass)
        items = list_locations(base_url, session=session)

        fields = QgsFields()
        for field_name, _ in self._FIELD_SOURCES:
            fields.append(QgsField(field_name, QVariant.String))
        fields.append(QgsField("raw_json", QVariant.String))

        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.PARAM_OUTPUT,
            context,
            fields,
            QgsWkbTypes.Point,
            QgsCoordinateReferenceSystem("EPSG:4326"),
        )
        if sink is None:
            # The output is optional, so the user may have chosen to skip it;
            # without a sink there is nothing to write to.
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.PARAM_OUTPUT)
            )

        for item in items:
            if feedback and feedback.isCanceled():
                break
            lat = item.get("latitude")
            lon = item.get("longitude")
            if lat is None or lon is None:
                continue
            try:
                point = QgsPointXY(float(lon), float(lat))
            except (TypeError, ValueError):
                if feedback:
                    feedback.reportError(
                        f"Skipping location '{item.get('id')}': "
                        f"invalid coordinates ({lon}, {lat})",
                        fatalError=False,
                    )
                continue

            feat = QgsFeature(fields)
            feat.setGeometry(QgsGeometry.fromPointXY(point))
            for field_name, key in self._FIELD_SOURCES:
                feat[field_name] = self._as_text(item.get(key))
            try:
                feat["raw_json"] = json.dumps(item, ensure_ascii=False)
            except (TypeError, ValueError):
                # Not JSON-serializable; keep the feature without the raw dump.
                feat["raw_json"] = ""
            sink.addFeature(feat, QgsFeatureSink.FastInsert)

        return {self.PARAM_OUTPUT: dest_id}
