# -*- coding: utf-8 -*-
"""Processing algorithm to prioritize grid cells based on polygon coverage.

This algorithm takes a polygon grid
(e.g. hex cells) and an additional polygon layer (e.g. building
footprints), computes how much of each grid cell is covered by the
polygons, and writes two attributes:

- ``<metric_name>_coverage`` – fraction of the cell's area covered (0–1)
- ``<metric_name>_score``    – 0–100 score derived from the coverage

The score is scaled so that it can be combined with other metrics which
also use a 0–100 range. Whether more coverage increases or decreases the
score is controlled by a boolean parameter.
"""

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsFeature,
    QgsFeatureRequest,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterString,
    QgsSpatialIndex,
)

from ..utils.scaling import scale_value


class GridPriorityFromPolygonCoverageAlgorithm(QgsProcessingAlgorithm):
    """Grid Priority (Coverage from Polygons).

    For each polygon cell in the input grid layer, this algorithm
    computes the proportion of the cell's area which is covered by
    polygons from a second layer (e.g. building footprints). The raw
    coverage fraction and a 0–100 score suitable for combining with
    other metrics are written to the output grid.

    Notes on the computation:

    - Polygons are reprojected into the grid layer's CRS before any
      geometric test, so the two inputs may use different CRSs.
    - Area covered by several overlapping polygons is counted once.
    - Cells without any coverage always receive the "no coverage" score
      (0 when more coverage increases the score, otherwise 100). Covered
      cells are scaled between the smallest and the largest positive
      coverage found in the grid (1–100, or 100–0 when inverted).
    """

    # Parameter names
    INPUT_GRID = "INPUT_GRID"
    INPUT_POLYGONS = "INPUT_POLYGONS"
    METRIC_NAME = "METRIC_NAME"
    MORE_COVERAGE_INCREASES_SCORE = "MORE_COVERAGE_INCREASES_SCORE"
    OUTPUT = "OUTPUT"

    def tr(self, string):
        """Returns a translatable string with the self.tr() function."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a fresh instance of this algorithm."""
        return GridPriorityFromPolygonCoverageAlgorithm()

    def name(self):
        """Return the internal algorithm identifier."""
        return "grid_priority_from_polygon_coverage"

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr("Grid Priority (Coverage from Polygons)")

    def group(self):
        """Return the group name; empty, i.e. the algorithm is ungrouped."""
        return ""

    def groupId(self):
        """Return the group identifier; empty, matching ``group()``."""
        return ""

    def shortHelpString(self):
        """Return the translated help text shown for the algorithm."""
        return self.tr(
            "Prioritizes grid cells based on the fraction of their area "
            "covered by polygons from another layer, such as building "
            "footprints. This is similar to calculating building density "
            "per cell."
        )

    def initAlgorithm(self, config=None):
        """Declare the algorithm's input and output parameters."""
        # Input grid layer (polygons)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_GRID,
                self.tr("Grid Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        # Input polygon layer (e.g. building footprints)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_POLYGONS,
                self.tr("Polygon Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        # Metric name used as base for output attributes
        self.addParameter(
            QgsProcessingParameterString(
                self.METRIC_NAME,
                self.tr("Metric name (base attribute name)"),
                defaultValue="building_density",
            )
        )

        # Whether more coverage should result in a higher score
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.MORE_COVERAGE_INCREASES_SCORE,
                self.tr("More coverage increases the score"),
                defaultValue=True,
            )
        )

        # Output grid layer
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Grid Priority (Polygons) Output"),
            )
        )

    def flags(self):
        """Return the default flags plus ``FlagNoThreading``."""
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        """Compute per-cell coverage and score and write the output grid.

        The grid is read twice: the first pass measures the coverage of
        every cell (needed to know the overall min/max before any score
        can be scaled), the second pass writes the output features.

        Returns:
            A dict mapping ``OUTPUT`` to the id of the created sink.

        Raises:
            QgsProcessingException: If a source or the sink is invalid,
                the metric name is empty, or the user cancels the run.
        """
        grid_source = self.parameterAsSource(parameters, self.INPUT_GRID, context)
        if grid_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_GRID)
            )

        poly_source = self.parameterAsSource(
            parameters, self.INPUT_POLYGONS, context
        )
        if poly_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_POLYGONS)
            )

        metric_name = self.parameterAsString(
            parameters, self.METRIC_NAME, context
        ).strip()
        if not metric_name:
            raise QgsProcessingException(
                self.tr("Metric name (base attribute name) must not be empty.")
            )

        more_coverage_increases_score = self.parameterAsBoolean(
            parameters, self.MORE_COVERAGE_INCREASES_SCORE, context
        )

        coverage_field_name = f"{metric_name}_coverage"
        score_field_name = f"{metric_name}_score"

        # Build spatial index for polygons (in the grid's CRS)
        feedback.pushInfo(self.tr("Building spatial index for polygon layer..."))
        poly_index, poly_geoms = self._index_polygons(
            poly_source, grid_source.sourceCrs(), context, feedback
        )

        # First pass: compute coverage fraction per grid cell
        feedback.pushInfo(
            self.tr("Calculating polygon coverage for each grid cell...")
        )
        coverage_by_fid = {}
        for grid_feat in grid_source.getFeatures():
            self._raise_if_canceled(feedback)
            coverage_by_fid[grid_feat.id()] = self._cell_coverage(
                grid_feat.geometry(), poly_index, poly_geoms
            )

        # Only covered cells define the scaling range; empty cells are
        # scored separately so they never compress the range.
        positive_values = [v for v in coverage_by_fid.values() if v > 0]
        min_cov = min(positive_values, default=0.0)
        max_cov = max(positive_values, default=0.0)

        feedback.pushInfo(
            self.tr("Coverage fractions range from {min_cov} to {max_cov}.").format(
                min_cov=min_cov, max_cov=max_cov
            )
        )

        # Prepare output fields; reuse same-named fields if the grid
        # already carries them instead of appending duplicates.
        fields = QgsFields(grid_source.fields())

        coverage_idx = fields.indexOf(coverage_field_name)
        if coverage_idx == -1:
            fields.append(QgsField(coverage_field_name, QVariant.Double))
            coverage_idx = fields.indexOf(coverage_field_name)

        score_idx = fields.indexOf(score_field_name)
        if score_idx == -1:
            # Scores are 0–100 integers; store them as Int for clarity.
            fields.append(QgsField(score_field_name, QVariant.Int))
            score_idx = fields.indexOf(score_field_name)

        # Create output sink
        sink, dest_id = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            fields,
            grid_source.wkbType(),
            grid_source.sourceCrs(),
        )

        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT)
            )

        # Second pass: compute scores and write features
        feedback.pushInfo(self.tr("Calculating scores and writing output..."))

        for grid_feat in grid_source.getFeatures():
            self._raise_if_canceled(feedback)

            coverage = float(coverage_by_fid.get(grid_feat.id(), 0.0))
            score = self._score_for_coverage(
                coverage, min_cov, max_cov, more_coverage_increases_score
            )
            out_feat = self._build_output_feature(
                grid_feat, fields, coverage_idx, coverage, score_idx, score
            )
            sink.addFeature(out_feat, QgsFeatureSink.FastInsert)

        feedback.pushInfo(
            self.tr("Grid priority calculation from polygon coverage complete.")
        )
        return {self.OUTPUT: dest_id}

    def _raise_if_canceled(self, feedback):
        """Abort the run with an exception if the user requested cancel."""
        if feedback.isCanceled():
            raise QgsProcessingException(
                self.tr("Calculation cancelled by user")
            )

    def _index_polygons(self, poly_source, target_crs, context, feedback):
        """Index the polygon layer and collect its geometries by feature id.

        Features are requested in ``target_crs`` (the grid's CRS) so the
        later intersection tests compare geometries in one coordinate
        system. Attributes are not fetched because only geometry is used.

        Returns:
            A ``(QgsSpatialIndex, dict)`` tuple; the dict maps feature id
            to the (reprojected) geometry.
        """
        request = QgsFeatureRequest()
        request.setSubsetOfAttributes([])
        if target_crs.isValid():
            request.setDestinationCrs(target_crs, context.transformContext())

        index = QgsSpatialIndex()
        geoms_by_id = {}
        for poly_feat in poly_source.getFeatures(request):
            self._raise_if_canceled(feedback)
            if not poly_feat.hasGeometry():
                continue
            index.addFeature(poly_feat)
            geoms_by_id[poly_feat.id()] = poly_feat.geometry()
        return index, geoms_by_id

    @staticmethod
    def _cell_coverage(cell_geom, poly_index, poly_geoms):
        """Return the fraction (0–1) of ``cell_geom`` covered by polygons.

        Cells with a missing/empty geometry or a non-positive area yield
        0.0. Overlapping polygons are merged before measuring so shared
        area is not counted more than once.
        """
        if not cell_geom or cell_geom.isEmpty():
            return 0.0
        cell_area = cell_geom.area()
        if cell_area <= 0:
            return 0.0

        # Clip every candidate polygon to the cell.
        pieces = []
        for pid in poly_index.intersects(cell_geom.boundingBox()):
            poly_geom = poly_geoms.get(pid)
            if not poly_geom or poly_geom.isEmpty():
                continue
            if not cell_geom.intersects(poly_geom):
                continue
            piece = poly_geom.intersection(cell_geom)
            # Touching-only results (points/lines) have zero area; skip.
            if piece and not piece.isEmpty() and piece.area() > 0:
                pieces.append(piece)

        if not pieces:
            return 0.0

        covered_area = sum(piece.area() for piece in pieces)
        if len(pieces) > 1:
            merged = QgsGeometry.unaryUnion(pieces)
            # If the union fails (null/empty result) keep the summed
            # area; the clamp below still bounds the result.
            if merged and not merged.isEmpty():
                covered_area = merged.area()

        # Clamp to [0, 1] to avoid numerical artefacts
        return min(max(covered_area / cell_area, 0.0), 1.0)

    @staticmethod
    def _score_for_coverage(
        coverage, min_cov, max_cov, more_coverage_increases_score
    ):
        """Map a coverage fraction to an integer score.

        - No coverage: 0, or 100 when the score is inverted.
        - All covered cells share one coverage value (``max_cov`` is not
          greater than ``min_cov``): the covered cells get the score of
          the maximum endpoint (100, or 0 when inverted) so they remain
          distinguishable from uncovered cells.
        - Otherwise the value is scaled with ``scale_value`` from
          ``[min_cov, max_cov]`` onto 1–100 (or 100–0 when inverted) and
          converted with ``int()``.
        """
        if coverage <= 0:
            return 0 if more_coverage_increases_score else 100
        if max_cov <= min_cov:
            return 100 if more_coverage_increases_score else 0
        if more_coverage_increases_score:
            return int(scale_value(coverage, min_cov, max_cov, 1, 100))
        return int(scale_value(coverage, min_cov, max_cov, 100, 0))

    @staticmethod
    def _build_output_feature(
        grid_feat, fields, coverage_idx, coverage, score_idx, score
    ):
        """Create the output feature for one grid cell.

        The cell's geometry and original attributes are copied, then the
        coverage and score attributes are set at their field indexes.
        """
        out_feat = QgsFeature()
        out_feat.setFields(fields)
        out_feat.setGeometry(grid_feat.geometry())
        out_feat.initAttributes(fields.count())

        # Copy original attributes by index (may be fewer than fields)
        field_count = fields.count()
        for i, val in enumerate(grid_feat.attributes()):
            if i < field_count:
                out_feat.setAttribute(i, val)

        # Set our metric-specific attributes
        out_feat.setAttribute(coverage_idx, coverage)
        out_feat.setAttribute(score_idx, int(score))
        return out_feat
