# -*- coding: utf-8 -*-
"""Processing algorithm to prioritize grid cells based on point data.

This algorithm takes a polygon grid (e.g. hex cells) and a point layer as
input, counts how many points fall within each cell, and writes two
attributes:

- ``<base_name>_cnt``   – raw count of points in the cell
- ``<base_name>_score`` – 0–100 score derived from the counts

The score is scaled so that it can be combined with other metrics which
also use a 0–100 range. Whether more matches increase or decrease the
score is controlled by a boolean parameter.

Useful, for example, for factoring in crash counts or traffic-calming counts.
"""

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsFeature,
    QgsFeatureRequest,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterString,
    QgsSpatialIndex,
    QgsWkbTypes,
    QgsFeatureSink,
)

from ..utils.scaling import scale_value


class GridPriorityFromPointsAlgorithm(QgsProcessingAlgorithm):
    """Grid Priority (Point Data).

    Generic building block for point-based metrics, such as crashes or
    traffic-calming locations. Each run completely recomputes the count
    and score for every grid cell, so it is safe to re-run on layers
    that already contain previous values.
    """

    # Parameter names
    INPUT_GRID = "INPUT_GRID"
    INPUT_POINTS = "INPUT_POINTS"
    ATTRIBUTE_BASENAME = "ATTRIBUTE_BASENAME"
    # Ex: More crashes could indicate increasing sidewalk priority,
    # but more traffic calming in the area could indicate decreasing sidewalk priority.
    MORE_POINTS_INCREASE_SCORE = "MORE_POINTS_INCREASE_SCORE"
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """Return *string* translated in the "Processing" context."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a new instance of the algorithm."""
        return GridPriorityFromPointsAlgorithm()

    def name(self):
        """Return the algorithm id used for identifying the algorithm."""
        return "grid_priority_from_points"

    def displayName(self):
        """Return the translated algorithm name shown in the toolbox."""
        return self.tr("Grid Priority (Point Data)")

    def group(self):
        """Return the group name this algorithm belongs to (none)."""
        return ""

    def groupId(self):
        """Return the unique ID of the group this algorithm belongs to (none)."""
        return ""

    def shortHelpString(self):
        """Return a localised short help string for the algorithm."""
        # Adjacent string literals are concatenated verbatim, so paragraph
        # breaks must be explicit "\n" characters or the sentences run
        # together. Angle-bracket placeholders are avoided because the help
        # panel may interpret them as markup tags.
        return self.tr(
            "Prioritizes grid cells based on the presence of point data.\n\n"
            "For each polygon cell in the input grid layer, this algorithm "
            "counts how many points from the supplied point layer fall "
            "within the cell. Two attributes are written:\n\n"
            "- [base name]_cnt – raw count of points per cell\n"
            "- [base name]_score – 0–100 score derived from the counts\n\n"
            "Use the 'Metric name' parameter to choose the base name (for "
            "example 'ped_crash' or 'traffic_calmed'). Use the boolean "
            "parameter to control whether more matches increase or decrease "
            "the score."
        )

    def initAlgorithm(self, config=None):
        """Define the inputs and outputs of the algorithm."""
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_GRID,
                self.tr("Grid Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_POINTS,
                self.tr("Point Layer"),
                [QgsProcessing.TypeVectorPoint],
            )
        )

        # Base attribute name, e.g. "ped_crash" or "traffic_calmed"
        self.addParameter(
            QgsProcessingParameterString(
                self.ATTRIBUTE_BASENAME,
                self.tr("Metric name (base attribute name)"),
                defaultValue="metric",
            )
        )

        self.addParameter(
            QgsProcessingParameterBoolean(
                self.MORE_POINTS_INCREASE_SCORE,
                self.tr("Higher point count increase the score"),
                defaultValue=True,
            )
        )

        # Output layer
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Grid Priority (Points) Output"),
            )
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _check_canceled(self, feedback):
        """Raise QgsProcessingException if the user cancelled the run."""
        if feedback.isCanceled():
            raise QgsProcessingException(
                self.tr("Calculation cancelled by user")
            )

    def _build_point_index(self, points_source, dest_crs, context, feedback):
        """Index point geometries, reprojected into *dest_crs*.

        Only geometries are fetched (no attributes). Features with a null
        or empty geometry are skipped.

        :returns: ``(QgsSpatialIndex, {feature id: QgsGeometry})``
        """
        request = QgsFeatureRequest()
        request.setNoAttributes()
        # Reproject points into the grid CRS so that the bounding-box lookup
        # and the containment test compare coordinates in the same system.
        request.setDestinationCrs(dest_crs, context.transformContext())

        point_index = QgsSpatialIndex()
        point_geoms = {}
        for p_feat in points_source.getFeatures(request):
            self._check_canceled(feedback)
            if not p_feat.hasGeometry():
                continue
            p_geom = p_feat.geometry()
            if p_geom.isEmpty():
                continue
            point_index.addFeature(p_feat)
            point_geoms[p_feat.id()] = p_geom
        return point_index, point_geoms

    def _count_points_per_cell(self, grid_source, point_index, point_geoms, feedback):
        """Count the indexed points contained in each grid cell.

        A point lying exactly on a cell edge is not "contained" by either
        neighbouring cell and therefore is not counted.

        Reports progress from 0 to 50 %.

        :returns: ``(counts_by_fid, count_values)``. ``counts_by_fid`` maps
            a grid feature id to its count. ``count_values`` lists the counts
            in iteration order.
        """
        request = QgsFeatureRequest()
        request.setNoAttributes()  # only geometry is needed for counting

        total = grid_source.featureCount()  # may be -1 when unknown
        step = 50.0 / total if total > 0 else 0.0

        counts_by_fid = {}
        count_values = []
        for current, grid_feat in enumerate(grid_source.getFeatures(request)):
            self._check_canceled(feedback)

            geom = grid_feat.geometry()
            cnt = 0
            if not (geom.isNull() or geom.isEmpty()):
                # Cheap bbox pre-filter via the index, then exact test.
                for pid in point_index.intersects(geom.boundingBox()):
                    p_geom = point_geoms.get(pid)
                    if p_geom is not None and geom.contains(p_geom):
                        cnt += 1

            counts_by_fid[grid_feat.id()] = cnt
            count_values.append(cnt)
            feedback.setProgress(current * step)
        return counts_by_fid, count_values

    @staticmethod
    def _output_fields(orig_fields, cnt_name, score_name):
        """Copy *orig_fields* and ensure the count and score fields exist.

        Existing fields with the same names are reused, so re-running the
        algorithm overwrites earlier values instead of adding duplicates.

        :returns: ``(fields, cnt_idx, score_idx)``
        """
        fields = QgsFields(orig_fields)
        for field_name in (cnt_name, score_name):
            if fields.indexOf(field_name) == -1:
                # Counts and 0–100 scores are both stored as integers.
                fields.append(QgsField(field_name, QVariant.Int))
        return fields, fields.indexOf(cnt_name), fields.indexOf(score_name)

    @staticmethod
    def _score_for_count(cnt, min_cnt, max_cnt, more_points_increase_score):
        """Map a point count to a 0–100 score.

        Empty cells always receive the extreme score: 0 when more points
        increase the score, 100 when more points decrease it. Because of
        this, ``scale_value`` is never called when every count is zero.

        NOTE(review): if every cell has the same non-zero count, then
        ``min_cnt == max_cnt`` and ``scale_value`` receives a zero-width
        input range. Confirm it handles that case.
        """
        if more_points_increase_score:
            # Crash-like metric: more points => higher score. Occupied cells
            # are scaled into 1–100 so that any presence scores above 0.
            if cnt == 0:
                return 0
            return scale_value(cnt, min_cnt, max_cnt, 1, 100)
        # Traffic-calming-like metric: more points => lower score.
        if cnt == 0:
            return 100
        return scale_value(cnt, min_cnt, max_cnt, 100, 0)

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def processAlgorithm(self, parameters, context, feedback):
        """Execute the algorithm.

        Counts the points contained in every grid cell and derives a 0–100
        score from the counts. Each grid feature is written to the output
        sink with its original geometry and attributes plus the count and
        score fields.

        :raises QgsProcessingException: if an input or the output cannot be
            opened, the metric name is empty, or the user cancels.
        :returns: ``{OUTPUT: destination layer id}``
        """
        grid_source = self.parameterAsSource(parameters, self.INPUT_GRID, context)
        if grid_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_GRID)
            )

        points_source = self.parameterAsSource(
            parameters, self.INPUT_POINTS, context
        )
        if points_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_POINTS)
            )

        base_name = self.parameterAsString(
            parameters, self.ATTRIBUTE_BASENAME, context
        ).strip()
        if not base_name:
            raise QgsProcessingException(
                self.tr("Metric name (base attribute name) must not be empty.")
            )

        more_points_increase_score = self.parameterAsBoolean(
            parameters, self.MORE_POINTS_INCREASE_SCORE, context
        )

        base_cnt_name = f"{base_name}_cnt"
        base_score_name = f"{base_name}_score"

        grid_count = grid_source.featureCount()
        points_count = points_source.featureCount()
        feedback.pushInfo(
            self.tr(
                "Grid features: {grid_count}, point features: {points_count}."
            ).format(
                grid_count=grid_count,
                points_count=points_count,
            )
        )

        grid_crs = grid_source.sourceCrs()
        if points_source.sourceCrs() != grid_crs:
            feedback.pushInfo(
                self.tr(
                    "Point layer CRS differs from grid CRS; reprojecting "
                    "points to the grid CRS."
                )
            )

        feedback.pushInfo(self.tr("Building spatial index for point layer..."))
        point_index, point_geoms = self._build_point_index(
            points_source, grid_crs, context, feedback
        )

        # First pass: count points per grid cell
        feedback.pushInfo(self.tr("Counting points in each grid cell..."))
        counts_by_fid, count_values = self._count_points_per_cell(
            grid_source, point_index, point_geoms, feedback
        )

        min_cnt = min(count_values, default=0)
        max_cnt = max(count_values, default=0)
        nonzero_cells = sum(1 for v in count_values if v > 0)
        feedback.pushInfo(
            self.tr(
                "Minimum count: {min_cnt}, maximum count: {max_cnt}. Cells with at least one point: {nz} out of {total}."
            ).format(min_cnt=min_cnt, max_cnt=max_cnt, nz=nonzero_cells, total=len(count_values))
        )

        fields, cnt_idx, score_idx = self._output_fields(
            grid_source.fields(), base_cnt_name, base_score_name
        )

        feedback.pushInfo(
            self.tr(
                "Writing results to fields '{cnt}' (count) and '{score}' (score)."
            ).format(cnt=base_cnt_name, score=base_score_name)
        )

        # Create output sink
        sink, dest_id = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            fields,
            grid_source.wkbType(),
            grid_crs,
        )

        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT)
            )

        # Second pass: compute scores and write features (progress 50–100 %)
        feedback.pushInfo(self.tr("Calculating scores and writing output..."))
        n_fields = fields.count()
        step = 50.0 / grid_count if grid_count > 0 else 0.0

        for current, grid_feat in enumerate(grid_source.getFeatures()):
            self._check_canceled(feedback)

            cnt = counts_by_fid.get(grid_feat.id(), 0)
            score = self._score_for_count(
                cnt, min_cnt, max_cnt, more_points_increase_score
            )

            # Original attributes first; pad the appended fields with NULL.
            attrs = grid_feat.attributes()[:n_fields]
            attrs.extend([None] * (n_fields - len(attrs)))
            attrs[cnt_idx] = int(cnt)
            attrs[score_idx] = int(score)

            out_feat = QgsFeature(fields)
            out_feat.setGeometry(grid_feat.geometry())
            out_feat.setAttributes(attrs)
            sink.addFeature(out_feat, QgsFeatureSink.FastInsert)
            feedback.setProgress(50.0 + current * step)

        feedback.pushInfo(
            self.tr(
                "Grid priority calculation from point data complete. Processed {n} grid features."
            ).format(n=len(count_values))
        )

        return {self.OUTPUT: dest_id}
