# -*- coding: utf-8 -*-
"""
Main walk potential calculation module
"""

import json
import os
from typing import List, Dict, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from qgis.core import (
    QgsGeometry, QgsFeature, QgsVectorLayer, QgsField,
    QgsCoordinateReferenceSystem, QgsProject, QgsWkbTypes,
    QgsGraduatedSymbolRenderer, QgsRendererRange, QgsSymbol, QgsStyle
)
from qgis.PyQt.QtCore import QVariant, Qt
from qgis.PyQt.QtGui import QColor

try:
    from .overpass_query import OverpassQuery, OverpassQueryError
    from .isochrone import IsochroneFetcher, IsochroneError
    from .geometry_utils import (
        geojson_to_qgs_geometry, qgs_geometry_to_geojson,
        create_hex_grid, buffer_geometry, union_geometries, get_bbox,
        calculate_area_sqm
    )
except ImportError:
    # Fallback for direct execution
    from overpass_query import OverpassQuery, OverpassQueryError
    from isochrone import IsochroneFetcher, IsochroneError
    from geometry_utils import (
        geojson_to_qgs_geometry, qgs_geometry_to_geojson,
        create_hex_grid, buffer_geometry, union_geometries, get_bbox,
        calculate_area_sqm
    )


class WalkPotentialCalculator:
    """
    Main class for calculating walk potential.

    For every amenity category in the amenities configuration the calculator
    queries Overpass for matching features around the boundary, requests one
    isochrone per feature (large polygons are first expanded into several
    points along their edge) and unions those isochrones into one coverage
    geometry per category. Each hex cell of a grid covering the boundary is
    then scored with the percentage (0-100, rounded down) of those coverage
    geometries that contain the cell's centroid.
    """

    # Constants from the original Node.js code
    HEX_SIDE_KM = 0.05  # 50 meters
    HEX_WIDTH_KM = 2 * HEX_SIDE_KM  # corner-to-corner width of one hex cell
    # Area of a regular hexagon with 50 m sides: 3 * sqrt(3) / 2 * 50^2 ~= 6495
    HEX_AREA_SQM = 6495  # square meters
    AMENITIES_BUFFER_KM = 0.8  # Buffer for amenity bbox
    # Approximate km per degree, used to express a km distance in degrees when
    # buffering geometries directly. Assumes geometries are in EPSG:4326
    # degrees (the result layer uses that CRS) -- only a rough conversion.
    KM_PER_DEGREE = 111.0

    def __init__(self, isochrone_url: str, concurrency: int = 4,
                 amenities_file: Optional[str] = None,
                 progress_callback: Optional[Callable[[str], None]] = None):
        """
        Initialize the walk potential calculator

        :param isochrone_url: URL template for isochrone service
        :param concurrency: Number of concurrent isochrone requests
        :param amenities_file: Path to amenities.json file; defaults to the
            ``amenities.json`` located next to this module
        :param progress_callback: Callback function for progress updates
        :raises OSError: if the amenities file cannot be opened
        :raises ValueError: if the amenities file is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass)
        """
        self.isochrone_url = isochrone_url
        self.concurrency = concurrency
        self.progress_callback = progress_callback or (lambda x: None)
        self._cancel_flag = False

        # Load amenities configuration: category name -> query handed to Overpass
        if amenities_file is None:
            plugin_dir = os.path.dirname(__file__)
            amenities_file = os.path.join(plugin_dir, 'amenities.json')

        # JSON is UTF-8; don't rely on the platform default encoding (e.g.
        # cp1252 on Windows). "utf-8-sig" additionally tolerates a leading BOM.
        with open(amenities_file, 'r', encoding='utf-8-sig') as f:
            self.amenity_queries = json.load(f)

        # Initialize API clients
        self.overpass = OverpassQuery()
        # Don't pass progress_callback to isochrone fetcher - it runs in threads and would cause Qt crashes
        self.isochrone_fetcher = IsochroneFetcher(isochrone_url)

        # Coverage geometries (one unioned isochrone geometry per category)
        self.amenity_layers = []

    def _cancel_calculation(self):
        """Signal that calculation should be cancelled"""
        self._cancel_flag = True

    def _check_cancelled(self):
        """
        Abort the running calculation if cancellation was requested.

        :raises ValueError: if ``_cancel_calculation`` has been called
        """
        if self._cancel_flag:
            raise ValueError("Calculation cancelled by user")

    def _do_progress_callback(self, message):
        """Call progress callback only if not cancelled"""
        if not self._cancel_flag:
            self.progress_callback(message)

    def calculate(self, boundary_layer: QgsVectorLayer) -> QgsVectorLayer:
        """
        Calculate walk potential for the given boundary

        :param boundary_layer: QgsVectorLayer containing the city boundary
        :return: QgsVectorLayer containing the hex grid with walk potential values
        :raises ValueError: if the boundary has no valid geometry or the
            calculation is cancelled
        :raises IsochroneError: if isochrone requests fail repeatedly
            (Overpass errors are only reported and that category is skipped)
        """
        # Step 1: Get the boundary geometry
        self._do_progress_callback("Getting boundary geometry...")
        boundary_geom = self._get_boundary_geometry(boundary_layer)

        # Step 2: Calculate amenities bbox with buffer, so amenities just
        # outside the boundary still contribute to cells near its edge
        self._do_progress_callback("Calculating amenities bounding box...")
        buffered_boundary = buffer_geometry(boundary_geom, self.AMENITIES_BUFFER_KM)
        amenities_bbox = get_bbox(buffered_boundary)
        self._do_progress_callback(f"Amenities bbox: {amenities_bbox}")

        # Step 3: Query amenities and generate isochrone layers
        self.amenity_layers = []
        total_isochrones_requested = 0

        for amenity_name, query in self.amenity_queries.items():
            self._check_cancelled()
            self._do_progress_callback(f"{amenity_name}: Querying Overpass API...")

            try:
                # The guarded callback already goes silent once cancelled
                features = self.overpass.bbox_query_to_geojson(
                    query, amenities_bbox, self._do_progress_callback
                )

                if features:
                    self._do_progress_callback(
                        f"{amenity_name}: Found {len(features)} features, generating isochrones..."
                    )
                    layer_geom, isochrone_count = self._create_amenity_layer(features)
                    total_isochrones_requested += isochrone_count
                    if layer_geom:
                        self.amenity_layers.append(layer_geom)
                else:
                    self._do_progress_callback(f"{amenity_name}: No features found")

            except OverpassQueryError as e:
                # A failing category is reported and skipped, not fatal
                self._do_progress_callback(f"{amenity_name}: Error - {str(e)}")

        # Step 4: Generate the hex grid
        self._check_cancelled()
        self._do_progress_callback("Generating hex grid...")
        hex_features = create_hex_grid(boundary_geom, self.HEX_SIDE_KM)
        self._do_progress_callback(f"Generated {len(hex_features)} hex cells")

        # Step 5: Calculate walk potential for each hex
        self._check_cancelled()
        self._do_progress_callback("Calculating walk potential...")
        hex_features = self._calculate_walk_potential(hex_features)

        # Step 6: Create output layer
        result_layer = self._create_result_layer(hex_features)

        # Report summary
        num_categories = len(self.amenity_queries)
        num_hex_cells = len(hex_features)
        self._do_progress_callback(
            f"DONE! Walk Potential has been calculated for {num_hex_cells:,} hex cells within the boundary, "
            f"using {num_categories} categories of amenities and {total_isochrones_requested:,} total travel-time isochrones."
        )
        return result_layer

    def _get_boundary_geometry(self, layer: QgsVectorLayer) -> QgsGeometry:
        """
        Extract and combine all features from the boundary layer

        :param layer: Boundary layer
        :return: Combined boundary geometry
        :raises ValueError: if the layer has no non-null geometry
        """
        geometries = []
        for feature in layer.getFeatures():
            geom = feature.geometry()
            if geom and not geom.isNull():
                geometries.append(geom)

        if not geometries:
            raise ValueError("Boundary layer contains no valid geometries")

        # Union all geometries (skip the union for the common single-feature case)
        if len(geometries) == 1:
            return geometries[0]
        return union_geometries(geometries)

    @staticmethod
    def _cancel_pending(futures):
        """
        Cancel every future that has not started running yet.

        Leaving a ``ThreadPoolExecutor`` ``with`` block waits for all submitted
        work; cancelling queued futures first limits that wait to the
        requests already in flight.
        """
        for future in futures:
            future.cancel()  # no-op for running/finished futures

    def _create_amenity_layer(self, features: List[Dict]) -> tuple[Optional[QgsGeometry], int]:
        """
        Create an amenity layer by generating isochrones for all features

        Large polygons are first expanded into several points
        (``_expand_large_features``); one isochrone is then requested for the
        centroid of each resulting feature, using ``self.concurrency`` threads.

        :param features: List of GeoJSON feature dictionaries
        :return: Tuple of (unioned geometry of all isochrones or None, count of
            isochrones attempted, i.e. the number of expanded features)
        :raises IsochroneError: as soon as three requests raise IsochroneError,
            or when no isochrone succeeded and at least one request raised it
        :raises ValueError: if the calculation is cancelled meanwhile
        """
        # Expand large polygons into multiple points
        expanded_features = self._expand_large_features(features)
        total = len(expanded_features)

        isochrones = []
        critical_errors = []
        failed_count = 0

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            # Submit one request per feature centroid
            futures = []
            sample_logged = False
            for feature in expanded_features:
                geom = geojson_to_qgs_geometry(feature['geometry'])
                if not geom:
                    # Never requested, so it cannot yield an isochrone
                    failed_count += 1
                    continue

                center = geom.centroid().asPoint()
                if not sample_logged:
                    # Log one URL so service misconfiguration is easy to diagnose
                    test_url = self.isochrone_fetcher.get_isochrone_url(center.y(), center.x())
                    self._do_progress_callback(f"\tSample isochrone URL: {test_url}")
                    sample_logged = True

                futures.append(executor.submit(
                    self.isochrone_fetcher.fetch_isochrone,
                    center.x(), center.y()
                ))

            # Collect results as they finish
            completed = 0
            for future in as_completed(futures):
                if self._cancel_flag:
                    # Drop queued requests, then abort (raises ValueError)
                    self._cancel_pending(futures)
                    self._check_cancelled()

                completed += 1
                if completed % 10 == 0:
                    self._do_progress_callback(f"\tProcessed {completed}/{total} isochrones")

                try:
                    isochrone_geojson = future.result()
                    isochrone_geom = None
                    if isochrone_geojson:
                        isochrone_geom = geojson_to_qgs_geometry(
                            isochrone_geojson['geometry']
                        )
                    if isochrone_geom:
                        isochrones.append(isochrone_geom)
                    else:
                        # Empty response or unconvertible geometry
                        failed_count += 1
                except IsochroneError as e:
                    # Critical error (network/service failure)
                    critical_errors.append(str(e))
                    failed_count += 1
                    # If we get critical errors early, fail fast -- without
                    # waiting for the rest of the queued requests
                    if len(critical_errors) >= 3:
                        self._cancel_pending(futures)
                        raise IsochroneError(
                            f"Multiple isochrone requests failed: {critical_errors[0]}"
                        ) from e
                except Exception:
                    # Other errors (parsing, etc) - deliberately best-effort:
                    # just count as failed
                    failed_count += 1

        # Check results
        if not isochrones:
            if critical_errors:
                raise IsochroneError(
                    f"All isochrone requests failed. {critical_errors[0]}"
                )
            self._do_progress_callback(f"\tWarning: No valid isochrones generated from {total} features!")
            return None, total

        # Warn if high failure rate (total > 0 here, since isochrones is non-empty)
        failure_rate = (failed_count / total) * 100
        if failure_rate > 2:
            self._do_progress_callback(
                f"\tWarning: {failure_rate:.1f}% of isochrone requests failed ({failed_count}/{total})"
            )

        # Union all isochrones
        self._do_progress_callback(f"\tSuccess: Generated {len(isochrones)}/{total} isochrones")
        self._do_progress_callback(f"\tMerging {len(isochrones)} isochrones...")
        return union_geometries(isochrones), total

    def _expand_large_features(self, features: List[Dict]) -> List[Dict]:
        """
        Expand large polygon features into multiple point features at their edges

        Polygons whose area exceeds one hex cell are replaced by point features
        near their edge (see ``_edge_points``). Non-polygon features, small
        polygons, and polygons for which no edge points are found are kept
        as-is. Features whose geometry cannot be converted are dropped.

        :param features: List of GeoJSON features
        :return: Expanded list of features, or ``features`` unchanged if
            nothing survived the conversion
        """
        expanded = []

        for feature in features:
            geom = geojson_to_qgs_geometry(feature['geometry'])
            if not geom:
                continue

            is_large_polygon = (
                geom.type() == QgsWkbTypes.PolygonGeometry
                and calculate_area_sqm(geom) > self.HEX_AREA_SQM
            )
            if not is_large_polygon:
                expanded.append(feature)
                continue

            # Fall back to the whole feature (its centroid is used later) when
            # the hex grid is empty or no edge points were found
            expanded.extend(self._edge_points(feature, geom) or [feature])

        return expanded if expanded else features

    def _edge_points(self, feature: Dict, geom: QgsGeometry) -> List[Dict]:
        """
        Build point features at hex-cell centroids lying near a polygon's edge.

        A centroid counts as "near the edge" when it is inside ``geom`` but
        not inside ``geom`` shrunk inward by one hex width.

        :param feature: Source GeoJSON feature; its properties are copied over
        :param geom: The feature's polygon geometry
        :return: GeoJSON point features (empty if the hex grid is empty or no
            centroid qualifies)
        """
        hex_features = create_hex_grid(geom, self.HEX_SIDE_KM)
        if not hex_features:
            return []

        # Buffer the polygon inward by one hex width (km -> approx. degrees)
        smaller_geom = geom.buffer(-self.HEX_WIDTH_KM / self.KM_PER_DEGREE, 5)

        points = []
        for hex_feat in hex_features:
            center = hex_feat.geometry().centroid()
            if not geom.contains(center):
                continue
            # If the inward buffer collapsed, every interior centroid is "edge"
            if not smaller_geom or not smaller_geom.contains(center):
                points.append({
                    'type': 'Feature',
                    'geometry': qgs_geometry_to_geojson(center),
                    'properties': feature.get('properties', {})
                })
        return points

    def _calculate_walk_potential(self, hex_features: List[QgsFeature]) -> List[QgsFeature]:
        """
        Calculate walk potential for each hex feature

        A cell's walk potential is the percentage (rounded down, 0-100) of
        amenity layers whose geometry contains the cell's centroid. The
        denominator is the number of amenity layers that produced isochrones,
        not the number of configured categories.

        :param hex_features: List of hex grid features; each feature's
            attributes are replaced with ``[walk_potential]``
        :return: The same list, updated in place
        :raises ValueError: if the calculation is cancelled
        """
        total = len(hex_features)
        num_amenity_layers = len(self.amenity_layers)

        if num_amenity_layers == 0:
            self._do_progress_callback("Warning: No amenity layers found!")
            # Nothing is reachable: give every cell an explicit 0 rather than
            # leaving the walkPotential field unset
            for hex_feature in hex_features:
                hex_feature.setAttributes([0])
            return hex_features

        for idx, hex_feature in enumerate(hex_features):
            if idx % 100 == 0:
                self._check_cancelled()
                self._do_progress_callback(f"\t{idx}/{total} cells processed")

            center = hex_feature.geometry().centroid()

            # Count how many amenity layers contain this point
            walk_potential = sum(
                1 for amenity_layer in self.amenity_layers
                if amenity_layer.contains(center)
            )

            # Scale to 0-100 with exact integer arithmetic: int(k / n * 100)
            # can land just below the true value (57 / 100 * 100 ==
            # 56.99999999999999) and truncate to the wrong integer.
            scaled_potential = walk_potential * 100 // num_amenity_layers

            hex_feature.setAttributes([scaled_potential])

        return hex_features

    def _create_result_layer(self, hex_features: List[QgsFeature]) -> QgsVectorLayer:
        """
        Create the result layer with hex grid and walk potential values

        :param hex_features: List of hex features with walk potential values
        :return: In-memory EPSG:4326 polygon layer with an integer
            ``walkPotential`` field and graduated styling applied
        """
        # Create memory layer
        crs = QgsCoordinateReferenceSystem("EPSG:4326")
        layer = QgsVectorLayer(
            f"Polygon?crs={crs.authid()}",
            "Walk Potential",
            "memory"
        )

        # Add walk potential field (must exist before features are added)
        provider = layer.dataProvider()
        provider.addAttributes([
            QgsField("walkPotential", QVariant.Int)
        ])
        layer.updateFields()

        # Add features
        provider.addFeatures(hex_features)
        layer.updateExtents()

        # Apply graduated styling
        self._apply_graduated_style(layer)

        return layer

    def _apply_graduated_style(self, layer: QgsVectorLayer):
        """
        Apply graduated color styling to the walk potential layer

        Five equal-width classes over 0-100, from red (low) to green (high),
        filled at 70% opacity without outlines.

        :param layer: The layer to style
        """
        # Define 5 ranges with colors from red (low) to green (high)
        ranges = [
            (0, 20, '#d73027', 'Very Low'),      # Red
            (20, 40, '#fc8d59', 'Low'),          # Orange
            (40, 60, '#fee08b', 'Medium'),       # Yellow
            (60, 80, '#91cf60', 'High'),         # Light green
            (80, 100, '#1a9850', 'Very High')    # Dark green
        ]

        # Create renderer ranges
        renderer_ranges = []
        for lower, upper, color, label in ranges:
            symbol = QgsSymbol.defaultSymbol(layer.geometryType())
            symbol.setColor(QColor(color))
            symbol.setOpacity(0.7)

            # Remove outline for cleaner look
            symbol.symbolLayer(0).setStrokeStyle(Qt.NoPen)

            renderer_range = QgsRendererRange(
                lower, upper, symbol, label
            )
            renderer_ranges.append(renderer_range)

        # Create and apply graduated renderer
        renderer = QgsGraduatedSymbolRenderer('walkPotential', renderer_ranges)
        # NOTE(review): setMode/Custom appear to be deprecated in newer QGIS 3.x
        # in favour of classification methods -- confirm before upgrading.
        renderer.setMode(QgsGraduatedSymbolRenderer.Custom)
        layer.setRenderer(renderer)
        layer.triggerRepaint()
