"""
==========QNetPlanner - I/O utilities module===========
Writes QNetPlanner optimisation results (gateways and sensors) to disk.
Provides:
 - save_results_geojson(...) : write one GeoJSON containing gateways + sensors
 - save_results_gpkg(...)    : write a GeoPackage with separate gateway and sensor layers

This file's save_results_geojson() signature matches the call used in QNetPlanner.run().


Gitlab:
    https://gitlab.com/binoy194/QNetPlanner
email:
    binoy194@gmail.com
    kavyask304@gmail.com

Authors:
    Binoy C
    Kavya S K
"""

import os
import logging
from qgis.core import (
    QgsVectorLayer, QgsFields, QgsField, QgsFeature, QgsGeometry,
    QgsPointXY, QgsProject, QgsVectorFileWriter, QgsWkbTypes,
    QgsCoordinateReferenceSystem
)
from qgis.PyQt.QtCore import QVariant

# Module logger, named as a child of the "QNetPlanner" logger namespace.
logger = logging.getLogger(name="QNetPlanner.io_utils")


def _safe_remove(path):
    try:
        if path and os.path.exists(path):
            os.remove(path)
    except Exception as e:
        logger.warning("Could not remove file %s: %s", path, e)


def save_results_geojson(
    out_geojson_path: str,
    gateway_positions: dict,
    gateway_attr_costs: dict,
    selected_gateways: list,
    sensor_positions: dict,
    sensor_attr_costs: dict,
    selected_sensors: dict,
    reverse_sensor_map: dict,
    coverage_map: dict,
    final_gateway_costs: dict,
    final_sensor_costs: dict,
    assigned_map: dict = None,
    crs=None,
    open_in_qgis=True,
    iface=None
):
    """
    Save both gateways and sensors into a single GeoJSON file.

    Features are built in a temporary in-memory point layer whose CRS is
    ``crs`` (EPSG:4326 when ``crs`` is None). They are then written to
    GeoJSON reprojected to EPSG:4326.

    Parameters (matching QNetPlanner.run() call):
      out_geojson_path: str; ".geojson" is appended unless the path already
        ends in ".geojson" or ".json" (case-insensitive)
      gateway_positions: {gw_id: (x,y)}
      gateway_attr_costs: {gw_id: attr_cost}
      selected_gateways: iterable of gw_id (None is treated as empty)
      sensor_positions: {sensor_gid: (x,y)}
      sensor_attr_costs: {sensor_gid: attr_cost}
      selected_sensors: {stype: [gid,...]}, {stype: gid}, or a flat
        list/set/tuple of gids
      reverse_sensor_map: {gid: (stype, orig_fid)}
      coverage_map: {gw_id: {stype: [gid,...], ...}, ...}
      final_gateway_costs: {gw_id: final_cost}; an entry here takes
        precedence over gateway_attr_costs
      final_sensor_costs: {gid: final_cost}; an entry here takes precedence
        over sensor_attr_costs
      assigned_map: optional {gid: gw_id or [gw_id,...]}. Ids not already
        derived from coverage_map are appended to "gateways covering".
      crs: QgsCoordinateReferenceSystem of the input coordinates (optional)
      open_in_qgis: bool; load the written file into the current project
      iface: QGIS iface; the file is only loaded when this is not None

    Returns:
      str: the path that was actually written.

    Raises:
      Exception: if QgsVectorFileWriter reports a write error.
    """
    logger.info("Saving QNetPlanner results to GeoJSON: %s", out_geojson_path)

    # Input coordinates are interpreted in `crs`; default to WGS84.
    target_crs = QgsCoordinateReferenceSystem("EPSG:4326") if crs is None else crs

    if not out_geojson_path.lower().endswith((".geojson", ".json")):
        out_geojson_path = out_geojson_path + ".geojson"

    out_dir = os.path.dirname(out_geojson_path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    def _pick_cost(key, final_costs, attr_costs):
        # Resolve a feature's cost as a float (or None). A key present in
        # final_costs wins, even when its value is None.
        if final_costs is not None and key in final_costs:
            cost = final_costs[key]
        else:
            cost = (attr_costs or {}).get(key)
        return float(cost) if cost is not None else None

    def _is_collection(value):
        # Strings are iterable but represent a single id, not a group of ids.
        return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))

    # create memory layer in target CRS
    uri = f"Point?crs={target_crs.authid()}"
    mem_layer = QgsVectorLayer(uri, "QNetPlanner_results_temp", "memory")
    prov = mem_layer.dataProvider()

    # define fields
    fields = QgsFields()
    fields.append(QgsField("gid", QVariant.Int))                   # gateway id or sensor global id
    fields.append(QgsField("orig_fid", QVariant.Int))              # original layer fid (sensors); NULL for gateways
    fields.append(QgsField("feature_type", QVariant.String))       # "gateway" or "sensor"
    fields.append(QgsField("stype", QVariant.String))              # sensor type, e.g. "sensor_1" ("" for gateways)
    fields.append(QgsField("orig_cost", QVariant.Double))          # final cost if known, else attribute cost
    fields.append(QgsField("selected", QVariant.String))           # "yes" = selected, "no" = not selected
    fields.append(QgsField("gateways covering", QVariant.String))  # sensors: comma-separated covering gateway ids

    prov.addAttributes(fields)
    mem_layer.updateFields()

    feats = []

    # Build the set once: O(1) membership instead of scanning a list per gateway.
    selected_gw_set = set(selected_gateways or ())

    # Add gateway features
    for gw_id, (gx, gy) in gateway_positions.items():
        f = QgsFeature()
        f.setFields(fields)
        f.setAttribute("gid", int(gw_id))
        f.setAttribute("orig_fid", None)
        f.setAttribute("feature_type", "gateway")
        f.setAttribute("stype", "")
        f.setAttribute("orig_cost", _pick_cost(gw_id, final_gateway_costs, gateway_attr_costs))
        f.setAttribute("selected", "yes" if gw_id in selected_gw_set else "no")
        f.setAttribute("gateways covering", "")
        f.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(float(gx), float(gy))))
        feats.append(f)

    # Flatten selected_sensors into a single set of gids. Accepted shapes:
    # {stype: [gid, ...]}, {stype: gid}, or a flat collection of gids.
    selected_sensor_set = set()
    if isinstance(selected_sensors, dict):
        for val in selected_sensors.values():
            if _is_collection(val):
                selected_sensor_set.update(val)
            else:
                # Scalar value: the value itself is taken as a single gid.
                try:
                    selected_sensor_set.add(int(val))
                except (TypeError, ValueError):
                    logger.debug("Ignoring non-integer selected sensor entry %r", val)
    elif _is_collection(selected_sensors):
        selected_sensor_set.update(selected_sensors)

    # Invert coverage_map into {sensor_gid: {gw_id, ...}}.
    sensor_covering_gws = {}
    for gw, stype_map in (coverage_map or {}).items():
        for gids in (stype_map or {}).values():
            for gid in gids:
                sensor_covering_gws.setdefault(gid, set()).add(gw)

    # Add sensor features
    for gid, (sx, sy) in sensor_positions.items():
        stype, orig_fid = (reverse_sensor_map or {}).get(gid, (None, None))
        f = QgsFeature()
        f.setFields(fields)
        f.setAttribute("gid", int(gid))
        f.setAttribute("orig_fid", int(orig_fid) if orig_fid is not None else None)
        f.setAttribute("feature_type", "sensor")
        f.setAttribute("stype", str(stype) if stype is not None else "")
        f.setAttribute("orig_cost", _pick_cost(gid, final_sensor_costs, sensor_attr_costs))
        f.setAttribute("selected", "yes" if gid in selected_sensor_set else "no")

        # Covering gateways from coverage_map (sorted) come first. Extra ids
        # from assigned_map that are not already listed follow, in given order.
        cov_ids = [str(int(g)) for g in sorted(sensor_covering_gws.get(gid, ()))]
        if assigned_map and gid in assigned_map:
            assigned_vals = assigned_map[gid]
            if not isinstance(assigned_vals, (list, tuple, set, frozenset)):
                assigned_vals = [assigned_vals]
            seen = set(cov_ids)
            for v in assigned_vals:
                try:
                    v_str = str(int(v))
                except (TypeError, ValueError):
                    # Skip only the bad id; keep merging the remaining ones.
                    logger.debug("Ignoring non-integer assigned gateway %r for sensor %s", v, gid)
                    continue
                if v_str not in seen:
                    seen.add(v_str)
                    cov_ids.append(v_str)
        f.setAttribute("gateways covering", ",".join(cov_ids))
        f.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(float(sx), float(sy))))
        feats.append(f)

    if feats:
        prov.addFeatures(feats)
    mem_layer.updateExtents()

    # remove existing file
    _safe_remove(out_geojson_path)

    # GeoJSON (RFC 7946) expects WGS84 coordinates, so always reproject to EPSG:4326.
    dest_crs = QgsCoordinateReferenceSystem("EPSG:4326")
    writer_result = QgsVectorFileWriter.writeAsVectorFormat(mem_layer, out_geojson_path, "utf-8", dest_crs, "GeoJSON")

    # Accept either a bare error code or a tuple starting with
    # (error_code, error_message), whatever its length.
    if isinstance(writer_result, tuple):
        err_code = writer_result[0]
        err_msg = writer_result[1] if len(writer_result) > 1 else ""
    else:
        err_code, err_msg = writer_result, ""

    if err_code != QgsVectorFileWriter.NoError:
        logger.error("Failed to write GeoJSON (code %s, msg='%s')", err_code, err_msg)
        raise Exception(f"Failed to write GeoJSON (code {err_code}) - {err_msg}")

    logger.info("Saved QNetPlanner results to %s", out_geojson_path)

    # optionally add to QGIS project
    if open_in_qgis and iface is not None:
        try:
            vlayer = QgsVectorLayer(out_geojson_path, f"QNetPlanner Results - {os.path.basename(out_geojson_path)}", "ogr")
            if vlayer.isValid():
                QgsProject.instance().addMapLayer(vlayer)
            else:
                logger.warning("Saved GeoJSON but failed to load into QGIS")
        except Exception as e:
            # Loading into the project is a convenience; the file is already written.
            logger.exception("Error adding result layer to QGIS: %s", e)

    return out_geojson_path


def _write_gpkg_layer(layer, gpkg_path, layer_name, overwrite_file):
    """
    Write a vector layer into a GeoPackage as the table ``layer_name``.

    The legacy ``QgsVectorFileWriter.writeAsVectorFormat`` overload cannot set
    the output layer name. It also recreates the target file on every call, so
    writing a second layer with it discards the first. ``SaveVectorOptions``
    exposes both ``layerName`` and ``actionOnExistingFile``.

    Parameters:
      layer: QgsVectorLayer to export (written in its own CRS, no reprojection)
      gpkg_path: target .gpkg path
      layer_name: name of the table inside the GeoPackage
      overwrite_file: True to (re)create the whole file; False to add the
        layer to the existing file, replacing a same-named layer

    Returns:
      (error_code, error_message) as reported by QgsVectorFileWriter.
    """
    options = QgsVectorFileWriter.SaveVectorOptions()
    options.driverName = "GPKG"
    options.fileEncoding = "utf-8"
    options.layerName = layer_name
    if overwrite_file:
        options.actionOnExistingFile = QgsVectorFileWriter.CreateOrOverwriteFile
    else:
        options.actionOnExistingFile = QgsVectorFileWriter.CreateOrOverwriteLayer
    context = QgsProject.instance().transformContext()

    if hasattr(QgsVectorFileWriter, "writeAsVectorFormatV3"):
        # QGIS >= 3.20: returns (error, errorMessage, newFilename, newLayer)
        result = QgsVectorFileWriter.writeAsVectorFormatV3(layer, gpkg_path, context, options)
        return result[0], result[1]
    # QGIS 3.10 - 3.18: V2 returns its out-parameters as
    # (error, newFilename, newLayer, errorMessage).
    # NOTE(review): confirm on the oldest QGIS version the plugin supports.
    result = QgsVectorFileWriter.writeAsVectorFormatV2(layer, gpkg_path, context, options)
    return result[0], result[-1]


def save_results_gpkg(
    out_gpkg_path: str,
    gateway_positions: dict,
    gateway_attr_costs: dict,
    selected_gateways: list,
    sensor_positions: dict,
    sensor_attr_costs: dict,
    selected_sensors: dict,
    reverse_sensor_map: dict,
    coverage_map: dict,
    final_gateway_costs: dict,
    final_sensor_costs: dict,
    crs=None,
    open_in_qgis=True,
    iface=None
):
    """
    Save gateways and sensors into one GeoPackage as two layers.

    Layers written:
      selected_gateways: gw_id, orig_cost, selected
      selected_sensors:  sensor_gid, stype, orig_fid, orig_cost, selected,
                         "gateways covering"
    Despite the layer names, every gateway/sensor position is written. The
    "selected" attribute ("yes"/"no") marks what the optimizer chose.

    Parameters:
      out_gpkg_path: target .gpkg path (an existing file is replaced)
      gateway_positions: {gw_id: (x,y)}
      gateway_attr_costs: {gw_id: attr_cost}
      selected_gateways: iterable of gw_id (None is treated as empty)
      sensor_positions: {sensor_gid: (x,y)}
      sensor_attr_costs: {sensor_gid: attr_cost}
      selected_sensors: {stype: [gid,...]}, {stype: gid}, or a flat
        collection of gids
      reverse_sensor_map: {gid: (stype, orig_fid)}
      coverage_map: {gw_id: {stype: [gid,...], ...}, ...}
      final_gateway_costs / final_sensor_costs: final costs; an entry there
        takes precedence over the corresponding *_attr_costs value
      crs: QgsCoordinateReferenceSystem of the coordinates (EPSG:4326 when
        None); features are written in this CRS without reprojection
      open_in_qgis: bool; load both layers into the current project
      iface: QGIS iface; layers are only loaded when this is not None

    Returns:
      str: out_gpkg_path.

    Raises:
      Exception: if either layer fails to write.
    """
    logger.info("Saving QNetPlanner results to GeoPackage: %s", out_gpkg_path)

    if crs is None:
        crs = QgsCoordinateReferenceSystem("EPSG:4326")

    out_dir = os.path.dirname(out_gpkg_path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    def _pick_cost(key, final_costs, attr_costs):
        # Resolve a feature's cost as a float (or None). A key present in
        # final_costs wins, even when its value is None.
        if final_costs is not None and key in final_costs:
            cost = final_costs[key]
        else:
            cost = (attr_costs or {}).get(key)
        return float(cost) if cost is not None else None

    def _is_collection(value):
        # Strings are iterable but represent a single id, not a group of ids.
        return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))

    # gateways layer
    gw_uri = f"Point?crs={crs.authid()}"
    gw_layer = QgsVectorLayer(gw_uri, "selected_gateways_temp", "memory")
    gw_fields = QgsFields()
    gw_fields.append(QgsField("gw_id", QVariant.Int))
    gw_fields.append(QgsField("orig_cost", QVariant.Double))
    gw_fields.append(QgsField("selected", QVariant.String))
    gw_layer.dataProvider().addAttributes(gw_fields)
    gw_layer.updateFields()

    selected_gw_set = set(selected_gateways or ())

    gw_feats = []
    for gw_id, (gx, gy) in gateway_positions.items():
        f = QgsFeature()
        f.setFields(gw_layer.fields())
        f.setAttribute("gw_id", int(gw_id))
        f.setAttribute("orig_cost", _pick_cost(gw_id, final_gateway_costs, gateway_attr_costs))
        f.setAttribute("selected", "yes" if gw_id in selected_gw_set else "no")
        f.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(float(gx), float(gy))))
        gw_feats.append(f)
    if gw_feats:
        gw_layer.dataProvider().addFeatures(gw_feats)
    gw_layer.updateExtents()

    # sensors layer
    s_uri = f"Point?crs={crs.authid()}"
    s_layer = QgsVectorLayer(s_uri, "selected_sensors_temp", "memory")
    s_fields = QgsFields()
    s_fields.append(QgsField("sensor_gid", QVariant.Int))
    s_fields.append(QgsField("stype", QVariant.String))
    s_fields.append(QgsField("orig_fid", QVariant.Int))
    s_fields.append(QgsField("orig_cost", QVariant.Double))
    s_fields.append(QgsField("selected", QVariant.String))
    s_fields.append(QgsField("gateways covering", QVariant.String))
    s_layer.dataProvider().addAttributes(s_fields)
    s_layer.updateFields()

    # Flatten selected_sensors once, rather than scanning every selection
    # list for every sensor; also accepts a flat collection of gids.
    selected_sensor_set = set()
    if isinstance(selected_sensors, dict):
        for val in selected_sensors.values():
            if _is_collection(val):
                selected_sensor_set.update(val)
            else:
                try:
                    selected_sensor_set.add(int(val))
                except (TypeError, ValueError):
                    logger.debug("Ignoring non-integer selected sensor entry %r", val)
    elif _is_collection(selected_sensors):
        selected_sensor_set.update(selected_sensors)

    # precompute sensor -> covering gateways
    sensor_covering_gws = {}
    for gw, stype_map in (coverage_map or {}).items():
        for gids in (stype_map or {}).values():
            for gid in gids:
                sensor_covering_gws.setdefault(gid, set()).add(gw)

    s_feats = []
    for gid, (sx, sy) in sensor_positions.items():
        f = QgsFeature()
        f.setFields(s_layer.fields())
        stype, orig_fid = (reverse_sensor_map or {}).get(gid, (None, None))
        f.setAttribute("sensor_gid", int(gid))
        f.setAttribute("stype", str(stype) if stype is not None else "")
        f.setAttribute("orig_fid", int(orig_fid) if orig_fid is not None else None)
        f.setAttribute("orig_cost", _pick_cost(gid, final_sensor_costs, sensor_attr_costs))
        # "yes"/"no" to match the String field type and the gateways layer.
        f.setAttribute("selected", "yes" if gid in selected_sensor_set else "no")
        cov = sensor_covering_gws.get(gid, ())
        f.setAttribute("gateways covering", ",".join(str(int(g)) for g in sorted(cov)))
        f.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(float(sx), float(sy))))
        s_feats.append(f)
    if s_feats:
        s_layer.dataProvider().addFeatures(s_feats)
    s_layer.updateExtents()

    _safe_remove(out_gpkg_path)

    gw_layer_name = "selected_gateways"
    s_layer_name = "selected_sensors"

    # First layer creates the file; the second is added alongside it.
    err_code, err_msg = _write_gpkg_layer(gw_layer, out_gpkg_path, gw_layer_name, overwrite_file=True)
    if err_code != QgsVectorFileWriter.NoError:
        logger.error("Failed writing gateways to gpkg: code %s msg='%s'", err_code, err_msg)
        raise Exception(f"Failed writing gateways to geopackage (code {err_code}) - {err_msg}")

    err_code2, err_msg2 = _write_gpkg_layer(s_layer, out_gpkg_path, s_layer_name, overwrite_file=False)
    if err_code2 != QgsVectorFileWriter.NoError:
        logger.error("Failed writing sensors to gpkg: code %s msg='%s'", err_code2, err_msg2)
        raise Exception(f"Failed writing sensors to geopackage (code {err_code2}) - {err_msg2}")

    logger.info("Results saved to %s (layers: %s, %s)", out_gpkg_path, gw_layer_name, s_layer_name)

    if open_in_qgis and iface is not None:
        try:
            ds_gw = f"{out_gpkg_path}|layername={gw_layer_name}"
            ds_s = f"{out_gpkg_path}|layername={s_layer_name}"
            gw_vl = QgsVectorLayer(ds_gw, "QNetPlanner - Selected Gateways", "ogr")
            s_vl = QgsVectorLayer(ds_s, "QNetPlanner - Selected Sensors", "ogr")
            if gw_vl.isValid():
                QgsProject.instance().addMapLayer(gw_vl)
            else:
                logger.warning("Saved GeoPackage but failed to load layer %s into QGIS", gw_layer_name)
            if s_vl.isValid():
                QgsProject.instance().addMapLayer(s_vl)
            else:
                logger.warning("Saved GeoPackage but failed to load layer %s into QGIS", s_layer_name)
        except Exception as e:
            # Loading into the project is a convenience; the file is already written.
            logger.exception("Failed to open result layers in QGIS: %s", e)

    return out_gpkg_path
