import os
import sys
import traceback
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import warnings
import math
import gc
import re
from datetime import datetime, timedelta
warnings.filterwarnings('ignore')

from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterString,
    QgsProcessingParameterEnum,
    QgsProcessingParameterFile,
    QgsProcessingParameterNumber,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFolderDestination,
    QgsProcessingException,
    QgsProcessing,
    QgsProject,
    QgsRasterLayer,
    QgsColorRampShader,
    QgsSingleBandPseudoColorRenderer,
    QgsRasterShader,
    QgsRasterBandStats,
    QgsProcessingOutputFile,
    QgsProcessingOutputFolder,
    QgsProcessingParameterPoint,
    QgsProcessingParameterDateTime,
    QgsCoordinateReferenceSystem,
    QgsPointXY,
    QgsGeometry
)
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtCore import QDateTime, QDate, QTime

try:
    from .rainfall_data_sources import download_openmeteo_hourly_data, load_excel_rainfall_data
    from .hydrological_calculations import (
        frequency_analysis, calculate_scs_cn_runoff, calculate_peak_discharge,
        calculate_time_to_peak, generate_scs_unit_hydrograph, generate_flood_hydrograph
    )
    from .flood_mapping import generate_flood_map
except ImportError:
    from rainfall_data_sources import download_openmeteo_hourly_data, load_excel_rainfall_data
    from hydrological_calculations import (
        frequency_analysis, calculate_scs_cn_runoff, calculate_peak_discharge,
        calculate_time_to_peak, generate_scs_unit_hydrograph, generate_flood_hydrograph
    )
    from flood_mapping import generate_flood_map

class FlashFloodAnalysis(QgsProcessingAlgorithm):
    # ------------------------------------------------------------------
    # Input parameter identifiers.
    # Each string is the key under which a parameter is registered in
    # initAlgorithm() and read back in processAlgorithm().
    # ------------------------------------------------------------------
    # Rainfall data source and location / period selection
    RAINFALL_SOURCE = 'RAINFALL_SOURCE'
    COORDINATES_POINT = 'COORDINATES_POINT'
    COORDINATES_MANUAL = 'COORDINATES_MANUAL'
    START_DATE = 'START_DATE'
    END_DATE = 'END_DATE'
    # Excel rainfall input (file plus the names of its two columns)
    EXCEL_FILE = 'EXCEL_FILE'
    DATE_COLUMN = 'DATE_COLUMN'
    RAINFALL_COLUMN = 'RAINFALL_COLUMN'
    # Analysis configuration
    ANALYSIS_METHOD = 'ANALYSIS_METHOD'
    STORM_DURATION = 'STORM_DURATION'
    RETURN_PERIODS = 'RETURN_PERIODS'
    FREQ_METHOD = 'FREQ_METHOD'
    # Catchment description: either layers (boundary, DEM, slope raster,
    # flow path) or manually entered area / slope
    CATCHMENT_INPUT = 'CATCHMENT_INPUT'
    CATCHMENT = 'CATCHMENT'
    DEM = 'DEM'
    SLOPE_RASTER = 'SLOPE_RASTER'
    FLOW_PATH = 'FLOW_PATH'
    AREA = 'AREA'
    SLOPE = 'SLOPE'
    # Hydrological coefficients (entered as free text, e.g. '65 (Agricultural)')
    CN = 'CN'
    RUNOFF_COEFF = 'RUNOFF_COEFF'
    MANNING_N = 'MANNING_N'
    # Output location and optional-product switches
    OUTPUT_DIR = 'OUTPUT_DIR'
    GENERATE_MAPS = 'GENERATE_MAPS'
    LOAD_RASTERS = 'LOAD_RASTERS'
    EXPORT_HEC_HMS = 'EXPORT_HEC_HMS'

    # Output identifiers, declared via addOutput() in initAlgorithm()
    OUTPUT_EXCEL = 'OUTPUT_EXCEL'
    OUTPUT_FLOOD_MAPS = 'OUTPUT_FLOOD_MAPS'
    OUTPUT_FLOOD_RASTERS = 'OUTPUT_FLOOD_RASTERS'
    OUTPUT_HYDROGRAPHS = 'OUTPUT_HYDROGRAPHS'
    OUTPUT_HEC_HMS = 'OUTPUT_HEC_HMS'

    # Analysis methods offered in the ANALYSIS_METHOD enum.
    # Order matters: processAlgorithm() branches on the enum index
    # (0 = SCS Unit Hydrograph, 1 = Rational Method, otherwise Time-Area)
    # and also uses the index to look the label up in this list.
    ANALYSIS_METHODS = [
        'SCS Unit Hydrograph',
        'Rational Method (Short Duration)',
        'Time-Area Method'
    ]

    # Storm duration options offered in the STORM_DURATION enum.
    # processAlgorithm() translates the enum index to hours through its own
    # storm_duration_map (0 -> 1, 1 -> 3, 2 -> 6, 3 -> 12, 4 -> 24), so the
    # order here must stay in sync with that mapping.
    STORM_DURATIONS = [
        '1 hour (Flash Flood)',
        '3 hours (Short Duration)',
        '6 hours (Medium Duration)',
        '12 hours (Long Duration)',
        '24 hours (Daily Storm)'
    ]

    # Curve Number suggestions in '<value> (<land cover>)' form.
    # Passed to parse_value_from_input() together with the user's text;
    # presumably the numeric prefix is what gets extracted — confirm against
    # that helper, which is defined elsewhere in the class.
    CN_VALUES = [
        '40 (Woods)', '45 (Farmland)', '50 (Pasture)', 
        '55 (Brush)', '60 (Residential)', '65 (Agricultural)',
        '70 (Agricultural)', '75 (Urban)', '80 (Commercial)',
        '85 (Industrial)', '90 (Paved areas)', '95 (Waterproof)'
    ]
    
    # Runoff coefficient suggestions in '<value> (<surface>)' form.
    # NOTE(review): the default registered in initAlgorithm() is
    # '0.75 (Urban)', which is not one of the entries below.
    RUNOFF_COEFF_VALUES = [
        '0.10 (Sandy soil)', '0.15 (Forest)', '0.20 (Grassland)',
        '0.30 (Cultivated land)', '0.40 (Residential)', '0.50 (Suburban)',
        '0.60 (Urban)', '0.70 (Commercial)', '0.80 (Industrial)',
        '0.90 (Paved areas)'
    ]
    
    # Manning's n suggestions in '<value> (<channel type>)' form
    MANNING_N_VALUES = [
        '0.01 (Smooth concrete)', '0.02 (Finished concrete)', 
        '0.03 (Earth channel)', '0.04 (Gravel channel)', 
        '0.05 (Natural stream)', '0.06 (Weedy stream)', 
        '0.07 (Dense vegetation)', '0.08 (Floodplain)', 
        '0.10 (Heavy brush)', '0.12 (Forest)'
    ]

    def initAlgorithm(self, config=None):
        """Register every input parameter and output of the algorithm.

        Parameters are added in the order they should appear in the
        Processing dialog: rainfall source, location, analysis period,
        Excel input, analysis settings, catchment description,
        hydrological coefficients, output switches and finally the output
        directory. The declared outputs follow.

        Args:
            config: Accepted for interface compatibility; not used here.
        """
        register = self.addParameter

        # --- Rainfall source -------------------------------------------
        register(QgsProcessingParameterEnum(
            self.RAINFALL_SOURCE,
            'Rainfall Source',
            options=['Open-Meteo (Hourly)', 'Excel File (Hourly)'],
            defaultValue=0,
        ))

        # --- Location: map-picked point, with a typed fallback ----------
        register(QgsProcessingParameterPoint(
            self.COORDINATES_POINT,
            'Select Location from Map',
            optional=True,
            defaultValue='',
        ))
        register(QgsProcessingParameterString(
            self.COORDINATES_MANUAL,
            'Or Enter Coordinates (Latitude, Longitude)',
            defaultValue='20.0, 77.0',
            optional=True,
        ))

        # --- Analysis period (calendar widgets) -------------------------
        # Default window: 30 days ago at 00:00 up to today at 23:00.
        period_start = QDateTime(QDate.currentDate().addDays(-30), QTime(0, 0, 0))
        register(QgsProcessingParameterDateTime(
            self.START_DATE,
            'Start Date and Time',
            type=QgsProcessingParameterDateTime.DateTime,
            defaultValue=period_start,
        ))
        period_end = QDateTime(QDate.currentDate(), QTime(23, 0, 0))
        register(QgsProcessingParameterDateTime(
            self.END_DATE,
            'End Date and Time',
            type=QgsProcessingParameterDateTime.DateTime,
            defaultValue=period_end,
        ))

        # --- Excel rainfall input (only relevant for the Excel source) --
        register(QgsProcessingParameterFile(
            self.EXCEL_FILE,
            'Excel File with Hourly Rainfall Data',
            extension='xlsx',
            optional=True,
        ))
        excel_columns = (
            (self.DATE_COLUMN, 'Date-Time Column Name', 'DateTime'),
            (self.RAINFALL_COLUMN, 'Rainfall Column Name', 'Rainfall_mm'),
        )
        for key, label, default in excel_columns:
            register(QgsProcessingParameterString(
                key, label, defaultValue=default, optional=True))

        # --- Analysis method and storm duration -------------------------
        method_enums = (
            (self.ANALYSIS_METHOD, 'Flash Flood Analysis Method', self.ANALYSIS_METHODS),
            (self.STORM_DURATION, 'Storm Duration for Analysis', self.STORM_DURATIONS),
        )
        for key, label, choices in method_enums:
            register(QgsProcessingParameterEnum(
                key, label, options=choices, defaultValue=0))

        register(QgsProcessingParameterString(
            self.RETURN_PERIODS,
            'Return Periods (comma separated)',
            defaultValue='10,25,50,100',
        ))

        # --- Frequency method and catchment input mode ------------------
        mode_enums = (
            (self.FREQ_METHOD, 'Frequency Analysis Method',
             ['Gumbel', 'Log-Pearson III', 'GEV']),
            (self.CATCHMENT_INPUT, 'Catchment Input Method',
             ['Shapefile & DEM', 'Manual Parameters']),
        )
        for key, label, choices in mode_enums:
            register(QgsProcessingParameterEnum(
                key, label, options=choices, defaultValue=0))

        # --- Catchment layers (all optional at the dialog level) --------
        register(QgsProcessingParameterVectorLayer(
            self.CATCHMENT,
            'Catchment Boundary',
            types=[QgsProcessing.TypeVectorPolygon],
            optional=True,
        ))
        raster_inputs = (
            (self.DEM, 'DEM Raster'),
            (self.SLOPE_RASTER, 'Slope Raster (percent)'),
        )
        for key, label in raster_inputs:
            register(QgsProcessingParameterRasterLayer(key, label, optional=True))
        register(QgsProcessingParameterVectorLayer(
            self.FLOW_PATH,
            'Longest Flow Path (Line)',
            types=[QgsProcessing.TypeVectorLine],
            optional=True,
        ))

        # --- Manual catchment numbers -----------------------------------
        manual_numbers = (
            (self.AREA, 'Catchment Area (km²)', 0.1),
            (self.SLOPE, 'Mean Slope (decimal)', 0.001),
        )
        for key, label, lower_bound in manual_numbers:
            register(QgsProcessingParameterNumber(
                key, label,
                type=QgsProcessingParameterNumber.Double,
                optional=True,
                minValue=lower_bound,
            ))

        # --- Hydrological coefficients (free text, "value (label)") -----
        coefficient_inputs = (
            (self.CN, 'Curve Number', '65 (Agricultural)', False),
            (self.RUNOFF_COEFF, 'Runoff Coefficient (for Rational Method)',
             '0.75 (Urban)', True),
            (self.MANNING_N, "Manning's Roughness Coefficient",
             '0.05 (Natural stream)', False),
        )
        for key, label, default, is_optional in coefficient_inputs:
            register(QgsProcessingParameterString(
                key, label, defaultValue=default, optional=is_optional))

        # --- Output switches (all enabled by default) -------------------
        switches = (
            (self.GENERATE_MAPS, 'Generate Flood Maps'),
            (self.LOAD_RASTERS, 'Load output flood rasters in QGIS'),
            (self.EXPORT_HEC_HMS, 'Export HEC-HMS Input Files'),
        )
        for key, label in switches:
            register(QgsProcessingParameterBoolean(key, label, defaultValue=True))

        register(QgsProcessingParameterFolderDestination(
            self.OUTPUT_DIR, 'Output Directory'))

        # --- Declared outputs -------------------------------------------
        self.addOutput(QgsProcessingOutputFile(
            self.OUTPUT_EXCEL, 'Results Excel File'))

        folder_outputs = (
            (self.OUTPUT_FLOOD_MAPS, 'Flood Maps Folder (PNG)'),
            (self.OUTPUT_FLOOD_RASTERS, 'Flood Rasters Folder (GeoTIFF)'),
            (self.OUTPUT_HYDROGRAPHS, 'Hydrographs Folder (PNG)'),
            (self.OUTPUT_HEC_HMS, 'HEC-HMS Input Files Folder'),
        )
        for key, label in folder_outputs:
            self.addOutput(QgsProcessingOutputFolder(key, label))

    def processAlgorithm(self, parameters, context, feedback):
        try:
            # Import required modules with error handling
            try:
                import geopandas as gpd
                import rasterio
                from rasterio.mask import mask
                from rasterio import features
                from shapely.geometry import mapping
                from scipy.stats import genextreme, gumbel_r, pearson3
                from scipy import ndimage
            except ImportError as e:
                raise QgsProcessingException(f"Required Python packages not available: {str(e)}")

            # Get coordinates - prefer point selection over manual input
            point = self.parameterAsPoint(parameters, self.COORDINATES_POINT, context)
            manual_coords = self.parameterAsString(parameters, self.COORDINATES_MANUAL, context)
            
            if point and not point.isEmpty():
                # Convert point to lat/lon (assuming input is in EPSG:4326)
                lat = point.y()
                lon = point.x()
                feedback.pushInfo(f"Using selected point: {lat:.6f}°N, {lon:.6f}°E")
            elif manual_coords and manual_coords.strip():
                # Use manual coordinates
                try:
                    parts = manual_coords.split(',')
                    lat = float(parts[0].strip())
                    lon = float(parts[1].strip())
                    feedback.pushInfo(f"Using manual coordinates: {lat:.6f}°N, {lon:.6f}°E")
                except:
                    raise QgsProcessingException("Invalid manual coordinate format. Use: 'latitude, longitude'")
            else:
                raise QgsProcessingException("Please provide coordinates either by selecting a point on map or entering manually")

            # Validate coordinates
            if not (-90 <= lat <= 90) or not (-180 <= lon <= 180):
                raise QgsProcessingException("Invalid coordinates. Latitude must be between -90 and 90, Longitude between -180 and 180")

            # Get date/time parameters
            start_dt = self.parameterAsDateTime(parameters, self.START_DATE, context)
            end_dt = self.parameterAsDateTime(parameters, self.END_DATE, context)
            
            # Format dates for API
            start_date_str = start_dt.toString('yyyy-MM-dd HH:mm')
            end_date_str = end_dt.toString('yyyy-MM-dd HH:mm')
            
            feedback.pushInfo(f"Analysis period: {start_date_str} to {end_date_str}")

            # Get other parameters
            return_periods = [int(rp) for rp in self.parameterAsString(parameters, self.RETURN_PERIODS, context).split(',')]
            rainfall_source = self.parameterAsEnum(parameters, self.RAINFALL_SOURCE, context)
            analysis_method = self.parameterAsEnum(parameters, self.ANALYSIS_METHOD, context)
            storm_duration_enum = self.parameterAsEnum(parameters, self.STORM_DURATION, context)
            freq_method = self.parameterAsEnum(parameters, self.FREQ_METHOD, context) + 1
            catchment_input = self.parameterAsEnum(parameters, self.CATCHMENT_INPUT, context)
            load_rasters = self.parameterAsBoolean(parameters, self.LOAD_RASTERS, context)
            export_hec_hms = self.parameterAsBoolean(parameters, self.EXPORT_HEC_HMS, context)
            
            # Map storm duration enum to actual hours
            storm_duration_map = {0: 1, 1: 3, 2: 6, 3: 12, 4: 24}
            storm_duration_hr = storm_duration_map[storm_duration_enum]
            feedback.pushInfo(f"Selected storm duration: {storm_duration_hr} hours")
            
            # Process CN value
            cn_input = self.parameterAsString(parameters, self.CN, context).strip()
            cn_value = self.parse_value_from_input(cn_input, self.CN_VALUES, "CN", 65, feedback)
            feedback.pushInfo(f"Using Curve Number: {cn_value}")
            
            # Process runoff coefficient
            runoff_coeff_input = self.parameterAsString(parameters, self.RUNOFF_COEFF, context)
            if runoff_coeff_input:
                runoff_coeff_input = runoff_coeff_input.strip()
                runoff_coeff_value = self.parse_value_from_input(
                    runoff_coeff_input, self.RUNOFF_COEFF_VALUES, "Runoff Coefficient", 0.75, feedback
                )
                feedback.pushInfo(f"Using Runoff Coefficient: {runoff_coeff_value}")
            else:
                runoff_coeff_value = 0.75
                feedback.pushInfo("Using default Runoff Coefficient: 0.75")
            
            # Process Manning's n
            manning_n_input = self.parameterAsString(parameters, self.MANNING_N, context).strip()
            manning_n_value = self.parse_value_from_input(
                manning_n_input, self.MANNING_N_VALUES, "Manning's n", 0.05, feedback
            )
            feedback.pushInfo(f"Using Manning's n: {manning_n_value}")
            
            # Get layers from project
            catchment_layer = self.parameterAsVectorLayer(parameters, self.CATCHMENT, context)
            dem_layer = self.parameterAsRasterLayer(parameters, self.DEM, context)
            slope_raster_layer = self.parameterAsRasterLayer(parameters, self.SLOPE_RASTER, context)
            flow_path_layer = self.parameterAsVectorLayer(parameters, self.FLOW_PATH, context)
            
            catchment_path = catchment_layer.source() if catchment_layer else None
            dem_path = dem_layer.source() if dem_layer else None
            slope_raster_path = slope_raster_layer.source() if slope_raster_layer else None
            flow_path = flow_path_layer.source() if flow_path_layer else None
            
            area = self.parameterAsDouble(parameters, self.AREA, context)
            slope = self.parameterAsDouble(parameters, self.SLOPE, context)
            generate_maps = self.parameterAsBoolean(parameters, self.GENERATE_MAPS, context)
            output_dir = self.parameterAsString(parameters, self.OUTPUT_DIR, context)

            # Create output directory
            os.makedirs(output_dir, exist_ok=True)
            
            # Run analysis
            feedback.pushInfo("Starting flash flood analysis...")
            feedback.pushInfo(f"Output directory: {output_dir}")

            # ===== CATCHMENT PARAMETERS =====
            area_m2 = None
            slope_percent = None
            longest_flow_path = None
            
            if catchment_input == 0:  # Shapefile & DEM
                if not catchment_path or not dem_path:
                    raise QgsProcessingException("Both catchment boundary and DEM are required")
                
                if not flow_path:
                    raise QgsProcessingException("Longest flow path vector is required for automated method")
                    
                if not slope_raster_path:
                    raise QgsProcessingException("Slope raster is required for automated method")
                
                try:
                    # FIXED: Improved memory management for geopandas read_file
                    feedback.pushInfo("Reading catchment boundary with improved memory handling...")
                    
                    # Force garbage collection before reading
                    gc.collect()
                    
                    # Read with explicit CRS handling to avoid pyproj issues
                    catchment = self.safe_read_file(catchment_path, feedback)
                    
                    if catchment.crs is None:
                        feedback.pushWarning("Catchment CRS is not defined. Assuming WGS84.")
                        catchment.crs = 'EPSG:4326'
                    
                    # Convert to projected CRS for area calculation
                    if catchment.crs.is_geographic:
                        centroid = catchment.geometry.unary_union.centroid
                        utm_zone = int((centroid.x + 180) // 6 + 1)
                        hemisphere = 'north' if centroid.y >= 0 else 'south'
                        crs_utm = f"EPSG:326{utm_zone:02d}" if hemisphere == 'north' else f"EPSG:327{utm_zone:02d}"
                        
                        feedback.pushInfo(f"Reprojecting catchment to UTM zone {utm_zone} ({hemisphere})")
                        catchment = catchment.to_crs(crs_utm)
                    
                    area_m2 = catchment.geometry.area.sum()
                    area = area_m2 / 1e6
                    feedback.pushInfo(f"Computed catchment area: {area:.3f} km²")
                    
                except Exception as e:
                    raise QgsProcessingException(f"Error computing area: {str(e)}")

                # Calculate slope from slope raster
                try:
                    with rasterio.open(slope_raster_path) as src:
                        # Convert catchment to raster CRS
                        catchment_slope = catchment.to_crs(src.crs)
                        geoms = [mapping(geom) for geom in catchment_slope.geometry]
                        
                        out_image, out_transform = mask(src, geoms, crop=True, nodata=np.nan)
                        slope_data = out_image[0]
                        
                        valid_slope = slope_data[(slope_data > 0) & (~np.isnan(slope_data))]
                        if len(valid_slope) == 0:
                            feedback.pushWarning("No valid slope values found in catchment. Using manual slope input.")
                            if not slope or slope <= 0:
                                slope_percent = 0.1
                            else:
                                slope_percent = slope * 100
                        else:
                            slope_percent = np.nanmean(valid_slope)
                            
                        feedback.pushInfo(f"Computed mean slope from raster: {slope_percent:.4f}%")
                        
                        # Clean up
                        del out_image, slope_data, valid_slope
                        gc.collect()
                        
                except Exception as e:
                    feedback.pushWarning(f"Error computing slope from raster: {str(e)}")
                    if not slope or slope <= 0:
                        slope_percent = 0.1
                    else:
                        slope_percent = slope * 100
                    feedback.pushInfo(f"Using manual slope input: {slope_percent:.4f}%")
                
                # Calculate longest flow path from vector
                longest_flow_path = self.calculate_flow_path_length(flow_path, catchment.crs, feedback)
                feedback.pushInfo(f"Computed longest flow path: {longest_flow_path:.2f} m")
            
            elif catchment_input == 1:  # Manual Parameters
                if not area or area <= 0:
                    raise QgsProcessingException("Catchment area must be >0")
                if not slope or slope <= 0:
                    slope = 0.1
                    feedback.pushWarning("Using default slope 0.1")
                slope_percent = slope * 100  # Convert to percent
                
                # Use manual flow path if provided, otherwise estimate
                if flow_path:
                    longest_flow_path = self.calculate_flow_path_length(flow_path, None, feedback)
                else:
                    area_m2 = area * 1e6
                    longest_flow_path = math.sqrt(area_m2) * 2
                    feedback.pushInfo(f"Estimated longest flow path: {longest_flow_path:.2f} m")

            if slope_percent < 0.1:
                slope_percent = 0.1
                feedback.pushWarning(f"Adjusted slope to safe minimum: {slope_percent}%")

            # === HOURLY RAINFALL LOADING ===
            rainfall_df = None
            if rainfall_source == 0:  # Open-Meteo Hourly
                feedback.pushInfo(f"Downloading Open-Meteo hourly data from {start_date_str} to {end_date_str}")
                rainfall_df = download_openmeteo_hourly_data(lat, lon, start_date_str, end_date_str, feedback)
            elif rainfall_source == 1:  # Excel Hourly
                excel_file = self.parameterAsString(parameters, self.EXCEL_FILE, context)
                date_column = self.parameterAsString(parameters, self.DATE_COLUMN, context)
                rainfall_column = self.parameterAsString(parameters, self.RAINFALL_COLUMN, context)
                feedback.pushInfo(f"Loading Excel hourly data from {excel_file}")
                rainfall_df = load_excel_rainfall_data(excel_file, date_column, rainfall_column, feedback)
            
            if rainfall_df is None or rainfall_df.empty:
                raise QgsProcessingException("Failed to load hourly rainfall data")
            
            # Validate rainfall data
            if rainfall_df['Rainfall (mm)'].isnull().all():
                raise QgsProcessingException("All rainfall values are missing")
                
            if (rainfall_df['Rainfall (mm)'] < 0).any():
                feedback.pushWarning("Negative rainfall values found. Setting to zero.")
                rainfall_df['Rainfall (mm)'] = rainfall_df['Rainfall (mm)'].clip(lower=0)
            
            # Save raw data
            raw_data_path = os.path.join(output_dir, "raw_hourly_rainfall_data.csv")
            rainfall_df.to_csv(raw_data_path)
            feedback.pushInfo(f"Raw hourly rainfall data saved to {raw_data_path}")
            
            # === FLASH FLOOD FREQUENCY ANALYSIS ===
            feedback.pushInfo("Performing flash flood frequency analysis...")
            
            # Calculate short-duration maxima (1-hour, 3-hour, 6-hour, 12-hour, 24-hour)
            durations = [1, 3, 6, 12, 24]  # hours for flash flood analysis
            short_duration_max = self.calculate_short_duration_maxima(rainfall_df, durations, feedback)
            
            # Perform frequency analysis for each duration
            design_storms = {}
            for duration in durations:
                if duration in short_duration_max and len(short_duration_max[duration]) > 0:
                    design_rainfalls, freq_method_name, freq_params = frequency_analysis(
                        short_duration_max[duration], return_periods, freq_method, feedback
                    )
                    design_storms[duration] = design_rainfalls
                    feedback.pushInfo(f"Duration {duration}hr design rainfalls: {design_rainfalls}")
            
            # === FLASH FLOOD RUNOFF CALCULATIONS ===
            feedback.pushInfo("Calculating flash flood runoff...")
            
            # Calculate time to peak using SCS Lag Equation
            tp = calculate_time_to_peak(longest_flow_path, slope_percent, cn_value, feedback)
            feedback.pushInfo(f"Time to Peak (Tp): {tp:.2f} hr")
            
            runoff_results = []
            for rp in return_periods:
                # Use selected storm duration for analysis
                duration_hr = storm_duration_hr
                if duration_hr in design_storms:
                    design_rainfall = design_storms[duration_hr][return_periods.index(rp)]
                else:
                    # Fallback: use nearest available duration
                    available_durations = list(design_storms.keys())
                    nearest_duration = min(available_durations, key=lambda x: abs(x - duration_hr))
                    design_rainfall = design_storms[nearest_duration][return_periods.index(rp)]
                    feedback.pushWarning(f"Using {nearest_duration}-hour duration as fallback for {duration_hr}-hour design")
                
                if analysis_method == 0:  # SCS Unit Hydrograph
                    runoff_depth = calculate_scs_cn_runoff([design_rainfall], cn_value)[0]
                    Q = calculate_peak_discharge(runoff_depth, area, tp)
                    method_name = "SCS Unit Hydrograph"
                    col_name = 'Runoff Depth (mm)'
                    col_value = runoff_depth
                    
                elif analysis_method == 1:  # Rational Method
                    intensity = design_rainfall / duration_hr
                    Q = 0.278 * runoff_coeff_value * area * intensity
                    method_name = "Rational Method"
                    col_name = 'Rainfall Intensity (mm/hr)'
                    col_value = intensity
                    
                else:  # Time-Area Method
                    # Simplified time-area method
                    time_concentration = 0.77 * (longest_flow_path**3 / slope_percent)**0.385
                    intensity = design_rainfall / duration_hr
                    Q = 0.278 * runoff_coeff_value * area * intensity * (duration_hr / time_concentration)
                    method_name = "Time-Area Method"
                    col_name = 'Time Concentration (hr)'
                    col_value = time_concentration
                
                # Check discharge
                if Q > 1000:
                    feedback.pushWarning(f"High discharge value ({Q:.1f} m³/s) for {rp}-year return period")
                
                runoff_results.append({
                    'Return Period (yr)': rp,
                    f'{duration_hr}-hr Rainfall (mm)': design_rainfall,
                    col_name: col_value,
                    'Discharge (m³/s)': Q
                })
            
            df_out = pd.DataFrame(runoff_results)
            
            # === IMPROVED HYDROGRAPH GENERATION ===
            hydrograph_paths = []
            hydrograph_data = {}
            if analysis_method == 0:  # Only for SCS Unit Hydrograph method
                feedback.pushInfo("Generating flash flood hydrographs...")
                hydrograph_dir = os.path.join(output_dir, "hydrographs")
                os.makedirs(hydrograph_dir, exist_ok=True)
                
                # Generate unit hydrograph
                time_uh, discharge_uh = generate_scs_unit_hydrograph(area, tp, feedback)
                unit_hydrograph_path = os.path.join(hydrograph_dir, "unit_hydrograph.png")
                self.plot_improved_hydrograph(time_uh, discharge_uh, "SCS Unit Hydrograph", 
                                            "Time (hr)", "Discharge (m³/s)", unit_hydrograph_path,
                                            hydrograph_type='unit')
                hydrograph_paths.append(unit_hydrograph_path)
                
                # Generate flash flood hydrographs for each return period
                for _, row in df_out.iterrows():
                    rp = row['Return Period (yr)']
                    runoff_depth = row['Runoff Depth (mm)']
                    
                    # Generate flood hydrograph using convolution for more realistic shape
                    time_fh, discharge_fh = self.generate_improved_flood_hydrograph(
                        time_uh, discharge_uh, runoff_depth, storm_duration_hr, feedback
                    )
                    
                    hydrograph_data[rp] = {
                        'time': time_fh,
                        'discharge': discharge_fh
                    }
                    
                    # Plot and save with improved visualization
                    fh_path = os.path.join(hydrograph_dir, f"flash_flood_hydrograph_{rp}yr.png")
                    self.plot_improved_hydrograph(
                        time_fh, discharge_fh,
                        f"{rp}-Year Flash Flood Hydrograph\n({storm_duration_hr}-hr storm, Peak: {discharge_fh.max():.1f} m³/s)",
                        "Time (hr)", "Discharge (m³/s)",
                        fh_path,
                        hydrograph_type='flood'
                    )
                    hydrograph_paths.append(fh_path)
            
            # === FLOOD MAPPING ===
            flood_map_paths = []
            flood_raster_paths = []
            if generate_maps and catchment_path and dem_path:
                feedback.pushInfo("Generating flash flood inundation maps...")
                for _, row in df_out.iterrows():
                    rp = row['Return Period (yr)']
                    discharge = row['Discharge (m³/s)']
                    
                    flood_path, flood_raster = generate_flood_map(
                        discharge=discharge,
                        dem_path=dem_path,
                        catchment_path=catchment_path,
                        output_dir=output_dir,
                        return_period=rp,
                        manning_n=manning_n_value,
                        feedback=feedback
                    )
                    
                    if flood_path and flood_raster:
                        flood_map_paths.append(flood_path)
                        flood_raster_paths.append(flood_raster)
            
            # === HEC-HMS EXPORT ===
            hec_hms_paths = []
            if export_hec_hms:
                feedback.pushInfo("Exporting HEC-HMS input files...")
                hec_dir = os.path.join(output_dir, "hec_hms")
                os.makedirs(hec_dir, exist_ok=True)
                
                if analysis_method == 0 and hydrograph_data:  # SCS Unit Hydrograph
                    # Export unit hydrograph
                    uh_path = os.path.join(hec_dir, "unit_hydrograph.csv")
                    pd.DataFrame({
                        'Time (hr)': time_uh,
                        'Discharge (m³/s)': discharge_uh
                    }).to_csv(uh_path, index=False)
                    hec_hms_paths.append(uh_path)
                    
                    # Export flood hydrographs
                    for rp, data in hydrograph_data.items():
                        fh_path = os.path.join(hec_dir, f"flash_flood_hydrograph_{rp}yr.csv")
                        pd.DataFrame({
                            'Time (hr)': data['time'],
                            'Discharge (m³/s)': data['discharge']
                        }).to_csv(fh_path, index=False)
                        hec_hms_paths.append(fh_path)
                
                # Export design storms for all durations
                for duration, storms in design_storms.items():
                    storm_path = os.path.join(hec_dir, f"design_storm_{duration}hr.csv")
                    storm_df = pd.DataFrame({
                        'Return_Period': return_periods,
                        'Rainfall_mm': storms
                    })
                    storm_df.to_csv(storm_path, index=False)
                    hec_hms_paths.append(storm_path)
                
                # Export parameters
                params_path = os.path.join(hec_dir, "flash_flood_parameters.txt")
                with open(params_path, 'w') as f:
                    f.write(f"Area (km²): {area:.2f}\n")
                    f.write(f"Curve Number: {cn_value}\n")
                    f.write(f"Runoff Coefficient: {runoff_coeff_value}\n")
                    f.write(f"Storm Duration (hr): {storm_duration_hr}\n")
                    f.write(f"Time to Peak (hr): {tp:.2f}\n")
                    f.write(f"Longest Flow Path (m): {longest_flow_path:.2f}\n")
                    f.write(f"Mean Slope (%): {slope_percent:.2f}\n")
                    f.write(f"Analysis Method: {self.ANALYSIS_METHODS[analysis_method]}\n")
                    f.write(f"Frequency Method: {['Gumbel', 'Log-Pearson III', 'GEV'][freq_method-1]}\n")
                hec_hms_paths.append(params_path)
            
            # Create summary
            summary_data = {
                'Parameter': [
                    'Catchment Area (km²)', 'Curve Number', 'Runoff Coefficient',
                    "Manning's n", 'Storm Duration (hr)', 'Time to Peak (hr)', 
                    'Longest Flow Path (m)', 'Mean Slope (%)', 'Rainfall Source', 
                    'Analysis Method', 'Frequency Method', 'Data Period', 
                    'Short Durations Analyzed'
                ],
                'Value': [
                    f"{area:.2f}", f"{cn_value}", f"{runoff_coeff_value:.2f}",
                    f"{manning_n_value:.4f}", f"{storm_duration_hr}", f"{tp:.2f}", 
                    f"{longest_flow_path:.2f}", f"{slope_percent:.2f}",
                    'Open-Meteo' if rainfall_source == 0 else 'Excel',
                    self.ANALYSIS_METHODS[analysis_method],
                    ['Gumbel', 'Log-Pearson III', 'GEV'][freq_method-1],
                    f"{rainfall_df.index.min()} to {rainfall_df.index.max()}",
                    ', '.join([f'{d}hr' for d in durations if d in design_storms])
                ]
            }
            
            df_summary = pd.DataFrame(summary_data)
            
            # === SAVE RESULTS ===
            excel_path = os.path.join(output_dir, "Flash_Flood_Analysis_Results.xlsx")
            
            with pd.ExcelWriter(excel_path) as writer:
                df_out.to_excel(writer, sheet_name='Results', index=False)
                df_summary.to_excel(writer, sheet_name='Parameters', index=False)
                
                # Add design storms sheet
                design_storm_data = []
                for duration in durations:
                    if duration in design_storms:
                        for rp, rain in zip(return_periods, design_storms[duration]):
                            design_storm_data.append({
                                'Duration (hr)': duration,
                                'Return Period (yr)': rp,
                                'Design Rainfall (mm)': rain,
                                'Intensity (mm/hr)': rain / duration
                            })
                if design_storm_data:
                    design_storm_df = pd.DataFrame(design_storm_data)
                    design_storm_df.to_excel(writer, sheet_name='Design Storms', index=False)
                
                # Add short duration maxima sheet
                short_duration_data = []
                for duration, values in short_duration_max.items():
                    for year_idx, value in enumerate(values):
                        short_duration_data.append({
                            'Duration (hr)': duration,
                            'Year': year_idx + 1,
                            'Max Rainfall (mm)': value
                        })
                if short_duration_data:
                    short_duration_df = pd.DataFrame(short_duration_data)
                    short_duration_df.to_excel(writer, sheet_name='Short Duration Maxima', index=False)
                
                if flood_map_paths:
                    flood_df = pd.DataFrame({
                        'Return Period': return_periods,
                        'Image Path': flood_map_paths,
                        'Raster Path': flood_raster_paths
                    })
                    flood_df.to_excel(writer, sheet_name='Flood Maps', index=False)
                
                if hydrograph_paths:
                    hydro_df = pd.DataFrame({
                        'Hydrograph Type': ['Unit'] + [f'{rp}-Year' for rp in return_periods],
                        'Image Path': hydrograph_paths
                    })
                    hydro_df.to_excel(writer, sheet_name='Hydrographs', index=False)
            
            # Generate plots
            self.generate_flash_flood_plots(output_dir, df_out, design_storms, durations, return_periods, feedback)
            
            # Load flood rasters in QGIS if requested
            if load_rasters and flood_raster_paths:
                feedback.pushInfo("Loading flash flood rasters in QGIS...")
                for raster_path in flood_raster_paths:
                    if os.path.exists(raster_path):
                        layer_name = f"Flash Flood {storm_duration_hr}hr - " + os.path.basename(raster_path).replace('.tif', '')
                        layer = QgsRasterLayer(raster_path, layer_name)
                        if layer.isValid():
                            # Apply color ramp
                            stats = layer.dataProvider().bandStatistics(1, QgsRasterBandStats.All)
                            min_val = stats.minimumValue
                            max_val = stats.maximumValue
                            
                            # Create color ramp from blue to red
                            color_ramp = QgsColorRampShader()
                            color_ramp.setColorRampType(QgsColorRampShader.Interpolated)
                            
                            # Define color stops
                            items = [
                                QgsColorRampShader.ColorRampItem(0.0, QColor(0, 0, 0, 0), "No Flood"),
                                QgsColorRampShader.ColorRampItem(0.01, QColor(173, 216, 230), "Shallow"),
                                QgsColorRampShader.ColorRampItem(max_val/2, QColor(0, 0, 255), "Medium"),
                                QgsColorRampShader.ColorRampItem(max_val, QColor(255, 0, 0), "Deep")
                            ]
                            color_ramp.setColorRampItemList(items)
                            
                            # Create shader
                            shader = QgsRasterShader()
                            shader.setRasterShaderFunction(color_ramp)
                            
                            # Create renderer
                            renderer = QgsSingleBandPseudoColorRenderer(
                                layer.dataProvider(), 1, shader
                            )
                            renderer.setClassificationMin(min_val)
                            renderer.setClassificationMax(max_val)
                            
                            # Apply renderer
                            layer.setRenderer(renderer)
                            layer.triggerRepaint()
                            
                            QgsProject.instance().addMapLayer(layer)
                            feedback.pushInfo(f"Loaded raster: {layer_name}")
                        else:
                            feedback.pushWarning(f"Failed to load raster: {raster_path}")
            
            # Clean up memory
            gc.collect()
            
            # Return outputs
            flood_maps_dir = os.path.join(output_dir, "flood_maps")
            flood_rasters_dir = os.path.join(output_dir, "flood_rasters")
            hydrographs_dir = os.path.join(output_dir, "hydrographs")
            hec_hms_dir = os.path.join(output_dir, "hec_hms")
            return {
                self.OUTPUT_EXCEL: excel_path,
                self.OUTPUT_FLOOD_MAPS: flood_maps_dir,
                self.OUTPUT_FLOOD_RASTERS: flood_rasters_dir,
                self.OUTPUT_HYDROGRAPHS: hydrographs_dir,
                self.OUTPUT_HEC_HMS: hec_hms_dir
            }
            
        except Exception as e:
            feedback.reportError(f"Flash flood analysis failed: {str(e)}")
            feedback.pushInfo(traceback.format_exc())
            raise
        finally:
            # Force garbage collection
            gc.collect()

    # ===== IMPROVED MEMORY MANAGEMENT FUNCTIONS =====
    def safe_read_file(self, file_path, feedback):
        """Load a vector file into a GeoDataFrame, defaulting a missing CRS to WGS84.

        Args:
            file_path: Path handed to ``geopandas.read_file``.
            feedback: Object providing ``pushInfo``, ``pushWarning`` and
                ``reportError`` for progress/diagnostic messages.

        Returns:
            The GeoDataFrame that was read. Its ``crs`` is set to
            ``'EPSG:4326'`` when the file carried none.

        Raises:
            Exception: Any error raised while importing geopandas or reading
                the file is reported through ``feedback`` and re-raised.
        """
        try:
            # Imported lazily so the module loads even without geopandas
            import geopandas as gpd

            # Release whatever is collectable before the read
            gc.collect()

            gdf = gpd.read_file(file_path)

            if gdf.crs is not None:
                # A CRS is present: log it, and fall back to WGS84 only if it
                # cannot even be rendered as text
                try:
                    feedback.pushInfo(f"File CRS: {str(gdf.crs)}")
                except Exception as e:
                    feedback.pushWarning(f"CRS validation warning: {str(e)}")
                    gdf.crs = 'EPSG:4326'  # Fallback to WGS84
                return gdf

            # No CRS at all: warn and assume geographic WGS84
            feedback.pushWarning(f"CRS not defined in {file_path}. Assuming WGS84.")
            gdf.crs = 'EPSG:4326'
            return gdf

        except Exception as e:
            feedback.reportError(f"Error reading file {file_path}: {str(e)}")
            raise

    def parse_value_from_input(self, input_str, value_list, param_name, default, feedback):
        """
        Parse a numeric value from user input (manual number or dropdown selection).

        The strategies are tried in this order:

        1. If ``input_str`` is one of the entries in ``value_list``, its first
           whitespace-separated token is converted to ``float``.
        2. The whole input is converted to ``float``.
        3. The first number embedded in the text is extracted (a leading
           ``+``/``-`` sign is honoured).
        4. ``default`` is returned and a warning is pushed to ``feedback``.

        Args:
            input_str: Raw user input.
            value_list: Known dropdown entries.
            param_name: Parameter name used in warning messages.
            default: Value returned when nothing numeric can be parsed.
            feedback: Object providing ``pushWarning``.

        Returns:
            The parsed ``float``, or ``default`` when parsing fails.
        """
        try:
            if input_str in value_list:
                # User selected from dropdown: the number is expected to lead
                # the entry; if it does not, try the other strategies instead
                # of giving up immediately.
                try:
                    return float(input_str.split()[0])
                except (AttributeError, IndexError, ValueError):
                    pass

            # Try direct numeric conversion
            try:
                return float(input_str)
            except (TypeError, ValueError):
                pass

            # Extract the first numeric value from the string. The sign applies
            # to both the decimal and the integer alternative, but only when it
            # is not glued to a preceding word/number (so "CN-75" and "10-20"
            # still yield 75 and 10 rather than negative values).
            match = re.search(r"(?:(?<![\w.])[-+])?(?:\d*\.\d+|\d+)", input_str)
            if match:
                return float(match.group())

            # If no numbers found, use default
            feedback.pushWarning(f"Couldn't parse {param_name} from '{input_str}'. Using default: {default}")
            return default

        except Exception as e:
            # Best effort: any unexpected input type ends up at the default
            feedback.pushWarning(f"Error parsing {param_name}: {str(e)}. Using default: {default}")
            return default

    def calculate_flow_path_length(self, flow_path, target_crs, feedback):
        """Return the length of the longest feature in a flow-path vector layer.

        The layer is read with ``self.safe_read_file`` and, when ``target_crs``
        is given and differs from the layer CRS, reprojected before measuring.

        Args:
            flow_path: Path of the vector file holding the flow path(s).
            target_crs: CRS to measure in, or a falsy value to measure in the
                layer's own CRS.
            feedback: Object providing ``pushInfo`` and ``reportError``.

        Returns:
            The longest feature length as a ``float``, expressed in the units
            of the CRS used for measuring (meters only if that CRS is metric —
            NOTE(review): confirm callers pass a projected CRS). ``0`` is
            returned when the layer cannot be read, is empty, or has no
            positive length; the reason is reported through ``feedback``.
        """
        try:
            flow_path_gdf = self.safe_read_file(flow_path, feedback)

            # Reproject to target CRS if needed
            if target_crs and flow_path_gdf.crs != target_crs:
                flow_path_gdf = flow_path_gdf.to_crs(target_crs)

            # Calculate length of all features
            lengths = flow_path_gdf.geometry.length
            if len(lengths) == 0:
                raise ValueError("Flow path layer contains no features")

            # Series.max() skips NaN entries; the builtin max() would return
            # NaN whenever the first length is NaN and let it slip through.
            longest_length = float(lengths.max())

            # Written as "not > 0" so that NaN (all lengths missing) is
            # rejected together with zero and negative values.
            if not longest_length > 0:
                raise ValueError("Invalid flow path length (<=0)")

            feedback.pushInfo(f"Longest flow path: {longest_length:.2f} m")

            # Clean up
            del flow_path_gdf, lengths
            gc.collect()

            return longest_length

        except Exception as e:
            feedback.reportError(f"Error calculating flow path: {str(e)}")
            return 0

    def calculate_short_duration_maxima(self, rainfall_df, durations, feedback):
        """Compute per-calendar-year maximum rolling rainfall totals.

        For every entry of ``durations`` and every calendar year present in the
        index, the ``'Rainfall (mm)'`` column is summed over a rolling window
        of ``duration`` consecutive rows (partial windows at the start of a
        year count as well) and the largest total of that year is kept. Years
        holding fewer than ``duration`` rows are skipped.

        Args:
            rainfall_df: DataFrame with a ``'Rainfall (mm)'`` column. If its
                index is not a ``DatetimeIndex`` it is converted in place.
                Rows are presumably hourly, so that a window of ``duration``
                rows spans ``duration`` hours — verify against the caller.
            durations: Iterable of window lengths (number of rows).
            feedback: Object providing ``pushInfo`` and ``pushWarning``.

        Returns:
            dict mapping each duration that produced at least one value to the
            list of yearly maxima, in order of first appearance of the years.
        """
        # Grouping by year below requires datetime semantics on the index
        if not isinstance(rainfall_df.index, pd.DatetimeIndex):
            rainfall_df.index = pd.to_datetime(rainfall_df.index)

        maxima_by_duration = {}

        for duration in durations:
            index_years = rainfall_df.index.year
            annual_peaks = []

            for year in index_years.unique():
                records = rainfall_df[index_years == year]
                if len(records) < duration:
                    # Not enough rows in this year to cover one full window
                    continue
                window_totals = records['Rainfall (mm)'].rolling(
                    window=duration, min_periods=1).sum()
                annual_peaks.append(window_totals.max())

            if not annual_peaks:
                feedback.pushWarning(f"No valid data for {duration}-hour duration")
                continue

            maxima_by_duration[duration] = annual_peaks
            feedback.pushInfo(f"Calculated {duration}-hour maxima: {len(annual_peaks)} years, "
                            f"Max={max(annual_peaks):.1f} mm, Mean={np.mean(annual_peaks):.1f} mm")

        return maxima_by_duration

    def plot_improved_hydrograph(self, time, discharge, title, xlabel, ylabel, output_path, hydrograph_type='flood'):
        """Plot a hydrograph with its peak annotated and save it as an image.

        The curve is drawn with a filled area underneath, the peak is marked
        and labelled, and a box with the total volume and the duration is
        added in the upper-left corner.

        Args:
            time: Sequence of time values; the annotations label them as hours.
            discharge: Sequence of discharge values of the same length; the
                annotations label them as m³/s.
            title: Figure title.
            xlabel: X-axis label.
            ylabel: Y-axis label.
            output_path: File path the figure is written to (300 dpi).
            hydrograph_type: ``'unit'`` selects the blue styling; any other
                value selects the red flood styling.

        Raises:
            ValueError: If the inputs are empty or differ in shape.
        """
        # Work on plain arrays so that positional indexing (peak lookup) does
        # not depend on a pandas index carried by the inputs.
        time = np.asarray(time, dtype=float)
        discharge = np.asarray(discharge, dtype=float)
        if time.size == 0 or time.shape != discharge.shape:
            raise ValueError("time and discharge must be non-empty and of equal length")

        if hydrograph_type == 'unit':
            color, alpha, fill_alpha = 'blue', 0.7, 0.3
        else:  # flood hydrograph
            color, alpha, fill_alpha = 'red', 0.8, 0.4

        fig = plt.figure(figsize=(12, 8))
        try:
            # Plot main hydrograph
            plt.plot(time, discharge, color=color, linewidth=2.5, linestyle='-', alpha=alpha)

            # Fill under the curve
            plt.fill_between(time, discharge, alpha=fill_alpha, color=color)

            # Add peak discharge point
            peak_idx = int(np.argmax(discharge))
            peak_time = time[peak_idx]
            peak_discharge = discharge[peak_idx]

            plt.plot(peak_time, peak_discharge, 'o', markersize=8, color='darkred',
                    markeredgecolor='white', markeredgewidth=2)

            # Add peak annotation
            plt.annotate(f'Peak: {peak_discharge:.1f} m³/s\nTime: {peak_time:.1f} hr',
                        xy=(peak_time, peak_discharge),
                        xytext=(peak_time + 0.5, peak_discharge * 0.9),
                        arrowprops=dict(arrowstyle='->', color='black', lw=1),
                        fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

            # Customize the plot
            plt.title(title, fontsize=14, fontweight='bold', pad=20)
            plt.xlabel(xlabel, fontsize=12)
            plt.ylabel(ylabel, fontsize=12)

            # Improve grid and styling
            plt.grid(True, linestyle='--', alpha=0.5, which='both')
            plt.gca().set_axisbelow(True)

            # Set nice axis limits; skipped when the upper bound would be 0,
            # which would collapse the axis to a single point.
            x_upper = time.max() * 1.05
            if x_upper > 0:
                plt.xlim(0, x_upper)
            y_upper = discharge.max() * 1.15
            if y_upper > 0:
                plt.ylim(0, y_upper)

            # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever exists.
            integrate = getattr(np, 'trapezoid', None) or np.trapz
            # Discharge is per second while time is in hours, so the integral
            # must be scaled by 3600 s/hr to be a volume in m³.
            volume_m3 = integrate(discharge, time) * 3600.0

            # Add some statistics in a box
            stats_text = f'Total Volume: {volume_m3:.1f} m³\nDuration: {time.max():.1f} hr'
            plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
                    verticalalignment='top', bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.7),
                    fontfamily='monospace')

            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            # Always release the figure, even when plotting or saving fails
            plt.close(fig)

    def generate_improved_flood_hydrograph(self, time_uh, discharge_uh, runoff_depth, storm_duration, feedback):
        """Derive a flood hydrograph by scaling (and lightly smoothing) a unit hydrograph.

        The unit hydrograph ordinates are multiplied by ``runoff_depth / 10``
        (mm converted to cm). When more than five ordinates are present the
        result is passed through a Gaussian filter (sigma 1.0) and clipped at
        zero. No convolution with a storm pattern is performed.

        Args:
            time_uh: Time axis of the unit hydrograph; returned unchanged.
            discharge_uh: Unit hydrograph ordinates (array-like supporting
                scalar multiplication, e.g. a NumPy array).
            runoff_depth: Runoff depth in mm.
            storm_duration: Currently not used by this method.
            feedback: Object providing ``pushInfo`` and ``pushWarning``.

        Returns:
            Tuple ``(time, discharge)`` of the flood hydrograph. If anything
            fails, a warning is pushed and the result of
            ``generate_flood_hydrograph`` is returned instead.
        """
        try:
            # mm -> cm
            depth_cm = runoff_depth / 10.0
            flood_q = discharge_uh * depth_cm

            # Very short series are left untouched; longer ones are smoothed
            if len(flood_q) > 5:
                from scipy.ndimage import gaussian_filter1d
                # Smooth, then make sure no ordinate is negative
                flood_q = np.maximum(gaussian_filter1d(flood_q, sigma=1.0), 0)

            feedback.pushInfo(f"Generated flood hydrograph with peak {flood_q.max():.1f} m³/s")
            return time_uh, flood_q

        except Exception as e:
            feedback.pushWarning(f"Improved hydrograph generation failed: {str(e)}. Using basic method.")
            # Fallback to basic scaling
            return generate_flood_hydrograph(time_uh, discharge_uh, runoff_depth, feedback)

    def generate_flash_flood_plots(self, output_dir, df_out, design_storms, durations, return_periods, feedback):
        """Generate the three-panel flash flood results figure.

        Panels, top to bottom: design rainfall per duration against return
        period, design discharge against return period, and rainfall
        intensity against duration for every return period. The figure is
        written to ``Flash_Flood_Results_Plot.png`` inside ``output_dir``.

        Plotting is best effort: any failure is pushed to ``feedback`` as a
        warning and does not propagate.

        Args:
            output_dir: Directory the PNG is written to.
            df_out: DataFrame with ``'Return Period (yr)'`` and
                ``'Discharge (m³/s)'`` columns.
            design_storms: Mapping of duration to a sequence of rainfall
                depths, one per entry of ``return_periods``.
            durations: Durations to plot; those missing from
                ``design_storms`` are skipped.
            return_periods: Sequence of return periods.
            feedback: Object providing ``pushInfo`` and ``pushWarning``.
        """
        fig = None
        try:
            # Results plot
            fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))

            # Only durations that actually have a design storm can be drawn
            plotted_durations = [d for d in durations if d in design_storms]

            # Plot 1: Design rainfall for different durations
            for duration in plotted_durations:
                ax1.plot(return_periods, design_storms[duration], 'o-',
                        label=f'{duration}-hour', linewidth=2, markersize=6)

            ax1.set_title("Flash Flood Design Rainfall", fontsize=14, fontweight='bold')
            ax1.set_xlabel("Return Period (years)")
            ax1.set_ylabel("Rainfall Depth (mm)")
            if plotted_durations:
                ax1.legend()
            ax1.grid(True, linestyle='--', alpha=0.7)
            ax1.set_xscale('log')

            # Plot 2: Flash flood discharge
            ax2.plot(df_out['Return Period (yr)'], df_out['Discharge (m³/s)'], 's-',
                    color='crimson', linewidth=2, markersize=6)
            ax2.set_title("Flash Flood Design Discharge", fontsize=14, fontweight='bold')
            ax2.set_xlabel("Return Period (years)")
            ax2.set_ylabel("Peak Discharge (m³/s)")
            ax2.grid(True, linestyle='--', alpha=0.7)
            ax2.set_xscale('log')

            # Plot 3: Intensity-Duration-Frequency
            # enumerate() yields the column position directly; a value lookup
            # would pick the wrong column for duplicated return periods and
            # does not work for sequences without an .index() method.
            idf_drawn = False
            for idx, rp in enumerate(return_periods):
                intensities = [design_storms[d][idx] / d for d in plotted_durations]
                if intensities:
                    ax3.plot(plotted_durations, intensities, 'o-', label=f'{rp}-year',
                            linewidth=2, markersize=6)
                    idf_drawn = True

            ax3.set_title("Intensity-Duration-Frequency (IDF) Relationships", fontsize=14, fontweight='bold')
            ax3.set_xlabel("Duration (hours)")
            ax3.set_ylabel("Rainfall Intensity (mm/hour)")
            if idf_drawn:
                ax3.legend()
            ax3.grid(True, linestyle='--', alpha=0.7)
            ax3.set_xticks(durations)

            fig.tight_layout()
            results_plot_path = os.path.join(output_dir, "Flash_Flood_Results_Plot.png")
            fig.savefig(results_plot_path, dpi=300, bbox_inches='tight')

            feedback.pushInfo("Flash flood plots generated successfully")

        except Exception as e:
            feedback.pushWarning(f"Error generating flash flood plots: {str(e)}")
        finally:
            # Release the figure on success and on failure alike
            if fig is not None:
                plt.close(fig)

    def name(self):
        """Return the algorithm's internal name string."""
        return "flash_flood_analysis"

    def displayName(self):
        """Return the human-readable algorithm title."""
        return "Flash Flood Analysis (Hourly)"

    def group(self):
        """Return the display name of the group this algorithm belongs to."""
        return "Hydrology"

    def groupId(self):
        """Return the identifier string of the algorithm's group."""
        return "hydrology"

    def createInstance(self):
        """Return a new, independent ``FlashFloodAnalysis`` instance."""
        fresh_instance = FlashFloodAnalysis()
        return fresh_instance