'''
==========QNetPlanner - Viewshed module===========
This module performs viewshed analysis using the GDAL viewshed tool and builds a sensor coverage map.
Before the analysis, it checks whether the coordinate reference system (CRS) of each input vector layer
matches that of the DEM; layers with a different CRS are automatically reprojected to the DEM CRS. The
module also reads optional "cost" values from the attribute tables for use in subsequent processing.

run_viewshed_analysis() returns a tuple with the following elements, in this order:
      coverage_map: {gw_id: {stype: [global_sensor_id,...], ...}, ...}
      reverse_sensor_map: {global_id: (stype, orig_feature_id)}
      gateway_positions: {gw_id: (x,y)}
      sensor_positions: {global_sensor_id: (x,y)}
      viewshed_paths: {gw_id: path_to_vs.tif, or None if the viewshed could not be computed}
      gateway_attr_costs: {gw_id: attr_cost (expected 1..10, not validated here) or None}
      sensor_attr_costs: {global_sensor_id: attr_cost (expected 1..10, not validated here) or None}

Gitlab:
    https://gitlab.com/binoy194/QNetPlanner
email:
    binoy194@gmail.com
    kavyask304@gmail.com

Authors:
    Binoy C
    Kavya S K

'''
import os
import tempfile
from qgis.core import (
    QgsRasterLayer,
    QgsVectorLayer,
    QgsProcessingFeedback,
    QgsProcessingFeatureSourceDefinition,
    QgsPointXY,
    QgsRaster,
    QgsCoordinateTransform,
    QgsProject,
    QgsWkbTypes,
    QgsFeature
)
from qgis.PyQt.QtCore import QVariant
import processing
from .logger import logger
from pprint import pformat


def reproject_if_needed(layer, target_crs):
    """Return an in-memory copy of *layer* expressed in *target_crs*.

    When the layer's CRS differs from *target_crs*, the layer is reprojected
    with the ``native:reprojectlayer`` processing algorithm into a memory layer.
    When the CRS already matches, every feature (geometry and attributes) is
    copied into a new memory layer named ``<name>_reindexed`` whose feature IDs
    are assigned sequentially starting from 1.

    In both cases the field names and every feature's attributes of the
    resulting layer are written to the log at INFO level.

    Args:
        layer: Source QgsVectorLayer.
        target_crs: CRS the result must use (the DEM CRS in this module).

    Returns:
        The resulting in-memory QgsVectorLayer.
    """
    if layer.crs() != target_crs:
        label = layer.name() if hasattr(layer, 'name') else layer
        logger.warning(f"CRS mismatch detected — reprojecting layer {label}")

        run_result = processing.run(
            "native:reprojectlayer",
            {"INPUT": layer, "TARGET_CRS": target_crs, "OUTPUT": "memory:"},
        )
        result_layer = run_result["OUTPUT"]
    else:
        logger.info(f"CRS of {layer.name()} matches DEM — reindexing feature IDs from 1")

        # Memory layer with the same geometry type and CRS as the source.
        geometry_type = QgsWkbTypes.displayString(layer.wkbType())
        result_layer = QgsVectorLayer(
            f"{geometry_type}?crs={layer.crs().authid()}",
            f"{layer.name()}_reindexed",
            "memory",
        )

        sink = result_layer.dataProvider()
        sink.addAttributes(layer.fields())
        result_layer.updateFields()

        # Clone each source feature, numbering the copies from 1.
        out_fields = result_layer.fields()
        clones = []
        for new_id, source_feature in enumerate(layer.getFeatures(), start=1):
            clone = QgsFeature(out_fields)
            clone.setGeometry(source_feature.geometry())
            clone.setAttributes(source_feature.attributes())
            clone.setId(new_id)
            clones.append(clone)

        sink.addFeatures(clones)
        result_layer.updateExtents()

    # Dump the attribute table of the resulting layer for diagnostics.
    column_names = [column.name() for column in result_layer.fields()]
    logger.info(f"Attribute table fields: {column_names}")
    for row in result_layer.getFeatures():
        logger.info({column: row[column] for column in column_names})

    return result_layer


def safe_point_from_geometry(geom):
    """Extract a representative point from *geom*, or None if there is none.

    - Missing, null or empty geometries return ``None``. Callers test for
      this instead of receiving a fabricated (0, 0) point from an empty
      bounding box.
    - Single-part point geometries (including Z/M variants) return the point
      itself via ``asPoint()``.
    - Multipoints, lines and polygons return their centroid. For a multipoint
      this is the mean of its points.
    - If any of the above raises, the centre of the bounding box is used as a
      last resort.

    Args:
        geom: A QgsGeometry (or None).

    Returns:
        QgsPointXY, or None when the geometry is missing/null/empty.
    """
    if geom is None or geom.isNull() or geom.isEmpty():
        return None
    try:
        # Multipoints also report PointGeometry, but asPoint() only accepts
        # single-part points, so route multipart geometries to the centroid.
        if geom.type() == QgsWkbTypes.PointGeometry and not geom.isMultipart():
            return geom.asPoint()
        return geom.centroid().asPoint()
    except Exception:
        # Final fallback: bounding-box centre of a non-empty geometry.
        bbox = geom.boundingBox()
        return QgsPointXY(
            (bbox.xMinimum() + bbox.xMaximum()) / 2.0,
            (bbox.yMinimum() + bbox.yMaximum()) / 2.0,
        )


# Fallbacks used when a gateway has no usable "height" / "range" attribute.
_DEFAULT_OBSERVER_HEIGHT = 10.0
# Passed to gdal:viewshed MAX_DISTANCE, i.e. in DEM CRS units (metres for a
# metric projected CRS) -- not kilometres.
_DEFAULT_GATEWAY_RANGE = 5000.0
_VIEWSHED_TARGET_HEIGHT = 1.0


def _has_no_geometry(geom):
    """Return True when *geom* is missing, null or empty.

    ``QgsFeature.geometry()`` returns a null ``QgsGeometry`` (not ``None``) for
    features without geometry, so an ``is None`` test alone never catches them.
    """
    return geom is None or geom.isNull() or geom.isEmpty()


def _read_float_attr(feature, field_name, default=None):
    """Read a numeric attribute using a case-insensitive field-name match.

    Returns ``float(value)`` for the first field whose name equals *field_name*
    ignoring case. Returns *default* when no such field exists, when the value
    is NULL, or when it cannot be converted to float.
    """
    wanted = field_name.lower()
    for name in feature.fields().names():
        if name.lower() != wanted:
            continue
        value = feature[name]
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            # e.g. a PyQGIS NULL QVariant or non-numeric text
            return default
    return default


def _sample_visible_sensors(provider, dem_extent, sensor_positions,
                            reverse_sensor_map, sensor_types, gw_id):
    """Sample a viewshed raster at every sensor and group visible ones by type.

    Args:
        provider: Data provider of the loaded viewshed raster.
        dem_extent: DEM extent; sensors outside it are skipped without sampling.
        sensor_positions: {global_sensor_id: (x, y)} in the DEM CRS.
        reverse_sensor_map: {global_sensor_id: (stype, orig_feature_id)}.
        sensor_types: Sensor type keys used to pre-populate the result.
        gw_id: Gateway ID, used only in log messages.

    Returns:
        {stype: [global_sensor_id, ...]} for sensors whose band-1 sample is > 0.
    """
    visible_by_type = {stype: [] for stype in sensor_types}
    for sid, (sx, sy) in sensor_positions.items():
        point = QgsPointXY(sx, sy)
        # Fast rejection: the viewshed is computed on the DEM, so nothing
        # outside its extent can be visible.
        if not dem_extent.contains(point):
            logger.debug("Sensor %s at (%.6f,%.6f) outside DEM extent, skipping", sid, sx, sy)
            continue

        try:
            ident = provider.identify(point, QgsRaster.IdentifyFormatValue)
        except Exception as ex:
            logger.exception("Error sampling viewshed at sensor %s for gw %s: %s", sid, gw_id, ex)
            continue

        if not ident.isValid():
            # identify invalid (point outside viewshed extent?)
            logger.debug("Viewshed identify invalid for sensor %s at (%.6f,%.6f) gw %s", sid, sx, sy, gw_id)
            continue

        results = ident.results()  # {band_number: value}
        # Take the lowest-numbered band's value.
        val = results[min(results)] if isinstance(results, dict) and results else None
        if val is None:
            logger.debug("Sensor %s sampling returned None at viewshed for gw %s", sid, gw_id)
            continue

        # GDAL viewshed writes 255 for visible and 0 for not visible by default.
        # Any positive sample counts as visible. NaN compares False, so it is
        # treated as not visible.
        if not isinstance(val, (int, float)) or not float(val) > 0.0:
            continue

        stype = reverse_sensor_map.get(sid, (None,))[0]
        if stype in visible_by_type:
            visible_by_type[stype].append(sid)
            logger.debug("Sensor %s (type=%s) visible from gateway %s (sample=%s)", sid, stype, gw_id, val)
        else:
            logger.debug("Sensor %s visible but type unknown, sample=%s", sid, val)
    return visible_by_type


def run_viewshed_analysis(gateway_file, sensor_files_dict, dem_file):
    """Compute a viewshed per gateway and the sensors visible from each one.

    Steps:
    1. Load the gateway layer, all sensor layers and the DEM.
    2. Bring the vector layers to the DEM CRS via ``reproject_if_needed``.
    3. Assign every sensor feature a global ID.
    4. Run ``gdal:viewshed`` once per gateway.
    5. Sample the resulting raster at each sensor position.

    Args:
        gateway_file: Path/URI of the gateway point layer (opened with OGR).
            Optional case-insensitive fields:
            - ``height``: observer height (default 10.0).
            - ``range``: viewshed MAX_DISTANCE in DEM CRS units (default 5000.0).
            - ``cost``.
        sensor_files_dict: ``{sensor_type: [path, ...]}``. Each path is opened
            with OGR. An optional ``cost`` field is read per feature; invalid
            layers are skipped with a warning.
        dem_file: Path of the DEM raster; band 1 is used.

    Returns:
        A tuple, in this order:
        - ``coverage_map``: ``{gw_id: {stype: [global_sensor_id, ...]}}``.
        - ``reverse_sensor_map``: ``{global_sensor_id: (stype, feature_id)}``.
        - ``gateway_positions``: ``{gw_id: (x, y)}``.
        - ``sensor_positions``: ``{global_sensor_id: (x, y)}``.
        - ``viewshed_paths``: ``{gw_id: path or None}``; None when the viewshed
          failed or produced an invalid raster.
        - ``gateway_attr_costs``: ``{gw_id: float or None}``.
        - ``sensor_attr_costs``: ``{global_sensor_id: float or None}``.

    Raises:
        Exception: if the gateway layer or the DEM is invalid, or a gateway has
            a non-point geometry.
    """
    logger.info("=== Viewshed Analysis Started ===")
    logger.info(f"Gateway file: {gateway_file}")
    logger.info(f"Sensor files dict: {pformat(sensor_files_dict)}")
    logger.info(f"DEM file: {dem_file}")

    gateways = QgsVectorLayer(gateway_file, "gateways", "ogr")
    if not gateways.isValid():
        raise Exception("Invalid gateway layer")

    dem = QgsRasterLayer(dem_file, "DEM")
    if not dem.isValid():
        raise Exception("Invalid DEM")

    # Load sensor layers, grouped by sensor type.
    sensors_by_type = {}  # stype -> list of QgsVectorLayer
    for stype, files in sensor_files_dict.items():
        sensors_by_type[stype] = []
        for f in files:
            vl = QgsVectorLayer(f, f"{stype}", "ogr")
            if not vl.isValid():
                logger.warning("Sensor layer invalid: %s", f)
            else:
                sensors_by_type[stype].append(vl)
                logger.info("Loaded sensor layer %s for type %s (features=%d)", f, stype, vl.featureCount())

    # Bring every vector layer into the DEM CRS (feature IDs restart at 1).
    target_crs = dem.crs()
    gateways = reproject_if_needed(gateways, target_crs)
    sensors_by_type = {
        stype: [reproject_if_needed(layer, target_crs) for layer in layers]
        for stype, layers in sensors_by_type.items()
    }

    coverage_map = {}
    reverse_sensor_map = {}
    gateway_positions = {}
    sensor_positions = {}
    viewshed_paths = {}
    gateway_attr_costs = {}
    sensor_attr_costs = {}

    # Flatten all sensors into one list keyed by a global ID starting at 1.
    next_sensor_id = 1
    for stype, layers in sensors_by_type.items():
        for layer in layers:
            if not layer.isValid():
                logger.warning("Skipping invalid sensor layer for type %s", stype)
                continue
            logger.info("Processing sensor layer: %s (type=%s) features=%d", layer.name(), stype, layer.featureCount())
            for feat in layer.getFeatures():
                geom = feat.geometry()
                if _has_no_geometry(geom):
                    logger.warning("Sensor feature %s has no geometry, skipping", feat.id())
                    continue
                # Multipoint/other geometries are reduced to a single point.
                pt = safe_point_from_geometry(geom)
                if pt is None:
                    logger.warning("Could not get point for sensor feature %s, skipping", feat.id())
                    continue
                sensor_positions[next_sensor_id] = (float(pt.x()), float(pt.y()))
                reverse_sensor_map[next_sensor_id] = (stype, feat.id())
                sensor_attr_costs[next_sensor_id] = _read_float_attr(feat, "cost")
                next_sensor_id += 1

    logger.info("Total sensors discovered: %d", len(sensor_positions))
    logger.info(f"processing_utils:sensor_positions:{sensor_positions}")
    logger.info(f"processing_utils:reverse_sensor map:{reverse_sensor_map}")

    # Loop invariants for the per-gateway processing below.
    sensor_types = list(sensors_by_type)
    dem_extent = dem.extent()
    logger.debug("DEM extent: %s", dem_extent.toString())

    for gw in gateways.getFeatures():
        gw_id = gw.id()
        geom = gw.geometry()
        if _has_no_geometry(geom):
            logger.warning("Gateway feature %s has no geometry, skipping", gw_id)
            continue
        if geom.type() != QgsWkbTypes.PointGeometry:
            raise Exception("Gateways must be POINT layer")
        pt = safe_point_from_geometry(geom)
        if pt is None:
            logger.warning("Could not get point for gateway feature %s, skipping", gw_id)
            continue
        x, y = float(pt.x()), float(pt.y())
        gateway_positions[gw_id] = (x, y)

        gateway_attr_costs[gw_id] = _read_float_attr(gw, "cost")
        obs_height = _read_float_attr(gw, "height", _DEFAULT_OBSERVER_HEIGHT)
        gw_range = _read_float_attr(gw, "range", _DEFAULT_GATEWAY_RANGE)

        logger.info("Processing gateway %s at (%.6f, %.6f) obs_height=%s", gw_id, x, y, obs_height)

        out_path = os.path.join(tempfile.gettempdir(), f"viewshed_{gw_id}.tif")
        params = {
            'INPUT': dem,
            'BAND': 1,
            # Plain QgsPointXY: coordinates are already in the DEM CRS.
            'OBSERVER': QgsPointXY(x, y),
            'OBSERVER_HEIGHT': float(obs_height),
            'TARGET_HEIGHT': _VIEWSHED_TARGET_HEIGHT,
            'MAX_DISTANCE': float(gw_range),
            'OUTPUT': out_path
        }

        logger.info(f"viewshed parameters:{params}")
        try:
            processing.run("gdal:viewshed", params)
        except Exception as e:
            logger.error("Viewshed failed for gw %s: %s", gw_id, e)
            coverage_map[gw_id] = {stype: [] for stype in sensor_types}
            viewshed_paths[gw_id] = None
            continue

        # Keep vs_raster referenced while its provider is sampled.
        vs_raster = QgsRasterLayer(out_path, f"vs_{gw_id}")
        if not vs_raster.isValid():
            logger.error("Viewshed raster invalid for gw %s (path: %s)", gw_id, out_path)
            coverage_map[gw_id] = {stype: [] for stype in sensor_types}
            viewshed_paths[gw_id] = None
            continue

        viewshed_paths[gw_id] = out_path

        visible_by_type = _sample_visible_sensors(
            vs_raster.dataProvider(),
            dem_extent,
            sensor_positions,
            reverse_sensor_map,
            sensor_types,
            gw_id,
        )

        coverage_map[gw_id] = visible_by_type
        logger.info("Gateway %s visible:", gw_id)
        for stype, lst in visible_by_type.items():
            logger.info("  %s: %s", stype, lst)

    logger.info("=== Viewshed Analysis Completed Successfully ===")
    logger.info("Coverage Map:\n%s", pformat(coverage_map))
    return coverage_map, reverse_sensor_map, gateway_positions, sensor_positions, viewshed_paths, gateway_attr_costs, sensor_attr_costs
