# optimization.py
"""
QNetPlanner - optimization module
=================================

Provides:
 - generalized_optimizer(...)  : main ILP entry point (auto-detects an available solver)
 - greedy_fallback(...)        : greedy heuristic used when the ILP cannot be solved

GitLab:
    https://gitlab.com/binoy194/QNetPlanner
Email:
    binoy194@gmail.com
    kavyask304@gmail.com

Authors:
    Binoy C
    Kavya S K
"""

import logging
import shutil
from typing import Dict, List, Optional, Any
from .logger import logger

#logger = logging.getLogger("QNetPlanner.optimization")

# PuLP imports (required)
try:
    import pulp
    from pulp import LpProblem, LpMinimize, LpVariable, lpSum, LpBinary, PULP_CBC_CMD, GLPK_CMD
except Exception as e:
    logger.error("PuLP import failed: %s", e)
    raise

# -------------------------
# Solver selection helpers
# -------------------------
def _auto_solver(msg: bool = False):
    """
    Return a PuLP solver instance, preferring a system CBC or GLPK binary.

    Resolution order:
      1. ``cbc`` on PATH    -> ``pulp.COIN_CMD`` pointed at that binary
      2. ``glpsol`` on PATH -> ``GLPK_CMD`` pointed at that binary
      3. ``PULP_CBC_CMD()`` -> the CBC binary bundled with PuLP (may still
         fail at solve time if that packaged binary is broken)

    A system candidate is skipped, with a warning, if constructing it raises
    or if the solver reports itself unavailable.

    Args:
        msg: forwarded to the solver; when True the solver echoes its log.
    """
    # NOTE: PULP_CBC_CMD only drives PuLP's bundled CBC and rejects an
    # explicit ``path`` (PuLP 2.x raises "Use COIN_CMD if you want to set a
    # path"), so PULP_CBC_CMD(path=...) always failed and a system CBC was
    # never used. COIN_CMD is PuLP's driver for an external CBC executable.
    system_candidates = (
        ("CBC", "cbc", pulp.COIN_CMD),
        ("GLPK", "glpsol", GLPK_CMD),
    )
    for label, executable, solver_cls in system_candidates:
        path = shutil.which(executable)
        if not path:
            continue
        try:
            solver = solver_cls(path=path, msg=msg)
            usable = solver.available()
        except Exception as e:
            logger.warning("Failed to use %s at %s: %s", label, path, e)
            continue
        if not usable:
            logger.warning("Failed to use %s at %s: %s", label, path, "solver reports unavailable")
            continue
        logger.info("Using system %s solver at: %s", label, path)
        return solver

    # Final fallback to default PULP_CBC_CMD (may be packaged binary)
    logger.info("Falling back to default PULP_CBC_CMD()")
    return PULP_CBC_CMD(msg=msg)

def _build_solver(solver_preference: Optional[Any], msg: bool = False):
    """
    Build a PuLP solver instance from ``solver_preference``.

    Accepted forms:
      - None (or any falsy value)   -> ``_auto_solver()``
      - a ``pulp.LpSolver`` instance -> returned unchanged
      - tuple ``(name, path, ...)``  -> ``GLPK_CMD`` when ``name`` is "GLPK",
        "GLPSOL" or "GLPK_CMD" (case-insensitive), otherwise
        ``pulp.COIN_CMD`` (CBC), run from ``path``
      - string                       -> path to a CBC binary (``pulp.COIN_CMD``)

    A preference that is unrecognized, cannot be constructed, or reports
    itself unavailable is logged and ``_auto_solver()`` is used instead.

    Args:
        solver_preference: see the accepted forms above.
        msg: forwarded to the solver; when True the solver echoes its log.
    """
    solver_base = getattr(pulp, "LpSolver", None)
    if solver_base is not None and isinstance(solver_preference, solver_base):
        # Caller already configured a solver; use it as-is.
        return solver_preference
    if not solver_preference:
        return _auto_solver(msg=msg)

    name, path = None, None
    if isinstance(solver_preference, tuple) and len(solver_preference) >= 2:
        name, path = solver_preference[0], solver_preference[1]
    elif isinstance(solver_preference, str):
        path = solver_preference

    if not path:
        logger.warning("Unrecognized solver preference %r; auto-detecting solver", solver_preference)
        return _auto_solver(msg=msg)

    # PULP_CBC_CMD refuses a custom path; COIN_CMD is PuLP's external-CBC
    # driver. Honour the tuple's solver name instead of always assuming CBC.
    use_glpk = str(name or "").strip().upper() in ("GLPK", "GLPSOL", "GLPK_CMD")
    solver_cls = GLPK_CMD if use_glpk else pulp.COIN_CMD
    try:
        solver = solver_cls(path=path, msg=msg)
        if solver.available():
            return solver
        logger.warning("Failed to use solver at %s: %s", path, "solver reports unavailable")
    except Exception as e:
        logger.warning("Failed to use solver at %s: %s", path, e)

    return _auto_solver(msg=msg)

# -------------------------
# Greedy fallback solver
# -------------------------
def greedy_fallback(
    coverage_map: Dict[int, Dict[str, List[int]]],
    gateway_costs: Dict[int, float],
    sensor_costs: Dict[str, Dict[int, float]],
    select_counts: Optional[Dict[str, Optional[int]]] = None
) -> Dict:
    """
    Greedy set-cover style heuristic used when the ILP cannot be solved.

    Repeatedly picks the not-yet-selected gateway with the best ratio
    "still-required sensors it newly covers / gateway cost". It then selects
    the cheapest newly covered sensors of each unmet type, up to the
    remaining requirement. It stops when every numeric requirement is met or
    when no remaining gateway helps.

    Args:
        coverage_map: {gw: {stype: [sensor_gid, ...], ...}, ...}
        gateway_costs: {gw: cost}. A missing/None cost counts as 1.0. A cost
            of 0 is kept, so free gateways rank first. Negative costs are
            clamped to 0.
        sensor_costs: {stype: {gid: cost}}. Sensors missing here are treated
            as cost 999.0 when choosing among candidates.
        select_counts: {stype: required_int or None}. None means "no
            requirement"; such types are not driven by the heuristic.

    Returns:
        A dict shaped like generalized_optimizer's result:
          - "status": "Greedy"
          - "pulp_status" / "objective": None
          - "selected_gateways": sorted gateway ids
          - "selected_sensors": {stype: sorted gids}
        Requirements that cannot be met are logged as a warning, and the
        partial selection is returned.
    """
    logger.warning("Entering greedy fallback optimizer")

    select_counts = select_counts or {}
    logger.info("count_sensor:%s", select_counts)

    selected_gateways = set()
    # One set per known sensor type. Types that appear only in select_counts
    # or coverage_map are added on demand (these previously raised KeyError).
    selected_sensors: Dict[str, set] = {stype: set() for stype in sensor_costs}

    def gateway_cost(gw) -> float:
        # None/missing -> neutral 1.0. A real 0 stays 0 (the old `or 1.0`
        # turned free gateways into cost 1.0). The +1e-6 at the division
        # site keeps the ratio finite.
        raw = gateway_costs.get(gw)
        if raw is None:
            return 1.0
        return max(float(raw), 0.0)

    def sensor_cost(stype, gid) -> float:
        return sensor_costs.get(stype, {}).get(gid, 999.0)

    def remaining(stype):
        # How many more sensors of this type are still required (0 if none).
        required = select_counts.get(stype)
        if required is None:
            return 0
        return max(0, required - len(selected_sensors.get(stype, ())))

    def unmet_types() -> List[str]:
        return [stype for stype in select_counts if remaining(stype) > 0]

    def fresh_sensors(gw, stype) -> set:
        # Sensors of `stype` covered by `gw` that are not selected yet.
        return set(coverage_map.get(gw, {}).get(stype, [])) - selected_sensors.get(stype, set())

    # Each productive iteration adds one new gateway, so len(coverage_map)
    # iterations suffice; the 2x cap is only a safety bound.
    for _ in range(max(1, len(coverage_map) * 2)):
        unmet = unmet_types()
        if not unmet:
            break

        best_gw, best_score = None, 0.0
        for gw in coverage_map:
            if gw in selected_gateways:
                continue
            # Only count sensors that still help satisfy a requirement.
            gain = sum(min(len(fresh_sensors(gw, stype)), remaining(stype)) for stype in unmet)
            if gain <= 0:
                continue
            score = gain / (gateway_cost(gw) + 1e-6)
            if score > best_score:
                best_gw, best_score = gw, score

        if best_gw is None:
            # can't cover remaining requirements
            logger.warning("Greedy fallback cannot cover remaining requirements, exiting loop")
            break

        selected_gateways.add(best_gw)

        # Take the cheapest newly covered sensors, up to each type's need.
        for stype in unmet:
            need = remaining(stype)
            if need <= 0:
                continue
            candidates = sorted(fresh_sensors(best_gw, stype), key=lambda g: sensor_cost(stype, g))
            selected_sensors.setdefault(stype, set()).update(candidates[:need])

    # The loop above only stops with unmet types when no unselected gateway
    # covers a still-needed sensor. A second "top-up" pass over unselected
    # gateways could therefore never add anything, so just report the shortfall.
    still_unmet = unmet_types()
    if still_unmet:
        logger.warning("Greedy fallback left requirements unmet for types: %s", still_unmet)

    # convert sets to sorted lists for stable output
    selected_gateways_list = sorted(selected_gateways)
    selected_sensors_dict = {stype: sorted(gids) for stype, gids in selected_sensors.items()}

    logger.info("Greedy fallback result: gateways=%s sensors=%s", selected_gateways_list, selected_sensors_dict)

    return {
        "status": "Greedy",
        "pulp_status": None,
        "objective": None,
        "selected_gateways": selected_gateways_list,
        "selected_sensors": selected_sensors_dict
    }

# -------------------------
# Main ILP optimizer
# -------------------------
def _is_selected(var) -> bool:
    """
    Return True if a solved binary PuLP variable is on.

    Solvers report binary values as floats that can carry tolerance noise
    (e.g. 0.9999999 or 1.0000001), so an exact ``== 1`` test can drop
    selected items. ``value()`` is also None for a variable without a value.
    """
    value = var.value()
    return value is not None and value > 0.5


def generalized_optimizer(
    coverage_map: Dict[int, Dict[str, List[int]]],
    gateway_costs: Dict[int, float],
    sensor_costs: Dict[str, Dict[int, float]],
    select_counts: Optional[Dict[str, Optional[int]]] = None,
    require_at_least_one_per_type: bool = False,
    solver_preference: Optional[Any] = None,
    solver_msg: bool = False
) -> Dict:
    """
    Select gateways and sensors at minimum total cost with a 0/1 ILP.

    Model:
      - Variables: a binary x_gw per gateway in ``gateway_costs`` and a
        binary y_s per sensor id in ``sensor_costs``.
      - Objective: minimize sum(gateway cost * x_gw) + sum(sensor cost * y_s).
      - Coverage: a sensor may be selected only if at least one selected
        gateway covers it (per ``coverage_map``). Uncoverable sensors are
        fixed to 0.
      - Numeric requirements: for each type with a numeric ``select_counts``
        entry, exactly that many sensors of the type are selected.
      - Other types: at least one sensor each when
        ``require_at_least_one_per_type`` is True.
      - At least one gateway is selected when any numeric requirement is > 0.

    The solver comes from ``_build_solver(solver_preference)``. If solving
    raises, ``_auto_solver()`` is tried once more. The result of
    ``greedy_fallback`` is returned instead when either:
      - that second attempt also raises, or
      - the solver status is not "Optimal"/"Feasible" (e.g. the requirements
        are infeasible).

    Args:
        coverage_map: {gw: {stype: [sensor_gid, ...]}}. Gateways absent from
            ``gateway_costs`` are ignored, with a warning.
        gateway_costs: {gw: cost}. Its keys define the candidate gateways.
        sensor_costs: {stype: {gid: cost}}. Its ids define the candidate sensors.
        select_counts: {stype: required_count or None}.
        require_at_least_one_per_type: see "Other types" in Model above.
        solver_preference: forwarded to ``_build_solver``.
        solver_msg: when True the solver echoes its log.

    Returns:
        dict with keys:
          - status: PuLP status string (e.g. "Optimal"), or "Greedy" when the
            greedy fallback produced the result
          - pulp_status: raw ``model.status`` (None for greedy results)
          - objective: objective value, or None if unavailable
          - selected_gateways: sorted [gw_id, ...]
          - selected_sensors: {stype: sorted [gid, ...]}
    """
    logger.info("Starting generalized_optimizer")

    # Flattened, sorted list of every sensor id across all types.
    sensor_global_ids = sorted({gid for mapping in sensor_costs.values() for gid in mapping})
    logger.info("sensor global ids:%s", sensor_global_ids)

    # gid -> type; if an id appears under several types the last one wins.
    sensor_gid_to_type: Dict[int, str] = {
        gid: stype for stype, mapping in sensor_costs.items() for gid in mapping
    }
    logger.info("sensor type: %s", sensor_gid_to_type)

    # Decision variables: one binary per candidate gateway and per sensor.
    x_vars = {gw: LpVariable(f"x_gw_{gw}", cat=LpBinary) for gw in gateway_costs}
    y_vars = {gid: LpVariable(f"y_s_{gid}", cat=LpBinary) for gid in sensor_global_ids}

    # sensor -> covering gateways. Only gateways that have a decision
    # variable are usable; indexing x_vars with any other id raised KeyError.
    sensor_to_gateways: Dict[int, List[int]] = {gid: [] for gid in sensor_global_ids}
    unknown_gateways = []
    for gw, stmap in coverage_map.items():
        if gw not in x_vars:
            unknown_gateways.append(gw)
            continue
        for gid_list in stmap.values():
            for gid in gid_list:
                if gid in sensor_to_gateways:
                    sensor_to_gateways[gid].append(gw)
    if unknown_gateways:
        logger.warning(
            "Ignoring gateways in coverage_map without an entry in gateway_costs: %s",
            unknown_gateways,
        )

    # Build ILP model
    model = LpProblem("QNetPlanner_Generalized", LpMinimize)

    # Objective: minimize gateway costs + sensor costs (None/0 cost -> 0.0)
    obj_terms = [float(gateway_costs.get(gw, 0.0) or 0.0) * var for gw, var in x_vars.items()]
    for mapping in sensor_costs.values():
        for gid, cost in mapping.items():
            obj_terms.append(float(cost or 0.0) * y_vars[gid])
    model += lpSum(obj_terms), "TotalCost"

    # A sensor may be selected only if some covering gateway is selected.
    for gid, y in y_vars.items():
        coverings = sensor_to_gateways[gid]
        if coverings:
            model += y <= lpSum(x_vars[gw] for gw in coverings), f"cover_sensor_{gid}"
        else:
            # no gateway can cover this sensor -> cannot select it
            model += y == 0, f"nocover_sensor_{gid}"

    # Per-type selection constraints
    select_counts = select_counts or {}
    logger.info("counts dict:%s", select_counts)

    sensors_by_type: Dict[str, List[int]] = {}
    for gid, stype in sensor_gid_to_type.items():
        sensors_by_type.setdefault(stype, []).append(gid)

    for stype, gids in sensors_by_type.items():
        type_total = lpSum(y_vars[gid] for gid in gids)
        required = select_counts.get(stype)
        if required is not None:
            model += type_total == int(required), f"select_count_{stype}"
        elif require_at_least_one_per_type:
            model += type_total >= 1, f"at_least_one_{stype}"

    # If user specified numeric requirements, require at least one gateway (avoid trivial zero selection)
    if any(v is not None and int(v) > 0 for v in select_counts.values()):
        model += lpSum(x_vars.values()) >= 1, "at_least_one_gateway"

    # Requirements for types that have no candidate sensors get no constraint
    # at all; surface that instead of silently reporting them as satisfied.
    missing_types = [
        stype for stype, req in select_counts.items()
        if req is not None and int(req) > 0 and stype not in sensors_by_type
    ]
    if missing_types:
        logger.warning("No candidate sensors for required types %s; their counts are not enforced", missing_types)

    # Solve with auto-detected solver (or provided preference)
    try:
        model.solve(_build_solver(solver_preference, msg=solver_msg))
    except Exception as e:
        logger.exception("Solver invocation failed: %s", e)
        # Attempt one more time with _auto_solver
        try:
            model.solve(_auto_solver(msg=solver_msg))
        except Exception as e2:
            logger.exception("Auto solver invocation failed: %s", e2)
            # fallback to greedy
            return greedy_fallback(coverage_map, gateway_costs, sensor_costs, select_counts)

    # Interpret pulp status
    status_names = getattr(pulp, "LpStatus", None) or {}
    status_str = status_names.get(model.status, str(model.status))
    logger.info("Solver finished with status: %s", status_str)

    # If solver didn't produce optimal/feasible results, use greedy fallback
    if status_str not in ("Optimal", "Feasible"):
        logger.warning("Solver status not optimal/feasible (%s) — using greedy fallback", status_str)
        return greedy_fallback(coverage_map, gateway_costs, sensor_costs, select_counts)

    # Collect solution (tolerance-aware test for solved binaries)
    selected_gateways = sorted(gw for gw, var in x_vars.items() if _is_selected(var))
    selected_sensors_by_type: Dict[str, List[int]] = {stype: [] for stype in sensors_by_type}
    for gid, var in y_vars.items():
        if _is_selected(var):
            stype = sensor_gid_to_type.get(gid, "unknown")
            selected_sensors_by_type.setdefault(stype, []).append(gid)

    # Objective value: None when the solver left any term without a value.
    try:
        objective_val = float(pulp.value(model.objective))
    except (TypeError, ValueError, AttributeError):
        objective_val = None

    result = {
        "status": status_str,
        "pulp_status": model.status,
        "objective": objective_val,
        "selected_gateways": selected_gateways,
        "selected_sensors": {k: sorted(v) for k, v in selected_sensors_by_type.items()}
    }

    logger.info("Optimization result: %s", result)
    return result
