# -*- coding: utf-8 -*-
"""Processing algorithm to prioritize grid cells based on polygon attributes at cell centers.

It takes a polygon grid (e.g. hex cells) and an additional polygon layer (e.g. census tracts),
computes the centroid of each grid cell, finds the polygon containing
that centroid, reads a selected numeric attribute from the polygon, and
writes two attributes to the grid:

- ``<metric_name>``       – the raw attribute value from the polygon
- ``<metric_name>_score`` – a 0–100 score derived from the attribute

The score is scaled so that it can be combined with other metrics which
also use a 0–100 range. A boolean parameter controls whether higher
attribute values increase or decrease the score.
"""

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsFeature,
    QgsFeatureRequest,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterField,
    QgsProcessingParameterString,
    QgsSpatialIndex,
)

from ..utils.scaling import scale_value


class GridPriorityFromPolygonCenterAlgorithm(QgsProcessingAlgorithm):
    """Grid Priority (Center within Polygon).

    For each polygon cell in the input grid layer, this algorithm
    computes the centroid of the cell and finds the feature from the
    input polygon layer which contains that centroid. It then reads a
    selected numeric attribute from the containing polygon and writes:

    - ``<metric_name>``       – the raw attribute value from the polygon
    - ``<metric_name>_score`` – 0–100 score based on that value

    Polygon features are reprojected to the grid layer's CRS before the
    containment test, so the two inputs may use different CRSs.

    Use the "Metric name" parameter to choose the base name for the
    output attributes, and the boolean parameter to control whether
    higher attribute values increase or decrease the score.
    """

    # Parameter names
    INPUT_GRID = "INPUT_GRID"
    INPUT_POLYGONS = "INPUT_POLYGONS"
    POLY_FIELD = "POLY_FIELD"
    METRIC_NAME = "METRIC_NAME"
    HIGHER_VALUES_INCREASE_SCORE = "HIGHER_VALUES_INCREASE_SCORE"
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """Returns a translatable string with the self.tr() function."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Returns a new instance of this algorithm."""
        return GridPriorityFromPolygonCenterAlgorithm()

    def name(self):
        """Returns the algorithm name, used for identifying the algorithm."""
        return "grid_priority_from_polygon_center"

    def displayName(self):
        """Returns the translated algorithm name."""
        return self.tr("Grid Priority (Center within Polygon)")

    def group(self):
        """Returns the name of the group this algorithm belongs to."""
        return ""

    def groupId(self):
        """Returns the unique ID of the group this algorithm belongs to."""
        return ""

    def shortHelpString(self):
        """Returns a localised short help string for the algorithm."""
        return self.tr(
            "Assigns a metric value to each grid cell based on the value of a "
            "selected attribute from the polygon feature which contains the "
            "cell's centroid. The raw attribute value and a 0–100 score are "
            "written to the grid, allowing combination with other metrics."
        )

    def initAlgorithm(self, config=None):
        """Declares the algorithm's input and output parameters."""
        # Input grid layer (e.g. hex cells)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_GRID,
                self.tr("Grid Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        # Input polygon layer (e.g. census tracts)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_POLYGONS,
                self.tr("Polygon Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        # Numeric attribute of the polygon layer to transfer to the grid
        self.addParameter(
            QgsProcessingParameterField(
                self.POLY_FIELD,
                self.tr("Polygon attribute field"),
                parentLayerParameterName=self.INPUT_POLYGONS,
                type=QgsProcessingParameterField.Numeric,
            )
        )

        # Base name for output attributes
        self.addParameter(
            QgsProcessingParameterString(
                self.METRIC_NAME,
                self.tr("Metric name (base attribute name)"),
                defaultValue="metric",
            )
        )

        # Whether higher attribute values should result in a higher score
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.HIGHER_VALUES_INCREASE_SCORE,
                self.tr("Higher values increase the score"),
                defaultValue=True,
            )
        )

        # Output grid layer
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Grid Priority (Polygon Center) Output"),
            )
        )

    def flags(self):
        """Returns the flags for the algorithm."""
        # NOTE(review): the reason for FlagNoThreading is not visible in this
        # file (possibly scale_value or the caller) — confirm it is required.
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        """Execute the algorithm.

        Indexes the polygon layer (reprojected to the grid CRS), looks up
        the polygon containing each grid cell's centroid, and writes the
        polygon's raw attribute value plus a 0–100 score to the output.

        Returns:
            dict: ``{OUTPUT: dest_id}`` identifying the created output layer.

        Raises:
            QgsProcessingException: if an input source or the sink is invalid,
                a required parameter is empty, the selected field does not
                exist, or the user cancels.
        """
        grid_source = self.parameterAsSource(parameters, self.INPUT_GRID, context)
        if grid_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_GRID)
            )

        poly_source = self.parameterAsSource(parameters, self.INPUT_POLYGONS, context)
        if poly_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_POLYGONS)
            )

        poly_field_name = self.parameterAsString(
            parameters, self.POLY_FIELD, context
        )
        if not poly_field_name:
            raise QgsProcessingException(
                self.tr("Polygon attribute field must be specified.")
            )

        metric_name = self.parameterAsString(
            parameters, self.METRIC_NAME, context
        ).strip()
        if not metric_name:
            raise QgsProcessingException(
                self.tr("Metric name (base attribute name) must not be empty.")
            )

        higher_values_increase_score = self.parameterAsBoolean(
            parameters, self.HIGHER_VALUES_INCREASE_SCORE, context
        )

        feedback.pushInfo(self.tr("Building spatial index for polygon layer..."))

        poly_fields = poly_source.fields()
        poly_field_index = poly_fields.indexOf(poly_field_name)
        if poly_field_index < 0:
            raise QgsProcessingException(
                self.tr("Selected polygon attribute field '{name}' was not found.")
                .format(name=poly_field_name)
            )
        value_field_def = poly_fields[poly_field_index]

        # Polygons are fetched in the grid CRS so that centroids (computed in
        # grid coordinates) can be tested against them directly.
        poly_index, poly_geoms, poly_values = self._index_polygons(
            poly_source,
            poly_field_index,
            grid_source.sourceCrs(),
            context,
            feedback,
        )

        # Determine scaling range from non-null polygon attribute values
        if poly_values:
            min_val = min(poly_values.values())
            max_val = max(poly_values.values())
        else:
            min_val = max_val = 0.0

        feedback.pushInfo(
            self.tr(
                "Attribute values for '{field}' range from {min_val} to {max_val}."
            ).format(field=poly_field_name, min_val=min_val, max_val=max_val)
        )

        # Prepare output fields: all grid fields plus the two metric fields.
        # If a field with the target name already exists it is reused and
        # overwritten rather than duplicated.
        fields = QgsFields(grid_source.fields())

        value_field_name = metric_name
        score_field_name = f"{metric_name}_score"

        # Add raw value field with same type as source attribute
        value_idx = fields.indexOf(value_field_name)
        if value_idx == -1:
            fields.append(QgsField(value_field_name, value_field_def.type()))
            value_idx = fields.indexOf(value_field_name)

        # Add integer score field
        score_idx = fields.indexOf(score_field_name)
        if score_idx == -1:
            fields.append(QgsField(score_field_name, QVariant.Int))
            score_idx = fields.indexOf(score_field_name)

        # Create output sink
        sink, dest_id = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            fields,
            grid_source.wkbType(),
            grid_source.sourceCrs(),
        )

        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT)
            )

        feedback.pushInfo(self.tr("Evaluating polygon centers for each grid cell..."))

        # featureCount() may be -1 when unknown; skip progress reporting then.
        total = grid_source.featureCount()
        progress_step = 100.0 / total if total > 0 else 0.0
        field_count = fields.count()

        for current, grid_feat in enumerate(grid_source.getFeatures()):
            if feedback.isCanceled():
                raise QgsProcessingException(
                    self.tr("Calculation cancelled by user")
                )

            geom = grid_feat.geometry()
            raw_val = None
            score = None

            # Cells without a usable geometry get a null value and null score.
            if geom and not geom.isEmpty():
                raw_val = self._value_at_centroid(
                    geom, poly_index, poly_geoms, poly_values
                )
                score = self._score(
                    raw_val, min_val, max_val, higher_values_increase_score
                )

            # Create a new feature with the full field structure
            out_feat = QgsFeature()
            out_feat.setFields(fields)
            out_feat.setGeometry(geom)
            out_feat.initAttributes(field_count)

            # Copy original attributes by index (grid fields come first)
            for i, val in enumerate(grid_feat.attributes()):
                if i < field_count:
                    out_feat.setAttribute(i, val)

            # Set metric-specific attributes
            out_feat.setAttribute(value_idx, raw_val)
            if score is not None:
                # Round rather than truncate so e.g. 99.6 becomes 100, not 99.
                out_feat.setAttribute(score_idx, int(round(score)))

            sink.addFeature(out_feat, QgsFeatureSink.FastInsert)

            if progress_step:
                feedback.setProgress(int((current + 1) * progress_step))

        feedback.pushInfo(self.tr("Grid priority calculation from polygon centers complete!"))

        return {self.OUTPUT: dest_id}

    def _index_polygons(self, poly_source, field_index, dest_crs, context, feedback):
        """Index polygon geometries and collect their numeric attribute values.

        Features are requested in *dest_crs* so the returned geometries and
        the spatial index share that coordinate system. Only the attribute at
        *field_index* is fetched. Features without geometry, or with an empty
        geometry, are skipped.

        Returns:
            tuple: ``(index, geoms, values)``. ``index`` is a
            ``QgsSpatialIndex`` of the kept features; ``geoms`` maps feature
            id to its (reprojected) geometry; ``values`` maps feature id to
            the attribute converted to ``float``. Features whose value is
            null or not convertible to ``float`` are absent from ``values``.

        Raises:
            QgsProcessingException: if the user cancels.
        """
        request = (
            QgsFeatureRequest()
            .setDestinationCrs(dest_crs, context.transformContext())
            .setSubsetOfAttributes([field_index])
        )

        index = QgsSpatialIndex()
        geoms = {}
        values = {}

        for feat in poly_source.getFeatures(request):
            if feedback.isCanceled():
                raise QgsProcessingException(
                    self.tr("Calculation cancelled by user")
                )
            if not feat.hasGeometry():
                continue
            geom = feat.geometry()
            if not geom or geom.isEmpty():
                continue

            index.addFeature(feat)
            geoms[feat.id()] = geom

            val = feat[field_index]
            try:
                # float() also rejects Qt NULL variants via TypeError.
                num_val = float(val) if val is not None else None
            except (TypeError, ValueError):
                num_val = None
            if num_val is not None:
                values[feat.id()] = num_val

        return index, geoms, values

    @staticmethod
    def _value_at_centroid(cell_geom, poly_index, poly_geoms, poly_values):
        """Return the value of the polygon containing the centroid of *cell_geom*.

        The first candidate polygon (from the spatial index) that strictly
        contains the centroid wins. Returns ``None`` if the centroid is empty,
        no polygon contains it, or the containing polygon has no value.
        """
        center = cell_geom.centroid()
        if not center or center.isEmpty():
            return None
        # The centroid's bounding box is degenerate; it still selects every
        # indexed polygon whose bbox covers the point.
        for pid in poly_index.intersects(center.boundingBox()):
            p_geom = poly_geoms.get(pid)
            if p_geom and p_geom.contains(center):
                return poly_values.get(pid)
        return None

    @staticmethod
    def _score(value, min_val, max_val, higher_values_increase_score):
        """Map *value* onto the 0–100 score range.

        With ``higher_values_increase_score`` the minimum maps to 0 and the
        maximum to 100; otherwise the mapping is reversed. A missing value
        or a degenerate range (``min_val == max_val``) yields 0 (or 100 when
        lower values increase the score).
        """
        if value is None or min_val == max_val:
            # NOTE(review): when lower values increase the score, cells with
            # no data receive the maximum score (100) — confirm intended.
            return 0 if higher_values_increase_score else 100
        if higher_values_increase_score:
            return scale_value(value, min_val, max_val, 0, 100)
        return scale_value(value, min_val, max_val, 100, 0)
