# SDB_Master_Orchestrator.py
# ---------------------------------------------------------------------------
# SDB MASTER ORCHESTRATOR (FULL LOGGING & AUTO-REPROJECT & OPTIONAL PHASES)
# ---------------------------------------------------------------------------

import os
import time
import datetime
import warnings
from qgis.core import (
    QgsProcessing, QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer, QgsProcessingParameterVectorLayer,
    QgsProcessingParameterField, QgsProcessingParameterNumber,
    QgsProcessingParameterBand, QgsProcessingParameterFolderDestination,
    QgsProcessingParameterBoolean, QgsProject,
    QgsProcessingParameterEnum, QgsProcessingException, QgsCoordinateReferenceSystem
)
import processing

warnings.filterwarnings("ignore")

class SDBMasterOrchestrator(QgsProcessingAlgorithm):
    """QGIS Processing algorithm that chains the five SDB sub-algorithms.

    The orchestrator does no raster or ML work itself. It collects the user's
    parameters, reprojects the vector inputs to the raster CRS when needed,
    runs the ``sdb_tools:*`` child algorithms in order (pre-processing,
    optional RANSAC filtering, global modelling, optional adaptive refinement,
    reporting) and finally writes ``FULL_PROJECT_LOG.txt`` into the output
    folder.
    """

    # --- INPUT CONSTANTS ---
    INPUT_RASTER = 'INPUT_RASTER'
    OUTPUT_FOLDER = 'OUTPUT_FOLDER'

    # Phase 1 (pre-processing)
    COASTAL_BAND = 'COASTAL_BAND'
    BLUE_BAND = 'BLUE_BAND'
    GREEN_BAND = 'GREEN_BAND'
    RED_BAND = 'RED_BAND'
    NIR_BAND = 'NIR_BAND'
    APPLY_SUNGLINT = 'APPLY_SUNGLINT'
    NIR_BAND_SUNGLINT = 'NIR_BAND_SUNGLINT'
    DEEP_WATER_POLY = 'DEEP_WATER_POLY'
    APPLY_WATER_MASK = 'APPLY_WATER_MASK'

    # Phase 2 (optional RANSAC filtering)
    ENABLE_RANSAC = 'ENABLE_RANSAC'
    RANSAC_THRESHOLD = 'RANSAC_THRESHOLD'

    # Phase 3 (global model)
    INPUT_TRAIN = 'INPUT_TRAIN'
    FIELD_DEPTH = 'FIELD_DEPTH'
    FIELD_WEIGHT = 'FIELD_WEIGHT'
    SELECTED_ALGOS = 'SELECTED_ALGOS'
    N_ITERATIONS = 'N_ITERATIONS'
    MEDIAN_SIZE = 'MEDIAN_SIZE'

    # Phase 4 (optional adaptive re-training on a separate point layer)
    ENABLE_ADAPTIVE = 'ENABLE_ADAPTIVE'
    INPUT_ADAPTIVE_TRAIN = 'INPUT_ADAPTIVE_TRAIN'
    FIELD_ADAPTIVE_DEPTH = 'FIELD_ADAPTIVE_DEPTH'

    # Phase 5 (validation / reporting)
    INPUT_TEST = 'INPUT_TEST'
    FIELD_TEST_DEPTH = 'FIELD_TEST_DEPTH'

    # Display names for the SELECTED_ALGOS enum. Phases 3 and 4 receive the
    # selected *indices*, so reordering this list changes which model an index
    # means (presumably the child algorithms use the same order - keep in sync).
    MODEL_LIST = [
        'Linear Regression',    # 0
        'Random Forest',        # 1
        'Gradient Boosting',    # 2
        'Extra Trees',          # 3
        'Ridge',                # 4
        'Lasso',                # 5
        'ElasticNet',           # 6
        'KNN',                  # 7
        'Decision Tree',        # 8
        'MLP (Neural Net)',     # 9
        'SVR'                   # 10
    ]

    def initAlgorithm(self, config=None):
        """Declare every input parameter shown in the Processing dialog."""
        # 1. Main
        self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_RASTER, 'Input Satellite Image (Sentinel-2)'))
        self.addParameter(QgsProcessingParameterFolderDestination(self.OUTPUT_FOLDER, 'Main Output Folder'))

        # 2. Bands (1-based band numbers of INPUT_RASTER)
        self.addParameter(QgsProcessingParameterBand(self.COASTAL_BAND, 'Coastal Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=1))
        self.addParameter(QgsProcessingParameterBand(self.BLUE_BAND, 'Blue Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=2))
        self.addParameter(QgsProcessingParameterBand(self.GREEN_BAND, 'Green Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=3))
        self.addParameter(QgsProcessingParameterBand(self.RED_BAND, 'Red Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=4))
        self.addParameter(QgsProcessingParameterBand(self.NIR_BAND, 'NIR Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=8))

        # 3. Pre-processing
        self.addParameter(QgsProcessingParameterBoolean(self.APPLY_SUNGLINT, 'Apply Sunglint Correction', defaultValue=True))
        self.addParameter(QgsProcessingParameterBand(self.NIR_BAND_SUNGLINT, 'Sunglint NIR Band', parentLayerParameterName=self.INPUT_RASTER, defaultValue=8))
        self.addParameter(QgsProcessingParameterVectorLayer(self.DEEP_WATER_POLY, 'Deep Water ROI (Optional)', optional=True))
        self.addParameter(QgsProcessingParameterBoolean(self.APPLY_WATER_MASK, 'Apply Water Mask', defaultValue=True))

        # 4. Phase 2 & 3 inputs (main training points)
        self.addParameter(QgsProcessingParameterVectorLayer(self.INPUT_TRAIN, 'Main Training Points (ICESat-2)'))
        self.addParameter(QgsProcessingParameterField(self.FIELD_DEPTH, 'Depth Field (Main)', parentLayerParameterName=self.INPUT_TRAIN, type=QgsProcessingParameterField.Numeric))
        self.addParameter(QgsProcessingParameterField(self.FIELD_WEIGHT, 'Confidence/Weight Field (Optional)', parentLayerParameterName=self.INPUT_TRAIN, type=QgsProcessingParameterField.Numeric, optional=True))

        # Phase 2 controls. A residual threshold cannot be negative, so the
        # dialog rejects negative input up front.
        self.addParameter(QgsProcessingParameterBoolean(self.ENABLE_RANSAC, 'Enable RANSAC Filtering (Phase 2)', defaultValue=True))
        self.addParameter(QgsProcessingParameterNumber(self.RANSAC_THRESHOLD, 'RANSAC Outlier Threshold (If Enabled)', type=QgsProcessingParameterNumber.Double, defaultValue=0.0, minValue=0.0))

        # Phase 3 controls
        self.addParameter(QgsProcessingParameterEnum(self.SELECTED_ALGOS, 'Select Algorithms (For Global Model)',
                                                     options=self.MODEL_LIST, allowMultiple=True, defaultValue=[0, 1, 2, 3]))
        self.addParameter(QgsProcessingParameterNumber(self.N_ITERATIONS, 'Optimization Iterations',
                                                       type=QgsProcessingParameterNumber.Integer, defaultValue=10, minValue=0))
        self.addParameter(QgsProcessingParameterNumber(self.MEDIAN_SIZE, 'Median Filter Size',
                                                       type=QgsProcessingParameterNumber.Integer, defaultValue=3, minValue=1))

        # Phase 4 controls. The layer and field are declared optional because
        # they are only needed when ENABLE_ADAPTIVE is checked; that dependency
        # is enforced in processAlgorithm.
        self.addParameter(QgsProcessingParameterBoolean(self.ENABLE_ADAPTIVE, 'Enable Adaptive Re-training (Phase 4)', defaultValue=True))
        self.addParameter(QgsProcessingParameterVectorLayer(self.INPUT_ADAPTIVE_TRAIN, 'Adaptive Correction Points (Separate Shapefile)', optional=True))
        self.addParameter(QgsProcessingParameterField(self.FIELD_ADAPTIVE_DEPTH, 'Depth Field (Adaptive)', parentLayerParameterName=self.INPUT_ADAPTIVE_TRAIN, type=QgsProcessingParameterField.Numeric, optional=True))

        # 5. Validation
        self.addParameter(QgsProcessingParameterVectorLayer(self.INPUT_TEST, 'Unseen Validation Points'))
        self.addParameter(QgsProcessingParameterField(self.FIELD_TEST_DEPTH, 'Validation Depth Field', parentLayerParameterName=self.INPUT_TEST, type=QgsProcessingParameterField.Numeric))

    def name(self):
        """Return the unique algorithm id used by processing.run()."""
        return 'sdb_master_orchestrator'

    def displayName(self):
        """Return the user-visible algorithm name."""
        return 'SDB Master Workflow (Full Pipeline)'

    def shortHelpString(self):
        """Return the HTML help text shown in the algorithm dialog."""
        return """
        <div style="color: #333333; font-family: 'Segoe UI', Arial, sans-serif;">
            <h2>🌊 SDB Master Workflow 🚀</h2>
            <p>
                This tool fully automates the Satellite Derived Bathymetry (SDB) process.
                It takes a satellite image and depth points to produce a final bathymetric map.
                <br><b>It now automatically handles CRS differences between inputs!</b>
            </p>
            <h3>✨ The 5-Phase Journey:</h3>
            <ol>
                <li><b>Phase 1: Pre-processing</b> 🛰️</li>
                <li><b>Phase 2: Data Filtering (Optional)</b> 📊</li>
                <li><b>Phase 3: Global Modeling</b> 🧠</li>
                <li><b>Phase 4: Adaptive Re-training (Optional & Separate Input)</b> 📈</li>
                <li><b>Phase 5: Final Reporting & Validation</b> 📝</li>
            </ol>
            <h3>🔬 Scientific Foundation:</h3>
            <p>
                Based on physics-based radiative transfer and ML regression.
            </p>
        </div>
        """

    def createInstance(self):
        """Return a fresh instance, as required by the Processing framework."""
        return SDBMasterOrchestrator()

    def reproject_layer_if_needed(self, vector_layer, target_crs, temp_output_path, context, feedback):
        """Return a data source for *vector_layer* expressed in *target_crs*.

        If the layer's CRS already equals *target_crs*, the layer's own
        source string is returned unchanged. Otherwise the layer is written,
        reprojected by ``native:reprojectlayer``, to *temp_output_path* and
        that algorithm's OUTPUT value is returned.

        Args:
            vector_layer: The vector layer to check, or None.
            target_crs: CRS to match (the input raster's CRS).
            temp_output_path: Destination file, used only when reprojecting.
            context: Processing context forwarded to the child algorithm.
            feedback: Processing feedback used for log messages.

        Returns:
            A layer source/path string, or None when *vector_layer* is None.
        """
        # Explicit None test instead of relying on the layer object's
        # truthiness to mean "a layer was supplied".
        if vector_layer is None:
            return None

        source_crs = vector_layer.crs()
        if source_crs == target_crs:
            feedback.pushInfo(f"✔ CRS for '{vector_layer.name()}' matches raster.")
            return vector_layer.source()

        if not source_crs.isValid():
            # Without a known source CRS the reprojected coordinates are
            # unlikely to be meaningful, so make the problem visible.
            feedback.pushWarning(f"'{vector_layer.name()}' has no valid CRS; reprojected coordinates may be wrong.")

        feedback.pushWarning(f"Reprojecting '{vector_layer.name()}'...")
        reproject_params = {'INPUT': vector_layer, 'TARGET_CRS': target_crs, 'OUTPUT': temp_output_path}
        result = processing.run("native:reprojectlayer", reproject_params, context=context, feedback=feedback, is_child_algorithm=True)
        return result['OUTPUT']

    def processAlgorithm(self, parameters, context, feedback):
        """Run the five SDB phases in sequence.

        Returns:
            ``{'FINAL_DEPTH_MAP': path}``. *path* is the Phase 4 output when
            the adaptive phase ran, otherwise the Phase 3 depth map. Returns
            ``{}`` if the user cancels between phases.

        Raises:
            QgsProcessingException: if the adaptive phase is enabled but its
                point layer or depth field is missing. This check runs before
                any phase starts. Exceptions raised by child algorithms
                (other than the best-effort Phase 2 filter) propagate
                unchanged.
        """
        total_start_time = time.time()
        start_dt = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        out_dir = self.parameterAsString(parameters, self.OUTPUT_FOLDER, context)
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(out_dir, exist_ok=True)

        config_info = self.get_config_info(parameters, context)
        timing_log = {}

        # Parameters used by more than one phase are read once here.
        field_depth = self.parameterAsString(parameters, self.FIELD_DEPTH, context)
        field_test_depth = self.parameterAsString(parameters, self.FIELD_TEST_DEPTH, context)
        selected_algos = self.parameterAsEnums(parameters, self.SELECTED_ALGOS, context)
        n_iterations = self.parameterAsInt(parameters, self.N_ITERATIONS, context)
        median_size = self.parameterAsInt(parameters, self.MEDIAN_SIZE, context)
        blue_idx = self.parameterAsInt(parameters, self.BLUE_BAND, context)
        green_idx = self.parameterAsInt(parameters, self.GREEN_BAND, context)
        enable_ransac = self.parameterAsBool(parameters, self.ENABLE_RANSAC, context)
        enable_adaptive = self.parameterAsBool(parameters, self.ENABLE_ADAPTIVE, context)

        # =====================================================================
        # INPUT VALIDATION (fail fast, before any expensive phase runs)
        # =====================================================================
        adaptive_layer = self.parameterAsVectorLayer(parameters, self.INPUT_ADAPTIVE_TRAIN, context)
        adaptive_field = self.parameterAsString(parameters, self.FIELD_ADAPTIVE_DEPTH, context)
        if enable_adaptive:
            # Both inputs are declared optional (they only matter when Phase 4
            # is on), so the dependency has to be enforced here.
            if adaptive_layer is None:
                raise QgsProcessingException("Adaptive Phase is Enabled but no Adaptive Input Layer was provided!")
            if not adaptive_field:
                raise QgsProcessingException("Adaptive Phase is Enabled but no Adaptive Depth Field was selected!")

        # =====================================================================
        # CRS CHECK & PREPARATION
        # =====================================================================
        input_raster = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        target_crs = input_raster.crs()

        train_layer = self.parameterAsVectorLayer(parameters, self.INPUT_TRAIN, context)
        test_layer = self.parameterAsVectorLayer(parameters, self.INPUT_TEST, context)

        # Vector inputs are brought into the raster CRS so that every child
        # algorithm samples the raster at matching coordinates.
        final_train_path = self.reproject_layer_if_needed(
            train_layer, target_crs, os.path.join(out_dir, 'temp_reprojected_train.gpkg'), context, feedback)
        final_test_path = self.reproject_layer_if_needed(
            test_layer, target_crs, os.path.join(out_dir, 'temp_reprojected_test.gpkg'), context, feedback)
        final_adaptive_path = None
        if enable_adaptive:
            final_adaptive_path = self.reproject_layer_if_needed(
                adaptive_layer, target_crs, os.path.join(out_dir, 'temp_reprojected_adaptive.gpkg'), context, feedback)

        if feedback.isCanceled():
            return {}

        # =====================================================================
        # STEP 1: PRE-PROCESSING
        # =====================================================================
        feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 1/5] Pre-processing...")
        t0 = time.time()

        p1_params = {
            'INPUT_RASTER': input_raster,
            'COASTAL_BAND': self.parameterAsInt(parameters, self.COASTAL_BAND, context),
            'BLUE_BAND': blue_idx,
            'GREEN_BAND': green_idx,
            'RED_BAND': self.parameterAsInt(parameters, self.RED_BAND, context),
            'NIR_BAND': self.parameterAsInt(parameters, self.NIR_BAND, context),
            'APPLY_SUNGLINT': self.parameterAsBool(parameters, self.APPLY_SUNGLINT, context),
            'NIR_BAND_SUNGLINT': self.parameterAsInt(parameters, self.NIR_BAND_SUNGLINT, context),
            'DEEP_WATER_POLY': self.parameterAsVectorLayer(parameters, self.DEEP_WATER_POLY, context),
            'APPLY_WATER_MASK': self.parameterAsBool(parameters, self.APPLY_WATER_MASK, context),
            'OUTPUT_FOLDER': out_dir
        }
        p1 = processing.run("sdb_tools:sdb_phase1_preprocessing", p1_params, context=context, feedback=feedback, is_child_algorithm=True)
        timing_log['Phase 1 (Pre-processing)'] = time.time() - t0

        path_features = p1['OUTPUT_FEATURES']
        path_mask = p1['OUTPUT_MASK']

        if feedback.isCanceled():
            return {}

        # =====================================================================
        # STEP 2: DATA FILTERING (OPTIONAL, BEST-EFFORT)
        # =====================================================================
        if enable_ransac:
            feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 2/5] Data Filtering (RANSAC) - ENABLED")
            t0 = time.time()

            p2_params = {
                'INPUT_STACK': path_features,
                'INPUT_POINTS': final_train_path,
                'FIELD_DEPTH': field_depth,
                # Explicitly pass the user-selected Blue and Green bands.
                'BLUE_BAND': blue_idx,
                'GREEN_BAND': green_idx,
                'RESIDUAL_THRESHOLD': self.parameterAsDouble(parameters, self.RANSAC_THRESHOLD, context),
                'OUTPUT_FOLDER': out_dir
            }
            try:
                p2 = processing.run("sdb_tools:sdb_02_filtering", p2_params, context=context, feedback=feedback, is_child_algorithm=True)
                path_clean_points = p2['OUTPUT_CLEAN_VEC']
                timing_log['Phase 2 (Data Filtering)'] = time.time() - t0
            except Exception as e:
                # Deliberate best-effort: a failed filter falls back to the
                # unfiltered points. -1 marks the failure in the project log.
                feedback.reportError(f"Phase 2 Failed: {str(e)}. Proceeding with original points.")
                path_clean_points = final_train_path
                timing_log['Phase 2 (Data Filtering)'] = -1
        else:
            feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 2/5] Data Filtering - SKIPPED")
            path_clean_points = final_train_path
            timing_log['Phase 2 (Data Filtering)'] = 0.0

        # A cancel during Phase 2 may have been swallowed by the fallback
        # above, so check explicitly before the next expensive phase.
        if feedback.isCanceled():
            return {}

        # =====================================================================
        # STEP 3: INITIAL (GLOBAL) MODELING
        # =====================================================================
        feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 3/5] Initial Global Modeling...")
        t0 = time.time()

        p3_params = {
            'INPUT_STACK': path_features,
            'INPUT_MASK': path_mask,
            'INPUT_POINTS': path_clean_points,
            'FIELD_DEPTH': field_depth,
            'FIELD_WEIGHT': self.parameterAsString(parameters, self.FIELD_WEIGHT, context),
            'SELECTED_ALGOS': selected_algos,
            'N_ITERATIONS': n_iterations,
            'MEDIAN_SIZE': median_size,
            'OUTPUT_FOLDER': out_dir
        }
        p3 = processing.run("sdb_tools:sdb_03_initial_modeling", p3_params, context=context, feedback=feedback, is_child_algorithm=True)
        timing_log['Phase 3 (Global Modeling)'] = time.time() - t0

        path_initial_depth = p3['OUTPUT_DEPTH_MAP']

        if feedback.isCanceled():
            return {}

        # =====================================================================
        # STEP 4: ADAPTIVE SPATIAL REFINEMENT (OPTIONAL, SEPARATE INPUT)
        # =====================================================================
        if enable_adaptive:
            feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 4/5] Adaptive Spatial Refinement - ENABLED")
            t0 = time.time()

            p4_params = {
                'INPUT_GLOBAL_RASTER': path_initial_depth,    # From Phase 3
                'INPUT_ORIGINAL_FEAT': path_features,         # From Phase 1
                'INPUT_MASK': path_mask,                      # From Phase 1
                'INPUT_TRAIN': final_adaptive_path,           # Separate adaptive input
                'FIELD_TRAIN': adaptive_field,
                'INPUT_VALIDATION': final_test_path,
                'FIELD_VALIDATION': field_test_depth,
                # Phase 4 uses the same model selection as Phase 3.
                'SELECTED_ALGOS': selected_algos,
                'N_ITERATIONS': n_iterations,
                'MEDIAN_SIZE': median_size,
                'OUTPUT_FOLDER': out_dir
            }
            p4 = processing.run("sdb_tools:sdb_phase4_adaptive", p4_params, context=context, feedback=feedback, is_child_algorithm=True)
            path_refined_depth = p4['OUTPUT_FINAL']

            timing_log['Phase 4 (Spatial Refinement)'] = time.time() - t0
        else:
            feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 4/5] Adaptive Spatial Refinement - SKIPPED")
            feedback.pushInfo("    > The Final Map will be the Initial Global Model.")
            path_refined_depth = path_initial_depth  # Bypass Phase 4
            timing_log['Phase 4 (Spatial Refinement)'] = 0.0

        if feedback.isCanceled():
            return {}

        # =====================================================================
        # STEP 5: FINAL REPORTING
        # =====================================================================
        feedback.pushInfo("\n>>> ORCHESTRATOR: [Step 5/5] Final Comparative Reporting...")
        t0 = time.time()

        p5_params = {
            'INPUT_MAP_P3': path_initial_depth,
            # Refined map, or the Phase 3 map again when Phase 4 was skipped.
            'INPUT_MAP_P4': path_refined_depth,
            # Training points as used for modelling: RANSAC-filtered when Phase 2
            # succeeded, otherwise the (reprojected) original points.
            'INPUT_TRAIN': path_clean_points,
            'FIELD_TRAIN': field_depth,
            'INPUT_VALIDATION': final_test_path,
            'FIELD_VAL_DEPTH': field_test_depth,
            'OUTPUT_FOLDER': out_dir
        }
        processing.run("sdb_tools:sdb_05_reporting", p5_params, context=context, feedback=feedback, is_child_algorithm=True)
        timing_log['Phase 5 (Reporting)'] = time.time() - t0

        # =====================================================================
        # FINAL LOG
        # =====================================================================
        self.write_full_project_log(out_dir, start_dt, config_info, timing_log, time.time() - total_start_time)

        feedback.pushInfo("\n>>> ORCHESTRATOR: WORKFLOW COMPLETED SUCCESSFULLY.")
        return {'FINAL_DEPTH_MAP': path_refined_depth}

    def get_config_info(self, params, context):
        """Return a display-ready summary of the main run settings.

        Args:
            params: The raw parameter map passed to processAlgorithm.
            context: Processing context used to evaluate the parameters.

        Returns:
            dict mapping a label to its value, written verbatim into the
            CONFIGURATION section of the project log.
        """
        indices = self.parameterAsEnums(params, self.SELECTED_ALGOS, context)
        names = [self.MODEL_LIST[i] for i in indices]

        return {
            'Algorithms': ", ".join(names),
            'Iterations': self.parameterAsInt(params, self.N_ITERATIONS, context),
            'RANSAC Enabled': self.parameterAsBool(params, self.ENABLE_RANSAC, context),
            'Adaptive Phase': self.parameterAsBool(params, self.ENABLE_ADAPTIVE, context),
            'Apply Sunglint': self.parameterAsBool(params, self.APPLY_SUNGLINT, context)
        }

    def write_full_project_log(self, out_dir, start_dt, config, timing, total_duration):
        """Write ``FULL_PROJECT_LOG.txt`` into *out_dir*.

        The log contains the start time, total duration, *config* entries,
        per-phase timings and the text of ``5_FINAL_SUMMARY.txt``, which is
        read from *out_dir* if present.

        Args:
            out_dir: Output folder for the run.
            start_dt: Pre-formatted start timestamp string.
            config: Label -> value mapping from get_config_info().
            timing: Phase label -> seconds. A negative value marks a phase
                that failed and was bypassed; it is shown as FAILED.
            total_duration: Wall-clock duration of the whole run, in seconds.
        """
        log_path = os.path.join(out_dir, "FULL_PROJECT_LOG.txt")
        summary_path = os.path.join(out_dir, '5_FINAL_SUMMARY.txt')

        # Catch only I/O and decoding problems (not KeyboardInterrupt etc.),
        # and report a missing file separately from an unreadable one.
        try:
            with open(summary_path, 'r') as f:
                p5_summary = f.read()
        except FileNotFoundError:
            p5_summary = "Final summary file not found."
        except (OSError, UnicodeDecodeError) as exc:
            p5_summary = f"Final summary file could not be read: {exc}"

        with open(log_path, 'w') as f:
            f.write("======================================================\n")
            f.write("           SDB PROJECT EXECUTION LOG                 \n")
            f.write("======================================================\n\n")
            f.write(f"Date/Time Started: {start_dt}\n")
            f.write(f"Total Execution Time: {total_duration/60:.2f} minutes\n\n")

            f.write("--- CONFIGURATION ---\n")
            for k, v in config.items():
                f.write(f"{k:<20}: {v}\n")
            f.write("\n")

            f.write("--- TIMING BREAKDOWN ---\n")
            for k, v in timing.items():
                if v < 0:
                    f.write(f"{k:<30}: FAILED\n")
                else:
                    f.write(f"{k:<30}: {v:.2f} sec\n")
            f.write("\n")

            f.write("--- FINAL RESULTS SUMMARY ---\n")
            f.write(p5_summary)
            f.write("\n======================================================\n")