# -*- coding: utf-8 -*-
"""Prioritize grid cells based on proximity to points on another layer.

This algorithm takes a polygon grid (e.g. hex cells) and a point layer
(e.g. transit stops), calls an isochrone service for each point, unions
the resulting polygons, and tests whether the centroid of each grid cell
lies within the union.

Two attributes are written:

- ``<metric_name>_within`` – boolean flag, True if the cell is within any
  isochrone polygon
- ``<metric_name>_score``  – numeric score assigned when within the
  isochrone, otherwise 0

The score value is provided directly by the user. For example, use 100 if
this is the only proximity test, or 100 and 50 for two different
isochrone thresholds in a model. The isochrone URL template works the same
way as in the Walk Potential algorithm and uses the same default value.
"""

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsCoordinateReferenceSystem,
    QgsCoordinateTransform,
    QgsFeature,
    QgsFeatureRequest,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterNumber,
    QgsProcessingParameterString,
    QgsSpatialIndex,
)

from ..services.isochrone import IsochroneFetcher, IsochroneError
from ..utils.geometry import geojson_to_qgs_geometry, union_geometries


class GridPriorityProximityAlgorithm(QgsProcessingAlgorithm):
    """Grid Priority (Proximity Test).

    For each polygon cell in the input grid layer, this algorithm checks
    whether the cell's centroid lies within the union of isochrone
    polygons generated from a set of input points (e.g. transit stops).
    Cells inside the union receive a user-specified score; all other
    cells receive zero.

    Point locations are reprojected to WGS 84 before being sent to the
    isochrone service, and the resulting union is reprojected into the
    grid layer's CRS before the containment test, so the two inputs do
    not need to share a CRS.
    """

    # Parameter names
    INPUT_GRID = "INPUT_GRID"
    INPUT_POINTS = "INPUT_POINTS"
    METRIC_NAME = "METRIC_NAME"
    SCORE_VALUE = "SCORE_VALUE"
    ISOCHRONE_URL = "ISOCHRONE_URL"
    OUTPUT = "OUTPUT"

    # CRS of the coordinates exchanged with the isochrone service: the URL
    # template is filled with {{lat}}/{{lon}}, and GeoJSON responses
    # (RFC 7946) are in WGS 84. Assumes geojson_to_qgs_geometry keeps the
    # response coordinates unchanged -- confirm if that helper changes.
    _SERVICE_CRS_AUTHID = "EPSG:4326"

    def tr(self, string):
        """Return *string* translated in the "Processing" context."""
        return QCoreApplication.translate("Processing", string)

    def createInstance(self):
        """Return a fresh instance, as required by the Processing framework."""
        return GridPriorityProximityAlgorithm()

    def name(self):
        """Return the unique, untranslated algorithm identifier."""
        return "grid_priority_proximity"

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr("Grid Priority (Proximity Test)")

    def group(self):
        """Return the group name (empty: the algorithm is ungrouped)."""
        return ""

    def groupId(self):
        """Return the group identifier (empty: the algorithm is ungrouped)."""
        return ""

    def shortHelpString(self):
        """Return the help text shown in the algorithm dialog."""
        # Explicit newlines separate the paragraphs and the lines of the
        # example expression; without them the literals ran together.
        return self.tr(
            "Prioritizes grid cells based on proximity to a set of points "
            "(for example, transit stops). Isochrones are generated for "
            "each point and unioned; cells whose centroid lies within the "
            "union receive the specified score, others receive 0. Run this "
            "multiple times with different score values in a model to "
            "emulate multi-step proximity grading (e.g. 5-minute vs "
            "10-minute walks).\n"
            "If multiple isochrones are used for different distances, "
            "they can later be combined into a single metric with an "
            "expression. For example:\n"
            "CASE\n"
            "    WHEN \"bus5_score\" = 100 THEN 100\n"
            "    WHEN \"bus10_score\" = 100 THEN 50\n"
            "    ELSE 0\n"
            "END"
        )

    def initAlgorithm(self, config=None):
        """Declare the algorithm's input and output parameters.

        :param config: unused; accepted for Processing API compatibility.
        """
        # Input grid layer (polygons)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_GRID,
                self.tr("Grid Layer"),
                [QgsProcessing.TypeVectorPolygon],
            )
        )

        # Input point layer (single- or multi-point)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_POINTS,
                self.tr("Point Layer"),
                [QgsProcessing.TypeVectorPoint],
            )
        )

        # Metric name used as base for output attributes
        self.addParameter(
            QgsProcessingParameterString(
                self.METRIC_NAME,
                self.tr("Metric name (base attribute name)"),
                defaultValue="proximity",
            )
        )

        # Score value to assign when a cell is within the isochrone union
        self.addParameter(
            QgsProcessingParameterNumber(
                self.SCORE_VALUE,
                self.tr("Score value when within isochrone"),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=100.0,
            )
        )

        # Isochrone service URL template
        default_url = 'http://valhalla1.openstreetmap.de/isochrone?json={"polygons":true,"locations":[{"lat":{{lat}}, "lon":{{lon}}}],"costing":"pedestrian","contours":[{"time":10.0}]}'
        self.addParameter(
            QgsProcessingParameterString(
                self.ISOCHRONE_URL,
                self.tr("Isochrone Service URL (use {{lat}} and {{lon}} as placeholders)"),
                defaultValue=default_url,
            )
        )

        # Output grid layer
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Grid Priority (Proximity) Output"),
            )
        )

    def flags(self):
        """Add FlagNoThreading so the algorithm runs on the main thread.

        NOTE(review): presumably required by the isochrone fetcher's
        network access -- confirm against IsochroneFetcher.
        """
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        """Score every grid cell by isochrone proximity and write the output.

        :param parameters: parameter values keyed by the names declared in
            :meth:`initAlgorithm`.
        :param context: Processing context (also supplies the coordinate
            transform context).
        :param feedback: receives log messages and progress; its
            cancellation flag is polled.
        :returns: ``{OUTPUT: destination id}`` of the written layer.
        :raises QgsProcessingException: invalid inputs, empty metric name,
            user cancellation, or any isochrone request/parse failure.
        """
        grid_source = self.parameterAsSource(parameters, self.INPUT_GRID, context)
        if grid_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_GRID)
            )

        points_source = self.parameterAsSource(
            parameters, self.INPUT_POINTS, context
        )
        if points_source is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_POINTS)
            )

        metric_name = self.parameterAsString(parameters, self.METRIC_NAME, context).strip()
        if not metric_name:
            raise QgsProcessingException(
                self.tr("Metric name (base attribute name) must not be empty.")
            )

        score_value = self.parameterAsDouble(parameters, self.SCORE_VALUE, context)
        isochrone_url = self.parameterAsString(parameters, self.ISOCHRONE_URL, context)

        within_field_name = f"{metric_name}_within"
        score_field_name = f"{metric_name}_score"

        # Integral scores (the common case, e.g. 100 or 50) are stored as
        # Int; fractional scores get a Double field instead of being
        # silently rounded.
        score_is_integral = score_value.is_integer()
        within_score = int(score_value) if score_is_integral else score_value

        service_crs = QgsCoordinateReferenceSystem(self._SERVICE_CRS_AUTHID)

        feedback.pushInfo(self.tr("Collecting point locations for isochrone generation..."))
        points = self._collect_points(points_source, service_crs, context, feedback)

        union_geom = None
        if not points:
            feedback.pushInfo(self.tr("No valid point geometries found; all cells will score 0."))
        else:
            union_geom = self._build_isochrone_union(points, isochrone_url, feedback)
            if union_geom is None:
                feedback.pushInfo(self.tr("No isochrone geometries were generated; all cells will score 0."))
            else:
                # Bring the WGS 84 union into the grid's CRS so the
                # containment test compares like with like. An invalid grid
                # CRS is left untransformed (assumed to already match).
                grid_crs = grid_source.sourceCrs()
                if grid_crs.isValid() and grid_crs != service_crs:
                    union_geom.transform(
                        QgsCoordinateTransform(
                            service_crs, grid_crs, context.transformContext()
                        )
                    )

        fields, within_idx, score_idx = self._build_output_fields(
            grid_source, within_field_name, score_field_name, score_is_integral
        )

        # Create output sink
        sink, dest_id = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            fields,
            grid_source.wkbType(),
            grid_source.sourceCrs(),
        )

        if sink is None:
            raise QgsProcessingException(
                self.invalidSinkError(parameters, self.OUTPUT)
            )

        # Evaluate each grid cell against the isochrone union
        feedback.pushInfo(self.tr("Evaluating grid cells against isochrone union..."))

        field_count = fields.count()
        cell_count = grid_source.featureCount()
        # featureCount() may be -1 when unknown; then progress stays put.
        progress_step = 50.0 / cell_count if cell_count > 0 else 0.0

        for current, grid_feat in enumerate(grid_source.getFeatures()):
            if feedback.isCanceled():
                raise QgsProcessingException(
                    self.tr("Calculation cancelled by user")
                )

            geom = grid_feat.geometry()
            cell_within = False
            cell_score = 0

            # A cell counts as "within" when its centroid lies in the union.
            if union_geom is not None and geom and not geom.isEmpty():
                if union_geom.contains(geom.centroid()):
                    cell_within = True
                    cell_score = within_score

            # Original attributes first, then padding for any appended
            # metric fields, then overwrite the metric slots.
            attrs = list(grid_feat.attributes())[:field_count]
            attrs.extend([None] * (field_count - len(attrs)))
            attrs[within_idx] = cell_within
            attrs[score_idx] = cell_score

            out_feat = QgsFeature(fields)
            out_feat.setGeometry(geom)
            out_feat.setAttributes(attrs)
            sink.addFeature(out_feat, QgsFeatureSink.FastInsert)

            # Second half of the progress bar tracks cell evaluation.
            feedback.setProgress(50.0 + (current + 1) * progress_step)

        feedback.pushInfo(self.tr("Grid priority proximity calculation complete."))
        return {self.OUTPUT: dest_id}

    def _collect_points(self, points_source, service_crs, context, feedback):
        """Return every point location of *points_source* in *service_crs*.

        Features without geometry or with empty geometry are skipped; every
        part of a multipoint feature is included. If the source CRS is
        invalid, coordinates are returned as stored.

        :raises QgsProcessingException: if the user cancels.
        """
        request = QgsFeatureRequest()
        if points_source.sourceCrs().isValid():
            request.setDestinationCrs(service_crs, context.transformContext())

        points = []
        for p_feat in points_source.getFeatures(request):
            if feedback.isCanceled():
                raise QgsProcessingException(
                    self.tr("Calculation cancelled by user")
                )
            if not p_feat.hasGeometry():
                continue
            geom = p_feat.geometry()
            if not geom or geom.isEmpty():
                continue
            # asPoint() raises on multipart geometries, which the point
            # layer filter accepts, so expand them part by part.
            if geom.isMultipart():
                points.extend(geom.asMultiPoint())
            else:
                points.append(geom.asPoint())
        return points

    def _build_isochrone_union(self, points, isochrone_url, feedback):
        """Fetch one isochrone per point and return their union.

        :param points: non-empty list of points in the service CRS
            (x = longitude, y = latitude).
        :returns: the union geometry in the service CRS, or ``None`` when
            the service produced no usable geometry.
        :raises QgsProcessingException: on cancellation, or wrapping any
            error raised while fetching or parsing isochrones.
        """
        feedback.pushInfo(self.tr("Requesting isochrones from service..."))
        fetcher = IsochroneFetcher(isochrone_url)
        total = len(points)
        isochrone_geoms = []

        try:
            for idx, pt in enumerate(points, start=1):
                if feedback.isCanceled():
                    raise QgsProcessingException(
                        self.tr("Calculation cancelled by user")
                    )

                lat = pt.y()
                lon = pt.x()
                feedback.pushInfo(
                    self.tr("Fetching isochrone {idx}/{total}...").format(
                        idx=idx, total=total
                    )
                )
                geojson = fetcher.fetch_isochrone(lon, lat)
                if geojson:
                    geom = geojson_to_qgs_geometry(geojson["geometry"])
                    if geom and not geom.isEmpty():
                        isochrone_geoms.append(geom)

                # First half of the progress bar tracks isochrone requests.
                feedback.setProgress(50.0 * idx / total)
        except QgsProcessingException:
            # Cancellation must surface unchanged, not re-wrapped.
            raise
        except IsochroneError as e:
            raise QgsProcessingException(str(e)) from e
        except Exception as e:  # e.g. a response without a "geometry" key
            raise QgsProcessingException(str(e)) from e

        if not isochrone_geoms:
            return None

        feedback.pushInfo(self.tr("Unioning isochrone geometries..."))
        return union_geometries(isochrone_geoms)

    def _build_output_fields(
        self, grid_source, within_field_name, score_field_name, score_is_integral
    ):
        """Return ``(fields, within_idx, score_idx)`` for the output layer.

        The grid's fields are kept. The metric fields are appended only when
        a field of that name does not already exist; an existing field is
        reused and overwritten. A newly created score field is Int for
        integral scores and Double otherwise.
        """
        fields = QgsFields(grid_source.fields())

        if fields.indexOf(within_field_name) == -1:
            fields.append(QgsField(within_field_name, QVariant.Bool))

        if fields.indexOf(score_field_name) == -1:
            score_type = QVariant.Int if score_is_integral else QVariant.Double
            fields.append(QgsField(score_field_name, score_type))

        return (
            fields,
            fields.indexOf(within_field_name),
            fields.indexOf(score_field_name),
        )
