# -*- coding: utf-8 -*-
"""
/***************************************************************************
 Sidewalk Priority Toolkit - Walk Potential Algorithm
                                 A QGIS plugin
 Processing algorithm for walk potential analysis
                              -------------------
        begin                : 2025-10-13
        copyright            : (C) 2025 by Mark Stosberg
        email                : mark@stosberg.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import json
import os
from typing import List, Dict, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from qgis.PyQt.QtCore import QCoreApplication, QVariant, Qt
from qgis.PyQt.QtGui import QColor
from qgis.core import (
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterString,
    QgsProcessingParameterNumber,
    QgsProcessingParameterFile,
    QgsProcessingParameterDefinition,
    QgsProcessingException,
    QgsProcessingUtils,
    QgsFeatureSink,
    QgsProject,
    QgsFeatureRequest,
    QgsGraduatedSymbolRenderer,
    QgsRendererRange,
    QgsSymbol,
    QgsWkbTypes,
    QgsGeometry,
    QgsFeature,
    QgsVectorLayer,
    QgsCoordinateReferenceSystem,
    QgsField,
)

from ..services.overpass_query import OverpassQuery, OverpassQueryError
from ..services.isochrone import IsochroneFetcher, IsochroneError
from ..utils.geometry import (
    geojson_to_qgs_geometry,
    qgs_geometry_to_geojson,
    create_hex_grid,
    buffer_geometry,
    union_geometries,
    get_bbox,
    calculate_area_sqm,
    hex_area_sqm_for_side_km,
)


class WalkPotentialCalculator:
    """Calculate a 0-100 walk potential score for each cell of a polygon grid.

    For every amenity category in the amenities configuration, amenities are
    fetched from Overpass as GeoJSON, a travel-time isochrone is requested
    around each amenity location, and the isochrones are unioned into one
    coverage geometry per category. A cell's score is the percentage of
    categories whose coverage contains the cell's centroid.

    The result layer is created in EPSG:4326, and amenity/isochrone
    geometries are built from GeoJSON, so the input grid is expected to be
    in EPSG:4326 as well.
    """

    # Constants from the original Node.js code
    HEX_SIDE_KM = 0.05  # 50 meters
    HEX_WIDTH_KM = 2 * HEX_SIDE_KM
    AMENITIES_BUFFER_KM = 0.8  # Buffer for amenity bbox

    def __init__(
        self,
        isochrone_url: str,
        concurrency: int = 4,
        amenities_file: str | None = None,
        progress_callback: Callable[[str], None] | None = None,
        overpass_endpoint: str = "https://overpass-api.de/api/interpreter",
    ):
        """Initialize the walk potential calculator.

        :param isochrone_url: Isochrone service URL template handed to
            ``IsochroneFetcher``.
        :param concurrency: Maximum number of simultaneous isochrone requests.
        :param amenities_file: Path to a JSON file mapping amenity category
            names to Overpass queries; defaults to ``amenities.json`` in the
            plugin directory.
        :param progress_callback: Receives human-readable progress messages.
        :param overpass_endpoint: Overpass API interpreter URL.
        """
        self.isochrone_url = isochrone_url
        self.concurrency = concurrency
        self.progress_callback = progress_callback or (lambda x: None)
        self._cancel_flag = False

        # Load amenities configuration
        if amenities_file is None:
            plugin_dir = os.path.dirname(os.path.dirname(__file__))
            amenities_file = os.path.join(plugin_dir, "amenities.json")

        # JSON is UTF-8; don't depend on the platform's locale encoding
        # (e.g. cp1252 on Windows) when category names are non-ASCII.
        with open(amenities_file, "r", encoding="utf-8") as f:
            self.amenity_queries = json.load(f)

        # Initialize API clients
        self.overpass = OverpassQuery(endpoint=overpass_endpoint)
        # Don't pass progress_callback to isochrone fetcher - it runs in threads and would cause Qt crashes
        self.isochrone_fetcher = IsochroneFetcher(isochrone_url)

        # Storage for amenity layers (one unioned isochrone geometry per category)
        self.amenity_layers: list[QgsGeometry] = []

    def _cancel_calculation(self) -> None:
        """Signal that calculation should be cancelled."""
        self._cancel_flag = True

    def _do_progress_callback(self, message: str) -> None:
        """Call progress callback only if not cancelled."""
        if not self._cancel_flag:
            self.progress_callback(message)

    def calculate_from_grid(self, grid_layer: QgsVectorLayer) -> QgsVectorLayer:
        """Calculate walk potential for an existing hex grid.

        :param grid_layer: Polygon grid layer. Cell centroids are tested
            against EPSG:4326 amenity coverage, so the layer should be in
            EPSG:4326.
        :return: New in-memory EPSG:4326 layer with a single
            ``walk_potential_score`` field.
        :raises ValueError: If the grid has neither valid geometries nor a
            non-empty extent, or if the calculation was cancelled.
        :raises IsochroneError: If isochrone requests fail repeatedly.
        """
        # Step 1: Derive boundary geometry from grid
        self._do_progress_callback("Deriving boundary geometry from grid...")
        grid_crs = grid_layer.crs()
        self._do_progress_callback(
            f"Input grid has {grid_layer.featureCount()} features and CRS {grid_crs.authid()}"
        )
        if grid_crs.isValid() and grid_crs.authid() != "EPSG:4326":
            # Amenity coverage and the output layer are EPSG:4326; a grid in
            # another CRS will not line up with them.
            self._do_progress_callback(
                f"Warning: input grid CRS {grid_crs.authid()} is not EPSG:4326; "
                "walk potential scores may be incorrect."
            )

        try:
            boundary_geom = self._get_boundary_geometry(grid_layer)
        except ValueError:
            # Fallback: if the grid layer has a non-empty extent but no valid
            # feature geometries (which can happen with some materialized
            # in-memory layers), use the layer extent as a proxy boundary.
            extent = grid_layer.extent()
            if extent is None or extent.isEmpty():
                raise
            self._do_progress_callback(
                "No valid feature geometries found in grid layer; "
                "falling back to layer extent as boundary polygon."
            )
            boundary_geom = QgsGeometry.fromRect(extent)

        # Steps 2-3: Calculate amenities bbox, query amenities, generate isochrones
        total_isochrones_requested = self._prepare_amenities_for_boundary(boundary_geom)

        # Step 4: Use existing hex grid
        if self._cancel_flag:
            raise ValueError("Calculation cancelled by user")
        self._do_progress_callback("Using provided hex grid...")
        hex_features = list(grid_layer.getFeatures())
        self._do_progress_callback(
            f"Using {len(hex_features)} hex cells from input grid"
        )

        # Step 5: Calculate walk potential for each hex
        if self._cancel_flag:
            raise ValueError("Calculation cancelled by user")
        self._do_progress_callback("Calculating walk potential...")
        hex_features = self._calculate_walk_potential(hex_features)

        # Step 6: Create output layer
        result_layer = self._create_result_layer(hex_features)

        # Report summary
        num_categories = len(self.amenity_queries)
        num_hex_cells = len(hex_features)
        self._do_progress_callback(
            "DONE! Walk Potential has been calculated for {num_cells:,} hex cells within the grid, "
            "using {num_categories} categories of amenities and {num_isochrones:,} total travel-time isochrones.".format(
                num_cells=num_hex_cells,
                num_categories=num_categories,
                num_isochrones=total_isochrones_requested,
            )
        )
        return result_layer

    def _prepare_amenities_for_boundary(self, boundary_geom: QgsGeometry) -> int:
        """Calculate amenities bbox and build amenity layers.

        Resets ``self.amenity_layers`` and appends one unioned isochrone
        geometry for each category that produced isochrones. Overpass errors
        are reported and the category is skipped; isochrone errors propagate.

        :param boundary_geom: Area of interest; it is buffered by
            ``AMENITIES_BUFFER_KM`` so amenities just outside still count.
        :return: Total number of isochrones requested across all categories.
        :raises ValueError: If the calculation was cancelled.
        """
        # Step 2: Calculate amenities bbox with buffer
        self._do_progress_callback("Calculating amenities bounding box...")
        buffered_boundary = buffer_geometry(boundary_geom, self.AMENITIES_BUFFER_KM)
        amenities_bbox = get_bbox(buffered_boundary)
        self._do_progress_callback(f"Amenities bbox: {amenities_bbox}")

        # Step 3: Query amenities and generate isochrone layers
        self.amenity_layers = []
        total_isochrones_requested = 0

        for amenity_name, query in self.amenity_queries.items():
            if self._cancel_flag:
                raise ValueError("Calculation cancelled by user")
            self._do_progress_callback(
                f"{amenity_name}: Querying Overpass API..."
            )

            try:
                features = self.overpass.bbox_query_to_geojson(
                    query,
                    amenities_bbox,
                    self._do_progress_callback if not self._cancel_flag else None,
                )

                if features:
                    self._do_progress_callback(
                        f"{amenity_name}: Found {len(features)} features, generating isochrones..."
                    )
                    layer_geom, isochrone_count = self._create_amenity_layer(
                        features
                    )
                    total_isochrones_requested += isochrone_count
                    if layer_geom:
                        self.amenity_layers.append(layer_geom)
                else:
                    self._do_progress_callback(
                        f"{amenity_name}: No features found"
                    )

            except OverpassQueryError as e:
                # A failing category is reported and skipped, not fatal.
                self._do_progress_callback(
                    f"{amenity_name}: Error - {str(e)}"
                )

        return total_isochrones_requested

    def _get_boundary_geometry(self, layer: QgsVectorLayer) -> QgsGeometry:
        """Extract and combine all features from the boundary layer.

        :param layer: Layer whose feature geometries are combined.
        :return: The single geometry, or the union of all non-null geometries.
        :raises ValueError: If the layer has no non-null geometries.
        """
        geometries: list[QgsGeometry] = []
        for feature in layer.getFeatures():
            geom = feature.geometry()
            if geom and not geom.isNull():
                geometries.append(geom)

        if not geometries:
            raise ValueError("Boundary layer contains no valid geometries")

        # Union all geometries
        if len(geometries) == 1:
            return geometries[0]
        return union_geometries(geometries)

    def _create_amenity_layer(
        self, features: List[Dict]
    ) -> tuple[Optional[QgsGeometry], int]:
        """Create an amenity layer by generating isochrones for all features.

        Large polygons are first expanded into edge points. One isochrone is
        requested per feature centroid, with up to ``self.concurrency``
        requests in flight.

        :param features: GeoJSON feature dicts (each with a ``"geometry"`` key).
        :return: (unioned isochrone geometry or None, number of isochrone
            requests submitted)
        :raises IsochroneError: After three service failures, or if every
            request failed and at least one was a service failure.
        """
        # Expand large polygons into multiple points
        expanded_features = self._expand_large_features(features)

        isochrones: list[QgsGeometry] = []
        failed_count = 0
        critical_errors: list[str] = []
        total = 0

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            try:
                future_to_idx: dict = {}
                for idx, feature in enumerate(expanded_features):
                    geom = geojson_to_qgs_geometry(feature["geometry"])
                    if not geom:
                        continue
                    center = geom.centroid().asPoint()
                    if not future_to_idx:
                        # Log the first request URL for debugging the template.
                        test_url = self.isochrone_fetcher.get_isochrone_url(
                            center.y(), center.x()
                        )
                        self._do_progress_callback(
                            f"\tSample isochrone URL: {test_url}"
                        )
                    future = executor.submit(
                        self.isochrone_fetcher.fetch_isochrone,
                        center.x(),
                        center.y(),
                    )
                    future_to_idx[future] = idx

                # Only features with a usable geometry produced a request.
                total = len(future_to_idx)

                # Collect results
                completed = 0
                for future in as_completed(future_to_idx):
                    completed += 1
                    if completed % 10 == 0:
                        self._do_progress_callback(
                            f"\tProcessed {completed}/{total} isochrones"
                        )

                    try:
                        isochrone_geojson = future.result()
                        if isochrone_geojson:
                            isochrone_geom = geojson_to_qgs_geometry(
                                isochrone_geojson["geometry"]
                            )
                            if isochrone_geom:
                                isochrones.append(isochrone_geom)
                        else:
                            failed_count += 1
                    except IsochroneError as e:
                        # Critical error (network/service failure)
                        critical_errors.append(str(e))
                        failed_count += 1
                        # Abort this category once three service failures occur
                        if len(critical_errors) >= 3:
                            raise IsochroneError(
                                f"Multiple isochrone requests failed: {critical_errors[0]}"
                            ) from e
                    except Exception:  # noqa: BLE001
                        # Other errors (parsing, etc.) - just count as failed
                        failed_count += 1
            except BaseException:
                # Leaving the `with` block waits for every submitted future;
                # drop the still-queued requests so failures and user
                # cancellation surface promptly instead of after all requests.
                executor.shutdown(wait=False, cancel_futures=True)
                raise

        # Check results
        if not isochrones:
            if critical_errors:
                raise IsochroneError(
                    f"All isochrone requests failed. {critical_errors[0]}"
                )
            self._do_progress_callback(
                f"\tWarning: No valid isochrones generated from {total} features!"
            )
            return None, total

        # Warn if high failure rate
        failure_rate = (failed_count / total) * 100 if total > 0 else 0
        if failure_rate > 2:
            self._do_progress_callback(
                f"\tWarning: {failure_rate:.1f}% of isochrone requests failed ({failed_count}/{total})"
            )

        # Union all isochrones
        self._do_progress_callback(
            f"\tSuccess: Generated {len(isochrones)}/{total} isochrones"
        )
        self._do_progress_callback(
            f"\tMerging {len(isochrones)} isochrones..."
        )
        return union_geometries(isochrones), total

    def _expand_large_features(self, features: List[Dict]) -> List[Dict]:
        """Expand large polygon features into multiple point features at edges.

        A polygon larger than one theoretical hex cell is replaced by point
        features at the centroids of hex cells lying inside the polygon
        but within about one hex width of its boundary. Other features are kept
        as-is. A feature whose geometry cannot be converted is dropped.

        :param features: GeoJSON feature dicts.
        :return: Expanded feature list, or ``features`` unchanged if
            expansion produced nothing.
        """
        expanded: list[Dict] = []

        # Use the theoretical hexagon area as a threshold for "large" polygons
        hex_area_threshold_sqm = hex_area_sqm_for_side_km(self.HEX_SIDE_KM)

        for feature in features:
            geom = geojson_to_qgs_geometry(feature["geometry"])
            if not geom:
                continue

            # Small polygons and non-polygons are used as-is (their centroid
            # becomes the isochrone origin later).
            if geom.type() != QgsWkbTypes.PolygonGeometry:
                expanded.append(feature)
                continue
            if calculate_area_sqm(geom) <= hex_area_threshold_sqm:
                expanded.append(feature)
                continue

            # Create a grid within the polygon
            hex_features = create_hex_grid(geom, self.HEX_SIDE_KM)

            # If hex grid creation failed (polygon too small), use centroid
            if not hex_features:
                expanded.append(feature)
                continue

            # Buffer the polygon inward by one hex width. Dividing km by
            # 111 approximates degrees (~111 km per degree of latitude).
            smaller_geom = geom.buffer(-self.HEX_WIDTH_KM / 111.0, 5)

            # Keep points near the edge
            edge_points_found = False
            for hex_feat in hex_features:
                center = hex_feat.geometry().centroid()
                # Point is near edge if it's in original but not in smaller
                if geom.contains(center) and (
                    not smaller_geom or not smaller_geom.contains(center)
                ):
                    expanded.append(
                        {
                            "type": "Feature",
                            "geometry": qgs_geometry_to_geojson(center),
                            "properties": feature.get("properties", {}),
                        }
                    )
                    edge_points_found = True

            # Fallback to centroid if no edge points found
            if not edge_points_found:
                expanded.append(feature)

        return expanded if expanded else features

    def _calculate_walk_potential(
        self, hex_features: List[QgsFeature]
    ) -> List[QgsFeature]:
        """Calculate walk potential for each hex feature.

        Each feature's attributes are replaced (in place) by a single integer:
        the percentage (truncated) of amenity layers containing the cell's
        centroid. When there are no amenity layers, every cell scores 0.

        :param hex_features: Grid cell features; mutated in place.
        :return: The same list, for chaining.
        """
        total = len(hex_features)
        num_amenity_layers = len(self.amenity_layers)

        if num_amenity_layers == 0:
            self._do_progress_callback("Warning: No amenity layers found!")
            # Still replace the grid's own attributes: the result layer has a
            # single integer field, and leftover input attributes would
            # otherwise end up in (or be rejected by) that field.
            for hex_feature in hex_features:
                hex_feature.setAttributes([0])
            return hex_features

        for idx, hex_feature in enumerate(hex_features):
            if idx % 100 == 0:
                self._do_progress_callback(f"\t{idx}/{total} cells processed")

            center = hex_feature.geometry().centroid()

            # Count how many amenity layers contain this point
            walk_potential = sum(
                1 for amenity_layer in self.amenity_layers
                if amenity_layer.contains(center)
            )

            # Scale from 0-100
            scaled_potential = int(
                (walk_potential / num_amenity_layers) * 100
            )

            # Replace the feature attributes with the single score value
            hex_feature.setAttributes([scaled_potential])

        return hex_features

    def _create_result_layer(
        self, hex_features: List[QgsFeature]
    ) -> QgsVectorLayer:
        """Create the result layer with hex grid and walk potential values.

        :param hex_features: Features whose attributes are ``[score]``.
        :return: In-memory EPSG:4326 polygon layer with a
            ``walk_potential_score`` integer field.
        """
        crs = QgsCoordinateReferenceSystem("EPSG:4326")
        layer = QgsVectorLayer(
            f"Polygon?crs={crs.authid()}",
            "Walk Potential",
            "memory",
        )

        provider = layer.dataProvider()
        provider.addAttributes([QgsField("walk_potential_score", QVariant.Int)])
        layer.updateFields()

        provider.addFeatures(hex_features)
        layer.updateExtents()

        return layer


class WalkPotentialFromGridAlgorithm(QgsProcessingAlgorithm):
    """Processing algorithm for calculating walk potential on an existing grid.

    Wraps ``WalkPotentialCalculator``: reads the grid and options from the
    Processing parameters, runs the calculation, and copies the resulting
    features into the output sink.
    """

    def __init__(self):
        super().__init__()

    # Parameter names
    INPUT_GRID = 'INPUT_GRID'
    ISOCHRONE_URL = 'ISOCHRONE_URL'
    CONCURRENCY = 'CONCURRENCY'
    OVERPASS_API_URL = 'OVERPASS_API_URL'
    AMENITIES_FILE = 'AMENITIES_FILE'
    OUTPUT = 'OUTPUT'

    def tr(self, string):
        """Returns a translatable string with the self.tr() function."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Returns a new instance of the algorithm."""
        return WalkPotentialFromGridAlgorithm()

    def name(self):
        """Returns the algorithm name, used for identifying the algorithm."""
        return 'walkpotential_from_grid'

    def displayName(self):
        """Returns the translated algorithm name."""
        return self.tr('Grid Priority (Walk Potential)')

    def group(self):
        """Returns the name of the group this algorithm belongs to."""
        return ''

    def groupId(self):
        """Returns the unique ID of the group this algorithm belongs to."""
        return ''

    def shortHelpString(self):
        """Returns a localised short helper string for the algorithm."""
        return self.tr(
            "Analyzes pedestrian access to amenities (shops, parks, restaurants, etc.) "
            "using an existing grid layer. "
            "This algorithm expects a polygon grid layer (e.g. hex cells) as input and "
            "calculates a walk potential score (0-100) for each cell based on how many "
            "amenity types are within a 10-minute walk. "
            "Use this when you already have a grid you want to reuse across multiple analyses. "
            "You can apply a default stop-light heatmap style afterwards using the "
            "'Grid Priority (Apply Default Heatmap Style)' algorithm."
        )

    def initAlgorithm(self, config=None):
        """Define the inputs and outputs of the algorithm."""
        # Input grid layer
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_GRID,
                self.tr('Grid Layer'),
                [QgsProcessing.TypeVectorPolygon]
            )
        )

        # Isochrone service URL (Advanced)
        default_url = 'http://valhalla1.openstreetmap.de/isochrone?json={"polygons":true,"locations":[{"lat":{{lat}}, "lon":{{lon}}}],"costing":"pedestrian","contours":[{"time":10.0}]}'
        isochrone_param = QgsProcessingParameterString(
            self.ISOCHRONE_URL,
            self.tr('Isochrone Service URL (use {{lat}} and {{lon}} as placeholders)'),
            defaultValue=default_url
        )
        isochrone_param.setFlags(isochrone_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
        self.addParameter(isochrone_param)

        # Concurrency setting (Advanced)
        concurrency_param = QgsProcessingParameterNumber(
            self.CONCURRENCY,
            self.tr('Concurrency (simultaneous requests)'),
            type=QgsProcessingParameterNumber.Integer,
            defaultValue=4,
            minValue=1,
            maxValue=64
        )
        concurrency_param.setFlags(concurrency_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
        self.addParameter(concurrency_param)

        # Overpass API endpoint (Advanced)
        overpass_param = QgsProcessingParameterString(
            self.OVERPASS_API_URL,
            self.tr('Overpass API URL'),
            defaultValue='https://overpass-api.de/api/interpreter'
        )
        overpass_param.setFlags(overpass_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
        self.addParameter(overpass_param)

        # Optional custom amenities file
        self.addParameter(
            QgsProcessingParameterFile(
                self.AMENITIES_FILE,
                self.tr('Custom Amenities File'),
                behavior=QgsProcessingParameterFile.File,
                fileFilter='JSON files (*.json)',
                optional=True
            )
        )

        # Output layer
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr('Walk Potential Output Layer')
            )
        )

    def flags(self):
        """Returns the flags for the algorithm.

        FlagNoThreading is added so the algorithm is not run in a
        background thread.
        """
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        """Execute the algorithm.

        :return: ``{OUTPUT: dest_id}`` for the written sink.
        :raises QgsProcessingException: On invalid input/output, cancellation,
            or any error raised by the calculation (with the original
            exception chained as the cause).
        """
        # Get parameters
        grid_source = self.parameterAsSource(parameters, self.INPUT_GRID, context)
        if grid_source is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_GRID))

        # Convert source to layer. The calculator compares cells against
        # EPSG:4326 isochrones and writes an EPSG:4326 result layer, so
        # reproject the grid while materializing it.
        request = QgsFeatureRequest()
        request.setDestinationCrs(
            QgsCoordinateReferenceSystem('EPSG:4326'),
            context.transformContext(),
        )
        grid_layer = grid_source.materialize(request, feedback)

        if grid_layer is None:
            raise QgsProcessingException('Could not materialize grid layer')

        isochrone_url = self.parameterAsString(parameters, self.ISOCHRONE_URL, context)
        concurrency = self.parameterAsInt(parameters, self.CONCURRENCY, context)
        overpass_endpoint = self.parameterAsString(parameters, self.OVERPASS_API_URL, context)
        amenities_file = self.parameterAsFile(parameters, self.AMENITIES_FILE, context)

        # Progress callback; raising here is how user cancellation reaches
        # the calculator.
        def progress_callback(message):
            if feedback.isCanceled():
                raise QgsProcessingException('Calculation cancelled by user')
            feedback.pushInfo(message)

        # Create calculator
        try:
            # Handle empty amenities file - use None to trigger default
            if not amenities_file or not amenities_file.strip():
                amenities_file = None

            calculator = WalkPotentialCalculator(
                isochrone_url=isochrone_url,
                concurrency=concurrency,
                amenities_file=amenities_file,
                progress_callback=progress_callback,
                overpass_endpoint=overpass_endpoint
            )

            # Run calculation
            feedback.pushInfo('Starting walk potential calculation on existing grid...')
            result_layer = calculator.calculate_from_grid(grid_layer)

            if feedback.isCanceled():
                raise QgsProcessingException('Calculation cancelled by user')

            # Add result to output sink
            (sink, dest_id) = self.parameterAsSink(
                parameters,
                self.OUTPUT,
                context,
                result_layer.fields(),
                result_layer.wkbType(),
                result_layer.crs()
            )

            if sink is None:
                raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

            # Copy features from result layer to sink
            for feature in result_layer.getFeatures():
                if feedback.isCanceled():
                    break
                sink.addFeature(feature, QgsFeatureSink.FastInsert)

            feedback.pushInfo('Walk potential calculation on existing grid complete!')

            results = {self.OUTPUT: dest_id}
            return results

        except QgsProcessingException:
            # Already a Processing error with a user-facing message.
            raise
        except Exception as e:
            raise QgsProcessingException(str(e)) from e
