'''
==========QNetPlanner - Main Plugin Controller===========

Gitlab:
    https://gitlab.com/binoy194/QNetPlanner
email:
    binoy194@gmail.com
    kavyask304@gmail.com

Authors:
    Binoy C
    Kavya S K

'''
import os
import traceback
from qgis.PyQt import QtWidgets
from qgis.PyQt.QtWidgets import QAction, QMessageBox
from qgis.PyQt.QtGui import QIcon
from .QNetPlanner_dialog import QNetPlannerDialog
from .processing_utils import run_viewshed_analysis, reproject_if_needed
from .costs import compute_all_costs_pipeline
from .optimization import generalized_optimizer
from .logger import logger, init_logger
from qgis.core import QgsRasterLayer

from .io_utils import save_results_geojson
import shutil
from qgis.core import QgsRasterLayer, QgsProject
import logging
def export_selected_viewsheds(
    viewshed_paths: dict,
    selected_gateways: list,
    output_dir: str,
    iface=None
):
    """Copy the viewshed rasters of the selected gateways and load them into QGIS.

    Each raster is copied to ``<output_dir>/viewsheds/viewshed_gateway_<id>.tif``
    and added to the current QGIS project as a layer named
    ``Viewshed_Gateway_<id>``. Gateways with no recorded viewshed path, or
    whose file does not exist, are skipped with a warning.

    Args:
        viewshed_paths: Mapping of gateway id -> path of its viewshed raster.
        selected_gateways: Gateway ids to export (looked up in ``viewshed_paths``).
        output_dir: Directory under which the ``viewsheds`` folder is created.
        iface: Not used by this function; accepted for call-site compatibility.

    Returns:
        dict: gateway id -> path of the exported raster, for every gateway
        whose raster is present at the export location.

    Raises:
        OSError: If the output folder cannot be created or a copy fails.
    """
    vs_dir = os.path.join(output_dir, "viewsheds")
    os.makedirs(vs_dir, exist_ok=True)

    exported = {}

    for gw_id in selected_gateways:
        src = viewshed_paths.get(gw_id)

        if not src or not os.path.exists(src):
            logger.warning("Viewshed missing for gateway %s", gw_id)
            continue

        dst = os.path.join(vs_dir, f"viewshed_gateway_{gw_id}.tif")
        try:
            shutil.copy(src, dst)
        except shutil.SameFileError:
            # The raster already lives at the export location; copying it onto
            # itself would raise and abort the remaining exports.
            logger.debug("Viewshed for gateway %s already at %s", gw_id, dst)
        exported[gw_id] = dst

        # Load the exported copy into the current QGIS project.
        layer = QgsRasterLayer(dst, f"Viewshed_Gateway_{gw_id}")
        if layer.isValid():
            QgsProject.instance().addMapLayer(layer)
        else:
            logger.warning(
                "Exported viewshed for gateway %s is not a valid raster: %s",
                gw_id, dst
            )

    logger.info("Exported %d viewsheds", len(exported))
    return exported

class QNetPlanner:
    """QGIS plugin object for QNetPlanner.

    Owns the toolbar/menu action and the planner dialog. It runs the pipeline:
    viewshed analysis -> cost computation -> optimisation -> export of the
    selected gateways' viewsheds -> GeoJSON result file.
    """

    # Exceptions int() can raise for values that are not integral ids.
    _INT_ERRORS = (TypeError, ValueError, OverflowError)

    def __init__(self, iface):
        """Store the QGIS interface; GUI elements are created in initGui()."""
        self.iface = iface
        self.plugin_dir = os.path.dirname(__file__)
        self.action = None
        self.dlg = None

    def initGui(self):
        """Create the plugin action and register it in the toolbar and menu.

        Any error is logged (not raised).
        """
        try:
            init_logger()

            logger.info("""
\\]]==================================================================[[//
\\]]   Q N E T P L A N N E R
\\]]   QGIS Network Planning & Optimization
\\]]==================================================================[[//
""")
            icon_path = os.path.join(self.plugin_dir, "icon.png")
            if not os.path.exists(icon_path):
                logger.warning("Icon file not found at %s", icon_path)
            self.action = QAction(QIcon(icon_path), "QNetPlanner", self.iface.mainWindow())
            self.action.triggered.connect(self.run)
            self.iface.addToolBarIcon(self.action)
            self.iface.addPluginToMenu("&QNetPlanner", self.action)

            logger.info("Added QNetPlanner to QGIS GUI")
        except Exception:
            logger.exception("Failed to initialize GUI")

    def unload(self):
        """Remove the plugin's GUI elements and close its dialog.

        Any error is logged (not raised).
        """
        try:
            if self.action:
                self.iface.removeToolBarIcon(self.action)
                self.iface.removePluginMenu("&QNetPlanner", self.action)
                self.action = None
            # Do not leave a dialog bound to an unloaded plugin on screen.
            if self.dlg is not None:
                self.dlg.close()
                self.dlg = None
        except Exception:
            logger.exception("Error unloading QNetPlanner plugin")

    def run(self):
        """Show (creating on first use) the QNetPlanner dialog and bring it to front."""
        logger.info("""
        ============== QNetPlanner plugin run started ================
        """)

        if self.dlg is None:
            self.dlg = QNetPlannerDialog(
                parent=self.iface.mainWindow(),
                plugin=self
            )

        self.dlg.show()
        self.dlg.raise_()
        self.dlg.activateWindow()

    def execute(self, dlg, parent=None):
        """Run the full planning pipeline for the inputs entered in ``dlg``.

        Every failure is logged and reported with a message box, after which
        the method returns early (it never raises to the caller for pipeline
        errors).

        Args:
            dlg: The planner dialog. It must provide
                ``get_selected_gateway_layer``, ``get_selected_dem_layer`` and
                ``get_sensor_files_dict_and_counts``.
            parent: Optional Qt widget used as the parent of all message boxes
                and of the save-file dialog.
        """
        inputs = self._collect_inputs(dlg, parent)
        if inputs is None:
            return
        gateway_layer, dem_layer, sensor_files_dict, select_counts = inputs

        gateway_file = gateway_layer.source()
        dem_file = dem_layer.source()

        logger.info(
            "Dialog inputs: gateway=%s dem=%s sensor_types=%s",
            gateway_file, dem_file, list(sensor_files_dict.keys())
        )

        if not self._validate_dem(dem_layer, dem_file, parent):
            return

        # Run viewshed analysis
        logger.info("Starting viewshed analysis...")
        try:
            (
                coverage_map,
                reverse_sensor_map,
                gateway_positions,
                sensor_positions,
                viewshed_paths,
                gateway_attr_costs,
                sensor_attr_costs
            ) = run_viewshed_analysis(
                    gateway_file,
                    sensor_files_dict,
                    dem_file
            )
        except Exception as e:
            logger.exception("Viewshed Analysis Failed: %s", e)
            QMessageBox.critical(parent, "Viewshed Analysis Failed", f"Viewshed analysis failed:\n{e}")
            return

        logger.info("Viewshed analysis completed.")
        logger.debug("Coverage map: %s", coverage_map)

        # Compute costs (gateway and sensors) using costs pipeline
        try:
            costs_out = compute_all_costs_pipeline(
                dem_path=dem_file,
                gateway_positions=gateway_positions,
                sensor_positions=sensor_positions,
                viewshed_paths=viewshed_paths,
                coverage_map=coverage_map,
                gateway_attr_costs=gateway_attr_costs,
                sensor_attr_costs=sensor_attr_costs,
                gateway_weights=None,
                sensor_weights=None,
                gateway_selection_for_sensor_distances=list(gateway_positions.keys()),
                reverse_sensor_map=reverse_sensor_map
            )
        except Exception as e:
            logger.exception("Cost computation failed: %s", e)
            QMessageBox.critical(parent, "Cost Computation Failed", f"Cost computation failed:\n{e}")
            return

        gateway_costs = costs_out.get("gateway_costs", {})
        sensor_costs_flat = costs_out.get("sensor_costs", {})

        # gateway_costs may be a dict, a list/tuple or a scalar (all handled by
        # _normalize_gateway_costs); only a dict has .keys().
        logger.info(
            "Costs computed. Gateways: %s",
            list(gateway_costs.keys()) if isinstance(gateway_costs, dict) else gateway_costs
        )
        logger.debug("Sensor flat costs sample: %s", dict(list(sensor_costs_flat.items())[:10]))

        # The optimizer expects sensor costs grouped per sensor type.
        sensor_costs_by_type = self._group_sensor_costs_by_type(sensor_costs_flat, reverse_sensor_map)

        logger.info("Starting optimization...")
        try:
            result = generalized_optimizer(
                coverage_map=coverage_map,
                gateway_costs=gateway_costs,
                sensor_costs=sensor_costs_by_type,
                select_counts=select_counts or {},
                require_at_least_one_per_type=False,
                solver_preference=None,
                solver_msg=False
            )
        except Exception as e:
            logger.exception("Optimization failed: %s", e)
            QMessageBox.critical(parent, "Optimization Failed", f"Optimization failed:\n{e}")
            return

        # Present results
        status = result.get("status", "Unknown")
        selected_gateways = result.get("selected_gateways", [])
        selected_sensors = result.get("selected_sensors", {})
        logger.info("Optimization status: %s; selected gateways: %s", status, selected_gateways)

        if not selected_gateways:
            QMessageBox.information(
                parent,
                "QNetPlanner",
                "Optimization completed, but no gateways were selected."
            )
            return

        # Ask user where to save results (GeoJSON single file)
        out_path, _ = QtWidgets.QFileDialog.getSaveFileName(
            parent,
            "Save QNetPlanner result (GeoJSON)",
            "",
            "GeoJSON (*.geojson *.json)"
        )
        if not out_path:
            logger.info("User cancelled save dialog")
            return

        # Viewsheds of the SELECTED gateways go next to the result file.
        out_dir = os.path.dirname(out_path)
        try:
            exported_viewsheds = export_selected_viewsheds(
                viewshed_paths=viewshed_paths,
                selected_gateways=selected_gateways,
                output_dir=out_dir,
                iface=self.iface
            )
            QMessageBox.information(
                parent,
                "QNetPlanner",
                f"Exported {len(exported_viewsheds)} viewshed rasters to:\n{os.path.join(out_dir, 'viewsheds')}"
            )
        except Exception as e:
            # Best effort: a failed viewshed export must not prevent saving results.
            logger.exception("Failed to export viewsheds: %s", e)
            QMessageBox.warning(
                parent,
                "Viewshed Export Failed",
                f"Viewsheds could not be exported:\n{e}"
            )

        assigned_map = self._build_assigned_map(coverage_map)
        final_sensor_costs = self._normalize_cost_mapping(sensor_costs_flat)
        final_gateway_costs = self._normalize_gateway_costs(gateway_costs, gateway_positions)

        self._log_key_diagnostics(sensor_positions, final_sensor_costs, final_gateway_costs)

        # Save results to GeoJSON
        try:
            saved = save_results_geojson(
                out_geojson_path=out_path,
                gateway_positions=gateway_positions,
                gateway_attr_costs=gateway_attr_costs,
                selected_gateways=selected_gateways,
                sensor_positions=sensor_positions,
                sensor_attr_costs=sensor_attr_costs,
                selected_sensors=selected_sensors,
                reverse_sensor_map=reverse_sensor_map,
                coverage_map=coverage_map,
                final_gateway_costs=final_gateway_costs,
                final_sensor_costs=final_sensor_costs,
                assigned_map=assigned_map,
                crs=dem_layer.crs(),
                open_in_qgis=True,
                iface=self.iface
            )
        except Exception as e:
            logger.exception("Failed to save results geojson: %s", e)
            QMessageBox.critical(parent, "Save Failed", f"Failed to save results:\n{e}")
            return

        QMessageBox.information(parent, "QNetPlanner", f"Results saved to:\n{saved}")
        logger.info("Results saved to %s", saved)
        logger.info("""
            ================ C O M P L E T E D ================
            """)

    # ------------------------------------------------------------------ #
    # execute() helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _collect_inputs(dlg, parent):
        """Read and validate the dialog inputs.

        Returns:
            tuple | None: ``(gateway_layer, dem_layer, sensor_files_dict,
            select_counts)``, or ``None`` after warning the user about a
            missing input.
        """
        gateway_layer = None
        dem_layer = None

        try:
            gateway_layer = dlg.get_selected_gateway_layer()
        except Exception:
            logger.debug("gatewayPath not available in UI")

        try:
            dem_layer = dlg.get_selected_dem_layer()
        except Exception:
            logger.debug("demPath not available in UI")

        sensor_files_dict, select_counts = dlg.get_sensor_files_dict_and_counts()

        if gateway_layer is None:
            QMessageBox.warning(parent, "Missing input", "Please select a gateway layer.")
            return None

        if dem_layer is None:
            QMessageBox.warning(parent, "Missing input", "Please select a DEM layer.")
            return None

        if not sensor_files_dict or all(len(v) == 0 for v in sensor_files_dict.values()):
            QMessageBox.warning(parent, "Missing input", "Please add at least one sensor type and attach file(s).")
            logger.warning("No sensor files provided")
            return None

        return gateway_layer, dem_layer, sensor_files_dict, select_counts

    @staticmethod
    def _validate_dem(dem_layer, dem_file, parent):
        """Return True if the DEM layer is valid and in a projected CRS.

        Otherwise log the problem, show an error box and return False.
        """
        if not dem_layer.isValid():
            QMessageBox.critical(parent, "Invalid DEM", f"Could not load DEM: {dem_file}")
            logger.error("DEM failed to load: %s", dem_file)
            return False

        if dem_layer.crs().isGeographic():
            logger.error("DEM uses a geographic CRS, a projected CRS is required: %s", dem_file)
            QMessageBox.critical(parent,
            "Invalid DEM CRS",
            "The selected DEM uses a geographic CRS (degrees).\n\n"
            "Viewshed analysis requires a projected CRS (meters).\n\n"
            "Please reproject the DEM before proceeding.")
            return False

        return True

    @classmethod
    def _as_int(cls, value):
        """Return ``int(value)`` when possible, otherwise ``value`` unchanged."""
        try:
            return int(value)
        except cls._INT_ERRORS:
            return value

    @staticmethod
    def _group_sensor_costs_by_type(sensor_costs_flat, reverse_sensor_map):
        """Regroup ``{gid: cost}`` into ``{sensor_type: {gid: cost}}``.

        The sensor type is the first element of ``reverse_sensor_map[gid]``.
        Sensors with no entry, or with a ``None`` type, go under ``"unknown"``.
        """
        grouped = {}
        for gid, cost in sensor_costs_flat.items():
            stype, _orig_fid = reverse_sensor_map.get(gid, (None, None))
            if stype is None:
                stype = "unknown"
            grouped.setdefault(stype, {})[gid] = cost
        return grouped

    @classmethod
    def _build_assigned_map(cls, coverage_map):
        """Invert ``{gw_id: {stype: [sensor_gid, ...]}}`` into ``{sensor_gid: [gw_id, ...]}``.

        Ids are converted to int where possible. Each gateway list is sorted
        and de-duplicated.
        """
        assigned = {}
        for gw_id, stype_map in (coverage_map or {}).items():
            gw_key = cls._as_int(gw_id)
            for gids in stype_map.values():
                for gid in gids:
                    assigned.setdefault(cls._as_int(gid), set()).add(gw_key)

        result = {}
        for gid, gateways in assigned.items():
            try:
                result[gid] = sorted(gateways)
            except TypeError:
                # Mixed int/str gateway ids cannot be compared directly.
                result[gid] = sorted(gateways, key=str)
        return result

    @classmethod
    def _normalize_cost_mapping(cls, costs):
        """Return ``{id: cost}`` with int-converted keys and float-converted values.

        A key or value that cannot be converted is kept as is. ``None``
        values stay ``None``.
        """
        normalized = {}
        for key, value in (costs or {}).items():
            if value is not None:
                try:
                    value = float(value)
                except (TypeError, ValueError):
                    pass
            normalized[cls._as_int(key)] = value
        return normalized

    @classmethod
    def _normalize_gateway_costs(cls, gateway_costs, gateway_positions):
        """Normalise gateway costs to a ``{gw_id: float | None}`` mapping.

        Supported shapes of ``gateway_costs``:
          * dict: normalised key/value-wise;
          * list/tuple: matched to ``gateway_positions`` ids by position, and
            ids without an entry get ``None``;
          * scalar: the same float for every gateway, or ``None`` for all if
            it is not numeric.
        """
        if isinstance(gateway_costs, dict):
            return cls._normalize_cost_mapping(gateway_costs)

        gw_ids = list(gateway_positions.keys())
        normalized = {}

        if isinstance(gateway_costs, (list, tuple)):
            for idx, gw_id in enumerate(gw_ids):
                if idx < len(gateway_costs):
                    raw = gateway_costs[idx]
                    try:
                        normalized[cls._as_int(gw_id)] = float(raw)
                    except (TypeError, ValueError):
                        normalized[cls._as_int(gw_id)] = raw
                else:
                    normalized[cls._as_int(gw_id)] = None
            return normalized

        try:
            scalar = float(gateway_costs)
        except (TypeError, ValueError):
            scalar = None
        for gw_id in gw_ids:
            normalized[cls._as_int(gw_id)] = scalar
        return normalized

    @classmethod
    def _log_key_diagnostics(cls, sensor_positions, final_sensor_costs, final_gateway_costs):
        """Debug-log sample keys and sensors that have a position but no cost."""
        logger.debug("DEBUG: sample sensor_positions keys (first 10): %s", list(sensor_positions.keys())[:10])
        logger.debug("DEBUG: sample final_sensor_costs keys (first 20): %s", list(final_sensor_costs.keys())[:20])
        logger.debug("DEBUG: sample final_gateway_costs keys (first 20): %s", list(final_gateway_costs.keys())[:20])

        # Compare as ints when every key converts, otherwise compare raw keys.
        try:
            sp_keys = {int(k) for k in sensor_positions}
        except cls._INT_ERRORS:
            sp_keys = set(sensor_positions)
        try:
            fc_keys = {int(k) for k in final_sensor_costs}
        except cls._INT_ERRORS:
            fc_keys = set(final_sensor_costs)

        missing = sp_keys - fc_keys
        logger.debug("DEBUG: keys in sensor_positions not in final_sensor_costs (sample 10): %s", list(missing)[:10])