# SDB_02_Data_Filtering.py
# ---------------------------------------------------------------------------
# MODULE 02: DATA FILTERING (CALCULATE RATIO ON-THE-FLY)
# ---------------------------------------------------------------------------
# Features:
# - Takes Blue & Green Band Indices explicitly.
# - Calculates Ln(Blue)/Ln(Green) for each point internally.
# - RANSAC rejects points that deviate from a linear depth-vs-ratio fit.
# ---------------------------------------------------------------------------

import os
import numpy as np
import rasterio
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")

from qgis.core import (
    QgsProcessing, QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer, QgsProcessingParameterVectorLayer,
    QgsProcessingParameterField, QgsProcessingParameterFolderDestination,
    QgsProcessingParameterNumber, QgsCoordinateTransform, QgsProject,
    QgsVectorFileWriter, QgsWkbTypes, QgsRasterLayer, QgsProcessingException, QgsVectorLayer,
    QgsProcessingParameterBand
)

from sklearn.linear_model import RANSACRegressor, LinearRegression

class SDBModule02(QgsProcessingAlgorithm):
    """Filter ICESat-2 depth points with RANSAC on an on-the-fly log ratio.

    For every input point the Blue and Green bands of the feature stack are
    sampled, ln(Blue) / ln(Green) is computed, and a RANSAC line of depth
    versus that ratio is fitted. Inliers and outliers are written to two
    shapefiles in the output folder, together with a diagnostic plot.
    """
    INPUT_STACK = 'INPUT_STACK'
    INPUT_POINTS = 'INPUT_POINTS'
    FIELD_DEPTH = 'FIELD_DEPTH'

    # Explicit band selectors (replace the former precomputed ratio-band input).
    BLUE_BAND = 'BLUE_BAND'
    GREEN_BAND = 'GREEN_BAND'

    RESIDUAL_THRESHOLD = 'RESIDUAL_THRESHOLD'
    OUTPUT_FOLDER = 'OUTPUT_FOLDER'

    # Multiplier applied to the median absolute OLS residual in auto mode.
    MAD_MULTIPLIER = 2.5
    # Minimum number of valid samples required before any fitting is attempted.
    MIN_POINTS = 10

    def initAlgorithm(self, config=None):
        self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_STACK, 'Input Feature Stack (From Phase 1)'))
        # Restrict to point layers: every feature is sampled as a single point.
        self.addParameter(QgsProcessingParameterVectorLayer(
            self.INPUT_POINTS, 'Raw ICESat-2 Points', types=[QgsProcessing.TypeVectorPoint]))
        self.addParameter(QgsProcessingParameterField(self.FIELD_DEPTH, 'Depth Field', parentLayerParameterName=self.INPUT_POINTS, type=QgsProcessingParameterField.Numeric))

        # Explicit bands (1-based band numbers of the stack).
        self.addParameter(QgsProcessingParameterBand(self.BLUE_BAND, 'Blue Band Number', parentLayerParameterName=self.INPUT_STACK, defaultValue=2))
        self.addParameter(QgsProcessingParameterBand(self.GREEN_BAND, 'Green Band Number', parentLayerParameterName=self.INPUT_STACK, defaultValue=3))

        self.addParameter(QgsProcessingParameterNumber(self.RESIDUAL_THRESHOLD, 'RANSAC Threshold (0 = Auto)', type=QgsProcessingParameterNumber.Double, defaultValue=0.0))
        self.addParameter(QgsProcessingParameterFolderDestination(self.OUTPUT_FOLDER, 'Output Folder'))

    def name(self): return 'sdb_02_filtering'
    def displayName(self): return '2. SDB Module 02: Filtering (On-Fly Ratio)'
    def group(self): return 'SDB Research Tools'
    def groupId(self): return 'sdb_tools'
    def createInstance(self): return SDBModule02()

    def flags(self):
        # The outputs are added to QgsProject.instance() and plotted through
        # pyplot; neither is safe from Processing's background worker thread.
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        """Extract samples, choose a threshold, run RANSAC and export results.

        Returns:
            ``{'OUTPUT_CLEAN_VEC': <path>}`` pointing at the cleaned shapefile
            (written only when at least one inlier exists), or ``{}`` when
            the user cancels during extraction.

        Raises:
            QgsProcessingException: if an input layer cannot be loaded, fewer
            than ``MIN_POINTS`` valid samples are found, or RANSAC fails.
        """
        out_dir = self.parameterAsString(parameters, self.OUTPUT_FOLDER, context)
        os.makedirs(out_dir, exist_ok=True)

        stack_layer = self.parameterAsRasterLayer(parameters, self.INPUT_STACK, context)
        points_layer = self.parameterAsVectorLayer(parameters, self.INPUT_POINTS, context)
        if stack_layer is None:
            raise QgsProcessingException("Could not load the input feature stack.")
        if points_layer is None:
            raise QgsProcessingException("Could not load the input points layer.")
        stack_path = stack_layer.source()
        depth_fld = self.parameterAsString(parameters, self.FIELD_DEPTH, context)

        b_idx = self.parameterAsInt(parameters, self.BLUE_BAND, context)
        g_idx = self.parameterAsInt(parameters, self.GREEN_BAND, context)
        user_threshold = self.parameterAsDouble(parameters, self.RESIDUAL_THRESHOLD, context)

        feedback.pushInfo("\n>>> MODULE 02: STARTING DATA FILTERING...")
        feedback.pushInfo(f"   Using Bands: Blue={b_idx}, Green={g_idx}")
        feedback.pushInfo("   Calculating: Ln(Blue) / Ln(Green) on the fly...")

        # 1. Extract & Calculate Ratio
        feedback.pushInfo("   [1/3] Extracting & Calculating Ratios...")
        X_ratio, y_depth, features_list = self.extract_and_calc_ratio(stack_path, points_layer, depth_fld, b_idx, g_idx, feedback)
        if feedback.isCanceled():
            return {}

        if len(y_depth) < self.MIN_POINTS:
            raise QgsProcessingException("Not enough valid points found! Check projection or overlap.")

        # 2. Threshold Calculation
        X = X_ratio.reshape(-1, 1)  # scikit-learn expects a 2-D feature matrix
        y = y_depth

        final_thresh = user_threshold
        if user_threshold <= 0:
            feedback.pushInfo("   [Auto] Calculating MAD Threshold...")
            lr = LinearRegression().fit(X, y)
            residuals = np.abs(y - lr.predict(X))
            mad = np.median(residuals)
            final_thresh = self.MAD_MULTIPLIER * mad
            feedback.pushInfo(f"      Calculated Threshold: {final_thresh:.4f}m (MAD={mad:.4f})")
            if final_thresh <= 0:
                # A zero threshold only admits exact fits; fall back to the value
                # scikit-learn uses when residual_threshold is None
                # (median absolute deviation of the target).
                final_thresh = float(np.median(np.abs(y - np.median(y))))
                feedback.pushInfo(f"      Threshold was 0; using MAD of depths instead: {final_thresh:.4f}m")
        else:
            feedback.pushInfo(f"   Using Manual Threshold: {final_thresh}m")

        # 3. RANSAC
        feedback.pushInfo("   [2/3] Running RANSAC...")
        ransac = RANSACRegressor(random_state=42, min_samples=0.1, residual_threshold=final_thresh)
        try:
            ransac.fit(X, y)
        except ValueError as e:
            raise QgsProcessingException(f"RANSAC Failed: {e}") from e

        inlier_mask = ransac.inlier_mask_
        outlier_mask = ~inlier_mask
        n_in = int(inlier_mask.sum())
        n_out = int(outlier_mask.sum())
        feedback.pushInfo(f"      Kept: {n_in}, Rejected: {n_out}")

        # 4. Export
        feedback.pushInfo("   [3/3] Exporting...")
        clean_shp = os.path.join(out_dir, '2_Cleaned_Training_Data.shp')
        reject_shp = os.path.join(out_dir, '2_Outliers_Rejected.shp')

        self.save_subset(points_layer, features_list, inlier_mask, clean_shp, "Cleaned Data", feedback)
        self.save_subset(points_layer, features_list, outlier_mask, reject_shp, "Rejected Data", feedback)
        self.save_ransac_plot(X, y, inlier_mask, outlier_mask, ransac, os.path.join(out_dir, '2_RANSAC_Plot.png'), final_thresh)

        return {'OUTPUT_CLEAN_VEC': clean_shp}

    # --- Helpers ---
    @staticmethod
    def _point_in_raster_crs(geom, tr):
        """Transform ``geom`` with ``tr`` and return it as a point, or None.

        None is returned for null/empty geometries, failed transforms, and
        multipart geometries with more than one part (a feature carries a
        single depth value, so only an unambiguous location is accepted).
        """
        if geom is None or geom.isNull() or geom.isEmpty():
            return None
        try:
            geom.transform(tr)
        except Exception:  # e.g. QgsCsException when the point cannot be projected
            return None
        if geom.isMultipart():
            parts = geom.asMultiPoint()
            return parts[0] if len(parts) == 1 else None
        return geom.asPoint()

    @staticmethod
    def _is_valid_pixel(value, nodata):
        """Return True if ``value`` is finite, not nodata and > 0 (log-safe)."""
        return bool(np.isfinite(value) and value != nodata and value > 0)

    @staticmethod
    def _to_finite_float(value):
        """Return ``value`` as a finite float, or None for NULL/non-numeric/NaN/inf."""
        try:
            number = float(value)
        except (TypeError, ValueError):  # None, PyQGIS NULL, text
            return None
        return number if np.isfinite(number) else None

    def extract_and_calc_ratio(self, ras_path, vec_layer, depth_fld, b_idx, g_idx, fb):
        """Sample Blue/Green under each point and compute ln(Blue) / ln(Green).

        Args:
            ras_path: Path to the raster stack (opened by both QGIS and rasterio).
            vec_layer: Point layer holding the depth samples.
            depth_fld: Name of the numeric depth attribute.
            b_idx: 1-based band number of the Blue band.
            g_idx: 1-based band number of the Green band.
            fb: Processing feedback used for logging, progress and cancellation.

        Returns:
            ``(ratios, depths, features)``: two equal-length 1-D numpy arrays
            and the list of source features, all in the same order. On
            cancellation, whatever was collected so far is returned.

        Raises:
            QgsProcessingException: if a band number is outside
            ``1..band count`` or a band cannot be read.
        """
        rlayer = QgsRasterLayer(ras_path)
        tr = QgsCoordinateTransform(vec_layer.sourceCrs(), rlayer.crs(), QgsProject.instance())

        ratios_calc, depths_list, feats_list = [], [], []
        skipped_geom = 0

        with rasterio.open(ras_path) as src:
            if not (1 <= b_idx <= src.count and 1 <= g_idx <= src.count):
                raise QgsProcessingException(
                    f"Requested bands (B:{b_idx}, G:{g_idx}) are outside the raster band range 1..{src.count}.")

            # Read Blue and Green bands once; points are then looked up by index.
            try:
                band_b = src.read(b_idx)
                band_g = src.read(g_idx)
            except (IndexError, rasterio.errors.RasterioIOError) as err:
                raise QgsProcessingException("Failed to read Blue or Green band.") from err

            nodata = src.nodata if src.nodata is not None else -9999.0
            h, w = src.height, src.width

            total = vec_layer.featureCount()
            # featureCount() may be -1 when the provider cannot count cheaply.
            step = 100.0 / total if total > 0 else 0.0

            for i, f in enumerate(vec_layer.getFeatures()):
                if fb.isCanceled():
                    break
                fb.setProgress(int(i * step))

                pt = self._point_in_raster_crs(f.geometry(), tr)
                if pt is None:
                    skipped_geom += 1
                    continue

                try:
                    r, c = src.index(pt.x(), pt.y())
                except Exception:  # non-finite coordinates etc.
                    skipped_geom += 1
                    continue

                if not (0 <= r < h and 0 <= c < w):
                    continue  # point falls outside the raster footprint

                val_b = band_b[r, c]
                val_g = band_g[r, c]
                # Both reflectances must be strictly positive for the logarithm.
                if not (self._is_valid_pixel(val_b, nodata) and self._is_valid_pixel(val_g, nodata)):
                    continue

                # Stumpf-style log ratio: ln(Blue) / ln(Green).
                lg = np.log(val_g)
                if lg == 0:
                    continue  # Green == 1 would divide by zero

                depth = self._to_finite_float(f[depth_fld])
                if depth is None:
                    continue

                ratios_calc.append(np.log(val_b) / lg)
                depths_list.append(depth)
                feats_list.append(f)

        fb.pushInfo(f"      Extracted {len(ratios_calc)}/{total} valid points.")
        if skipped_geom:
            fb.pushInfo(f"      Skipped {skipped_geom} features with unusable geometry.")
        return np.array(ratios_calc, dtype=float), np.array(depths_list, dtype=float), feats_list

    def save_subset(self, original_layer, all_features, mask, out_path, layer_name, feedback=None):
        """Write the features selected by ``mask`` to a shapefile and load it.

        Any existing shapefile at ``out_path`` (including sidecar files) is
        removed first, so an empty selection does not leave stale results.
        Nothing is written when ``mask`` selects no features. Writer errors
        are reported through ``feedback`` (if given) instead of raising.
        """
        if os.path.exists(out_path) and not QgsVectorFileWriter.deleteShapeFile(out_path):
            if feedback is not None:
                feedback.pushInfo(f"      Could not delete existing {out_path} (file may be in use).")

        subset = [f for f, keep in zip(all_features, mask) if keep]
        if not subset:
            return

        # Use the source layer's geometry type so Z values / multipoints survive.
        writer = QgsVectorFileWriter(out_path, "UTF-8", original_layer.fields(), original_layer.wkbType(), original_layer.sourceCrs(), "ESRI Shapefile")
        if writer.hasError() != QgsVectorFileWriter.NoError:
            message = f"Could not create {out_path}: {writer.errorMessage()}"
            del writer
            if feedback is not None:
                feedback.reportError(message)
            return

        for f in subset:
            writer.addFeature(f)
        del writer  # deleting the writer flushes and closes the file
        QgsProject.instance().addMapLayer(QgsVectorLayer(out_path, layer_name, "ogr"))

    def save_ransac_plot(self, X, y, inlier, outlier, model, path, thresh):
        """Save a depth-vs-ratio scatter plot with the fitted RANSAC line to ``path``."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            x = X.ravel()
            ax.scatter(x[inlier], y[inlier], c='dodgerblue', marker='.', label='Inliers')
            ax.scatter(x[outlier], y[outlier], c='crimson', marker='x', label='Outliers')

            if inlier.sum() > 1:
                line_x = np.linspace(x.min(), x.max(), 100).reshape(-1, 1)
                ax.plot(line_x.ravel(), model.predict(line_x), 'k-', lw=2, label='RANSAC Model')

            ax.set_title(f"RANSAC Filtering (Threshold={thresh:.2f}m)")
            ax.set_xlabel("Calculated Ratio: Ln(Blue) / Ln(Green)")
            ax.set_ylabel("Depth (m)")
            ax.legend()
            ax.grid(True, alpha=0.5)
            fig.tight_layout()
            fig.savefig(path)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)