from qgis.core import (
    QgsVectorLayer, QgsFeature, QgsGeometry, QgsPointXY,
    QgsFields, QgsField
)
from PyQt5.QtCore import QVariant


def run_allocation(polygons_layer, pipes_layer, threshold, feedback, output_crs,
                   max_snap_dist=1000.0):
    """Aggregate house polygons into load points along a pipe network.

    Steps:
      1. Take the centroid of every non-empty polygon geometry.
      2. Snap each centroid to the closest point on the nearest pipe,
         discarding centroids farther than ``max_snap_dist`` from any pipe.
      3. Greedily cluster the snapped points: a point joins the first
         cluster whose *first* member lies within ``threshold``; otherwise
         it starts a new cluster.
      4. Emit one point per cluster (mean of its snapped points) and one
         reference line per snapped house (centroid -> its cluster point).

    NOTE(review): no reprojection is done here, so both input layers are
    presumably already in ``output_crs``, and ``threshold`` /
    ``max_snap_dist`` are presumably in that CRS's units -- confirm with
    the caller.

    Args:
        polygons_layer: Vector layer providing the house polygons.
        pipes_layer: Vector layer providing the pipe geometries.
        threshold: Maximum distance between a snapped point and a cluster's
            first member for the point to join that cluster.
        feedback: Object exposing ``pushInfo`` and ``reportError``.
        output_crs: CRS object; its ``authid()`` is assigned to both
            output layers.
        max_snap_dist: Maximum centroid-to-pipe distance for a centroid to
            be kept. Defaults to 1000.0.

    Returns:
        Tuple ``(point_layer, line_layer)`` of memory layers:
        ``point_layer`` has field ``NUM`` (houses in the cluster);
        ``line_layer`` has fields ``NUM_HOUSES`` and ``DISTANCE_M``
        (length of the reference line).

    Raises:
        Exception: If no valid centroid exists, or if no centroid could be
            snapped to a pipe.
    """
    feedback.pushInfo(
        f"Input house polygons feature count: {polygons_layer.featureCount()}"
    )

    # ─────────────────────────────────────────
    # Generate centroids
    # ─────────────────────────────────────────
    centroids = _collect_centroids(polygons_layer)
    if not centroids:
        feedback.reportError("No valid centroids found → stopping here")
        raise Exception("No valid centroids found")

    # ─────────────────────────────────────────
    # Snap centroids to nearest pipe
    # ─────────────────────────────────────────
    # Read the pipe geometries once instead of re-querying the provider for
    # every centroid; empty geometries are dropped because they cannot be
    # snapped to.
    pipe_geoms = []
    for pipe in pipes_layer.getFeatures():
        g = pipe.geometry()
        if g and not g.isEmpty():
            pipe_geoms.append(g)

    # Keep each centroid paired with its snapped point. Centroids that fail
    # to snap are dropped, so positions in `centroids` and in the snapped
    # list differ and must not be used interchangeably.
    snapped_pairs = []
    for centroid in centroids:
        snapped_pt = _snap_to_nearest_pipe(centroid, pipe_geoms, max_snap_dist)
        if snapped_pt is not None:
            snapped_pairs.append((centroid, snapped_pt))

    if not snapped_pairs:
        feedback.reportError("No centroids could be snapped to any pipe → no output will be created")
        raise Exception("No centroids could be snapped to pipes")

    # ─────────────────────────────────────────
    # Cluster snapped points
    # ─────────────────────────────────────────
    clusters, membership = _cluster_points(
        [snapped_pt for _, snapped_pt in snapped_pairs], threshold
    )

    # Mean position of each cluster, computed once and shared by both outputs.
    cluster_centers = [
        QgsPointXY(
            sum(p.x() for p in cluster) / len(cluster),
            sum(p.y() for p in cluster) / len(cluster),
        )
        for cluster in clusters
    ]

    # ─────────────────────────────────────────
    # Create POINT output (memory) - House Load Points
    # ─────────────────────────────────────────
    point_fields = QgsFields()
    point_fields.append(QgsField("NUM", QVariant.Int))

    point_layer = QgsVectorLayer(
        f"Point?crs={output_crs.authid()}",
        "house_load_points",
        "memory"
    )
    point_layer.dataProvider().addAttributes(point_fields)
    point_layer.updateFields()

    for cluster, center in zip(clusters, cluster_centers):
        f = QgsFeature()
        f.setGeometry(QgsGeometry.fromPointXY(center))
        f.setAttributes([len(cluster)])
        point_layer.dataProvider().addFeature(f)

    # ─────────────────────────────────────────
    # Create REFERENCE LINES: centroid → cluster average
    # ─────────────────────────────────────────
    line_fields = QgsFields()
    line_fields.append(QgsField("NUM_HOUSES", QVariant.Int))
    line_fields.append(QgsField("DISTANCE_M", QVariant.Double))

    line_layer = QgsVectorLayer(
        f"LineString?crs={output_crs.authid()}",
        "centroid_to_demand_points",
        "memory"
    )
    line_layer.dataProvider().addAttributes(line_fields)
    line_layer.updateFields()

    # One line per snapped house. `membership[i]` is the cluster index of
    # `snapped_pairs[i]`, so houses sharing an identical snapped point each
    # still get their own line.
    for (centroid, _), cluster_idx in zip(snapped_pairs, membership):
        line_geom = QgsGeometry.fromPolylineXY(
            [centroid, cluster_centers[cluster_idx]]
        )

        f = QgsFeature()
        f.setGeometry(line_geom)
        f.setAttributes([
            len(clusters[cluster_idx]),
            line_geom.length()
        ])
        line_layer.dataProvider().addFeature(f)

    feedback.pushInfo(f"Created {point_layer.featureCount()} house load points")
    feedback.pushInfo(f"Created {line_layer.featureCount()} reference lines")

    return point_layer, line_layer


def _collect_centroids(polygons_layer):
    """Return the centroid point of every feature with a non-empty geometry."""
    centroids = []
    for feat in polygons_layer.getFeatures():
        geom = feat.geometry()
        if not geom or geom.isEmpty():
            continue
        centroid = geom.centroid()
        if centroid and not centroid.isEmpty():
            centroids.append(centroid.asPoint())
    return centroids


def _snap_to_nearest_pipe(point, pipe_geoms, max_snap_dist):
    """Return the closest point on the nearest pipe, or None if out of range."""
    pt_geom = QgsGeometry.fromPointXY(point)
    best_dist = float("inf")
    best_geom = None

    for g in pipe_geoms:
        d = g.distance(pt_geom)
        # QgsGeometry.distance reports failure (empty/null operand) as a
        # negative value; such a result must never win the "nearest" test.
        if d < 0:
            continue
        # Strict comparison: on ties the first pipe encountered wins.
        if d < best_dist:
            best_dist = d
            best_geom = g

    if best_geom is None or best_dist > max_snap_dist:
        return None

    # Resolve the actual snap location only once, for the winning pipe.
    nearest = best_geom.nearestPoint(pt_geom)
    if not nearest or nearest.isEmpty():
        return None
    return nearest.asPoint()


def _cluster_points(points, threshold):
    """Greedily cluster points; return (clusters, cluster index per point)."""
    clusters = []
    membership = []
    for p in points:
        for idx, cluster in enumerate(clusters):
            # Compare against the cluster's first member only (its "leader").
            if p.distance(cluster[0]) <= threshold:
                cluster.append(p)
                membership.append(idx)
                break
        else:
            # No existing cluster is close enough: start a new one.
            clusters.append([p])
            membership.append(len(clusters) - 1)
    return clusters, membership