"""
==========QNetPlanner - Cost calculation module===========

Compute cost scores (0..10) for gateways and sensors.

Functions:
    - sample_dem_elevations(dem_path, points)
    - compute_viewshed_area(vs_path)
    - scale_to_0_10(norm_map, invert, min_cost)
    - gateway_base_metrics(...)
    - sensor_base_metrics(...)
    - compute_gateway_costs(...)
    - compute_sensor_costs_for_selection(...)
    - compute_all_costs_pipeline(...)

Notes:
    - This implementation uses GDAL (osgeo.gdal) for raster reading.
    - In QGIS environment you may replace sampling with QgsRasterLayer sampling if preferred.
Gitlab:
    https://gitlab.com/binoy194/QNetPlanner
email:
    binoy194@gmail.com
    kavyask304@gmail.com

Authors:
    Binoy C
    Kavya S K
"""

from osgeo import gdal
import numpy as np
import os
from math import sqrt
from typing import Dict, Tuple, List, Optional
from .logger import logger


# ----------------------------
# Low-level raster helpers
# ----------------------------
def _open_raster(dem_path):
    """
    Open a raster with GDAL and hand back what the samplers need.

    Returns:
        (dataset, geotransform, band_1, band_1_nodata_value). The nodata value
        is whatever GDAL reports for band 1 (None when unset).

    Raises:
        RuntimeError: if GDAL cannot open ``dem_path``.
    """
    dataset = gdal.Open(dem_path)
    if dataset is None:
        raise RuntimeError(f"Could not open raster: {dem_path}")
    first_band = dataset.GetRasterBand(1)
    return dataset, dataset.GetGeoTransform(), first_band, first_band.GetNoDataValue()

def sample_dem_elevations(dem_path: str, points: List[Tuple[float, float]]):
    """
    Sample band 1 of a DEM at the given (x, y) points.

    Points must be in the same CRS as the DEM. The pixel lookup uses only the
    origin and pixel-size terms of the geotransform (gt[0], gt[1], gt[3],
    gt[5]), i.e. it assumes a north-up raster without rotation terms.

    Returns:
        dict mapping each point's index in ``points`` -> elevation (float), or
        None when the point falls outside the raster, the pixel read fails, or
        the pixel holds the band's nodata value or NaN.

    Raises:
        RuntimeError: if the DEM cannot be opened.
    """
    ds, gt, band, nodata = _open_raster(dem_path)
    try:
        xsize = ds.RasterXSize
        ysize = ds.RasterYSize

        def world2pixel(x, y):
            # floor() rather than int(): int() truncates toward zero, which
            # would snap points up to one pixel left of / above the origin onto
            # column/row 0 instead of letting the bounds check reject them.
            px = int(np.floor((x - gt[0]) / gt[1]))
            py = int(np.floor((y - gt[3]) / gt[5]))
            return px, py

        elevs = {}
        for i, (x, y) in enumerate(points):
            px, py = world2pixel(x, y)
            if px < 0 or py < 0 or px >= xsize or py >= ysize:
                elevs[i] = None
                continue
            arr = band.ReadAsArray(px, py, 1, 1)
            if arr is None:
                elevs[i] = None
                continue
            val = float(arr[0, 0])
            # NaN never compares equal (not even to a NaN nodata value), so
            # test it explicitly; a NaN would otherwise poison normalization.
            if np.isnan(val) or (nodata is not None and val == nodata):
                elevs[i] = None
            else:
                elevs[i] = val
        return elevs
    finally:
        # Drop references so GDAL releases the file, also on error.
        band = None
        ds = None

def compute_viewshed_area(vs_path: str):
    """
    Compute visible pixel count and visible area for a viewshed raster.

    A pixel of band 1 counts as visible when it is non-zero, differs from the
    band's nodata value (if one is set) and is not NaN. The area of one pixel
    is the absolute determinant of the geotransform's affine part,
    |gt[1]*gt[5] - gt[2]*gt[4]|, which equals abs(gt[1]*gt[5]) for north-up
    rasters and stays correct for rotated ones.

    Returns:
        (area_map_units, visible_pixel_count)

    Raises:
        RuntimeError: if the raster cannot be opened or band 1 cannot be read.
    """
    ds = gdal.Open(vs_path)
    if ds is None:
        raise RuntimeError(f"Could not open viewshed raster: {vs_path}")
    try:
        gt = ds.GetGeoTransform()
        band = ds.GetRasterBand(1)
        arr = band.ReadAsArray()
        if arr is None:
            # Without this, `None != 0` is True and would count as 1 pixel.
            raise RuntimeError(f"Could not read viewshed raster: {vs_path}")
        nodata = band.GetNoDataValue()
        mask = arr != 0
        if nodata is not None:
            mask &= arr != nodata
        if np.issubdtype(arr.dtype, np.floating):
            # NaN != 0 and NaN != NaN are both True, so NaN pixels (including a
            # NaN nodata value) would otherwise be counted as visible.
            mask &= ~np.isnan(arr)
        visible_count = int(np.count_nonzero(mask))
        pixel_area = abs(gt[1] * gt[5] - gt[2] * gt[4])
        return visible_count * pixel_area, visible_count
    finally:
        band = None
        ds = None

# ----------------------------
# Normalization & scaling
# ----------------------------
def _normalize_dict(values_dict: Dict, eps: float = 1e-9):
    """
    Normalize numeric values in dict to 0..1; None values remain None.
    Min->0, Max->1. If all equal, return 0.5 for those values.
    """
    clean = {k: v for k, v in values_dict.items() if v is not None}
    if not clean:
        return {k: None for k in values_dict.keys()}
    vals = np.array(list(clean.values()), dtype=float)
    vmin = float(vals.min())
    vmax = float(vals.max())
    rng = vmax - vmin
    out = {}
    for k in values_dict.keys():
        v = values_dict[k]
        if v is None:
            out[k] = None
        else:
            if rng < eps:
                out[k] = 0.5
            else:
                out[k] = float((v - vmin) / rng)
    return out

def scale_to_0_10(norm_map: Dict, invert: bool = False, min_cost: float = 1e-6):
    """
    Convert a normalized (0..1) map to a 0..10 cost scale.

    Args:
        norm_map: {key: normalized value in 0..1, or None}.
        invert: if True, a higher normalized value gives a LOWER cost
            (cost = 10 * (1 - norm)); otherwise cost = 10 * norm.
        min_cost: value substituted for any score that is <= 0 after
            rounding, so no key ends up with an exact zero cost.

    Returns:
        {key: cost rounded to 6 decimals (or ``min_cost``), or None where the
        input value was None}.
    """
    out = {}
    for k, v in norm_map.items():
        if v is None:
            out[k] = None
            continue
        val = float(v)
        score = 10.0 * (1.0 - val) if invert else 10.0 * val
        score = float(round(score, 6))
        # Floor AFTER rounding: a tiny positive score (< 5e-7) would otherwise
        # pass the floor check and then round down to 0.0.
        if score <= 0.0:
            score = float(min_cost)
        out[k] = score
    return out

# ----------------------------
# Metric aggregation functions
# ----------------------------
def gateway_base_metrics(
    gateway_positions: Dict[int, Tuple[float, float]],
    dem_path: Optional[str] = None,
    viewshed_paths: Optional[Dict[int, str]] = None,
    coverage_map: Optional[Dict[int, Dict]] = None
):
    """
    Calculate base metrics per gateway.

    Args:
        gateway_positions: {gw_id: (x, y)} in the DEM's CRS.
        dem_path: optional DEM path; when given, each gateway's elevation is
            sampled from it, otherwise elevation is None.
        viewshed_paths: optional {gw_id: path_to_viewshed.tif}. A gateway with
            no path, a path that does not exist, or a raster that fails to
            read gets viewshed_area 0.0 and visible_pixel_count 0.
        coverage_map: optional {gw_id: {sensor_type: [sensor ids]}}; a
            gateway's visible_sensor_count is the total length of its lists.

    Returns:
        {gw_id: metrics} in ascending gw_id order, where metrics has the keys
        elevation, viewshed_area, visible_pixel_count, visible_sensor_count.
    """
    gw_ids = sorted(gateway_positions.keys())
    if dem_path:
        elevs = sample_dem_elevations(dem_path, [gateway_positions[g] for g in gw_ids])
        elev_map = {gw: elevs[i] for i, gw in enumerate(gw_ids)}
    else:
        elev_map = {gw: None for gw in gw_ids}

    viewshed_paths = viewshed_paths or {}
    coverage_map = coverage_map or {}

    metrics = {}
    for gw in gw_ids:
        area, pix_count = 0.0, 0
        vs_path = viewshed_paths.get(gw)
        if vs_path and os.path.exists(vs_path):
            try:
                area, pix_count = compute_viewshed_area(vs_path)
            except Exception as exc:
                # Best effort: keep the zero defaults, but record the failure
                # so an unreadable viewshed is not silently scored as
                # "sees nothing".
                logger.info(f"viewshed read failed for gateway {gw} ({vs_path}): {exc}")
                area, pix_count = 0.0, 0
        # `or {}` also tolerates a gateway entry explicitly set to None.
        vis_count = sum(len(lst) for lst in (coverage_map.get(gw) or {}).values())
        metrics[gw] = {
            'elevation': elev_map.get(gw),
            'viewshed_area': area,
            'visible_pixel_count': pix_count,
            'visible_sensor_count': vis_count
        }

    return metrics

def sensor_base_metrics(
    sensor_positions: Dict[int, Tuple[float, float]],
    dem_path: Optional[str] = None,
    gateway_positions: Optional[Dict[int, Tuple[float, float]]] = None,
    coverage_map: Optional[Dict[int, Dict]] = None
):
    """
    Calculate base metrics per sensor.

    Args:
        sensor_positions: {global_sensor_id: (x, y)} in the DEM's CRS.
        dem_path: optional DEM path; when given, each sensor's elevation is
            sampled from it, otherwise elevation is None.
        gateway_positions: accepted for interface compatibility; currently
            unused (the distance-based metrics were removed).
        coverage_map: optional {gw_id: {sensor_type: [global sensor ids]}}
            used to count how many gateways cover each sensor.

    Returns:
        {sensor_id: {'elevation': float or None, 'sensor_covering_gws': int}}
        in ascending sensor_id order. sensor_covering_gws is the number of
        distinct gateways whose coverage lists contain the sensor (0 if none).
    """
    sids = sorted(sensor_positions.keys())
    if dem_path:
        elevs = sample_dem_elevations(dem_path, [sensor_positions[s] for s in sids])
        elev_map = {sid: elevs[i] for i, sid in enumerate(sids)}
    else:
        elev_map = {sid: None for sid in sids}

    # Count distinct covering gateways: a sensor listed more than once under
    # the same gateway (e.g. under two sensor types) counts that gateway once.
    covering_counts = {}
    for stype_map in (coverage_map or {}).values():
        covered_here = set()
        for gids in (stype_map or {}).values():
            covered_here.update(gids)
        for gid in covered_here:
            covering_counts[gid] = covering_counts.get(gid, 0) + 1

    return {
        sid: {
            'elevation': elev_map.get(sid),
            'sensor_covering_gws': covering_counts.get(sid, 0)
        }
        for sid in sids
    }

# ----------------------------
# Cost calculators
# ----------------------------
def compute_gateway_costs(metrics: Dict[int, Dict], weights: Optional[Dict[str, float]] = None):
    """
    Combine gateway metrics into a single 0..10 cost per gateway.

    Each factor is min-max normalized across all gateways, mapped to 0..10
    with ``scale_to_0_10(..., invert=True)`` and combined as a weighted mean.
    Factors that are missing from ``weights`` or None for a gateway are left
    out, and the remaining weights are renormalized to sum to 1.

    Interpretation (all factors inverted):
      - elevation: higher -> lower cost
      - viewshed_area: larger -> lower cost
      - visible_sensor_count: larger -> lower cost

    Args:
        metrics: {gw_id: {'elevation', 'viewshed_area', 'visible_sensor_count', ...}}
        weights: factor weights; defaults to
            {'elevation': 0.2, 'viewshed_area': 0.4, 'visible_sensor_count': 0.4}.

    Returns:
        {gw_id: cost rounded to 6 decimals, or None when no weighted factor is
        available for that gateway}.
    """
    if weights is None:
        weights = {'elevation': 0.2, 'viewshed_area': 0.4, 'visible_sensor_count': 0.4}

    elev_map = {gw: metrics[gw].get('elevation') for gw in metrics}
    area_map = {gw: metrics[gw].get('viewshed_area') for gw in metrics}
    count_map = {gw: metrics[gw].get('visible_sensor_count') for gw in metrics}

    nelev = _normalize_dict(elev_map)
    narea = _normalize_dict(area_map)
    ncount = _normalize_dict(count_map)

    celev = scale_to_0_10(nelev, invert=True)
    carea = scale_to_0_10(narea, invert=True)
    ccount = scale_to_0_10(ncount, invert=True)

    logger.info(f"gateway normalised elevation:{nelev}")
    logger.info(f"gateway normalised viewshed area:{narea}")
    logger.info(f"gateway normalised sensor_count:{ncount}")

    logger.info(f"gateway scaled to 1-10 elevation:{celev}")
    logger.info(f"gateway scaled to 1-10 viewshed area:{carea}")
    logger.info(f"gateway scaled to 1-10 sensor_count:{ccount}")

    # (weight key, per-gateway 0..10 factor cost)
    factor_costs = (
        ('elevation', celev),
        ('viewshed_area', carea),
        ('visible_sensor_count', ccount),
    )

    gw_costs = {}
    for gw in metrics:
        # Only factors that are weighted AND available for this gateway count.
        parts = [
            (costs[gw], weights[name])
            for name, costs in factor_costs
            if name in weights and costs.get(gw) is not None
        ]
        wsum = sum(w for _, w in parts)
        if wsum <= 0:
            gw_costs[gw] = None
            continue
        combined = sum(p * (w / wsum) for p, w in parts)
        gw_costs[gw] = float(round(combined, 6))

    logger.info(f"final gateway costs:{gw_costs}")
    return gw_costs

def compute_sensor_costs_for_selection(
    sensor_metrics: Dict[int, Dict],
    gateway_selection: Optional[List[int]],
    weights: Optional[Dict[str, float]] = None
):
    """
    Combine sensor factors into a single 0..10 cost per sensor.

    Each factor is min-max normalized across sensors, mapped to 0..10 with
    ``scale_to_0_10(..., invert=True)`` and combined as a weighted mean:
      - elevation: higher -> lower cost
      - gw_num (metric key 'sensor_covering_gws'): more covering gateways
        -> lower cost
    Factors that are missing from ``weights`` or None for a sensor are left
    out, and the remaining weights are renormalized to sum to 1.

    Args:
        sensor_metrics: {sensor_id: {'elevation', 'sensor_covering_gws', ...}}
        gateway_selection: kept for interface compatibility; not used since
            the distance factor was removed.
        weights: factor weights; defaults to {'elevation': 0.3, 'gw_num': 0.7}.

    Returns:
        {sensor_id: cost rounded to 6 decimals, or None when no weighted
        factor is available for that sensor}.
    """
    if weights is None:
        weights = {'elevation': 0.3, 'gw_num': 0.7}

    elev_map = {sid: info.get('elevation') for sid, info in sensor_metrics.items()}
    gw_num_map = {sid: info.get('sensor_covering_gws') for sid, info in sensor_metrics.items()}

    nelev = _normalize_dict(elev_map)
    ngw_num = _normalize_dict(gw_num_map)

    celev = scale_to_0_10(nelev, invert=True)
    cgw_num = scale_to_0_10(ngw_num, invert=True)

    logger.info(f"sensor normalised elevation:{nelev}")
    logger.info(f"sensor normalised gateway count:{ngw_num}")

    logger.info(f"sensor scaled to 1-10 elevation:{celev}")
    logger.info(f"sensor scaled to 1-10 gateway count:{cgw_num}")

    # (weight key, per-sensor 0..10 factor cost)
    factor_costs = (
        ('elevation', celev),
        ('gw_num', cgw_num),
    )

    sensor_costs = {}
    for sid in sensor_metrics:
        # Only factors that are weighted AND available for this sensor count.
        parts = [
            (costs[sid], weights[name])
            for name, costs in factor_costs
            if name in weights and costs.get(sid) is not None
        ]
        wsum = sum(w for _, w in parts)
        if wsum <= 0:
            sensor_costs[sid] = None
            continue
        combined = sum(p * (w / wsum) for p, w in parts)
        sensor_costs[sid] = float(round(combined, 6))

    logger.info(f"final sensor costs:{sensor_costs}")
    return sensor_costs

# ----------------------------
# Top-level pipeline
# ----------------------------

def compute_all_costs_pipeline(
    dem_path: str,
    gateway_positions: Dict[int, Tuple[float, float]],
    sensor_positions: Dict[int, Tuple[float, float]],
    viewshed_paths: Optional[Dict[int, str]] = None,
    coverage_map: Optional[Dict[int, Dict]] = None,
    gateway_attr_costs: Optional[Dict[int, float]] = None,
    sensor_attr_costs: Optional[Dict[int, float]] = None,
    gateway_weights: Optional[Dict[str, float]] = None,
    sensor_weights: Optional[Dict[str, float]] = None,
    gateway_selection_for_sensor_distances: Optional[List[int]] = None,
    reverse_sensor_map: Optional[Dict[int, Tuple[str, str]]] = None,
    default_gateway_attr_cost: float = 10.0,
    default_sensor_attr_cost: float = 5.0
):
    """
    Run the full gateway/sensor cost pipeline.

    Computes gateway metrics and metric-based costs, then sensor metrics and
    metric-based costs. Each computed cost is then blended with an attribute
    cost (attribute costs are assumed to be on the same 1..10 scale):
      - attribute cost present: final = (computed + attribute) / 2, or the
        attribute cost alone when the computed cost is None;
      - attribute cost missing: final = (computed + default) / 2, where default
        is ``default_gateway_attr_cost`` (10.0) for gateways and
        ``default_sensor_attr_cost`` (5.0) for sensors; a missing attribute is
        therefore NOT simply ignored;
      - both missing: None.
    Final costs are clamped to 0..10.

    Args:
        gateway_selection_for_sensor_distances: forwarded to
            compute_sensor_costs_for_selection; defaults to all gateway ids.
        reverse_sensor_map: accepted for interface compatibility; not used here.
        (other arguments are forwarded to the metric/cost functions)

    Returns:
        dict with keys 'gateway_metrics', 'gateway_costs', 'sensor_metrics',
        'sensor_costs'.
    """
    gateway_attr_costs = gateway_attr_costs or {}
    sensor_attr_costs = sensor_attr_costs or {}

    def _blend(comp_costs, attr_costs, default_attr):
        # Average computed and attribute costs per key, clamped to 0..10.
        blended = {}
        for key, comp in comp_costs.items():
            attr = attr_costs.get(key)
            if attr is not None:
                final = (comp + float(attr)) / 2.0 if comp is not None else float(attr)
            elif comp is not None:
                final = (comp + default_attr) / 2.0
            else:
                # Neither a computed nor an attribute cost: nothing to score
                # (previously `None + default` raised TypeError here).
                final = None
            blended[key] = float(max(0.0, min(10.0, final))) if final is not None else None
        return blended

    gw_metrics = gateway_base_metrics(gateway_positions, dem_path=dem_path,
                                      viewshed_paths=viewshed_paths, coverage_map=coverage_map)

    logger.info(f"gateway factors:{gw_metrics}")
    logger.info(f"gateway_other_costs:{gateway_attr_costs}")

    gw_comp_costs = compute_gateway_costs(gw_metrics, weights=gateway_weights)

    sensor_metrics = sensor_base_metrics(sensor_positions, dem_path=dem_path,
                                         gateway_positions=gateway_positions, coverage_map=coverage_map)

    logger.info(f"sensor factors:{sensor_metrics}")
    logger.info(f"sensor other costs:{sensor_attr_costs}")

    if gateway_selection_for_sensor_distances is None:
        gateway_selection_for_sensor_distances = list(gateway_positions.keys())

    sensor_comp_costs = compute_sensor_costs_for_selection(sensor_metrics, gateway_selection_for_sensor_distances,
                                                           weights=sensor_weights)

    gw_final = _blend(gw_comp_costs, gateway_attr_costs, default_gateway_attr_cost)
    sensor_final = _blend(sensor_comp_costs, sensor_attr_costs, default_sensor_attr_cost)

    logger.info(f"gateway final cost with other cost:{gw_final}")
    logger.info(f"sensor final cost with other cost:{sensor_final}")

    return {
        'gateway_metrics': gw_metrics,
        'gateway_costs': gw_final,
        'sensor_metrics': sensor_metrics,
        'sensor_costs': sensor_final
    }
