# -*- coding: utf-8 -*-
"""
/***************************************************************************
 RSWaterQualityMapper
                                 A QGIS plugin
 RS Water Quality Mapper
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-09-25
        git sha              : $Format:%H$
        copyright            : (C) 2024 by Haibin Su
        email                : haibin.su@tamuk.edu
 ***************************************************************************/

/***************************************************************************
 * *
 * This program is free software; you can redistribute it and/or modify  *
 * it under the terms of the GNU General Public License as published by  *
 * the Free Software Foundation; either version 2 of the License, or     *
 * (at your option) any later version.                                   *
 * *
 ***************************************************************************/
"""

import os
import subprocess
import json
import pickle
import graphviz
import rasterio
import rasterio.mask
import rasterio.features
import numpy as np
import pandas as pd
import processing
import tempfile
from rasterio.enums import Resampling
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.mixture import GaussianMixture
from sklearn import tree
from scipy.ndimage import median_filter

from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVR
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
from shapely.geometry import shape
import rioxarray

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from qgis.PyQt import QtWidgets
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication, Qt, QThread, pyqtSignal
from qgis.PyQt.QtGui import QColor, QPixmap
from qgis.PyQt.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLabel,  
                                 QPushButton, QMessageBox, QTableWidgetItem, QListWidgetItem, QPlainTextEdit, 
                                 QGridLayout, QComboBox, QCheckBox, QDoubleSpinBox, QSpinBox, QDialogButtonBox, QRadioButton)
from qgis.core import (
    QgsProject, QgsRasterLayer, QgsColorRampShader, QgsRasterShader, QgsSingleBandPseudoColorRenderer,
    QgsMapLayerProxyModel, QgsPalettedRasterRenderer, QgsField, QgsWkbTypes, 
    Qgis, QgsCoordinateTransform, QgsRectangle, QgsGeometry, 
)

# Import the missing QGIS GUI classes
from qgis.gui import (QgsMapLayerComboBox, QgsMapCanvas, QgsRubberBand, QgsMapTool, QgsFileWidget)

# Import gdal and enable exceptions to handle the FutureWarning
try:
    from osgeo import gdal
    gdal.UseExceptions()
except ImportError:
    pass # GDAL may not be available in all environments

def get_image_data_and_profile(image_path, satellite_type):
    """Load an image as a float32 array together with a rasterio-style profile.

    Paths ending in ``.nc`` (case-insensitive) are handed to
    ``read_netcdf_bands``, which uses ``satellite_type`` to choose bands.
    Every other path is opened with rasterio and all of its bands are read.

    Returns:
        tuple: ``(array, profile)``; for the rasterio path ``array`` has shape
        (bands, rows, cols).
    """
    is_netcdf = image_path.lower().endswith('.nc')
    if is_netcdf:
        return read_netcdf_bands(image_path, satellite_type)

    with rasterio.open(image_path) as dataset:
        meta = dataset.profile
        stacked = dataset.read().astype(np.float32)
    return stacked, meta

def find_band_in_range(target_wl, available_bands, tolerance=5):
    """Return the ``rhos_<wavelength>`` band closest to ``target_wl``.

    Only names starting with ``rhos_`` whose second ``_``-separated token parses
    as an integer wavelength are considered. A band qualifies when
    ``abs(wavelength - target_wl) <= tolerance``. When several bands qualify,
    the one with the smallest wavelength difference is returned; ties keep the
    earliest band in ``available_bands``.

    Args:
        target_wl: Target wavelength (same units as the band-name suffix, nm).
        available_bands: Iterable of band/variable names.
        tolerance: Maximum allowed absolute wavelength difference (inclusive).

    Returns:
        The best matching band name, or ``None`` if no band is within tolerance.
    """
    best_name = None
    best_diff = None
    for band_name in available_bands:
        if not band_name.startswith('rhos_'):
            continue
        try:
            wl = int(band_name.split('_')[1])
        except (IndexError, ValueError):
            # Names like 'rhos_' or 'rhos_abc' carry no usable wavelength.
            continue
        diff = abs(wl - target_wl)
        if diff > tolerance:
            continue
        # Prefer the nearest band rather than the first one in list order; with a
        # wide tolerance two bands can both qualify. Strict '<' keeps the first on ties.
        if best_diff is None or diff < best_diff:
            best_name, best_diff = band_name, diff
    return best_name
        
def read_netcdf_bands(nc_path, satellite):
    """Read and stack ACOLITE ``rhos_*`` bands from a NetCDF file.

    A fixed list of target wavelengths (nm) for the satellite family is matched
    against the dataset variables with ``find_band_in_range`` (10 nm tolerance).
    Targets without a match are reported in a warning dialog and skipped, so
    the returned stack can hold fewer bands than there are targets.

    Args:
        nc_path: Path to the NetCDF file.
        satellite: Satellite label containing "Landsat 8/9", "Sentinel-2" or
            "Landsat 5/7".

    Returns:
        tuple: ``(array, profile)`` where ``array`` is float32 with the found
        bands stacked along the first axis, and ``profile`` is a GeoTIFF-style
        rasterio profile dict.

    Raises:
        ImportError: rioxarray is not installed.
        ValueError: the satellite label is not supported.
        IOError: the file cannot be read or has none of the required bands
            (the underlying exception is chained as the cause).
    """
    try:
        import rioxarray
    except ImportError as exc:
        raise ImportError("The 'rioxarray' and 'xarray' packages are required to read NetCDF files. Please install them in your QGIS Python environment (e.g., via the OSGeo4W Shell).") from exc

    # Target wavelengths for each satellite, in the order the stack is built.
    if "Landsat 8/9" in satellite:
        target_wavelengths = [443, 482, 561, 655, 865]
    elif "Sentinel-2" in satellite:
        target_wavelengths = [443, 492, 560, 665, 704, 740, 783, 842, 865]
    elif "Landsat 5/7" in satellite:
        target_wavelengths = [485, 560, 660, 830]
    else:
        raise ValueError(f"NetCDF reading for satellite '{satellite}' is not supported.")

    rds = None
    try:
        rds = rioxarray.open_rasterio(nc_path, masked=True)

        available_bands = list(rds.data_vars)
        found_band_names = []
        missing_wavelengths = []

        for target_wl in target_wavelengths:
            found_band = find_band_in_range(target_wl, available_bands, tolerance=10)  # Increased tolerance for robustness
            if found_band:
                found_band_names.append(found_band)
            else:
                missing_wavelengths.append(str(target_wl))

        if not found_band_names:
            raise ValueError(f"No required rhos bands found in NetCDF file for {satellite}. Expected bands near: {', '.join(map(str, target_wavelengths))} nm.")

        if missing_wavelengths:
            QMessageBox.warning(None, "Missing Bands", f"Some expected rhos bands for {satellite} not found: {', '.join(missing_wavelengths)} nm. Proceeding with available bands.")

        # Stack only the found bands.
        # NOTE(review): a skipped target shifts every later band to a lower
        # position in the stack; callers that index bands by position should
        # be aware of this.
        stacked_bands = rds[found_band_names].to_array(dim='band').to_numpy()

        # `rio.nodata` always exists (it is None when no nodata is declared), so a
        # getattr() default would never apply; fall back explicitly instead.
        nodata_val = rds[found_band_names[0]].rio.nodata
        if nodata_val is None:
            nodata_val = -9999.0

        profile = {
            'driver': 'GTiff', 'dtype': rasterio.float32, 'nodata': nodata_val,
            'width': rds.rio.width, 'height': rds.rio.height, 'count': len(found_band_names),
            'crs': rds.rio.crs, 'transform': rds.rio.transform(),
        }
        return stacked_bands.astype(np.float32), profile

    except Exception as e:
        raise IOError(f"Failed to read or process NetCDF file {nc_path}: {e}") from e
    finally:
        # Release the file handle; the band values were already copied into NumPy.
        close = getattr(rds, "close", None)
        if callable(close):
            close()
    

def single_band_raster_to_numpy(raster_path):
    """Return band 1 of ``raster_path`` as a float32 NumPy array.

    Raises:
        ValueError: the dataset reports fewer than one band.
        FileNotFoundError: rasterio cannot open/read the file (chained from the
            underlying ``RasterioIOError``).
    """
    try:
        with rasterio.open(raster_path) as ds:
            if ds.count >= 1:
                return ds.read(1).astype(np.float32)
            raise ValueError(f"Invalid single-band raster: {raster_path}")
    except rasterio.errors.RasterioIOError as err:
        raise FileNotFoundError(f"Cannot open the raster file: {raster_path}") from err


def numpy_to_raster(numpy_array, output_raster_path, profile, nodata_value=None):
    """Save ``numpy_array`` as band 1 of a float32 GeoTIFF.

    ``profile`` is copied, never modified. Its dtype, band count, driver and
    nodata entries are overridden before the file is created.

    Raises:
        Exception: rasterio fails to create or write the output (chained from
            the underlying ``RasterioIOError``).
    """
    write_profile = profile.copy()
    write_profile.update({
        'dtype': rasterio.float32,
        'count': 1,
        'driver': 'GTiff',
        'nodata': nodata_value,
    })
    try:
        with rasterio.open(output_raster_path, 'w', **write_profile) as sink:
            sink.write(numpy_array.astype(rasterio.float32), 1)
    except rasterio.errors.RasterioIOError as err:
        raise Exception(f"Could not create output raster: {output_raster_path}") from err


def apply_model_to_raster(parent, model, model_x_field, image_path, band_num, output_path, use_mask=False, mask_path=None):
    """Apply a fitted single-predictor regression model to one raster band.

    Each valid pixel of band ``band_num`` is passed to ``model.predict`` as a
    one-column feature. The predictions are written to ``output_path`` as a
    single-band float32 GeoTIFF, and the new layer is added to the QGIS
    project. Errors are shown to the user in message boxes instead of being
    raised.

    Args:
        parent: Parent widget for message boxes.
        model: Object with a scikit-learn style ``predict(X)``.
        model_x_field: Name of the model's predictor field (not used here; kept
            for interface compatibility).
        image_path: Path of the input raster.
        band_num: 1-based band index used as the predictor.
        output_path: Destination path of the prediction raster.
        use_mask: When True and ``mask_path`` is set, pixels whose mask value is
            not 1 are set to nodata.
        mask_path: Path of a water-mask raster with the same shape as the image.
    """
    try:
        with rasterio.open(image_path) as src:
            if band_num > src.count:
                QMessageBox.critical(parent, "Error", f"Invalid band number. Raster has only {src.count} bands.")
                return

            index_band_data = src.read(band_num, masked=True).astype(np.float32)
            profile = src.profile
            # Note: `or` also replaces a declared nodata of 0 with -9999.
            nodata_val = src.nodata or -9999.0

        original_shape = index_band_data.shape
        band_values = np.ma.getdata(index_band_data).reshape(-1)
        # rasterio masks only the declared nodata value. NaN/inf pixels must also
        # be excluded, because scikit-learn's predict() rejects non-finite input.
        valid_data_mask = ~np.ma.getmaskarray(index_band_data).reshape(-1) & np.isfinite(band_values)

        prediction_data_flat = np.full(band_values.size, nodata_val, dtype=np.float32)
        if np.any(valid_data_mask):
            # predict() raises on zero samples, so skip it for an all-nodata band.
            predicted_values = model.predict(band_values[valid_data_mask].reshape(-1, 1))
            prediction_data_flat[valid_data_mask] = np.asarray(predicted_values).reshape(-1)

        prediction_data = prediction_data_flat.reshape(original_shape)

        # Apply the optional water mask
        if use_mask and mask_path:
            with rasterio.open(mask_path) as mask_src:
                if mask_src.shape != prediction_data.shape:
                    raise ValueError("The Water Mask layer does not align with the input image.")

                mask_data = mask_src.read(1)
                prediction_data[mask_data != 1] = nodata_val

        # Always write a single-band GeoTIFF, matching numpy_to_raster. Do not
        # inherit the input driver (e.g. JPEG2000 for .jp2 inputs cannot be
        # created this way), and drop any multi-band photometric interpretation.
        profile.update(driver='GTiff', dtype=rasterio.float32, count=1, nodata=nodata_val)
        profile.pop('photometric', None)
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(prediction_data, 1)

        # Add the new layer to QGIS
        layer_name = os.path.basename(output_path)
        new_layer = QgsRasterLayer(output_path, layer_name)
        if new_layer.isValid():
            QgsProject.instance().addMapLayer(new_layer)
            QMessageBox.information(parent, "Success", "Prediction raster created and added to the project.")
        else:
            QMessageBox.critical(parent, "Error", "Failed to create or load the prediction raster layer.")

    except Exception as e:
        QMessageBox.critical(parent, "Processing Error", f"An error occurred: {str(e)}")


def calculate_all_indices(df, satellite):
    """Add twelve spectral water-quality index columns to ``df`` in place.

    The band columns used depend on ``satellite``, matched by substring and
    checked in the order Landsat 5/7, Landsat 8/9, Sentinel-2. An unrecognised
    label leaves ``df`` untouched. All required band columns are looked up
    before any index column is written, so a missing band adds no columns.

    Returns:
        The same ``df`` object.

    Raises:
        Exception: a required band column is missing from ``df``.
    """
    def landsat_indices(blue, green, red, nir):
        # Landsat 5/7 and 8/9 share these formulas; only the band numbers differ.
        yield 'NDTI', (red - green) / (red + green)
        yield 'RB_Ratio', red / blue
        yield 'GB_Ratio', green / blue
        yield 'RedPlusNIR', red + nir
        yield 'NDCI', (nir - red) / (nir + red)
        yield 'TWOBDA', nir / red
        yield 'THREEBDA', (blue - red) / green
        yield 'SABI', (nir - red) / (blue + green)
        yield 'BG_Ratio', blue / green
        yield 'RG_Ratio', red / green
        yield 'RMinusG', red - green
        yield 'GR_Ratio', green / red

    def sentinel2_indices(b1, b2, b3, b4, b5, b6, b8, b8a):
        yield 'NDTI', (b4 - b3) / (b4 + b3)
        yield 'RB_Ratio', b4 / b2
        yield 'GB_Ratio', b3 / b2
        yield 'RedPlusNIR', b4 + b8
        yield 'NDCI', (b5 - b4) / (b5 + b4)
        yield 'TWOBDA', b5 / b4
        yield 'THREEBDA', (1 / b4 - 1 / b5) * b6
        yield 'SABI', (b8a - b4) / (b1 + b3)
        yield 'BG_Ratio', b2 / b3
        yield 'RG_Ratio', b4 / b3
        yield 'RMinusG', b4 - b3
        yield 'GR_Ratio', b3 / b4

    if "Landsat 5/7" in satellite:
        band_names, index_source = ['B1', 'B2', 'B3', 'B4'], landsat_indices
    elif "Landsat 8/9" in satellite:
        band_names, index_source = ['B2', 'B3', 'B4', 'B5'], landsat_indices
    elif "Sentinel-2" in satellite:
        band_names, index_source = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B8', 'B8A'], sentinel2_indices
    else:
        return df

    try:
        band_series = [df[name] for name in band_names]
        # Each index is computed and stored before the next one is evaluated.
        for column, values in index_source(*band_series):
            df[column] = values
    except KeyError as e:
        raise Exception(f"Input layer is missing a required band field: {e}")
    return df


def _ensemble_index_value(model_name, band_data, field_satellite, epsilon):
    """Return spectral index ``model_name`` computed from the arrays in ``band_data``.

    ``band_data`` maps band names ('B1', 'B2', ...) to equally sized 1-D arrays.
    Ratio denominators get ``epsilon`` added. An unrecognised index name yields
    the scalar 0. Divisions can still produce inf/NaN, so callers should wrap
    the call in ``np.errstate``.
    """
    b, eps = band_data, epsilon
    if "Sentinel-2" in field_satellite and "Landsat 8/9" not in field_satellite:
        formulas = {
            'RB_Ratio': lambda: b['B4'] / (b['B2'] + eps),
            'NDTI': lambda: (b['B4'] - b['B3']) / (b['B4'] + b['B3'] + eps),
            'GB_Ratio': lambda: b['B3'] / (b['B2'] + eps),
            'RedPlusNIR': lambda: b['B4'] + b['B8'],
            'NDCI': lambda: (b['B5'] - b['B4']) / (b['B5'] + b['B4'] + eps),
            'TWOBDA': lambda: b['B5'] / (b['B4'] + eps),
            'THREEBDA': lambda: (1 / (b['B4'] + eps) - 1 / (b['B5'] + eps)) * b['B6'],
            'SABI': lambda: (b['B8A'] - b['B4']) / (b['B1'] + b['B3'] + eps),
            'BG_Ratio': lambda: b['B2'] / (b['B3'] + eps),
            'RG_Ratio': lambda: b['B4'] / (b['B3'] + eps),
            'RMinusG': lambda: b['B4'] - b['B3'],
            'GR_Ratio': lambda: b['B3'] / (b['B4'] + eps),
        }
    else:
        # Landsat 8/9 and Landsat 5/7 use the same formulas on different band numbers.
        if "Landsat 8/9" in field_satellite:
            blue, green, red, nir = 'B2', 'B3', 'B4', 'B5'
        else:  # Landsat 5/7
            blue, green, red, nir = 'B1', 'B2', 'B3', 'B4'
        formulas = {
            'RB_Ratio': lambda: b[red] / (b[blue] + eps),
            'NDTI': lambda: (b[red] - b[green]) / (b[red] + b[green] + eps),
            'GB_Ratio': lambda: b[green] / (b[blue] + eps),
            'RedPlusNIR': lambda: b[red] + b[nir],
            'NDCI': lambda: (b[nir] - b[red]) / (b[nir] + b[red] + eps),
            'TWOBDA': lambda: b[nir] / (b[red] + eps),
            'THREEBDA': lambda: (b[blue] - b[red]) / (b[green] + eps),
            'SABI': lambda: (b[nir] - b[red]) / (b[blue] + b[green] + eps),
            'BG_Ratio': lambda: b[blue] / (b[green] + eps),
            'RG_Ratio': lambda: b[red] / (b[green] + eps),
            'RMinusG': lambda: b[red] - b[green],
            'GR_Ratio': lambda: b[green] / (b[red] + eps),
        }
    # Lambdas ensure only the selected formula runs (other bands may be absent).
    formula = formulas.get(model_name)
    return 0 if formula is None else formula()


def run_ensemble_model_application(classifier, models_para_list, field_satellite, selected_columns, input_raster_path, watermask_path, output_raster_path):
    """Apply a trained ensemble model to a raster image and add the result to QGIS.

    For every water pixel (water-mask value 1) whose feature bands are finite,
    ``classifier`` picks a component model from the ``selected_columns`` band
    values. That model's spectral index is then converted with its linear
    parameters (``index * coefficients + intercept``).

    The predictions are clipped to [0, 1000] and median-smoothed (3x3). They are
    written to ``output_raster_path`` as a float32 GeoTIFF with nodata -9999,
    and the layer is added to the current project.

    Args:
        classifier: Fitted object whose ``predict(features)`` returns one
            component-model name per row.
        models_para_list: Dicts with "model_name", "coefficients" and "intercept".
        field_satellite: Satellite label ("Landsat 8/9", "Sentinel-2", "Landsat 5/7").
        selected_columns: Band names used as classifier features.
        input_raster_path: GeoTIFF or NetCDF image path.
        watermask_path: Water-mask raster; reprojected onto the image grid when
            its size or transform differs.
        output_raster_path: Destination GeoTIFF path.

    Returns:
        bool: True on success; False when processing stopped early or failed.
        Errors are reported with message boxes rather than raised.
    """
    try:
        # Imported explicitly: `import rasterio` alone does not guarantee the
        # rasterio.warp submodule is loaded.
        from rasterio.warp import reproject
        from scipy.ndimage import binary_erosion

        slopes = {m["model_name"]: m["coefficients"] for m in models_para_list}
        intercepts = {m["model_name"]: m["intercept"] for m in models_para_list}

        band_arrays, profile = get_image_data_and_profile(input_raster_path, field_satellite)

        # Validate band count
        expected_bands = len(selected_columns)
        actual_bands = band_arrays.shape[0]
        if actual_bands < expected_bands:
            raise ValueError(f"Input raster has {actual_bands} bands, but {expected_bands} are required for {field_satellite} ({selected_columns}).")

        with rasterio.open(watermask_path) as mask_src:
            # Check for alignment; if not aligned, reproject mask in memory
            if (mask_src.height, mask_src.width) != (profile['height'], profile['width']) or mask_src.transform != profile['transform']:
                # zeros = "not water" wherever the mask does not cover the image;
                # np.empty would leave arbitrary values that could equal 1.
                watermask_array = np.zeros((profile['height'], profile['width']), dtype=mask_src.dtypes[0])
                reproject(
                    source=rasterio.band(mask_src, 1),
                    destination=watermask_array,
                    dst_transform=profile['transform'],
                    dst_crs=profile['crs'],
                    resampling=Resampling.nearest
                )
            else:
                watermask_array = mask_src.read(1)

        nodata_val, epsilon = -9999.0, 1e-9
        output_numpy_array = np.full(watermask_array.shape, nodata_val, dtype=np.float32)

        # Define band mapping based on satellite
        if "Landsat 8/9" in field_satellite:
            band_map = {'B1': 0, 'B2': 1, 'B3': 2, 'B4': 3, 'B5': 4}  # Adjusted to match read_netcdf_bands
        elif "Sentinel-2" in field_satellite:
            band_map = {'B1': 0, 'B2': 1, 'B3': 2, 'B4': 3, 'B5': 4, 'B6': 5, 'B7': 6, 'B8': 7, 'B8A': 8}
        elif "Landsat 5/7" in field_satellite:
            band_map = {'B1': 0, 'B2': 1, 'B3': 2, 'B4': 3}
        else:
            raise ValueError(f"Unsupported satellite type: {field_satellite}")

        # Validate selected columns
        feature_indices = []
        for col in selected_columns:
            if col not in band_map:
                raise ValueError(f"Band {col} not found in band mapping for {field_satellite}")
            feature_indices.append(band_map[col])

        pixels_to_process = watermask_array == 1
        if not np.any(pixels_to_process):
            QMessageBox.warning(None, "No Data", "The water mask does not contain any water pixels (value of 1).")
            return False

        n_features = len(feature_indices)
        selected_band_rows = band_arrays[feature_indices, :, :].reshape(n_features, -1)  # (n_bands, n_pixels)
        # Water pixels with NaN/inf features (e.g. masked NetCDF fill values) cannot
        # be classified; they are left as nodata instead of aborting the run.
        pixels_to_process_flat = pixels_to_process.reshape(-1) & np.all(np.isfinite(selected_band_rows), axis=0)
        if not np.any(pixels_to_process_flat):
            QMessageBox.warning(None, "No Data", "None of the water pixels have finite band values in the input image.")
            return False
        pixels_to_process = pixels_to_process_flat.reshape(pixels_to_process.shape)

        features = selected_band_rows[:, pixels_to_process_flat].T  # (n_valid_pixels, n_bands)

        best_model_per_pixel = classifier.predict(features)
        result_pixels = np.full(features.shape[0], nodata_val, dtype=np.float32)

        # Extract every available band at the processed pixels once, outside the loop.
        water_bands = {
            b_name: band_arrays[b_idx].reshape(-1)[pixels_to_process_flat]
            for b_name, b_idx in band_map.items() if b_idx < len(band_arrays)
        }

        for model_name in set(best_model_per_pixel):
            if model_name not in slopes:
                continue
            mask = best_model_per_pixel == model_name
            band_data = {b_name: values[mask] for b_name, values in water_bands.items()}

            with np.errstate(divide='ignore', invalid='ignore'):
                index_val = _ensemble_index_value(model_name, band_data, field_satellite, epsilon)

            result_pixels[mask] = (np.nan_to_num(index_val) * slopes[model_name]) + intercepts[model_name]

        output_numpy_array[pixels_to_process] = result_pixels
        data_mask = output_numpy_array != nodata_val
        np.clip(output_numpy_array, 0, 1000, out=output_numpy_array, where=data_mask)

        # 3x3 median smoothing. Only pixels whose entire 3x3 neighbourhood is valid
        # take the filtered value. Elsewhere the window would include the -9999
        # fill, dragging the median down (or to -9999) along water edges. Outside
        # the image counts as valid because mode='reflect' reuses in-image values.
        filtered_data = median_filter(output_numpy_array, size=3, mode='reflect')
        interior = binary_erosion(data_mask, structure=np.ones((3, 3), dtype=bool), border_value=1)
        output_numpy_array[interior] = filtered_data[interior]

        numpy_to_raster(output_numpy_array, output_raster_path, profile, nodata_value=nodata_val)
        output_layer = QgsRasterLayer(output_raster_path, os.path.basename(output_raster_path))
        if not output_layer.isValid():
            raise IOError(f"The output raster could not be loaded: {output_raster_path}")
        QgsProject.instance().addMapLayer(output_layer)
        return True

    except Exception as e:
        import traceback
        QMessageBox.critical(None, "Model Application Error", f"An error occurred: {str(e)}\n{traceback.format_exc()}")
        return False


class ISODATA:
    """Simplified ISODATA-style clustering.

    Runs k-means style iterations from ``k`` evenly spaced seed samples. After
    every update, clusters with fewer than ``min_members`` samples are dropped,
    so the number of clusters can only shrink. Iteration stops when the cluster
    count is unchanged and the mean center shift is below ``min_distance``, or
    after ``max_iterations`` rounds.
    """

    def __init__(self, data, k, min_members, max_iterations=1000, min_distance=1e-4):
        self.data = np.array(data)  # one sample per row (or one scalar per entry)
        self.k = k
        self.min_members = min_members
        self.max_iterations = max_iterations
        self.min_distance = min_distance

    def initialize_centers(self):
        """Choose ``k`` seed centers at evenly spaced (truncated) sample positions."""
        last_row = len(self.data) - 1
        seed_rows = np.linspace(0, last_row, self.k, dtype=int)
        return self.data[seed_rows]

    def euclidean_distance(self, x1, x2):
        """Return the L2 distance between ``x1`` and ``x2``."""
        delta = x1 - x2
        return np.sqrt(np.sum(delta ** 2))

    def assign_clusters(self, centers):
        """Assign every sample to its nearest center.

        Returns:
            (clusters, assignments): ``clusters[i]`` lists the samples nearest
            to ``centers[i]``; ``assignments`` is an array holding one center
            index per sample. Ties go to the lowest center index.
        """
        groups = [[] for _ in range(len(centers))]
        labels = []
        for sample in self.data:
            gaps = [self.euclidean_distance(sample, center) for center in centers]
            nearest = np.argmin(gaps)
            groups[nearest].append(sample)
            labels.append(nearest)
        return groups, np.array(labels)

    def update_centers(self, clusters):
        """Average every cluster holding at least ``min_members`` samples.

        Returns:
            (centers, kept): the means of the retained clusters and those
            clusters, in their original order.
        """
        kept = [members for members in clusters if len(members) >= self.min_members]
        means = [np.mean(members, axis=0) for members in kept]
        return np.array(means), kept

    def fit(self):
        """Run the clustering.

        Returns:
            (centers, clusters, assignments) from a final assignment pass
            against the last accepted centers.
        """
        centers = self.initialize_centers()
        for _ in range(self.max_iterations):
            previous = centers.copy()
            groups, _labels = self.assign_clusters(centers)
            centers, _kept = self.update_centers(groups)
            if len(centers) == 0:
                # Every cluster fell below min_members: keep the previous centers.
                centers = previous
                break
            if len(previous) != len(centers):
                continue
            shift = np.mean([self.euclidean_distance(a, b) for a, b in zip(previous, centers)])
            if shift < self.min_distance:
                break
        final_clusters, final_assignments = self.assign_clusters(centers)
        return centers, final_clusters, final_assignments
    

class RectangleMapTool(QgsMapTool):
    """Map tool for dragging out an axis-aligned rectangle on the canvas.

    While the left button is held, a rubber band previews the rectangle. On
    release, the rectangle is stored in ``self.geometry`` as a ``QgsGeometry``
    and the tool unsets itself from the canvas.
    """

    def __init__(self, canvas):
        super().__init__(canvas)
        self.canvas = canvas
        self.rubber_band = QgsRubberBand(self.canvas, QgsWkbTypes.PolygonGeometry)
        self.start_point = None
        self.is_drawing = False
        self.geometry = None

    def _rect_to(self, event):
        """Rectangle spanning the press point and the event position, in map coordinates."""
        return QgsRectangle(self.start_point, self.toMapCoordinates(event.pos()))

    def _clear_band(self):
        """Empty the rubber-band preview."""
        self.rubber_band.reset(QgsWkbTypes.PolygonGeometry)

    def canvasPressEvent(self, event):
        if event.button() != Qt.LeftButton:
            return
        self.start_point = self.toMapCoordinates(event.pos())
        self.is_drawing = True
        self._clear_band()

    def canvasMoveEvent(self, event):
        if not self.is_drawing:
            return
        preview = QgsGeometry.fromRect(self._rect_to(event))
        self._clear_band()
        self.rubber_band.setToGeometry(preview, None)
        self.rubber_band.show()

    def canvasReleaseEvent(self, event):
        if event.button() != Qt.LeftButton or not self.is_drawing:
            return
        self.is_drawing = False
        self.geometry = QgsGeometry.fromRect(self._rect_to(event))
        self.canvas.unsetMapTool(self)
        self._clear_band()

    def deactivate(self):
        self._clear_band()
        super().deactivate()


class PolygonMapTool(QgsMapTool):
    """Map tool that digitizes a polygon: left-clicks add vertices, a double-click finishes.

    Each left-click appends a vertex and extends the rubber-band preview. A
    double-click with more than two vertices stores the polygon in
    ``self.geometry`` and unsets the tool from the canvas.
    """

    def __init__(self, canvas):
        super().__init__(canvas)
        self.canvas = canvas
        self.rubber_band = QgsRubberBand(self.canvas, QgsWkbTypes.PolygonGeometry)
        self.points = []
        self.is_drawing = False
        self.geometry = None

    def canvasPressEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.is_drawing = True
            point = self.toMapCoordinates(event.pos())
            self.points.append(point)
            self.rubber_band.addPoint(point, True)  # Add the point and update the rubber band
            self.rubber_band.show()

    def canvasDoubleClickEvent(self, event):
        """Finalizes the polygon on a double-click."""
        if self.is_drawing and len(self.points) > 2:
            self.is_drawing = False
            # Create the final polygon geometry from the collected points
            self.geometry = QgsGeometry.fromPolygonXY([self.points])
            # Deactivate this map tool
            self.canvas.unsetMapTool(self)
            # Reset the rubber band to clear the temporary drawing
            self.rubber_band.reset(QgsWkbTypes.PolygonGeometry)

    def deactivate(self):
        """Clear the preview and any in-progress vertices when the tool is unset.

        ``geometry`` is kept so callers can still read the finished polygon
        after the tool deactivates itself.
        """
        self.rubber_band.reset(QgsWkbTypes.PolygonGeometry)
        # Discard vertices so a reactivated tool (or one switched away mid-drawing)
        # starts a new polygon instead of extending a stale one.
        self.points = []
        self.is_drawing = False
        super().deactivate()


class ApplyModelDialog(QDialog):
    """Form that collects the inputs for applying a regression model to an image.

    Fields: the input image, the 1-based band number that holds the model's
    predictor, an optional water mask, and the output GeoTIFF path. The getter
    methods return the current widget values.
    """

    def __init__(self, model_x_field, parent=None):
        super(ApplyModelDialog, self).__init__(parent)
        self.setWindowTitle("Apply Regression Model")
        grid = QGridLayout(self)

        raster_filter = "*.tif *.tiff *.jp2"

        self.image_path_widget = QgsFileWidget()
        self.image_path_widget.setFilter(raster_filter)

        self.band_spinbox = QSpinBox()
        self.band_spinbox.setRange(1, 100)  # upper bound is an arbitrary cap

        self.use_mask_checkbox = QCheckBox("Use Water Mask (pixels where mask is not 1 will be ignored)")
        self.use_mask_checkbox.stateChanged.connect(self.toggle_mask_input)

        self.mask_path_widget = QgsFileWidget()
        self.mask_path_widget.setFilter(raster_filter)

        self.output_path_widget = QgsFileWidget()
        self.output_path_widget.setStorageMode(QgsFileWidget.SaveFile)
        self.output_path_widget.setFilter("GeoTIFF (*.tif)")

        buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)

        # Widgets are placed in the same row order as they appear on screen.
        grid.addWidget(QLabel("Image to Apply Model To:"), 0, 0)
        grid.addWidget(self.image_path_widget, 0, 1)
        grid.addWidget(QLabel(f"Band number in image corresponding to '{model_x_field}':"), 1, 0)
        grid.addWidget(self.band_spinbox, 1, 1)
        grid.addWidget(self.use_mask_checkbox, 2, 0, 1, 2)
        grid.addWidget(self.mask_path_widget, 3, 0, 1, 2)
        grid.addWidget(QLabel("Output Prediction Raster:"), 4, 0)
        grid.addWidget(self.output_path_widget, 4, 1)
        grid.addWidget(buttons, 5, 0, 1, 2)

        # Sync the mask picker with the (initially unchecked) checkbox.
        self.toggle_mask_input()

    def toggle_mask_input(self):
        """Enable the mask path picker only while the mask checkbox is ticked."""
        self.mask_path_widget.setEnabled(self.use_mask_checkbox.isChecked())

    def imagePath(self):
        """Path of the image the model is applied to."""
        return self.image_path_widget.filePath()

    def bandNumber(self):
        """1-based band number holding the predictor values."""
        return self.band_spinbox.value()

    def useMask(self):
        """Whether the water mask should be applied."""
        return self.use_mask_checkbox.isChecked()

    def maskPath(self):
        """Path of the water-mask raster."""
        return self.mask_path_widget.filePath()

    def outputPath(self):
        """Destination path of the prediction raster."""
        return self.output_path_widget.filePath()
    

class ApplyEnsembleModelDialog(QDialog):
    """Dialog that applies a saved ensemble model (JSON + pickled classifier) to an image.

    The user selects:
    - the component-model JSON and the decision-tree pickle;
    - the satellite family and the water-quality parameter (the parameter list
      is read from the JSON);
    - the input image, a water mask and an output path.

    ``process`` gathers these inputs and delegates to
    ``run_ensemble_model_application``.
    """

    def __init__(self, parent=None):
        super(ApplyEnsembleModelDialog, self).__init__(parent)
        self.setWindowTitle("Apply Ensemble Model to Image")
        layout = QGridLayout(self)  # passing `self` installs it as the dialog's layout
        self.widgets = {}

        ui_setup = [
            ('json_model', "Ensemble Component Model File (.json):", QgsFileWidget(), {'filter': "*.json"}),
            ('pickle_model', "Decision Tree Model File (.pickle):", QgsFileWidget(), {'filter': "*.pickle"}),
            ('satellite', "Satellite Type:", QComboBox(), {'items': ["Landsat 8/9", "Sentinel-2", "Landsat 5/7"]}),
            ('parameter', "Water Quality Parameter:", QComboBox(), {}),
            ('input_image', "Input Satellite Image (SR):", QgsFileWidget(), {'filter': "Image Files (*.tif *.tiff *.nc)"}),
            ('water_mask', "Input Water Mask Raster:", QgsFileWidget(), {'filter': "*.tif *.tiff"}),
            ('output_raster', "Output Prediction Raster (.tif):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "GeoTIFF (*.tif)"}),
        ]

        for i, (key, label_text, widget, setup) in enumerate(ui_setup):
            label = QLabel(label_text)
            if isinstance(widget, QComboBox) and 'items' in setup:
                widget.addItems(setup['items'])
            elif isinstance(widget, QgsFileWidget):
                widget.setFilter(setup['filter'])

            layout.addWidget(label, i, 0)
            layout.addWidget(widget, i, 1)
            self.widgets[key] = widget

        # The parameter choices come from the selected JSON file.
        self.widgets['json_model'].fileChanged.connect(self.update_parameter_list)
        self.widgets['parameter'].setEnabled(False)

        buttons_layout = QHBoxLayout()
        self.run_button = QPushButton("Run Model")
        self.run_button.clicked.connect(self.process)
        self.close_button = QPushButton("Close")
        self.close_button.clicked.connect(self.accept)

        buttons_layout.addStretch()
        buttons_layout.addWidget(self.run_button)
        buttons_layout.addWidget(self.close_button)
        layout.addLayout(buttons_layout, len(ui_setup), 0, 1, 2)

    def update_parameter_list(self, json_path):
        """Fill the parameter combo box with the distinct "para" values found in ``json_path``.

        The combo box is disabled, with an explanatory placeholder where
        applicable, when:
        - the path is empty or missing;
        - the file is not readable JSON or its top-level value is not iterable;
        - no entry is an object with a string "para" field.
        """
        param_combo = self.widgets['parameter']
        param_combo.clear()

        if not json_path or not os.path.exists(json_path):
            param_combo.setEnabled(False)
            return

        try:
            with open(json_path, "r") as f:
                models = json.load(f)
            # Only dict entries with a string "para" can be listed and later matched
            # against the combo text in process(); anything else is ignored.
            parameters = sorted({m["para"] for m in models
                                 if isinstance(m, dict) and isinstance(m.get("para"), str)})
        except (json.JSONDecodeError, UnicodeDecodeError, OSError, TypeError):
            # TypeError: the top-level JSON value is not iterable (e.g. a number or null).
            param_combo.addItem("Invalid JSON file")
            param_combo.setEnabled(False)
            return

        if parameters:
            param_combo.addItems(parameters)
            param_combo.setEnabled(True)
        else:
            param_combo.addItem("No parameters found in file")
            param_combo.setEnabled(False)

    def process(self):
        """Gather the inputs, validate them and run the ensemble application."""
        try:
            inputs = {k: v.currentText() if isinstance(v, QComboBox) else v.filePath() for k, v in self.widgets.items()}
            if not all(inputs.values()):
                QMessageBox.critical(self, "Input Error", "Please specify all input and output files.")
                return

            # ACOLITE writes several NetCDF products; only *_L2R.nc holds surface reflectance.
            input_image_path = inputs['input_image']
            if input_image_path.lower().endswith('.nc'):
                base_name = os.path.basename(input_image_path)
                if not base_name.lower().endswith('_l2r.nc'):
                    QMessageBox.critical(self,
                                         "Incorrect NetCDF File",
                                         "For ACOLITE NetCDF files, please select the file ending in '_L2R.nc', which contains the surface reflectance data.")
                    return  # Stop processing

            with open(inputs['json_model'], "r") as f:
                all_models = json.load(f)
            with open(inputs['pickle_model'], "rb") as f:
                # SECURITY NOTE: unpickling can execute arbitrary code embedded in the
                # file. Only load classifier pickles from trusted sources.
                classifier = pickle.load(f)

            selected_parameter = inputs['parameter']
            filtered_models_list = [model for model in all_models if isinstance(model, dict) and model.get("para") == selected_parameter]

            if not filtered_models_list:
                QMessageBox.critical(self, "Model Not Found", f"No models for '{selected_parameter}' were found in the selected JSON file.")
                return

            sat_map = {"Landsat 5/7": ['B1', 'B2', 'B3', 'B4'], "Landsat 8/9": ['B2', 'B3', 'B4', 'B5'], "Sentinel-2": ['B2', 'B3', 'B4', 'B8']}
            selected_columns = sat_map[inputs['satellite']]

            result = run_ensemble_model_application(
                classifier,
                filtered_models_list,
                inputs['satellite'],
                selected_columns,
                inputs['input_image'],
                inputs['water_mask'],
                inputs['output_raster']
            )
            # The helper reports its own errors in message boxes. An explicit False
            # return means it failed, so do not show a success message on top.
            if result is False:
                return
            QMessageBox.information(self, "Success", "Ensemble model applied successfully.")
            self.accept()

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Processing Error", f"An error occurred:\n{e}\n\n{traceback.format_exc()}")


# --- INTEGRATED ENSEMBLE MODEL DIALOG CLASS
class WaterQualityEnsembleModelDialog(QDialog):
    """Dialog that runs the complete ensemble workflow in one pass.

    Calibrates the component regression models on a point layer of samples,
    trains a decision tree that selects a component model from band values,
    saves the model parameters (JSON), the tree (pickle) and its rendering
    (PNG), and finally applies the ensemble to a satellite image.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Complete Ensemble Model Workflow (Training & Application)")

        layout = QGridLayout(self)
        self.widgets = {}

        # Each row: (widget key, label text, widget, widget configuration).
        ui_setup = [
            ('input_layer', "Input Shapefile with Sample Data:", QgsMapLayerComboBox(), {'filters': QgsMapLayerProxyModel.PointLayer}),
            ('satellite', "Satellite Type:", QComboBox(), {'items': ["Landsat 5/7", "Landsat 8/9", "Sentinel-2"]}),
            ('parameter', "Water Quality Parameter:", QComboBox(), {}),
            ('input_image', "Input Satellite Image (SR):", QgsFileWidget(), {'filter': "Image Files (*.tif *.tiff *.nc)"}),
            ('water_mask', "Input Water Mask Raster:", QgsFileWidget(), {'filter': "*.tif *.tiff"}),
            ('r2', "Minimum R-squared to Stop:", QDoubleSpinBox(), {'range': (0.0, 1.0), 'value': 0.95, 'step': 0.05}),
            ('max_iterations', "Max Iterations:", QSpinBox(), {'range': (1, 10000), 'value': 500}),
            ('min_samples', "Min Samples per Node:", QSpinBox(), {'range': (1, 1000), 'value': 2}),
            ('max_depth', "Max Tree Depth:", QSpinBox(), {'range': (1, 100), 'value': 10}),
            ('output_raster', "Output Prediction Raster (.tif):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "GeoTIFF (*.tif)"}),
            ('output_json', "Output Model Parameters File (.json):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "JSON files (*.json)"}),
            ('output_pickle_model', "Output Decision Tree Model (.pickle):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "Pickle files (*.pickle)"}),
            ('output_dt_image', "Output Decision Tree Image (.png):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "PNG image (*.png)"})
        ]

        for i, (key, label_text, widget, setup) in enumerate(ui_setup):
            label = QLabel(label_text)
            # QgsMapLayerComboBox is a QComboBox subclass, so it is tested first.
            if isinstance(widget, QgsMapLayerComboBox): widget.setFilters(setup['filters'])
            elif isinstance(widget, QComboBox) and 'items' in setup: widget.addItems(setup['items'])
            elif isinstance(widget, QgsFileWidget): widget.setFilter(setup['filter'])
            elif isinstance(widget, QSpinBox): widget.setRange(*setup['range']); widget.setValue(setup['value'])
            elif isinstance(widget, QDoubleSpinBox): widget.setRange(*setup['range']); widget.setValue(setup['value']); widget.setSingleStep(setup['step'])

            layout.addWidget(label, i, 0); layout.addWidget(widget, i, 1)
            self.widgets[key] = widget

        buttons_layout = QHBoxLayout()
        self.run_button = QPushButton("Run Workflow"); self.close_button = QPushButton("Close")
        self.run_button.clicked.connect(self.process); self.close_button.clicked.connect(self.accept)
        buttons_layout.addStretch(); buttons_layout.addWidget(self.run_button); buttons_layout.addWidget(self.close_button)
        layout.addLayout(buttons_layout, len(ui_setup), 0, 1, 2)

        self.widgets['input_layer'].layerChanged.connect(self.update_parameters)
        self.update_parameters(self.widgets['input_layer'].currentLayer())

    def update_parameters(self, layer):
        """Populate the parameter dropdown with the supported fields present in ``layer``.

        The dropdown is disabled when no layer is selected or when the layer has
        none of the supported water quality fields.
        """
        param_combo = self.widgets['parameter']
        param_combo.clear()

        if not layer:
            param_combo.setEnabled(False)
            return

        potential_params = ["Turbidity", "Chlorophyl", "CDOM"]
        layer_fields = [field.name() for field in layer.fields()]
        available_params = [p for p in potential_params if p in layer_fields]

        if available_params:
            param_combo.addItems(available_params)
            param_combo.setEnabled(True)
        else:
            param_combo.addItem("No valid parameter fields found")
            param_combo.setEnabled(False)

    def process(self):
        """Run calibration, decision-tree training and model application.

        Samples are first grouped by ISODATA clustering of their band values.
        They are then iteratively reassigned to whichever component model's
        linear fit predicts them best. Iteration stops when the assignments stop
        changing, when the overall R-squared reaches the user threshold, or
        after the maximum number of iterations. Any error is reported to the
        user in a message box.
        """
        try:
            # Gather all inputs from the UI. QgsMapLayerComboBox is tested
            # before QComboBox because it is a QComboBox subclass.
            inputs = {k: v.currentLayer() if isinstance(v, QgsMapLayerComboBox) else v.currentText() if isinstance(v, QComboBox) else v.value() if isinstance(v, (QSpinBox, QDoubleSpinBox)) else v.filePath() for k, v in self.widgets.items()}

            # Only the layer and the file paths can be missing. The spin boxes
            # always hold a value, and an R-squared threshold of 0.0 is valid,
            # so numeric inputs must not take part in a truthiness check.
            required_paths = ('input_image', 'water_mask', 'output_raster', 'output_json', 'output_pickle_model', 'output_dt_image')
            if not inputs['input_layer'] or not all(inputs[key] for key in required_paths):
                QMessageBox.critical(self, "Input Error", "Please specify all inputs and outputs."); return

            if not self.widgets['parameter'].isEnabled() or not self.widgets['parameter'].currentText():
                QMessageBox.critical(self, "Input Error", "No valid water quality parameter selected or found in the layer."); return

            input_image_path = inputs['input_image']
            if input_image_path.lower().endswith('.nc'):
                base_name = os.path.basename(input_image_path)
                if not base_name.lower().endswith('_l2r.nc'):
                    QMessageBox.critical(self,
                                         "Incorrect NetCDF File",
                                         "For ACOLITE NetCDF files, please select the file ending in '_L2R.nc', which contains the surface reflectance data.")
                    return  # Stop processing

            # --- Model Training ---
            layer = inputs['input_layer']
            parameter = inputs['parameter']
            df = pd.DataFrame([f.attributes() for f in layer.getFeatures()], columns=[field.name() for field in layer.fields()]).dropna()
            if df.empty:
                QMessageBox.critical(self, "Input Error", "The input layer has no complete sample records to train on."); return
            df = calculate_all_indices(df, inputs['satellite'])

            pred_col_name = f"{parameter[:5]}_Pred"
            model_map = {"Turbidity": ['RB_Ratio','NDTI','GB_Ratio','RedPlusNIR'], "Chlorophyl": ['NDCI','TWOBDA','THREEBDA','BG_Ratio','RG_Ratio','SABI'], "CDOM": ['GB_Ratio','RB_Ratio','RMinusG','GR_Ratio']}
            model_names = model_map[parameter]

            sat_map = {"Landsat 5/7": ['B1','B2','B3','B4'], "Landsat 8/9": ['B2','B3','B4','B5'], "Sentinel-2": ['B2','B3','B4','B8']}
            selected_columns = sat_map[inputs['satellite']]

            # Initial model assignments using ISODATA clustering of the band
            # values; cluster labels are wrapped onto the available model names.
            isodata = ISODATA(np.array(df[selected_columns]), k=len(model_names), min_members=2)
            _, _, labels = isodata.fit()
            df['BestModel'] = [model_names[label % len(model_names)] for label in labels]
            df['LastBestModel'] = df['BestModel']

            count = 0
            converged_by_model = False
            converged_by_r2 = False
            current_r2 = 0.0

            for iteration in range(inputs['max_iterations']):
                count = iteration + 1

                frames = []
                for _, group_data in df.groupby('LastBestModel'):
                    y_group = group_data[parameter]
                    # Fit every component model on this group, then keep, per
                    # sample, the model whose prediction is closest to the observation.
                    predicted_group = [LinearRegression().fit(np.array(group_data[name]).reshape(-1, 1), y_group).predict(np.array(group_data[name]).reshape(-1, 1)) for name in model_names]
                    min_indices = np.argmin(np.abs(np.array(predicted_group).T - y_group.values[:, None]), axis=1)

                    group_data = group_data.assign(BestModel=[model_names[idx] for idx in min_indices])
                    group_data[pred_col_name] = [predicted_group[model_idx][row_idx] for row_idx, model_idx in enumerate(min_indices)]

                    frames.append(group_data)

                if frames:
                    df = pd.concat(frames)

                # Check stopping conditions
                y_actual_iter = df[parameter]
                y_predicted_iter = df[pred_col_name]

                if len(y_actual_iter) > 1:
                    current_r2 = r2_score(y_actual_iter, y_predicted_iter)

                converged_by_model = df['BestModel'].equals(df['LastBestModel'])
                converged_by_r2 = current_r2 >= inputs['r2']

                if converged_by_model or converged_by_r2:
                    break  # Exit the loop early if any condition is met

                df['LastBestModel'] = df['BestModel']

            # Save model parameters: one linear fit per model, trained on the
            # samples finally assigned to that model.
            models_para_list = []
            for name, group in df.groupby('BestModel'):
                lr = LinearRegression().fit(np.array(group[name]).reshape(-1, 1), group[parameter])
                models_para_list.append({"para": parameter, "model_name": name, "coefficients": lr.coef_[0], "intercept": lr.intercept_})
            with open(inputs['output_json'], "w") as f:
                json.dump(models_para_list, f, indent=4)

            # Classifier that predicts the component model from band values.
            classifier = DecisionTreeClassifier(criterion="gini", max_depth=inputs['max_depth'], min_samples_leaf=inputs['min_samples'])
            classifier.fit(df[selected_columns], df['BestModel'])

            with open(inputs['output_pickle_model'], "wb") as pkl_file:
                pickle.dump(classifier, pkl_file)

            # class_names must follow the classifier's own (sorted) class order.
            dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=selected_columns, class_names=[str(c) for c in classifier.classes_], filled=True, rounded=True)
            graphviz.Source(dot_data).render(os.path.splitext(inputs['output_dt_image'])[0], format='png', view=False, cleanup=True)

            # --- Model Application ---
            run_ensemble_model_application(
                classifier=classifier,
                models_para_list=models_para_list,
                field_satellite=inputs['satellite'],
                selected_columns=selected_columns,
                input_raster_path=inputs['input_image'],
                watermask_path=inputs['water_mask'],
                output_raster_path=inputs['output_raster']
            )

            # Report the actual stopping reason. Convergence is checked before
            # the iteration limit so that a run which stabilised on its final
            # allowed iteration is not misreported as having hit the maximum.
            message = "Complete workflow finished successfully!\n\n"
            if converged_by_r2:
                message += f"Calibration process converged after {count} iterations because R-squared met the threshold."
            elif converged_by_model:
                message += f"Model converged in {count} iterations (model assignments stabilized)."
            else:
                message += f"Calibration process stopped after reaching the maximum of {count} iterations."
            QMessageBox.information(self, "Success", message)

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Processing Error", f"An error occurred during the workflow:\n{e}\n\n{traceback.format_exc()}")


# --- DECISION TREE DIALOG CLASS
class DecisonTreeDialog(QDialog):
    """Dialog that trains a decision tree partitioning the spectral space.

    The tree learns to predict each sample's calibrated 'BestModel' label from
    its band values. The fitted classifier is pickled, and its graph is
    rendered to a PNG that is also shown inside the dialog.
    """

    def __init__(self, parent=None):
        """Constructor."""
        super(DecisonTreeDialog, self).__init__(parent)
        self.setWindowTitle("Spectral-Space Partition (Decision Tree)")
        self.setMinimumWidth(800)

        layout = QGridLayout(self)
        self.widgets = {}

        # --- UI Setup ---
        # Each row: (widget key, label text, widget, widget configuration).
        ui_setup = [
            ('input_layer', "Input Calibrated Shapefile:", QgsMapLayerComboBox(), {'filters': QgsMapLayerProxyModel.PointLayer}),
            ('satellite', "Satellite Type:", QComboBox(), {'items': ["Landsat 5/7", "Landsat 8/9", "Sentinel-2"]}),
            ('parameter', "Water Quality Parameter:", QComboBox(), {}),
            ('min_samples', "Min Samples per Node:", QSpinBox(), {'range': (1, 1000), 'value': 2}),
            ('max_depth', "Max Tree Depth:", QSpinBox(), {'range': (1, 100), 'value': 10}),
            ('output_dt_image', "Output Decision Tree Image (.png):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "PNG image (*.png)"}),
            ('output_pickle', "Output Decision Tree Model (.pickle):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "Pickle file (*.pickle)"})
        ]

        for i, (key, label_text, widget, setup) in enumerate(ui_setup):
            label = QLabel(label_text)
            # QgsMapLayerComboBox is a QComboBox subclass, so it is tested first.
            if isinstance(widget, QgsMapLayerComboBox): widget.setFilters(setup['filters'])
            elif isinstance(widget, QComboBox) and 'items' in setup: widget.addItems(setup['items'])
            elif isinstance(widget, QgsFileWidget): widget.setFilter(setup['filter'])
            elif isinstance(widget, QSpinBox): widget.setRange(*setup['range']); widget.setValue(setup['value'])

            layout.addWidget(label, i, 0); layout.addWidget(widget, i, 1)
            self.widgets[key] = widget

        # Label that displays the rendered tree once training has run.
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setMinimumSize(600, 400)
        layout.addWidget(self.image_label, len(ui_setup), 0, 1, 2)
        self.image_label.setVisible(False)

        buttons_layout = QHBoxLayout()
        self.run_button = QPushButton("Run Training & View Tree")
        self.close_button = QPushButton("OK")
        self.run_button.clicked.connect(self.process)
        self.close_button.clicked.connect(self.accept)
        buttons_layout.addStretch()
        buttons_layout.addWidget(self.run_button)
        buttons_layout.addWidget(self.close_button)
        layout.addLayout(buttons_layout, len(ui_setup) + 1, 0, 1, 2)

        self.widgets['input_layer'].layerChanged.connect(self.update_parameters)
        self.update_parameters(self.widgets['input_layer'].currentLayer())

    def update_parameters(self, layer):
        """Populate the parameter dropdown with the supported fields present in ``layer``.

        The dropdown is disabled when no layer is selected or when the layer has
        none of the supported water quality fields.
        """
        param_combo = self.widgets['parameter']
        param_combo.clear()

        if not layer:
            param_combo.setEnabled(False)
            return

        potential_params = ["Turbidity", "Chlorophyl", "CDOM"]
        layer_fields = [field.name() for field in layer.fields()]
        available_params = [p for p in potential_params if p in layer_fields]

        if available_params:
            param_combo.addItems(available_params)
            param_combo.setEnabled(True)
        else:
            param_combo.addItem("No valid parameter fields found")
            param_combo.setEnabled(False)

    def process(self):
        """Train the decision tree, save it, and display its rendering.

        Requires the layer to contain the satellite's band columns and a
        'BestModel' column (written by the calibration tool). Rows without a
        calibrated model or with non-numeric band values are ignored. Errors are
        reported to the user in a message box.
        """
        inputs = {k: v.currentLayer() if isinstance(v, QgsMapLayerComboBox) else v.currentText() if isinstance(v, QComboBox) else v.value() if isinstance(v, QSpinBox) else v.filePath() for k, v in self.widgets.items()}

        if not all(w.filePath() for w in [self.widgets['output_dt_image'], self.widgets['output_pickle']]) or not inputs['input_layer']:
            QMessageBox.critical(self, "Input Error", "Please specify an input layer and both output file paths."); return

        if not self.widgets['parameter'].isEnabled() or not self.widgets['parameter'].currentText():
            QMessageBox.critical(self, "Input Error", "No valid water quality parameter selected or found in the layer."); return

        try:
            layer = inputs['input_layer']
            df = pd.DataFrame([f.attributes() for f in layer.getFeatures()], columns=[field.name() for field in layer.fields()])

            sat_map = {"Landsat 5/7": ['B1','B2','B3','B4'], "Landsat 8/9": ['B2','B3','B4','B5'], "Sentinel-2": ['B2','B3','B4','B8']}
            selected_columns = sat_map[inputs['satellite']]

            required_cols = selected_columns + ['BestModel']
            missing = [col for col in required_cols if col not in df.columns]
            if missing:
                QMessageBox.critical(self, "Missing Columns", f"The input layer is missing required columns: {', '.join(missing)}. Please run the calibration tool first.")
                return

            # Features that were never calibrated have an empty/NULL BestModel
            # and must not become a training class. Band values are coerced to
            # numbers so that NULL or non-numeric entries can be dropped too.
            df = df[df['BestModel'].map(lambda v: isinstance(v, str) and v.strip() != '')].copy()
            for col in selected_columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            df = df.dropna(subset=selected_columns)
            if df.empty:
                QMessageBox.critical(self, "Missing Data", "The input layer has no features with a calibrated 'BestModel' value and valid band values.")
                return

            classifier = DecisionTreeClassifier(criterion="gini", max_depth=inputs['max_depth'], min_samples_leaf=inputs['min_samples'])
            classifier.fit(df[selected_columns], df['BestModel'])

            with open(inputs['output_pickle'], "wb") as outclassifierfile:
                pickle.dump(classifier, outclassifierfile)

            # class_names must follow the classifier's own (sorted) class order.
            dot_data = tree.export_graphviz(classifier, out_file=None,
                                            feature_names=selected_columns,
                                            class_names=[str(c) for c in classifier.classes_],
                                            filled=True, rounded=True,
                                            special_characters=True)
            graph = graphviz.Source(dot_data)
            output_image_base, _ = os.path.splitext(inputs['output_dt_image'])
            png_path = graph.render(output_image_base, format='png', view=False, cleanup=True)

            # Load and display the rendered image inside the dialog.
            if os.path.exists(png_path):
                pixmap = QPixmap(png_path)
                if not pixmap.isNull():
                    # Scale pixmap to fit the label's width, maintaining aspect ratio
                    scaled_pixmap = pixmap.scaledToWidth(self.image_label.width(), Qt.SmoothTransformation)
                    self.image_label.setPixmap(scaled_pixmap)
                    self.image_label.setVisible(True)
                    self.resize(self.sizeHint())
                else:
                    self.image_label.setText("Could not load generated image.")
                    self.image_label.setVisible(True)
            else:
                self.image_label.setText("Generated image file not found.")
                self.image_label.setVisible(True)

            QMessageBox.information(self, "Success", "Decision tree training complete. Model and image files have been saved.")

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Processing Error", f"An error occurred:\n{e}\n\n{traceback.format_exc()}")


# --- ENSEMBLE CALIBRATION DIALOG CLASS
class EnsembleCalibrationDialog(QDialog):
    """Dialog that calibrates the ensemble's component regression models.

    Each sample in the selected point layer is iteratively assigned to the
    component model whose linear fit predicts it best. The resulting model
    parameters are saved to JSON, a predicted-vs-actual plot is shown, and the
    layer is updated with 'BestModel' and a prediction field.
    """

    def __init__(self, parent=None):
        """Constructor."""
        super(EnsembleCalibrationDialog, self).__init__(parent)
        self.setWindowTitle("Ensemble Component Models Calibration")
        self.setMinimumSize(600, 400)

        layout = QGridLayout(self)
        self.widgets = {}

        # Each row: (widget key, label text, widget, widget configuration).
        ui_setup = [
            ('input_layer', "Select Layer to Calibrate/Update:", QgsMapLayerComboBox(), {'filters': QgsMapLayerProxyModel.PointLayer}),
            ('satellite', "Satellite Type:", QComboBox(), {'items': ["Landsat 5/7", "Landsat 8/9", "Sentinel-2"]}),
            ('parameter', "Water Quality Parameter:", QComboBox(), {}),
            ('model_list', "Select Component Models:", QtWidgets.QListWidget(), {'selection_mode': QtWidgets.QAbstractItemView.ExtendedSelection}),
            ('r2', "Minimum R-squared:", QDoubleSpinBox(), {'range': (0.0, 1.0), 'value': 0.90, 'step': 0.05}),
            ('max_iterations', "Max Iterations:", QSpinBox(), {'range': (1, 10000), 'value': 500}),
            ('output_json', "Output Calibrated Model File (.json):", QgsFileWidget(storageMode=QgsFileWidget.SaveFile), {'filter': "JSON files (*.json)"})
        ]

        for i, (key, label_text, widget, setup) in enumerate(ui_setup):
            label = QLabel(label_text)
            # QgsMapLayerComboBox is a QComboBox subclass, so it is tested first.
            if isinstance(widget, QgsMapLayerComboBox): widget.setFilters(setup['filters'])
            elif isinstance(widget, QComboBox) and 'items' in setup: widget.addItems(setup['items'])
            elif isinstance(widget, QgsFileWidget): widget.setFilter(setup['filter'])
            elif isinstance(widget, QSpinBox): widget.setRange(*setup['range']); widget.setValue(setup['value'])
            elif isinstance(widget, QDoubleSpinBox): widget.setRange(*setup['range']); widget.setValue(setup['value']); widget.setSingleStep(setup['step'])
            elif isinstance(widget, QtWidgets.QListWidget): widget.setSelectionMode(setup['selection_mode'])

            layout.addWidget(label, i, 0); layout.addWidget(widget, i, 1)
            self.widgets[key] = widget

        # A standalone Figure (instead of plt.figure) is not registered with
        # pyplot's global figure manager. It is therefore released together
        # with the dialog rather than accumulating for the whole QGIS session.
        from matplotlib.figure import Figure
        self.figure = Figure(dpi=300, constrained_layout=True)
        self.plot_canvas = FigureCanvas(self.figure)
        self.plot_canvas.setVisible(False)
        layout.addWidget(self.plot_canvas, len(ui_setup), 0, 1, 2)

        buttons_layout = QHBoxLayout()
        self.run_button = QPushButton("Run Calibration"); self.close_button = QPushButton("Close")
        self.run_button.clicked.connect(self.process); self.close_button.clicked.connect(self.accept)
        buttons_layout.addStretch(); buttons_layout.addWidget(self.run_button); buttons_layout.addWidget(self.close_button)
        layout.addLayout(buttons_layout, len(ui_setup) + 1, 0, 1, 2)

        self.widgets['input_layer'].layerChanged.connect(self.update_parameters)
        self.widgets['parameter'].currentIndexChanged.connect(self.update_model_list)

        self.update_parameters(self.widgets['input_layer'].currentLayer())

    def update_parameters(self, layer):
        """Populate the parameter dropdown for ``layer`` and refresh the model list.

        The dropdown is disabled when no layer is selected or when the layer has
        none of the supported water quality fields.
        """
        param_combo = self.widgets['parameter']
        param_combo.clear()

        if not layer:
            param_combo.setEnabled(False)
            self.update_model_list()  # Clear model list if no layer
            return

        potential_params = ["Turbidity", "Chlorophyl", "CDOM"]
        layer_fields = [field.name() for field in layer.fields()]
        available_params = [p for p in potential_params if p in layer_fields]

        if available_params:
            param_combo.addItems(available_params)
            param_combo.setEnabled(True)
        else:
            param_combo.addItem("No valid parameter fields found")
            param_combo.setEnabled(False)

        self.update_model_list()

    def update_model_list(self):
        """Fill the model list with the candidate models for the selected parameter, all pre-selected."""
        model_list_widget = self.widgets['model_list']
        model_list_widget.clear()

        parameter = self.widgets['parameter'].currentText()
        if not parameter or not self.widgets['parameter'].isEnabled():
            return

        model_map = {
            "Turbidity": ['RB_Ratio', 'NDTI', 'GB_Ratio', 'RedPlusNIR'],
            "Chlorophyl": ['NDCI', 'TWOBDA', 'THREEBDA', 'BG_Ratio', 'RG_Ratio', 'SABI'],
            "CDOM": ['GB_Ratio', 'RB_Ratio', 'RMinusG', 'GR_Ratio']
        }

        available_models = model_map.get(parameter, [])
        model_list_widget.addItems(available_models)
        for i in range(model_list_widget.count()):
            model_list_widget.item(i).setSelected(True)

    def process(self):
        """Calibrate the selected component models and update the layer.

        Stops iterating when the model assignments stop changing, when the
        overall R-squared reaches the threshold, or after the maximum number of
        iterations. Then it:

        - writes the per-model coefficients to JSON,
        - plots predicted vs. actual values, and
        - writes 'BestModel' and the prediction column back to the input layer.
        """
        self.plot_canvas.setVisible(False)
        self.resize(self.minimumSizeHint())

        inputs = {k: v.currentLayer() if isinstance(v, QgsMapLayerComboBox) else v.selectedItems() if isinstance(v, QtWidgets.QListWidget) else v.currentText() if isinstance(v, QComboBox) else v.value() if isinstance(v, (QSpinBox, QDoubleSpinBox)) else v.filePath() for k, v in self.widgets.items()}

        if not inputs['input_layer'] or not inputs['output_json']:
            QMessageBox.critical(self, "Input Error", "Please specify an input layer and a JSON output path."); return

        if not self.widgets['parameter'].isEnabled() or not self.widgets['parameter'].currentText():
            QMessageBox.critical(self, "Input Error", "No valid water quality parameter selected or found in the layer."); return

        model_names = [item.text() for item in inputs['model_list']]
        if not model_names:
            QMessageBox.critical(self, "Input Error", "Please select at least one component model to calibrate.")
            return

        layer = inputs['input_layer']
        parameter = inputs['parameter']

        features = list(layer.getFeatures())
        if not features: QMessageBox.warning(self, "Warning", "Input layer has no features."); return

        # Keep each feature's id so the results can be written back to it.
        field_names = [field.name() for field in layer.fields()]
        feature_data = [{'fid': f.id(), **dict(zip(field_names, f.attributes()))} for f in features]
        df = pd.DataFrame(feature_data).dropna()
        if df.empty:
            QMessageBox.warning(self, "Warning", "Input layer has no features with complete attribute values."); return

        try:
            # Use the returned frame, as the full-workflow dialog does.
            df = calculate_all_indices(df, inputs['satellite'])

            pred_col_name = f"{parameter[:5]}_Pred"

            df[pred_col_name] = 0.0
            df['LastBestModel'], df['BestModel'] = '', ''

            # Initial assignment: fit every model on all samples and keep, per
            # sample, the model whose prediction is closest to the observation.
            y_series = df[parameter]
            predicted = [LinearRegression().fit(np.array(df[name]).reshape(-1, 1), y_series).predict(np.array(df[name]).reshape(-1, 1)) for name in model_names]
            min_indices_initial = np.argmin(np.abs(np.array(predicted).T - y_series.values[:, None]), axis=1)
            df['BestModel'] = [model_names[idx] for idx in min_indices_initial]
            df[pred_col_name] = [predicted[model_idx][row_idx] for row_idx, model_idx in enumerate(min_indices_initial)]
            df['LastBestModel'] = df['BestModel']

            count = 0
            converged_by_model = False
            converged_by_r2 = False
            current_r2 = 0.0

            for iteration in range(inputs['max_iterations']):
                count = iteration + 1

                frames = []
                for _, group_data in df.groupby('LastBestModel'):
                    y_group = group_data[parameter]
                    # Refit every model within the group and reassign each sample.
                    predicted_group = [LinearRegression().fit(np.array(group_data[name]).reshape(-1, 1), y_group).predict(np.array(group_data[name]).reshape(-1, 1)) for name in model_names]
                    min_indices = np.argmin(np.abs(np.array(predicted_group).T - y_group.values[:, None]), axis=1)

                    group_data = group_data.assign(BestModel=[model_names[idx] for idx in min_indices])
                    group_data[pred_col_name] = [predicted_group[model_idx][row_idx] for row_idx, model_idx in enumerate(min_indices)]

                    frames.append(group_data)

                if frames:
                    df = pd.concat(frames)

                # Check stopping conditions
                y_actual_iter = df[parameter]
                y_predicted_iter = df[pred_col_name]

                if len(y_actual_iter) > 1:
                    current_r2 = r2_score(y_actual_iter, y_predicted_iter)

                converged_by_model = df['BestModel'].equals(df['LastBestModel'])
                converged_by_r2 = current_r2 >= inputs['r2']

                if converged_by_model or converged_by_r2:
                    break  # Exit the loop early if any condition is met

                df['LastBestModel'] = df['BestModel']

            # Final coefficients: one linear fit per model on its assigned samples.
            models_para_list = []
            for name, group in df.groupby('BestModel'):
                lr = LinearRegression().fit(np.array(group[name]).reshape(-1, 1), group[parameter])
                models_para_list.append({"para": parameter, "model_name": name, "coefficients": lr.coef_[0], "intercept": lr.intercept_})
            with open(inputs['output_json'], "w") as f:
                json.dump(models_para_list, f, indent=4)
        except Exception as e:
            QMessageBox.critical(self, "Calculation Error", f"An error occurred during calibration:\n{e}"); return

        # --- Predicted vs. actual plot ---
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        y_actual, y_predicted = df[parameter], df[pred_col_name]
        ax.scatter(y_actual, y_predicted, alpha=0.6, s=20, edgecolors='w', linewidth=0.5, label="Predicted vs. Actual")

        X_plot = y_actual.values.reshape(-1, 1)
        plot_model = LinearRegression().fit(X_plot, y_predicted)
        r2 = plot_model.score(X_plot, y_predicted)

        # Square axis limits covering both series, padded by 5% of the range.
        all_plot_values = np.concatenate([y_actual, y_predicted])
        lims = [all_plot_values.min(), all_plot_values.max()]
        range_buffer = (lims[1] - lims[0]) * 0.05
        lims = [lims[0] - range_buffer, lims[1] + range_buffer]

        ax.plot(lims, lims, 'r--', alpha=0.75, zorder=0, label="1:1 Line")

        r2_text = f'$R^2 = {r2:.3f}$'
        ax.text(0.05, 0.95, r2_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        ax.set_aspect('equal', adjustable='box')
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_xlabel("Actual Values")
        ax.set_ylabel("Predicted Values")
        ax.set_title(f"Ensemble Model Performance for {parameter}")
        ax.legend(loc='lower right')
        ax.grid(True, linestyle=':', alpha=0.7)  # Cleaner grid for high res

        self.plot_canvas.draw(); self.plot_canvas.setVisible(True); self.resize(self.sizeHint())

        # --- Write results back to the layer ---
        provider = layer.dataProvider()
        fields_to_add = []
        # 10 and 6 are the QVariant type ids for String and Double.
        if 'BestModel' not in provider.fieldNameMap(): fields_to_add.append(QgsField('BestModel', 10, "text", 20))
        if pred_col_name not in provider.fieldNameMap(): fields_to_add.append(QgsField(pred_col_name, 6, "double", 20, 10))

        if fields_to_add: provider.addAttributes(fields_to_add); layer.updateFields()

        best_model_idx = layer.fields().indexOf('BestModel')
        pred_col_idx = layer.fields().indexOf(pred_col_name)
        # indexOf returns -1 if the fields could not be created; writing to -1
        # would silently change nothing while still reporting success.
        if best_model_idx < 0 or pred_col_idx < 0:
            QMessageBox.critical(self, "Error", f"Could not create the 'BestModel' and '{pred_col_name}' fields on the layer."); return

        layer.startEditing()
        for _, row in df.iterrows():
            layer.changeAttributeValue(row['fid'], best_model_idx, row['BestModel'])
            layer.changeAttributeValue(row['fid'], pred_col_idx, row[pred_col_name])

        if not layer.commitChanges():
            QMessageBox.critical(self, "Error", "Could not commit changes to the layer."); layer.rollBack()
            return

        # Report the actual stopping reason. Convergence is checked before the
        # iteration limit so that a run which stabilised on its final allowed
        # iteration is not misreported as having hit the maximum.
        message = "Calibration successful!\n"
        message += "Layer updated, plot generated, and model JSON file saved.\n\n"
        if converged_by_r2:
            message += f"Process converged after {count} iterations because R-squared met the threshold."
        elif converged_by_model:
            message += f"Model converged in {count} iterations (model assignments stabilized)."
        else:
            message += f"Process stopped after reaching the maximum of {count} iterations."

        QMessageBox.information(self, "Success", message)


# --- SPECTRAL INDEX DIALOG CLASS
class SpectralIndexDialog(QDialog):
    """Dialog that adds water-quality spectral indices to a point layer in place.

    The user picks a point layer whose attribute table already holds band
    values in fields named after the satellite's band numbers (``B1``,
    ``B2``, ... ``B8A``), a satellite family and a water-quality parameter.
    The matching indices are computed with pandas and written back to the
    layer as double fields, which are created when missing. Non-finite
    results (e.g. from a zero denominator) are stored as NULL.
    """

    # (Blue, Green, Red, NIR) band field names for each satellite option.
    _BGRN_BANDS = {
        "Landsat 5": ('B1', 'B2', 'B3', 'B4'),
        "Landsat 7": ('B1', 'B2', 'B3', 'B4'),
        "Landsat 8": ('B2', 'B3', 'B4', 'B5'),
        "Landsat 9": ('B2', 'B3', 'B4', 'B5'),
        "Sentinel-2": ('B2', 'B3', 'B4', 'B8'),
    }

    # Every band field any formula below may read. These are coerced to float
    # so NULL / text attributes become NaN instead of raising TypeError.
    _BAND_FIELDS = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B8', 'B8A')

    def __init__(self, parent=None):
        """Build the layer, satellite and parameter pickers plus OK/Cancel."""
        super(SpectralIndexDialog, self).__init__(parent)
        self.setWindowTitle("Calculate Water Quality Spectral Indices")

        # --- UI Setup ---
        layout = QGridLayout(self)

        # Input Layer
        self.layer_label = QLabel("Select Point Layer to Update:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.PointLayer)
        layout.addWidget(self.layer_label, 0, 0)
        layout.addWidget(self.layer_combo, 0, 1, 1, 2)

        # Satellite Type
        self.satellite_label = QLabel("Satellite Type:")
        self.satellite_combo = QComboBox()
        self.satellite_combo.addItems(["Landsat 5", "Landsat 7", "Landsat 8", "Landsat 9", "Sentinel-2"])
        layout.addWidget(self.satellite_label, 1, 0)
        layout.addWidget(self.satellite_combo, 1, 1, 1, 2)

        # Water Quality Parameter
        self.param_label = QLabel("Water Quality Parameter:")
        self.param_combo = QComboBox()
        self.param_combo.addItems(["Turbidity", "Chlorophyl", "CDOM"])
        layout.addWidget(self.param_label, 2, 0)
        layout.addWidget(self.param_combo, 2, 1, 1, 2)

        # Buttons
        self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        self.button_box.accepted.connect(self.process)
        self.button_box.rejected.connect(self.reject)
        layout.addWidget(self.button_box, 4, 1, 1, 2)

    @staticmethod
    def _as_float(value):
        """Return *value* as a float, or NaN when it cannot be converted (e.g. NULL)."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    @staticmethod
    def _attribute_value(value):
        """Map a computed index value to an attribute value; NaN/inf become NULL (None)."""
        value = float(value)
        return value if np.isfinite(value) else None

    def process(self):
        """Compute the indices for the chosen parameter and write them to the layer.

        Shows an error and returns early when no layer is selected, the layer is
        empty, a required band field is missing, or the new fields cannot be
        added. Otherwise commits the edits (rolling back on failure) and closes
        the dialog.
        """
        input_layer = self.layer_combo.currentLayer()
        field_satellite = self.satellite_combo.currentText()
        field_qual_para = self.param_combo.currentText()

        # --- Input Validation ---
        if not input_layer:
            QMessageBox.critical(self, "Input Error", "Please select an input layer.")
            return

        features = list(input_layer.getFeatures())
        if not features:
            QMessageBox.warning(self, "Warning", "Input layer contains no features.")
            return

        # Feature ids are kept in their own list (aligned with the DataFrame rows)
        # so an attribute that happens to be called 'fid' cannot replace them.
        field_names = [field.name() for field in input_layer.fields()]
        feature_ids = [f.id() for f in features]
        df = pd.DataFrame([dict(zip(field_names, f.attributes())) for f in features])

        for band in self._BAND_FIELDS:
            if band in df.columns:
                df[band] = df[band].map(self._as_float)

        # --- Spectral Index Calculation ---
        blue_col, green_col, red_col, nir_col = self._BGRN_BANDS[field_satellite]
        model_names = []  # Names of the fields that will be written back
        try:
            if field_qual_para == "Turbidity":
                model_names = ['RB_Ratio', 'NDTI', 'GB_Ratio', 'RedPlusNIR']
                Blue, Green, Red, NIR = df[blue_col], df[green_col], df[red_col], df[nir_col]
                df['NDTI'] = (Red - Green) / (Red + Green)
                df['RB_Ratio'] = Red / Blue
                df['GB_Ratio'] = Green / Blue
                df['RedPlusNIR'] = Red + NIR
            elif field_qual_para == "Chlorophyl":
                model_names = ['NDCI', 'TWOBDA', 'THREEBDA', 'BG_Ratio', 'RG_Ratio', 'SABI']
                if field_satellite == "Sentinel-2":
                    # Sentinel-2 formulas use B1, B5, B6 and B8A in addition to B2-B4.
                    B1, Blue, Green, Red = df['B1'], df['B2'], df['B3'], df['B4']
                    B5, B6, B8A = df['B5'], df['B6'], df['B8A']
                    df['NDCI'] = (B5 - Red) / (B5 + Red)
                    df['TWOBDA'] = B5 / Red
                    df['THREEBDA'] = (1 / Red - 1 / B5) * B6
                    df['SABI'] = (B8A - Red) / (B1 + Green)
                else:
                    # Landsat 5/7 and 8/9 share the same formulas on their own bands.
                    Blue, Green, Red, NIR = df[blue_col], df[green_col], df[red_col], df[nir_col]
                    df['NDCI'] = (NIR - Red) / (NIR + Red)
                    df['TWOBDA'] = NIR / Red
                    df['THREEBDA'] = (Blue - Red) / Green
                    df['SABI'] = (NIR - Red) / (Blue + Green)
                df['BG_Ratio'] = Blue / Green
                df['RG_Ratio'] = Red / Green
            elif field_qual_para == "CDOM":
                model_names = ['GB_Ratio', 'RB_Ratio', 'RMinusG', 'GR_Ratio']
                Blue, Green, Red = df[blue_col], df[green_col], df[red_col]
                df['GB_Ratio'] = Green / Blue
                df['RB_Ratio'] = Red / Blue
                df['RMinusG'] = Red - Green
                df['GR_Ratio'] = Green / Red
        except KeyError as e:
            QMessageBox.critical(self, "Missing Field", f"The input layer is missing a required band field: {e}")
            return

        # --- Update the layer in-place ---
        provider = input_layer.dataProvider()
        existing_field_names = {field.name() for field in provider.fields()}

        # 6 == QVariant.Double
        fields_to_add = [QgsField(name, 6, "double", 20, 10)
                         for name in model_names if name not in existing_field_names]
        if fields_to_add:
            if not provider.addAttributes(fields_to_add):
                QMessageBox.critical(self, "Error", "Could not add the index fields to the layer.")
                return
            input_layer.updateFields()

        layer_fields = input_layer.fields()
        field_indices = {name: layer_fields.indexOf(name) for name in model_names}

        input_layer.startEditing()
        index_rows = df[model_names].itertuples(index=False, name=None)
        for feature_id, values in zip(feature_ids, index_rows):
            changes = {field_indices[name]: self._attribute_value(value)
                       for name, value in zip(model_names, values)}
            input_layer.changeAttributeValues(feature_id, changes)

        if input_layer.commitChanges():
            QMessageBox.information(self, "Success", "The selected layer has been updated successfully.")
        else:
            QMessageBox.critical(self, "Error", "Could not commit changes to the layer.")
            input_layer.rollBack()

        self.accept()
   

# --- SPECTRAL CLUSTER DIALOG CLASS ---
class SpectralClusterDialog(QDialog):
    """Dialog that groups sample points by their band values with ISODATA.

    Reads the Blue/Green/Red/NIR band attributes of the chosen vector layer
    (field names depend on the satellite), runs the module's ``ISODATA``
    clusterer on the rows with complete numeric values, and writes each
    clustered feature's label to an integer ``cluster`` field.
    """

    def __init__(self, parent=None):
        """Build the layer/satellite pickers, clustering parameters and buttons."""
        super(SpectralClusterDialog, self).__init__(parent)
        self.setWindowTitle("Spectral Clustering of Sample Points")
        # NOTE(review): this attribute hides QWidget.layout(); kept as-is for
        # backward compatibility with any code reading ``dialog.layout``.
        self.layout = QGridLayout()

        # Input Layer
        self.layer_label = QLabel("Input Shapefile:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.VectorLayer)
        self.layout.addWidget(self.layer_label, 0, 0)
        self.layout.addWidget(self.layer_combo, 0, 1)

        # Satellite Type
        self.satellite_label = QLabel("Satellite Type:")
        self.satellite_combo = QComboBox()
        self.satellite_combo.addItems(["Landsat 5/7", "Landsat 8/9", "Sentinel-2"])
        self.layout.addWidget(self.satellite_label, 1, 0)
        self.layout.addWidget(self.satellite_combo, 1, 1)

        # Cluster Number
        self.cluster_label = QLabel("Number of Clusters:")
        self.cluster_spinbox = QSpinBox()
        self.cluster_spinbox.setRange(2, 100)
        self.cluster_spinbox.setValue(3)
        self.layout.addWidget(self.cluster_label, 2, 0)
        self.layout.addWidget(self.cluster_spinbox, 2, 1)

        # Minimum Members
        self.min_member_label = QLabel("Minimum Members per Cluster:")
        self.min_member_spinbox = QSpinBox()
        self.min_member_spinbox.setRange(1, 1000)
        self.min_member_spinbox.setValue(2)
        self.layout.addWidget(self.min_member_label, 3, 0)
        self.layout.addWidget(self.min_member_spinbox, 3, 1)

        # Max Iterations
        self.max_iter_label = QLabel("Max Iterations:")
        self.max_iter_spinbox = QSpinBox()
        self.max_iter_spinbox.setRange(1, 10000)
        self.max_iter_spinbox.setValue(1000)
        self.layout.addWidget(self.max_iter_label, 4, 0)
        self.layout.addWidget(self.max_iter_spinbox, 4, 1)

        # Buttons
        self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        self.button_box.accepted.connect(self.process)
        self.button_box.rejected.connect(self.reject)
        self.layout.addWidget(self.button_box, 5, 0, 1, 2)

        self.setLayout(self.layout)

    def process(self):
        """Cluster the selected layer's features and store the labels (OK handler).

        Rows with any missing or non-numeric band value are excluded from the
        clustering and their ``cluster`` attribute is left untouched.
        """
        layer = self.layer_combo.currentLayer()
        if not layer:
            QMessageBox.critical(self, "Error", "No layer selected.")
            return

        field_satellite = self.satellite_combo.currentText()
        cluster_number = self.cluster_spinbox.value()
        min_member = self.min_member_spinbox.value()
        max_iteration = self.max_iter_spinbox.value()

        # Band fields (Blue, Green, Red, NIR) for the selected satellite
        if "Landsat 5" in field_satellite or "Landsat 7" in field_satellite:
            selected_columns = ['B1', 'B2', 'B3', 'B4']
        elif "Landsat 8" in field_satellite or "Landsat 9" in field_satellite:
            selected_columns = ['B2', 'B3', 'B4', 'B5']
        elif "Sentinel-2" in field_satellite:
            selected_columns = ['B2', 'B3', 'B4', 'B8']
        else:
            QMessageBox.critical(self, "Error", "Invalid satellite type.")
            return

        layer_fields = [field.name() for field in layer.fields()]
        if not all(col in layer_fields for col in selected_columns):
            QMessageBox.critical(self, "Error", f"The selected layer is missing one or more required fields for {field_satellite}: {', '.join(selected_columns)}")
            return

        # Attributes go into the frame; feature ids become the index, so an
        # attribute with a colliding name cannot replace them.
        feature_ids, records = [], []
        for feature in layer.getFeatures():
            feature_ids.append(feature.id())
            records.append(dict(zip(layer_fields, feature.attributes())))

        if not records:
            QMessageBox.warning(self, "Warning", "The selected layer has no features.")
            return

        df = pd.DataFrame(records, index=feature_ids)

        def as_float(value):
            """Convert an attribute to float; NULL/text become NaN so dropna() removes them."""
            try:
                return float(value)
            except (TypeError, ValueError):
                return np.nan

        band_values = df[selected_columns].apply(lambda column: column.map(as_float)).dropna()
        if band_values.empty:
            QMessageBox.warning(self, "Warning", "No valid data for clustering after removing null values.")
            return

        # Run ISODATA clustering (class defined elsewhere in this module)
        isodata = ISODATA(np.array(band_values), cluster_number, min_member, max_iteration)
        _centers, _clusters, labels = isodata.fit()

        provider = layer.dataProvider()
        cluster_field_name = 'cluster'

        if cluster_field_name not in layer_fields:
            # 2 == QVariant.Int (10 would be QVariant.String)
            provider.addAttributes([QgsField(cluster_field_name, 2, "integer", 10)])
            layer.updateFields()

        cluster_field_index = layer.fields().indexOf(cluster_field_name)

        layer.startEditing()
        # labels are in the same order as the rows of band_values
        for feature_id, label in zip(band_values.index, labels):
            layer.changeAttributeValue(int(feature_id), cluster_field_index, int(label))

        if layer.commitChanges():
            QMessageBox.information(self, "Success", f"Clustering complete. '{cluster_field_name}' field has been added/updated.")
        else:
            QMessageBox.critical(self, "Error", "Could not commit changes to the layer.")
            layer.rollBack()  # leave the layer out of edit mode

        self.accept()


# --- WATER QUALITY INDICES CLASS ---
class WaterQualityIndicesDialog(QDialog):
    """Dialog that derives water-quality spectral index rasters from band images.

    The user picks a water-quality parameter, the indices to compute, the
    Blue/Green/Red/NIR band rasters (plus an optional red-edge band for
    Sentinel-2), a water mask and an output directory. All rasters must share
    CRS, transform and shape. Each selected index is computed on water pixels
    only (mask == 1), saved as a float32 GeoTIFF and added to the project with
    a two-colour ramp.
    """

    def __init__(self, parent=None):
        """Build the parameter/index pickers, file inputs and process button."""
        super(WaterQualityIndicesDialog, self).__init__(parent)
        self.setWindowTitle("Calculate Water Quality Spectral Indices from Image")
        self.layout = QVBoxLayout()

        # --- UI Setup ---
        controls_layout = QGridLayout()

        # Water Quality Parameter dropdown
        self.param_label = QLabel("1. Select Water Quality Parameter:")
        self.param_combo = QComboBox()
        self.param_combo.addItems(["Chlorophyl", "Turbidity", "CDOM"])
        controls_layout.addWidget(self.param_label, 0, 0, 1, 2)
        controls_layout.addWidget(self.param_combo, 1, 0, 1, 2)

        # Index Selection List
        self.index_label = QLabel("2. Select Indices to Calculate:")
        self.index_list_widget = QtWidgets.QListWidget()
        self.index_list_widget.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) # Allow multi-select
        controls_layout.addWidget(self.index_label, 2, 0, 1, 2)
        controls_layout.addWidget(self.index_list_widget, 3, 0, 1, 2)

        # Connect the parameter combo to update the index list
        self.param_combo.currentIndexChanged.connect(self.update_index_list)

        self.layout.addLayout(controls_layout)

        # Input Bands and Paths
        paths_layout = QGridLayout()
        paths_layout.addWidget(QLabel("3. Select Input Bands and Output Directory:"), 0, 0, 1, 2)

        paths_layout.addWidget(QLabel("Blue Band:"), 1, 0)
        self.blue_path_widget = QgsFileWidget(); paths_layout.addWidget(self.blue_path_widget, 1, 1)

        paths_layout.addWidget(QLabel("Green Band:"), 2, 0)
        self.green_path_widget = QgsFileWidget(); paths_layout.addWidget(self.green_path_widget, 2, 1)

        paths_layout.addWidget(QLabel("Red Band:"), 3, 0)
        self.red_path_widget = QgsFileWidget(); paths_layout.addWidget(self.red_path_widget, 3, 1)

        paths_layout.addWidget(QLabel("NIR Band:"), 4, 0)
        self.nir_path_widget = QgsFileWidget(); paths_layout.addWidget(self.nir_path_widget, 4, 1)

        paths_layout.addWidget(QLabel("RedEdge Band (Optional for Sentinel-2):"), 5, 0)
        self.rededge_path_widget = QgsFileWidget(); paths_layout.addWidget(self.rededge_path_widget, 5, 1)

        # Water mask is a required input (pixels equal to 1 are water)
        paths_layout.addWidget(QLabel("Water Mask (1=water):"), 6, 0)
        self.water_mask_widget = QgsFileWidget(); paths_layout.addWidget(self.water_mask_widget, 6, 1)

        paths_layout.addWidget(QLabel("Output Directory:"), 7, 0)
        self.output_dir_widget = QgsFileWidget(); self.output_dir_widget.setStorageMode(QgsFileWidget.GetDirectory); paths_layout.addWidget(self.output_dir_widget, 7, 1)

        self.layout.addLayout(paths_layout)

        # Process button
        self.process_button = QPushButton("Generate Selected Indices")
        self.process_button.clicked.connect(self.process)
        self.layout.addWidget(self.process_button)

        self.setLayout(self.layout)

        # Populate the index list for the initial parameter
        self.update_index_list()

    def update_index_list(self):
        """Fill the index list for the selected parameter and pre-select every entry."""
        self.index_list_widget.clear()
        selected_parameter = self.param_combo.currentText()

        index_map = {
            "Chlorophyl": ["NDCI", "TWOBDA", "THREEBDA", "BlueGreenRatio", "RedGreenRatio", "SABI"],
            "Turbidity": ["RedBlueRatio", "NDTI", "GreenBlueRatio", "RedPlusNIR"],
            "CDOM": ["GreenBlueRatio", "RedBlueRatio", "RMinusG", "GreenRedRatio"]
        }

        available_indices = index_map.get(selected_parameter, [])
        self.index_list_widget.addItems(available_indices)
        for i in range(self.index_list_widget.count()):
            self.index_list_widget.item(i).setSelected(True)

    def validate_raster_inputs(self, layer_info_list):
        """Check that every raster matches the first one's CRS, transform and shape.

        Args:
            layer_info_list: list of ``(name, path)`` tuples; the first entry is
                the reference raster.

        Returns:
            ``(True, "")`` when all rasters align (or the list is empty),
            otherwise ``(False, message)`` describing the first mismatch or
            the read error.
        """
        if not layer_info_list: return True, ""
        try:
            ref_name, ref_path = layer_info_list[0]
            with rasterio.open(ref_path) as ref_ds:
                ref_crs, ref_transform, ref_shape = ref_ds.crs, ref_ds.transform, ref_ds.shape
            for name, path in layer_info_list[1:]:
                with rasterio.open(path) as ds:
                    if ds.crs != ref_crs: return False, f"CRS Mismatch: Layer '{name}' ({ds.crs}) does not match reference '{ref_name}' ({ref_crs})."
                    if ds.transform != ref_transform: return False, f"Transform/Pixel Size Mismatch: Layer '{name}' does not align with reference '{ref_name}'."
                    if ds.shape != ref_shape: return False, f"Dimension Mismatch: Layer '{name}' ({ds.height}x{ds.width}) does not match reference '{ref_name}' ({ref_shape[0]}x{ref_shape[1]})."
            return True, ""
        except Exception as e:
            return False, f"Could not open or read one of the input files. Please check paths.\n\nDetails: {e}"

    def save_raster(self, data, profile, output_path):
        """Write the 2-D array *data* as band 1 of a new raster at *output_path*."""
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(data, 1)

    def add_layer_with_color_ramp(self, path, layer_name, color_start, color_end, min_val, max_val):
        """Add the raster at *path* to the project with a two-stop pseudocolor ramp.

        Does nothing when QGIS cannot load the raster.
        """
        layer = QgsRasterLayer(path, layer_name)
        if not layer.isValid(): return
        shader = QgsColorRampShader(); shader.setColorRampType(QgsColorRampShader.Interpolated)
        color_ramp_items = [QgsColorRampShader.ColorRampItem(min_val, QColor(color_start), str(min_val)), QgsColorRampShader.ColorRampItem(max_val, QColor(color_end), str(max_val))]
        shader.setColorRampItemList(color_ramp_items)
        raster_shader = QgsRasterShader(); raster_shader.setRasterShaderFunction(shader)
        renderer = QgsSingleBandPseudoColorRenderer(layer.dataProvider(), 1, raster_shader)
        layer.setRenderer(renderer); layer.triggerRepaint()
        QgsProject.instance().addMapLayer(layer)

    @staticmethod
    def _read_band(path):
        """Return band 1 of the raster at *path* as a float32 masked array.

        The dataset is closed before returning.
        """
        with rasterio.open(path) as src:
            return src.read(1, masked=True).astype(np.float32)

    def process(self):
        """Validate the inputs, compute the selected indices and add them to the map."""
        # 1. Gather Inputs
        output_dir = self.output_dir_widget.filePath()
        selected_indices = [item.text() for item in self.index_list_widget.selectedItems()]

        # The four bands and the water mask are required; RedEdge is optional.
        inputs_to_validate = []
        required_paths = {
            "Blue": self.blue_path_widget.filePath(),
            "Green": self.green_path_widget.filePath(),
            "Red": self.red_path_widget.filePath(),
            "NIR": self.nir_path_widget.filePath(),
            "WaterMask": self.water_mask_widget.filePath()
        }

        for name, path in required_paths.items():
            if not path:
                QMessageBox.warning(self, "Input Error", f"Please specify the required '{name}' file.")
                return
            inputs_to_validate.append((name, path))

        rededge_path = self.rededge_path_widget.filePath()
        if rededge_path:
            inputs_to_validate.append(("RedEdge", rededge_path))

        if not selected_indices:
            QMessageBox.warning(self, "Input Error", "Please select at least one index to calculate.")
            return

        if not output_dir:
            QMessageBox.warning(self, "Input Error", "Please specify an output directory.")
            return

        # 2. Validate Raster Alignment
        is_valid, error_message = self.validate_raster_inputs(inputs_to_validate)
        if not is_valid:
            QMessageBox.critical(self, "Input Layer Error", error_message)
            return

        try:
            os.makedirs(output_dir, exist_ok=True)

            # 3. Load raster data; every dataset is closed right after reading.
            with rasterio.open(required_paths["Blue"]) as blue_src:
                blue = blue_src.read(1, masked=True).astype(np.float32)
                profile = blue_src.profile.copy()
            green = self._read_band(required_paths["Green"])
            red = self._read_band(required_paths["Red"])
            nir = self._read_band(required_paths["NIR"])
            with rasterio.open(required_paths["WaterMask"]) as mask_src:
                water_mask = mask_src.read(1)
            rededge = self._read_band(rededge_path) if rededge_path else None

            profile.update(driver='GTiff', dtype=rasterio.float32, count=1, nodata=np.nan)

            # 4. Index formulas, evaluated lazily so only the selected ones run.
            # With a red-edge band, NDCI and the two-band model compare red-edge
            # against red (the Sentinel-2 formulas used for point layers);
            # without it, NIR takes the red-edge role.
            chl_band = rededge if rededge is not None else nir

            def three_band_model():
                """Three-band chlorophyll model; red-edge variant when available."""
                if rededge is not None:
                    return (1 / red - 1 / rededge) * nir
                return (blue - red) / green

            index_formulas = {
                "NDCI": lambda: (chl_band - red) / (chl_band + red),
                "TWOBDA": lambda: chl_band / red,
                "THREEBDA": three_band_model,
                "BlueGreenRatio": lambda: blue / green,
                "RedGreenRatio": lambda: red / green,
                "SABI": lambda: (nir - red) / (blue + green),
                "RedBlueRatio": lambda: red / blue,
                "NDTI": lambda: (red - green) / (red + green),
                "GreenBlueRatio": lambda: green / blue,
                "RedPlusNIR": lambda: red + nir,
                "RMinusG": lambda: red - green,
                "GreenRedRatio": lambda: green / red,
            }

            # (start colour, end colour, ramp min, ramp max) per index
            color_map_info = {
                "NDCI": ("blue", "red", -1, 1), "TWOBDA": ("green", "purple", 0, 10), "THREEBDA": ("yellow", "brown", -5, 5),
                "BlueGreenRatio": ("cyan", "magenta", 0, 5), "RedGreenRatio": ("orange", "blue", 0, 5), "SABI": ("red", "green", -1, 1),
                "RedBlueRatio": ("red", "blue", 0, 5), "NDTI": ("green", "brown", -1, 1), "GreenBlueRatio": ("green", "red", 0, 2),
                "RedPlusNIR": ("yellow", "orange", 0, 1), "RMinusG": ("blue", "yellow", -0.5, 0.5), "GreenRedRatio": ("orange", "purple", 0, 5)
            }

            # 5. Compute and save only the selected indices
            for name in selected_indices:
                formula = index_formulas.get(name)
                if formula is None:
                    continue
                with np.errstate(divide='ignore', invalid='ignore'):
                    data = formula()

                # Keep water pixels only; everything else is written as NaN (nodata).
                masked_data = np.ma.where(water_mask == 1, data, np.ma.masked)

                start_color, end_color, min_v, max_v = color_map_info[name]
                index_path = os.path.join(output_dir, f"{name}.tif")
                self.save_raster(masked_data.filled(fill_value=np.nan), profile, index_path)
                self.add_layer_with_color_ramp(index_path, name, start_color, end_color, min_v, max_v)

            QMessageBox.information(self, "Success", "Selected water quality indices generated successfully.")
            self.accept()

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Error", f"Processing failed: {str(e)}\n{traceback.format_exc()}")


# --- POINT RASTER SAMPLER DIALOG CLASS ---
class PointRasterSamplerDialog(QDialog):
    def __init__(self, parent=None):
        """Create the sampler dialog and wire its widgets.

        Grid rows: point layer picker (0-1), optional water-mask raster picker
        (2-3), raster list and field-name table side by side (4-5), sampling
        options (6), progress bar (7) and OK/Cancel buttons (8). The raster
        list is repopulated whenever the current project adds or removes
        layers.
        """
        super(PointRasterSamplerDialog, self).__init__(parent)
        self.setWindowTitle("Matching Pixel Values with Water Sample Points")
        self.setMinimumSize(600, 500)

        grid = QGridLayout(self)

        # Rows 0-1: point layer that receives the sampled values.
        lbl_points = QLabel("Input Point Layer:")
        self.point_layer_combo = QgsMapLayerComboBox()
        self.point_layer_combo.setFilters(QgsMapLayerProxyModel.PointLayer)
        grid.addWidget(lbl_points, 0, 0, 1, 2)
        grid.addWidget(self.point_layer_combo, 1, 0, 1, 2)

        # Rows 2-3: optional water mask; the empty entry means "no mask".
        lbl_mask = QLabel("Optional Water Mask Layer (1=water):")
        self.mask_layer_combo = QgsMapLayerComboBox()
        self.mask_layer_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        self.mask_layer_combo.setAllowEmptyLayer(True)
        grid.addWidget(lbl_mask, 2, 0, 1, 2)
        grid.addWidget(self.mask_layer_combo, 3, 0, 1, 2)

        # Rows 4-5, left column: rasters to sample (multi-select).
        lbl_rasters = QLabel("Select Raster Layers to Sample:")
        self.raster_list = QtWidgets.QListWidget()
        self.raster_list.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
        grid.addWidget(lbl_rasters, 4, 0)
        grid.addWidget(self.raster_list, 5, 0)

        # Rows 4-5, right column: raster -> output field name table.
        lbl_fields = QLabel("Specify Output Field Names:")
        self.fields_table = QtWidgets.QTableWidget()
        self.fields_table.setColumnCount(2)
        self.fields_table.setHorizontalHeaderLabels(["Raster Layer", "New Field Name"])
        table_header = self.fields_table.horizontalHeader()
        for section in (0, 1):
            table_header.setSectionResizeMode(section, QtWidgets.QHeaderView.Stretch)
        grid.addWidget(lbl_fields, 4, 1)
        grid.addWidget(self.fields_table, 5, 1)

        # Row 6: sampling method plus the statistic used by window methods.
        lbl_method = QLabel("Sampling Method:")
        self.method_combo = QComboBox()
        self.method_combo.addItems(["Single Pixel (Nearest Neighbor)", "3x3 Window", "5x5 Window"])
        lbl_stat = QLabel("Statistic:")
        self.stat_combo = QComboBox()
        self.stat_combo.addItems(["Mean", "Median"])
        options_row = QHBoxLayout()
        for widget in (lbl_method, self.method_combo, lbl_stat, self.stat_combo):
            options_row.addWidget(widget)
        grid.addLayout(options_row, 6, 0, 1, 2)

        self.method_combo.currentIndexChanged.connect(self.update_statistic_widget_state)
        self.update_statistic_widget_state()

        # Rows 7-8: progress bar and OK/Cancel buttons.
        self.progress_bar = QtWidgets.QProgressBar()
        grid.addWidget(self.progress_bar, 7, 0, 1, 2)
        self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        grid.addWidget(self.button_box, 8, 0, 1, 2)

        self.button_box.accepted.connect(self.process)
        self.button_box.rejected.connect(self.reject)
        self.raster_list.itemSelectionChanged.connect(self.update_fields_table)

        # Keep the raster list in sync with the project's layer set.
        # NOTE(review): these connections are not disconnected in this method.
        project = QgsProject.instance()
        project.layersAdded.connect(self.populate_raster_list)
        project.layersRemoved.connect(self.populate_raster_list)

        self.populate_raster_list()

    def update_statistic_widget_state(self):
        """Enable the statistic selector only for the windowed sampling methods."""
        self.stat_combo.setEnabled("Window" in self.method_combo.currentText())


    def process(self):
        """Core processing function with on-the-fly CRS transformation."""
        point_layer = self.point_layer_combo.currentLayer()
        water_mask_layer = self.mask_layer_combo.currentLayer()
        sampling_method = self.method_combo.currentText()
        statistic = self.stat_combo.currentText()

        # Input validation
        if not point_layer:
            QMessageBox.critical(self, "Error", "Please select a point vector layer.")
            return
        if self.fields_table.rowCount() == 0:
            QMessageBox.critical(self, "Error", "Please select at least one raster layer.")
            return

        field_names, raster_layers = [], []
        for row in range(self.fields_table.rowCount()):
            field_item = self.fields_table.item(row, 1)
            raster_item = self.fields_table.item(row, 0)
            if not field_item or not field_item.text():
                QMessageBox.critical(self, "Error", f"Provide field name for raster '{raster_item.text()}'.")
                return
            field_names.append(field_item.text())
            layer = QgsProject.instance().mapLayer(raster_item.data(Qt.UserRole))
            if not layer:
                QMessageBox.critical(self, "Error", f"Could not find raster layer '{raster_item.text()}'.")
                return
            raster_layers.append(layer)

        # Prepare layer for editing
        provider = point_layer.dataProvider()
        existing_fields = [field.name() for field in provider.fields()]
        fields_to_add = [QgsField(name, 6, "double", 20, 10) for name in field_names if name not in existing_fields]
        point_layer.startEditing()
        if fields_to_add:
            provider.addAttributes(fields_to_add)
            point_layer.updateFields()
        field_indices = [point_layer.fields().indexOf(name) for name in field_names]

        self.progress_bar.setMaximum(point_layer.featureCount())
        
        # --- ADDED: Cache for coordinate transforms to improve performance ---
        transform_cache = {}

        try:
            with rasterio.open(water_mask_layer.source()) if water_mask_layer else open(os.devnull) as mask_src:
                for i, point_feature in enumerate(point_layer.getFeatures()):
                    geom = point_feature.geometry()
                    if not geom: continue
                    point = geom.asPoint()
                    
                    is_water = True
                    if water_mask_layer:
                        try:
                            # Transform point to the mask's CRS
                            if point_layer.crs() != water_mask_layer.crs():
                                transform_key = (point_layer.crs().authid(), water_mask_layer.crs().authid())
                                if transform_key not in transform_cache:
                                    transform_cache[transform_key] = QgsCoordinateTransform(point_layer.crs(), water_mask_layer.crs(), QgsProject.instance())
                                point_for_mask = transform_cache[transform_key].transform(point)
                            else:
                                point_for_mask = point
                            
                            # Sample the mask value
                            mask_val = next(mask_src.sample([(point_for_mask.x(), point_for_mask.y())]), [0])[0]
                            if mask_val != 1:
                                is_water = False
                        except IndexError:
                             # Point is outside the mask's extent
                            is_water = False
                    
                    for j, raster_layer in enumerate(raster_layers):
                        result_val = None
                        
                        if is_water:
                            # --- ADDED: On-the-fly CRS transformation ---
                            source_crs = point_layer.crs()
                            dest_crs = raster_layer.crs()
                            
                            if source_crs != dest_crs:
                                transform_key = (source_crs.authid(), dest_crs.authid())
                                if transform_key not in transform_cache:
                                    transform_cache[transform_key] = QgsCoordinateTransform(source_crs, dest_crs, QgsProject.instance())
                                xform = transform_cache[transform_key]
                                point_transformed = xform.transform(point)
                            else:
                                point_transformed = point
                            # --- END of new block ---

                            if "Single Pixel" in sampling_method:
                                val, res = raster_layer.dataProvider().sample(point_transformed, 1)
                                if res: result_val = float(val)
                            else: # Windowed sampling logic
                                window_dim = 3 if "3x3" in sampling_method else 5
                                offset = window_dim // 2
                                try:
                                    with rasterio.open(raster_layer.source()) as src:
                                        row, col = src.index(point_transformed.x(), point_transformed.y())
                                        window = rasterio.windows.Window(col - offset, row - offset, window_dim, window_dim)
                                        window_data = src.read(1, window=window, boundless=True, fill_value=src.nodata)
                                        nodata_val = src.nodata
                                        if nodata_val is not None:
                                            valid_pixels = window_data[window_data != float(nodata_val)]
                                        else:
                                            valid_pixels = window_data.flatten()
                                        if valid_pixels.size > 0:
                                            if statistic == "Mean": result_val = float(np.mean(valid_pixels))
                                            elif statistic == "Median": result_val = float(np.median(valid_pixels))
                                except (IndexError, rasterio.errors.RasterioIOError):
                                    # Point is outside raster extent, result_val remains None
                                    pass
                        
                        point_layer.changeAttributeValue(point_feature.id(), field_indices[j], result_val)
                    self.progress_bar.setValue(i + 1)
        except Exception as e:
             QMessageBox.critical(self, "Processing Error", f"A critical error occurred: {e}")
             point_layer.rollBack()
             return

        if not point_layer.commitChanges():
            QMessageBox.critical(self, "Error", "Could not commit changes to the layer.")
        else:
            QMessageBox.information(self, "Success", "Processing complete.")
        
        self.accept()



    # --- Raster list and field-table helpers ---
    def populate_raster_list(self):
        """Rebuild the raster list widget from the raster layers in the current QGIS project.

        Layer ids are stored on each item under ``Qt.UserRole``. Items whose
        layer id was selected before the rebuild are re-selected afterwards.
        """
        widget = self.raster_list

        # Remember which layer ids were selected before the list is wiped.
        previously_selected = set()
        for idx in range(widget.count()):
            if widget.item(idx).isSelected():
                previously_selected.add(widget.item(idx).data(Qt.UserRole))

        widget.clear()
        for lyr in QgsProject.instance().mapLayers().values():
            if not isinstance(lyr, QgsRasterLayer):
                continue
            entry = QListWidgetItem(lyr.name())
            entry.setData(Qt.UserRole, lyr.id())
            widget.addItem(entry)
            if lyr.id() in previously_selected:
                entry.setSelected(True)

    def add_to_table(self, layer_name, layer_id):
        """Append one row to the fields table.

        Column 0 holds the read-only raster name, with the layer id stored
        under ``Qt.UserRole``. Column 1 holds an editable default field name.
        """
        table = self.fields_table
        new_row = table.rowCount()
        table.insertRow(new_row)

        name_cell = QTableWidgetItem(layer_name)
        name_cell.setFlags(name_cell.flags() & ~Qt.ItemIsEditable)
        name_cell.setData(Qt.UserRole, layer_id)
        table.setItem(new_row, 0, name_cell)

        # Keep only alphanumeric characters and cap at 10 characters
        # (presumably to respect the 10-character shapefile/DBF field-name limit).
        sanitized = "".join(filter(str.isalnum, layer_name))[:10]
        table.setItem(new_row, 1, QTableWidgetItem(sanitized))

    def update_fields_table(self):
        """Rebuild the fields table so it has one row per selected raster, in list order."""
        self.fields_table.setRowCount(0)
        chosen = set(entry.data(Qt.UserRole) for entry in self.raster_list.selectedItems())
        for row in range(self.raster_list.count()):
            entry = self.raster_list.item(row)
            lid = entry.data(Qt.UserRole)
            if lid not in chosen:
                continue
            self.add_to_table(entry.text(), lid)

    def closeEvent(self, event):
        """Disconnect the project layer signals when the dialog is closed.

        Each signal is disconnected independently. A ``TypeError`` from one
        (raised by PyQt when the slot was not connected) therefore no longer
        prevents the other from being disconnected.
        """
        project = QgsProject.instance()
        for signal in (project.layersAdded, project.layersRemoved):
            try:
                signal.disconnect(self.populate_raster_list)
            except TypeError:
                pass  # Signal was not connected
        super(PointRasterSamplerDialog, self).closeEvent(event)


# --- ACOLITE DIALOG CLASS ---
class AcoliteDialog(QDialog):
    """Dialog that runs ACOLITE atmospheric correction in a background thread.

    The user selects ACOLITE's ``launch.py``, an optional settings file, and the
    input and output directories. If no settings file is given, a temporary one
    is generated using the algorithm chosen in the combo box. After a
    successful run, the GeoTIFFs found directly in the output directory are
    added to the current QGIS project.
    """

    def __init__(self, parent=None):
        """Create the dialog and build its widgets.

        Args:
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Atmospheric Correction with ACOLITE")
        self.setMinimumSize(600, 450)
        self.worker = None    # AcoliteWorker of the current run, if any
        self.temp_files = []  # temporary files deleted by cleanup()
        self.setup_ui()

    def setup_ui(self):
        """Set up the dialog UI."""
        # NOTE(review): ``self.layout`` shadows QWidget.layout(); kept as-is
        # because other code may rely on the attribute.
        self.layout = QGridLayout(self)

        # Row 0: ACOLITE launch.py path
        self.acolite_label = QLabel("ACOLITE launch.py Path:")
        self.acolite_path = QgsFileWidget()
        self.acolite_path.setFilter("Python Files (*.py)")
        self.layout.addWidget(self.acolite_label, 0, 0)
        self.layout.addWidget(self.acolite_path, 0, 1)

        # Row 1: optional settings file; choosing one disables the algorithm combo
        self.settings_label = QLabel("Settings File (Optional):")
        self.settings_path = QgsFileWidget()
        self.settings_path.setFilter("Text Files (*.txt *.ini)")
        self.settings_path.fileChanged.connect(self.toggle_ac_combo)
        self.layout.addWidget(self.settings_label, 1, 0)
        self.layout.addWidget(self.settings_path, 1, 1)

        # Row 2: algorithm written into the generated settings file
        self.ac_label = QLabel("Default Algorithm (if no settings file):")
        self.ac_combo = QComboBox()
        self.ac_combo.addItems(["DSF", "RAdCor", "EXP"])
        self.ac_combo.setCurrentText("RAdCor")
        self.layout.addWidget(self.ac_label, 2, 0)
        self.layout.addWidget(self.ac_combo, 2, 1)

        # Row 3: input directory
        self.input_label = QLabel("Input Directory:")
        self.input_path = QgsFileWidget()
        self.input_path.setStorageMode(QgsFileWidget.GetDirectory)
        self.layout.addWidget(self.input_label, 3, 0)
        self.layout.addWidget(self.input_path, 3, 1)

        # Row 4: output directory
        self.output_label = QLabel("Output Directory:")
        self.output_path = QgsFileWidget()
        self.output_path.setStorageMode(QgsFileWidget.GetDirectory)
        self.layout.addWidget(self.output_label, 4, 0)
        self.layout.addWidget(self.output_path, 4, 1)

        # Rows 5-6: log window
        self.log_label = QLabel("ACOLITE Output:")
        self.log_window = QPlainTextEdit()
        self.log_window.setReadOnly(True)
        self.layout.addWidget(self.log_label, 5, 0)
        self.layout.addWidget(self.log_window, 6, 0, 1, 3)

        # Row 7: buttons
        self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        self.button_box.accepted.connect(self.run_acolite_processing)
        self.button_box.rejected.connect(self.reject)
        self.layout.addWidget(self.button_box, 7, 0, 1, 3)

        self.setLayout(self.layout)

        # Set the initial enabled state of the algorithm combo box.
        self.toggle_ac_combo(self.settings_path.filePath())

    def toggle_ac_combo(self, file_path_str=""):
        """Disables the algorithm dropdown if a settings file is provided."""
        is_enabled = not bool(file_path_str)
        self.ac_label.setEnabled(is_enabled)
        self.ac_combo.setEnabled(is_enabled)

    def run_acolite_processing(self):
        """Validate the inputs and start ACOLITE in a background AcoliteWorker.

        If no settings file is given, a temporary one is written using the
        algorithm selected in the combo box. cleanup() deletes that file.
        """
        acolite_script = self.acolite_path.filePath()
        settings_file = self.settings_path.filePath()
        input_dir = self.input_path.filePath()
        output_dir = self.output_path.filePath()
        selected_algorithm = self.ac_combo.currentText().lower()

        # Validate inputs
        if not acolite_script or not os.path.isfile(acolite_script):
            QMessageBox.critical(self, "Input Error", "Please specify a valid ACOLITE launch.py file.")
            return
        if settings_file and not os.path.isfile(settings_file):
            QMessageBox.critical(self, "Input Error", "The specified settings file does not exist.")
            return
        if not input_dir or not os.path.isdir(input_dir):
            QMessageBox.critical(self, "Input Error", "Please specify a valid input directory.")
            return
        if not output_dir:
            QMessageBox.critical(self, "Input Error", "Please specify an output directory.")
            return

        # Clear the log BEFORE writing this run's messages, so the
        # temporary-settings notice below is not wiped out.
        self.log_window.clear()

        # If no settings file is provided, create a temporary one with default content
        if not settings_file:
            default_settings_content = (
                "polygon=None\n"
                "atmospheric_correction=True\n"
                f"atmospheric_correction_method={selected_algorithm}\n"
                "tact_run=False\n"
                "l2w_parameters=rhos_*\n"
                "l2r_export_geotiff=True\n"
                "l1r_export_geotiff=False\n"
                "l2w_export_geotiff=False\n"
                "rgb_rhos=True\n"
                "map_l2w=False\n"
            )
            try:
                fd, temp_settings_path = tempfile.mkstemp(suffix=".txt", prefix="acolite_settings_")
                # Register immediately so cleanup() removes it even if writing fails.
                self.temp_files.append(temp_settings_path)
                with os.fdopen(fd, 'w') as temp_file:
                    temp_file.write(default_settings_content)
            except OSError as e:
                self.cleanup()
                QMessageBox.critical(self, "Error", f"Could not create temporary settings file: {e}")
                return
            settings_file = temp_settings_path
            self.log_window.appendPlainText(f"No settings file provided. Using temporary settings with method: {selected_algorithm}")

        self.log_window.appendPlainText("Starting ACOLITE processing...")
        self.button_box.button(QDialogButtonBox.Ok).setEnabled(False)
        self.button_box.button(QDialogButtonBox.Cancel).setText("Cancel")

        # NOTE(review): relies on a 'python' executable on PATH that can import
        # ACOLITE's dependencies (deliberately not QGIS's own interpreter).
        command = ['python', acolite_script, '--cli', '--inputfile', input_dir, '--output', output_dir, '--settings', settings_file]

        self.log_window.appendPlainText(f"Running command: {' '.join(command)}")

        # Start worker thread
        self.worker = AcoliteWorker(command, output_dir)
        self.worker.log_message.connect(self.update_log)
        self.worker.finished.connect(self.on_processing_finished)
        self.worker.error.connect(self.on_processing_error)
        self.worker.start()

    def update_log(self, message):
        """Update the log window with messages from the worker."""
        self.log_window.appendPlainText(message)
        self.log_window.ensureCursorVisible()

    def _load_output_rasters(self, output_dir):
        """Add every valid GeoTIFF found directly in ``output_dir`` to the project."""
        try:
            names = sorted(os.listdir(output_dir))
        except OSError as e:
            self.log_window.appendPlainText(f"Could not list output directory {output_dir}: {e}")
            return
        for name in names:
            if not name.lower().endswith((".tif", ".tiff")):
                continue
            output_path = os.path.join(output_dir, name)
            layer = QgsRasterLayer(output_path, os.path.basename(output_path))
            if layer.isValid():
                QgsProject.instance().addMapLayer(layer)

    def on_processing_finished(self, return_code, output_dir):
        """Handle completion of ACOLITE processing.

        Args:
            return_code: Process exit code. The worker uses -1 for an internal
                error (already reported through on_processing_error) and -2
                for a user cancellation.
            output_dir: Directory ACOLITE wrote its products to.
        """
        self.button_box.button(QDialogButtonBox.Ok).setEnabled(True)
        self.button_box.button(QDialogButtonBox.Cancel).setText("Cancel")

        if return_code == 0:
            self.log_window.appendPlainText("\nACOLITE processing finished successfully.")
            self._load_output_rasters(output_dir)
            QMessageBox.information(self, "Success", f"ACOLITE processing completed. Output saved to:\n{output_dir}")
            self.accept()
        elif return_code == -2:  # User cancellation
            self.log_window.appendPlainText("\nACOLITE processing was cancelled by the user.")
        elif return_code == -1:
            # Worker error: on_processing_error has already shown a message,
            # so do not pop up a second one.
            pass
        else:
            self.log_window.appendPlainText(f"\nACOLITE processing failed with exit code {return_code}.")
            QMessageBox.critical(self, "Error", "ACOLITE processing failed. Check the log for details.")

        self.cleanup()

    def on_processing_error(self, error_message):
        """Handle errors from the worker thread."""
        self.log_window.appendPlainText(f"\nError: {error_message}")
        QMessageBox.critical(self, "Error", f"ACOLITE processing failed: {error_message}")
        self.button_box.button(QDialogButtonBox.Ok).setEnabled(True)
        self.button_box.button(QDialogButtonBox.Cancel).setText("Cancel")
        self.cleanup()

    def reject(self):
        """Stop any running worker, remove temporary files, then reject the dialog."""
        self.cleanup()
        super().reject()

    def closeEvent(self, event):
        """Stop any running worker and remove temporary files before closing."""
        self.cleanup()
        super().closeEvent(event)

    def cleanup(self):
        """Stop and drop the worker (waiting for its thread) and delete temporary files."""
        if self.worker:
            if self.worker.isRunning():
                self.worker.stop()
                self.worker.wait()
            self.worker = None
        for temp_file in self.temp_files:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except OSError as e:
                print(f"Error removing temp file {temp_file}: {e}")
        self.temp_files = []


class AcoliteWorker(QThread):
    """
    Worker thread for running ACOLITE processing to avoid freezing the GUI.

    Signals:
        log_message(str): one stripped line of ACOLITE's combined stdout/stderr,
            or a cancellation notice.
        finished(int, str): process return code, with -1 for a worker error
            and -2 for cancellation, plus the output directory.
        error(str): description of an exception raised inside run().
    """
    log_message = pyqtSignal(str)
    # NOTE: this shadows QThread's built-in argument-less ``finished`` signal;
    # callers connect to this (int, str) variant.
    finished = pyqtSignal(int, str)  # Return code and output directory
    error = pyqtSignal(str)

    def __init__(self, command, output_dir):
        super().__init__()
        self.command = command
        self.output_dir = output_dir
        self.process = None
        self._is_running = True

    def run(self):
        """Runs the ACOLITE subprocess in a cleaned environment."""
        try:
            if not self._is_running:
                # stop() was called before the process could be started.
                self.finished.emit(-2, self.output_dir)
                return

            # Create a clean environment to avoid conflicts with QGIS's GDAL/PROJ variables
            env = os.environ.copy()
            for var in ('PROJ_LIB', 'GDAL_DATA', 'GDAL_DRIVER_PATH'):
                env.pop(var, None)

            self.process = subprocess.Popen(
                self.command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                # Bytes that cannot be decoded in the locale encoding must not
                # abort the log stream (and orphan the child process).
                errors='replace',
                env=env,
                creationflags=subprocess.CREATE_NO_WINDOW if os.name == 'nt' else 0,
            )

            try:
                for line in iter(self.process.stdout.readline, ''):
                    if not self._is_running:
                        break
                    self.log_message.emit(line.strip())
                if not self._is_running:
                    # Cancelled: ensure the child does not keep running while we wait on it.
                    self._terminate_process()
            finally:
                self.process.stdout.close()

            return_code = self.process.wait()
            if self._is_running:
                self.finished.emit(return_code, self.output_dir)
            else:
                self.finished.emit(-2, self.output_dir)  # Custom code for cancellation

        except Exception as e:
            # Do not leave an orphaned ACOLITE process behind on worker failure.
            self._terminate_process()
            self.error.emit(f"An error occurred in the worker thread: {str(e)}")
            self.finished.emit(-1, self.output_dir)

    def _terminate_process(self):
        """Terminate the child process if alive; kill it if it does not exit within 3 seconds."""
        proc = self.process
        if proc is None or proc.poll() is not None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=3)  # Wait up to 3 seconds for graceful termination
        except subprocess.TimeoutExpired:
            proc.kill()  # Force kill if it doesn't terminate

    def stop(self):
        """Request cancellation and stop the running process, if any."""
        self._is_running = False
        if self.process and self.process.poll() is None:
            self._terminate_process()
            self.log_message.emit("Processing was cancelled by the user.")


# ---TSI CALCULATOR DIALOG CLASS
class TSICalculatorDialog(QDialog):
    """Dialog that computes Trophic State Index (TSI) rasters.

    TSI(Chl-a) = 9.81 * ln(Chl-a) + 30.6 and, optionally,
    TSI(SD) = 60 - 14.41 * ln(SD). Each result is restricted to pixels whose
    water-mask value equals 1. A classified companion raster is also written,
    with 1 = oligotrophic, 2 = mesotrophic, 3 = eutrophic and 255 = nodata.
    """

    # Class boundaries shared by the statistics and the classified raster.
    OLIGOTROPHIC_UPPER = 40
    MESOTROPHIC_UPPER = 50
    # Nodata value written to the continuous TSI rasters.
    OUTPUT_NODATA = -9999.0

    def __init__(self, iface, parent=None):
        """Constructor."""
        super(TSICalculatorDialog, self).__init__(parent)
        self.iface = iface
        self.setWindowTitle("Trophic State Index (TSI) Calculator")
        self.setup_ui()

    def setup_ui(self):
        """Build the user interface with modern PyQGIS widgets."""
        layout = QGridLayout(self)

        # Chlorophyll-a layer selection
        layout.addWidget(QLabel("Select Chlorophyll-a Raster Layer (µg/L):"), 0, 0)
        self.chla_combo = QgsMapLayerComboBox(self)
        self.chla_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        layout.addWidget(self.chla_combo, 0, 1)

        # Secchi Depth layer selection
        self.use_sd_checkbox = QCheckBox("Use Secchi Depth Layer")
        self.use_sd_checkbox.stateChanged.connect(self.toggle_sd_input)
        layout.addWidget(self.use_sd_checkbox, 1, 0, 1, 2)

        self.sd_label = QLabel("Select Secchi Depth Raster Layer (m):")
        self.sd_combo = QgsMapLayerComboBox(self)
        self.sd_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        layout.addWidget(self.sd_label, 2, 0)
        layout.addWidget(self.sd_combo, 2, 1)

        # Water mask (required)
        self.mask_label = QLabel("Water Mask Layer (1=water):")
        self.mask_combo = QgsMapLayerComboBox(self)
        self.mask_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        layout.addWidget(self.mask_label, 3, 0)
        layout.addWidget(self.mask_combo, 3, 1)

        # Output file selection
        layout.addWidget(QLabel("Output TSI (Chl-a) Raster:"), 4, 0)
        self.output_chla_path = QgsFileWidget(self)
        self.output_chla_path.setStorageMode(QgsFileWidget.SaveFile)
        self.output_chla_path.setFilter("GeoTIFF (*.tif)")
        layout.addWidget(self.output_chla_path, 4, 1)

        self.output_sd_label = QLabel("Output TSI (Secchi Depth) Raster:")
        self.output_sd_path = QgsFileWidget(self)
        self.output_sd_path.setStorageMode(QgsFileWidget.SaveFile)
        self.output_sd_path.setFilter("GeoTIFF (*.tif)")
        layout.addWidget(self.output_sd_label, 5, 0)
        layout.addWidget(self.output_sd_path, 5, 1)

        # Run and Cancel buttons
        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.run_processing)
        button_box.rejected.connect(self.reject)
        layout.addWidget(button_box, 6, 0, 1, 2)

        # Set initial state
        self.toggle_sd_input()

    def toggle_sd_input(self):
        """Enable the Secchi Depth widgets only while the checkbox is checked."""
        enabled = self.use_sd_checkbox.isChecked()
        self.sd_label.setEnabled(enabled)
        self.sd_combo.setEnabled(enabled)
        self.output_sd_label.setEnabled(enabled)
        self.output_sd_path.setEnabled(enabled)

    def run_calculation(self, input_path, output_path, expression_func, mask_path):
        """Apply ``expression_func`` to band 1 of a raster and write a water-masked GeoTIFF.

        Input pixels that are nodata or <= 0 (the TSI formulas take a log) are
        masked, as are pixels whose mask value is not 1. Masked pixels are
        written as ``OUTPUT_NODATA``.

        Args:
            input_path: Raster readable by rasterio (band 1 is used).
            output_path: Destination GeoTIFF path.
            expression_func: Callable mapping a float32 masked array to a masked array.
            mask_path: Water-mask raster on the same grid as ``input_path``.

        Raises:
            ValueError: If the mask's dimensions or transform differ from the
                input, or both rasters have a CRS and the CRSs differ.
        """
        with rasterio.open(input_path) as src:
            profile = src.profile
            # Capture grid metadata while the dataset is still open.
            src_transform = src.transform
            src_crs = src.crs
            input_data = src.read(1, masked=True).astype('float32')
        # masked_where ORs the new condition with the existing nodata mask.
        input_data = np.ma.masked_where(input_data.filled(0) <= 0, input_data)

        with rasterio.open(mask_path) as mask_src:
            crs_mismatch = bool(mask_src.crs) and bool(src_crs) and mask_src.crs != src_crs
            if mask_src.shape != input_data.shape or mask_src.transform != src_transform or crs_mismatch:
                raise ValueError("The Water Mask layer does not align with the input raster (CRS, Extent, and Dimensions must match).")
            mask_data = mask_src.read(1)

        # Perform the calculation and apply the mask
        output_data = expression_func(input_data)
        output_data[mask_data != 1] = np.ma.masked

        # The output is always GeoTIFF, whatever driver the input used.
        profile.update(driver='GTiff', dtype=rasterio.float32, count=1, nodata=self.OUTPUT_NODATA)

        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(output_data.filled(profile['nodata']).astype('float32'), 1)

    def calculate_trophic_stats(self, tsi_chla_path):
        """Return the percentage of valid pixels in each trophic class of a TSI raster.

        Returns:
            dict with keys "oligotrophic", "mesotrophic" and "eutrophic". All
            values are 0 if there are no valid pixels or the file cannot be read.
        """
        try:
            with rasterio.open(tsi_chla_path) as src:
                data = src.read(1, masked=True).astype('float32')

            # compressed() works whether the mask is a full array or the scalar nomask.
            valid_data = data.compressed()
            if valid_data.size == 0:
                return {"oligotrophic": 0, "mesotrophic": 0, "eutrophic": 0}

            total_pixels = float(valid_data.size)
            low, high = self.OLIGOTROPHIC_UPPER, self.MESOTROPHIC_UPPER
            oligotrophic = np.sum(valid_data < low) / total_pixels * 100
            mesotrophic = np.sum((valid_data >= low) & (valid_data < high)) / total_pixels * 100
            eutrophic = np.sum(valid_data >= high) / total_pixels * 100

            return {"oligotrophic": oligotrophic, "mesotrophic": mesotrophic, "eutrophic": eutrophic}
        except Exception as e:
            print(f"Could not calculate stats: {e}")
            return {"oligotrophic": 0, "mesotrophic": 0, "eutrophic": 0}

    def _process_tsi(self, input_path, output_path, expression_func, mask_path, layer_name):
        """Compute one TSI raster and its classified version, then add both to the project."""
        self.run_calculation(input_path, output_path, expression_func, mask_path)
        continuous_layer = QgsRasterLayer(output_path, layer_name)
        if continuous_layer.isValid():
            QgsProject.instance().addMapLayer(continuous_layer)

        classified_path = os.path.splitext(output_path)[0] + '_classified.tif'
        self.create_classified_tsi_raster(output_path, classified_path)
        classified_layer = QgsRasterLayer(classified_path, f"{layer_name}_Classified")
        if classified_layer.isValid():
            self.style_classified_layer(classified_layer)
            QgsProject.instance().addMapLayer(classified_layer)

    def run_processing(self):
        """Validate the inputs, compute the TSI rasters and add them to the project."""
        # 1. Gather inputs
        chla_layer = self.chla_combo.currentLayer()
        mask_layer = self.mask_combo.currentLayer()
        use_sd = self.use_sd_checkbox.isChecked()
        sd_layer = self.sd_combo.currentLayer() if use_sd else None
        output_chla = self.output_chla_path.filePath()
        output_sd = self.output_sd_path.filePath() if use_sd else None

        # 2. Validate inputs
        if not all([chla_layer, mask_layer, output_chla]):
            QMessageBox.critical(self, "Input Error", "Please select a valid Chl-a layer, Water Mask, and output path.")
            return
        if use_sd and (not sd_layer or not output_sd):
            QMessageBox.critical(self, "Input Error", "Please select a valid Secchi Depth layer and output path.")
            return
        if use_sd and os.path.abspath(output_sd) == os.path.abspath(output_chla):
            QMessageBox.critical(self, "Input Error", "The Chl-a and Secchi Depth outputs must be different files.")
            return

        try:
            mask_path = mask_layer.source()

            # TSI for Chlorophyll-a
            chla_func = lambda data: 9.81 * np.ma.log(data) + 30.6
            self._process_tsi(chla_layer.source(), output_chla, chla_func, mask_path, "TSI_Chlorophyll")

            # TSI for Secchi Depth (if requested)
            if use_sd:
                sd_func = lambda data: 60 - 14.41 * np.ma.log(data)
                self._process_tsi(sd_layer.source(), output_sd, sd_func, mask_path, "TSI_Secchi_Depth")

            QMessageBox.information(self, "Success", "TSI calculation and classification completed successfully.")
            self.accept()

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Processing Error", f"An error occurred:\n{e}\n\n{traceback.format_exc()}")

    def style_classified_layer(self, layer):
        """Applies a unique value renderer to a classified TSI raster."""
        color_map = [
            QgsPalettedRasterRenderer.Class(1, QColor("#3b528b"), "Oligotrophic"),
            QgsPalettedRasterRenderer.Class(2, QColor("#21918c"), "Mesotrophic"),
            QgsPalettedRasterRenderer.Class(3, QColor("#fde725"), "Eutrophic"),
        ]
        renderer = QgsPalettedRasterRenderer(layer.dataProvider(), 1, color_map)
        layer.setRenderer(renderer)
        layer.triggerRepaint()

    def create_classified_tsi_raster(self, input_tsi_path, output_classified_path):
        """Write a uint8 GeoTIFF classifying a continuous TSI raster.

        Values below OLIGOTROPHIC_UPPER map to 1, values up to
        MESOTROPHIC_UPPER map to 2, and values at or above it map to 3.
        Masked, nodata or NaN pixels are written as 255, which is the output
        nodata value.
        """
        with rasterio.open(input_tsi_path) as src:
            profile = src.profile
            tsi_data = src.read(1, masked=True)

        NODATA_VAL, OLIGO_VAL, MESO_VAL, EUTRO_VAL = 255, 1, 2, 3
        # Use plain arrays so boolean indexing does not depend on masked-array semantics.
        valid = ~np.ma.getmaskarray(tsi_data)
        values = np.ma.getdata(tsi_data)
        low, high = self.OLIGOTROPHIC_UPPER, self.MESOTROPHIC_UPPER

        classified_array = np.full(values.shape, NODATA_VAL, dtype=np.uint8)
        classified_array[valid & (values < low)] = OLIGO_VAL
        classified_array[valid & (values >= low) & (values < high)] = MESO_VAL
        classified_array[valid & (values >= high)] = EUTRO_VAL

        profile.update(driver='GTiff', dtype=rasterio.uint8, count=1, nodata=NODATA_VAL)
        with rasterio.open(output_classified_path, 'w', **profile) as dst:
            dst.write(classified_array, 1)


# --- WATER MASK DIALOG CLASS ---
class WaterMaskDialog(QDialog):
    """Interactive tool that produces a binary water mask (1 = water, 0 = not water).

    Three workflows are offered:
      * threshold a single NIR raster (values below the threshold become water),
      * compute NDWI from Green and NIR rasters and threshold it (values at or
        above the threshold become water),
      * rasterize a polygon layer onto the grid of a reference raster.

    For the two threshold workflows, the threshold is either estimated from a
    user-drawn area with a 2-component Gaussian Mixture Model or typed in.
    """

    def __init__(self, parent=None):
        """Build the three-step UI (input, threshold analysis, output) and wire its signals."""
        super(WaterMaskDialog, self).__init__(parent)
        self.setWindowTitle("Water Mask Creator")
        self.setMinimumSize(700, 800)

        # Initialize all instance attributes at the beginning
        self.raster_layer = None          # QgsRasterLayer shown in the preview canvas
        self.map_tool = None
        self.rubber_band = None
        self.raster_path = None           # file path backing self.raster_layer
        self.threshold = None             # last GMM-derived threshold, None until calculated
        self.temp_ndwi_file = None        # temporary NDWI GeoTIFF owned by this dialog
        self.temp_mask_file = None
        self.classification_mode = None   # 'nir' or 'ndwi': decides which side of the threshold is water

        # Main Layout and UI Setup
        main_layout = QHBoxLayout(self)
        controls_layout = QVBoxLayout()

        # Step 1: Input Source Selection
        input_groupbox = QtWidgets.QGroupBox("Step 1: Choose Input Source")
        input_groupbox_layout = QVBoxLayout(input_groupbox)

        # Radio buttons with descriptive text
        self.use_existing_radio = QRadioButton("Load NIR Band for Threshold")
        self.calc_ndwi_radio = QRadioButton("Load Green & NIR to Calculate NDWI for Threshold")
        self.create_mask_radio = QRadioButton("Create Mask from Polygon Shapefile (saves directly)")
        self.use_existing_radio.setChecked(True)

        radio_layout = QHBoxLayout()
        radio_layout.addWidget(self.use_existing_radio)
        radio_layout.addWidget(self.calc_ndwi_radio)
        radio_layout.addWidget(self.create_mask_radio)
        input_groupbox_layout.addLayout(radio_layout)

        # Stacked widget to hold the different input panels
        self.input_stack = QtWidgets.QStackedWidget()
        input_groupbox_layout.addWidget(self.input_stack)

        # Panel 1: Existing Raster Input (NIR)
        existing_raster_widget = QtWidgets.QWidget()
        existing_raster_layout = QVBoxLayout(existing_raster_widget)
        self.existing_file_widget = QgsFileWidget()
        self.existing_file_widget.setFilter("*.tif *.tiff *.jp2")
        self.load_existing_button = QPushButton("Load Raster for Analysis")
        existing_raster_layout.addWidget(self.existing_file_widget)
        existing_raster_layout.addWidget(self.load_existing_button)
        self.input_stack.addWidget(existing_raster_widget)

        # Panel 2: NDWI Calculation Input
        ndwi_inputs_widget = QtWidgets.QWidget()
        ndwi_layout = QGridLayout(ndwi_inputs_widget)
        ndwi_layout.addWidget(QLabel("Green Band:"), 0, 0)
        self.green_band_widget = QgsFileWidget()
        self.green_band_widget.setFilter("*.tif *.tiff *.jp2")
        ndwi_layout.addWidget(self.green_band_widget, 0, 1)
        ndwi_layout.addWidget(QLabel("NIR Band:"), 1, 0)
        self.nir_band_widget = QgsFileWidget()
        self.nir_band_widget.setFilter("*.tif *.tiff *.jp2")
        ndwi_layout.addWidget(self.nir_band_widget, 1, 1)
        self.calculate_ndwi_button = QPushButton("Calculate and Load NDWI for Analysis")
        ndwi_layout.addWidget(self.calculate_ndwi_button, 2, 0, 1, 2)
        self.input_stack.addWidget(ndwi_inputs_widget)

        # Panel 3: Create Mask from Polygon
        create_mask_widget = QtWidgets.QWidget()
        create_mask_layout = QGridLayout(create_mask_widget)
        create_mask_layout.addWidget(QLabel("Polygon Layer:"), 0, 0)
        self.polygon_layer_combo = QgsMapLayerComboBox()
        self.polygon_layer_combo.setFilters(QgsMapLayerProxyModel.PolygonLayer)
        create_mask_layout.addWidget(self.polygon_layer_combo, 0, 1)
        create_mask_layout.addWidget(QLabel("Reference Raster Layer:"), 1, 0)
        self.reference_raster_combo = QgsMapLayerComboBox()
        self.reference_raster_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        create_mask_layout.addWidget(self.reference_raster_combo, 1, 1)

        self.use_buffer_checkbox = QCheckBox("Apply Negative Buffer to Shrink Mask")
        self.use_buffer_checkbox.stateChanged.connect(self.toggle_buffer_widgets)
        create_mask_layout.addWidget(self.use_buffer_checkbox, 2, 0, 1, 2)

        self.buffer_distance_label = QLabel("Shrink Distance (in Polygon Layer's Unit):")
        self.buffer_distance_spinbox = QDoubleSpinBox()
        self.buffer_distance_spinbox.setRange(0.0, 10000.0) # Set a practical range
        self.buffer_distance_spinbox.setValue(10.0) # A sensible default
        self.buffer_distance_spinbox.setDecimals(2)
        create_mask_layout.addWidget(self.buffer_distance_label, 3, 0)
        create_mask_layout.addWidget(self.buffer_distance_spinbox, 3, 1)

        create_mask_layout.addWidget(QLabel("Output Mask Path (.tif):"), 4, 0)
        self.mask_output_path_widget = QgsFileWidget(storageMode=QgsFileWidget.SaveFile)
        self.mask_output_path_widget.setFilter("GeoTIFF (*.tif)")
        create_mask_layout.addWidget(self.mask_output_path_widget, 4, 1)

        self.generate_from_poly_button = QPushButton("Create and Save Mask")
        create_mask_layout.addWidget(self.generate_from_poly_button, 5, 0, 1, 2)
        self.input_stack.addWidget(create_mask_widget)

        controls_layout.addWidget(input_groupbox)

        # Connect signals
        self.use_existing_radio.toggled.connect(self.update_input_mode)
        self.calc_ndwi_radio.toggled.connect(self.update_input_mode)
        self.create_mask_radio.toggled.connect(self.update_input_mode)
        self.load_existing_button.clicked.connect(self.load_existing_raster)
        self.calculate_ndwi_button.clicked.connect(self.calculate_and_load_ndwi)
        self.generate_from_poly_button.clicked.connect(self.generate_mask_from_polygon)

        # Step 2: Analysis Section
        self.analysis_groupbox = QtWidgets.QGroupBox("Step 2: Calculate Threshold")
        analysis_layout = QVBoxLayout(self.analysis_groupbox)

        self.shape_label = QLabel("Drawing Shape Type:")
        self.shape_combo = QComboBox()
        self.shape_combo.addItems(["Rectangle", "Polygon"])
        analysis_layout.addWidget(self.shape_label)
        analysis_layout.addWidget(self.shape_combo)

        self.draw_button = QPushButton("Draw Analysis Area on Map")
        self.draw_button.clicked.connect(self.draw_shape)
        analysis_layout.addWidget(self.draw_button)

        self.calculate_button = QPushButton("Calculate Threshold from Area")
        self.calculate_button.clicked.connect(self.calculate_distribution)
        analysis_layout.addWidget(self.calculate_button)

        calculated_layout = QHBoxLayout()
        self.calculated_radio = QRadioButton("Use Calculated Threshold:")
        self.calculated_radio.setChecked(True)
        self.result_label = QLabel("Not calculated")
        calculated_layout.addWidget(self.calculated_radio)
        calculated_layout.addWidget(self.result_label)
        calculated_layout.addStretch()
        analysis_layout.addLayout(calculated_layout)

        custom_layout = QHBoxLayout()
        self.custom_radio = QRadioButton("Use Custom Threshold:")
        self.custom_threshold_input = QDoubleSpinBox()
        self.custom_threshold_input.setRange(-9999.0, 9999.0)
        self.custom_threshold_input.setDecimals(4)
        self.custom_threshold_input.setEnabled(False)
        custom_layout.addWidget(self.custom_radio)
        custom_layout.addWidget(self.custom_threshold_input)
        analysis_layout.addLayout(custom_layout)
        self.custom_radio.toggled.connect(self.custom_threshold_input.setEnabled)
        # A custom threshold is usable without running the GMM, so re-evaluate the
        # generate button whenever the threshold source changes.
        self.custom_radio.toggled.connect(self._update_generate_button_state)
        controls_layout.addWidget(self.analysis_groupbox)

        # Step 3: Output Section
        self.output_groupbox = QtWidgets.QGroupBox("Step 3: Generate Mask")
        output_layout = QVBoxLayout(self.output_groupbox)
        self.refine_checkbox = QCheckBox("Refine Mask with QA/SCL Band")
        self.refine_checkbox.setChecked(False)
        output_layout.addWidget(self.refine_checkbox)
        self.qa_scl_label = QLabel("QA (Landsat) or SCL (Sentinel-2) Band:")
        self.qa_scl_path_widget = QgsFileWidget()
        self.qa_scl_path_widget.setFilter("*.tif *.tiff *.jp2")
        output_layout.addWidget(self.qa_scl_label)
        output_layout.addWidget(self.qa_scl_path_widget)
        self.output_label = QLabel("Final Water Mask Output Path (.tif):")
        self.output_path_widget = QgsFileWidget(storageMode=QgsFileWidget.SaveFile)
        self.output_path_widget.setFilter("GeoTIFF (*.tif)")
        output_layout.addWidget(self.output_label)
        output_layout.addWidget(self.output_path_widget)
        self.generate_button = QPushButton("Generate Water Mask")
        self.generate_button.clicked.connect(self.generate_water_mask)
        output_layout.addWidget(self.generate_button)
        controls_layout.addWidget(self.output_groupbox)

        controls_layout.addStretch()

        # Right Side: Map and Plot Canvases
        right_layout = QVBoxLayout()
        self.canvas = QgsMapCanvas()
        self.canvas.setCanvasColor(Qt.white)
        right_layout.addWidget(self.canvas)
        self.figure = plt.figure()
        self.plot_canvas = FigureCanvas(self.figure)
        right_layout.addWidget(self.plot_canvas)

        main_layout.addLayout(controls_layout, 1)
        main_layout.addLayout(right_layout, 3)

        # Final State Initialization
        self.toggle_refinement_widgets(False)
        self.refine_checkbox.toggled.connect(self.toggle_refinement_widgets)

        self.update_input_mode()
        self.toggle_buffer_widgets()
        self._update_generate_button_state()

    def update_input_mode(self):
        """Switches the visible input panel and enables/disables analysis steps."""
        is_raster_loaded = self.raster_layer is not None

        if self.create_mask_radio.isChecked():
            self.input_stack.setCurrentIndex(2)
            self.analysis_groupbox.setEnabled(False)
            self.output_groupbox.setEnabled(False)
        elif self.use_existing_radio.isChecked():
            self.input_stack.setCurrentIndex(0)
            self.analysis_groupbox.setEnabled(is_raster_loaded)
            self.output_groupbox.setEnabled(is_raster_loaded)
        elif self.calc_ndwi_radio.isChecked():
            self.input_stack.setCurrentIndex(1)
            self.analysis_groupbox.setEnabled(is_raster_loaded)
            self.output_groupbox.setEnabled(is_raster_loaded)

    def toggle_buffer_widgets(self):
        """Enables or disables the buffer input widgets."""
        is_enabled = self.use_buffer_checkbox.isChecked()
        self.buffer_distance_label.setEnabled(is_enabled)
        self.buffer_distance_spinbox.setEnabled(is_enabled)

    def _update_generate_button_state(self, *_):
        """Enable 'Generate Water Mask' when a threshold is available (custom or calculated).

        Accepts and ignores signal arguments so it can be connected to ``toggled``.
        """
        has_threshold = self.custom_radio.isChecked() or self.threshold is not None
        self.generate_button.setEnabled(has_threshold)

    def generate_mask_from_polygon(self):
        """
        Rasterizes a polygon layer to create a binary water mask.

        Pixels inside the polygon(s) are assigned a value of 1. All other pixels
        are assigned a value of 0. The reference raster is used only to
        define the output grid (dimensions, transform, CRS).

        When the negative buffer is enabled, it is applied in the polygon
        layer's own units (as the UI label states), before the geometry is
        reprojected into the reference raster's CRS. The output is always
        written as GeoTIFF.
        """
        polygon_layer = self.polygon_layer_combo.currentLayer()
        reference_raster = self.reference_raster_combo.currentLayer()
        output_path = self.mask_output_path_widget.filePath()

        use_buffer = self.use_buffer_checkbox.isChecked()
        buffer_distance = self.buffer_distance_spinbox.value()

        if not polygon_layer or not reference_raster or not output_path:
            QMessageBox.warning(self, "Input Error", "Please specify a polygon layer, a reference raster, and an output path.")
            return

        try:
            # Open the reference raster only to get its profile (grid definition)
            with rasterio.open(reference_raster.source()) as src:
                ref_profile = src.profile

            if polygon_layer.featureCount() == 0:
                QMessageBox.warning(self, "Input Error", "The selected polygon layer has no features.")
                return

            source_crs = polygon_layer.crs()
            dest_crs = reference_raster.crs()
            transform = None
            if source_crs != dest_crs:
                transform = QgsCoordinateTransform(source_crs, dest_crs, QgsProject.instance())

            geometries = []
            for feature in polygon_layer.getFeatures():
                geom = feature.geometry()
                if geom is None or geom.isNull() or geom.isEmpty():
                    continue  # nothing to burn for features without geometry

                # Buffer BEFORE reprojecting so the distance is measured in the
                # polygon layer's units; after reprojection it would silently be
                # interpreted in the reference raster's CRS units.
                if use_buffer and buffer_distance > 0:
                    geom = geom.buffer(-buffer_distance, 5)
                    if geom is None or geom.isEmpty():
                        continue  # the polygon vanished under the negative buffer

                if transform:
                    geom.transform(transform)  # in-place reprojection

                geometries.append(json.loads(geom.asJson()))

            if not geometries:
                QMessageBox.warning(self, "Processing Warning", "No valid geometries found to rasterize. This can occur if the negative buffer is too large.")
                return

            # Rasterize the polygons. Pixels inside are 1, all others are 0.
            rasterized_array = rasterio.features.rasterize(
                shapes=geometries,
                out_shape=(ref_profile['height'], ref_profile['width']),
                transform=ref_profile['transform'],
                fill=0,              # Pixels outside the polygon get 0
                default_value=1,     # Pixels inside the polygon get 1
                all_touched=True,
                dtype=np.uint8
            )

            # 8-bit GeoTIFF mask with no nodata value. The driver is forced because
            # the reference raster may be e.g. JPEG2000 while the output is a .tif.
            out_profile = ref_profile.copy()
            out_profile.update(driver='GTiff', dtype=rasterio.uint8, count=1, nodata=None)

            with rasterio.open(output_path, 'w', **out_profile) as dst:
                dst.write(rasterized_array, 1)

            if self._add_styled_mask_layer(output_path):
                QMessageBox.information(self, "Success", f"Mask created and saved to:\n{output_path}\nand added to the project.")
            else:
                QMessageBox.critical(self, "Error", "Failed to save or load the created mask.")

        except Exception as e:
            QMessageBox.critical(self, "Mask Generation Error", f"An error occurred: {e}")

    def calculate_and_load_ndwi(self):
        """Compute NDWI = (Green - NIR) / (Green + NIR) into a temporary GeoTIFF and load it.

        The classification mode switches to 'ndwi' only once the new raster has
        loaded successfully, so a failed attempt leaves the currently loaded
        raster and its mode consistent. The previous temporary NDWI file is
        deleted only after its layer has been replaced.
        """
        green_path = self.green_band_widget.filePath()
        nir_path = self.nir_band_widget.filePath()

        if not green_path or not nir_path:
            QMessageBox.warning(self, "Input Error", "Please specify both Green and NIR band rasters.")
            return

        new_temp_file = None
        try:
            with rasterio.open(green_path) as green_src, rasterio.open(nir_path) as nir_src:
                if green_src.shape != nir_src.shape:
                    QMessageBox.critical(self, "Error", "The dimensions of the Green and NIR bands do not match.")
                    return

                profile = green_src.profile
                green = green_src.read(1, masked=True).astype(np.float32)
                nir = nir_src.read(1, masked=True).astype(np.float32)

            with np.errstate(divide='ignore', invalid='ignore'):
                ndwi = (green - nir) / (green + nir)

            fd, new_temp_file = tempfile.mkstemp(suffix=".tif", prefix="ndwi_")
            os.close(fd)

            profile.update(driver='GTiff', dtype=rasterio.float32, count=1, nodata=-9999.0)
            with rasterio.open(new_temp_file, 'w', **profile) as dst:
                dst.write(ndwi.filled(-9999.0).astype(rasterio.float32), 1)

            if not self.load_raster(new_temp_file):
                self._remove_file(new_temp_file)
                return

            # load_raster released the previous layer, so the previous temporary
            # file is no longer held open and can be deleted.
            previous_temp_file = self.temp_ndwi_file
            self.temp_ndwi_file = new_temp_file
            new_temp_file = None
            self.classification_mode = 'ndwi'
            self._remove_file(previous_temp_file)

        except Exception as e:
            if new_temp_file:
                self._remove_file(new_temp_file)
            QMessageBox.critical(self, "NDWI Calculation Error", f"An error occurred: {e}")

    def load_existing_raster(self):
        """Load the user-selected NIR raster; switch to 'nir' mode only if it loads."""
        path = self.existing_file_widget.filePath()
        if not path:
            QMessageBox.warning(self, "Input Error", "Please select a file to load.")
            return
        if self.load_raster(path):
            self.classification_mode = 'nir'

    def load_raster(self, file_path):
        """Load ``file_path`` into the preview canvas and enable the threshold workflow.

        The previously loaded preview layer is removed from the project first,
        so repeated loads do not accumulate hidden layers. The canvas CRS is set
        to the raster's CRS so shapes drawn on it share the raster's coordinates.

        Returns:
            True if the raster loaded, False otherwise (the previous raster, if
            any, stays loaded in that case).
        """
        if not file_path:
            return False

        layer = QgsRasterLayer(file_path, "Loaded Input Raster")
        if not layer.isValid():
            QMessageBox.critical(self, "Error", f"Failed to load the raster layer:\n{file_path}")
            return False

        self._release_raster_layer()
        self.raster_path = file_path
        self.raster_layer = layer

        # Added without a legend entry: the project keeps the layer alive for the canvas.
        QgsProject.instance().addMapLayer(layer, False)
        self.canvas.setDestinationCrs(layer.crs())
        self.canvas.setLayers([layer])
        self.canvas.setExtent(layer.extent())
        self.canvas.refresh()
        self.threshold = None
        self.result_label.setText("Not calculated")

        self.analysis_groupbox.setEnabled(True)
        self.output_groupbox.setEnabled(True)
        self._update_generate_button_state()
        return True

    def _release_raster_layer(self):
        """Detach the current preview layer from the canvas and remove it from the project."""
        layer, self.raster_layer = self.raster_layer, None
        if layer is None:
            return
        self.canvas.setLayers([])
        try:
            layer_id = layer.id()
        except RuntimeError:
            # The underlying C++ layer is already gone (e.g. the project was cleared).
            return
        project = QgsProject.instance()
        if project.mapLayer(layer_id) is not None:
            project.removeMapLayer(layer_id)

    @staticmethod
    def _gmm_intersection(means, stds, weights):
        """Return the x strictly between the two means where the weighted Gaussians are equal.

        Solves w1*N(x; mu1, s1) == w2*N(x; mu2, s2), a quadratic in x.
        Returns None when no real root lies between the means.
        """
        mu1, mu2 = means[0], means[1]
        std1, std2 = stds[0], stds[1]
        w1, w2 = weights[0], weights[1]
        a = 1 / (2 * std1 ** 2) - 1 / (2 * std2 ** 2)
        b = mu2 / (std2 ** 2) - mu1 / (std1 ** 2)
        c = mu1 ** 2 / (2 * std1 ** 2) - mu2 ** 2 / (2 * std2 ** 2) - np.log(std2 / std1) - np.log(w1 / w2)

        roots = np.roots([a, b, c])
        # np.roots yields complex values when the discriminant is negative, and
        # ordering comparisons on complex numbers raise TypeError: keep real roots.
        real_roots = roots[np.isreal(roots)].real

        low, high = min(mu1, mu2), max(mu1, mu2)
        for root in real_roots:
            if low < root < high:
                return float(root)
        return None

    def calculate_distribution(self):
        """Estimate a threshold from the pixels inside the drawn shape.

        A 2-component Gaussian Mixture Model is fitted to the valid pixel values;
        the threshold is where the two weighted components intersect between
        their means, falling back to the mean of the two means when there is no
        such intersection. The histogram and fitted components are plotted.
        """
        if not self.raster_layer or not self.rubber_band or self.rubber_band.asGeometry().isEmpty():
            QMessageBox.warning(self, "Warning", "Please load a raster and draw a shape first.")
            return

        polygon = self.rubber_band.asGeometry()

        try:
            with rasterio.open(self.raster_path) as src:
                polygon_geojson = json.loads(polygon.asJson())
                # filled=False returns a masked array whose mask covers pixels outside
                # the polygon and the source's nodata pixels. The filled variant would
                # write 0 there when the source has no nodata value, and those zeros
                # would be fitted as real data.
                out_image, _ = rasterio.mask.mask(
                    src, [polygon_geojson], crop=True, all_touched=True, filled=False
                )

            pixel_values = np.ma.compressed(out_image)
            # GaussianMixture cannot fit NaN/inf values.
            pixel_values = pixel_values[np.isfinite(pixel_values)]

            if pixel_values.size < 20:
                QMessageBox.warning(self, "Warning", "Not enough valid pixels found for GMM.")
                return

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to extract pixel values: {e}")
            return

        try:
            gmm = GaussianMixture(n_components=2, random_state=0).fit(pixel_values.reshape(-1, 1))
            means = gmm.means_.flatten()
            stds = np.sqrt(gmm.covariances_).flatten()
            weights = gmm.weights_.flatten()

            self.threshold = self._gmm_intersection(means, stds, weights)
            if self.threshold is None:
                self.threshold = float(np.mean(means))
                QMessageBox.warning(self, "GMM Warning", "Could not find clear intersection. Using mean as a fallback.")
        except Exception as e:
            QMessageBox.critical(self, "GMM Error", f"Could not calculate threshold with GMM: {e}")
            return

        self.result_label.setText(f"{self.threshold:.4f}")
        self._update_generate_button_state()

        self.figure.clear()
        ax = self.figure.add_subplot(111)
        ax.hist(pixel_values, bins=50, density=True, alpha=0.6, label='Pixel Histogram')
        ax.axvline(self.threshold, color='r', linestyle='--', label=f"GMM Threshold: {self.threshold:.4f}")

        # In NIR mode water is the low-valued component; in NDWI mode the high-valued one.
        low_mean_idx = np.argmin(means)
        high_mean_idx = 1 - low_mean_idx
        labels = [''] * 2
        if self.classification_mode == 'nir':
            labels[low_mean_idx], labels[high_mean_idx] = "Water", "Land"
        else:
            labels[low_mean_idx], labels[high_mean_idx] = "Land", "Water"

        x_plot = np.linspace(pixel_values.min(), pixel_values.max(), 500)
        for i in range(2):
            pdf = weights[i] * (1 / (stds[i] * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x_plot - means[i]) / stds[i]) ** 2)
            ax.plot(x_plot, pdf, label=labels[i])

        ax.set_title("Pixel Distribution and GMM Fit")
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Density")
        ax.legend()
        self.plot_canvas.draw()

    def _read_aligned_qa_band(self, qa_scl_path, target_shape, target_profile):
        """Return band 1 of the QA/SCL raster on the grid described by ``target_profile``.

        If the grids differ, the band is resampled with nearest neighbour so class
        codes and bit flags are preserved. When both rasters carry a CRS, the
        band is reprojected onto the target transform/CRS, so pixels line up even
        if the two rasters cover different extents. Otherwise it is only
        rescaled to ``target_shape``.
        """
        target_shape = tuple(target_shape)
        with rasterio.open(qa_scl_path) as qa_src:
            if qa_src.shape == target_shape and qa_src.transform == target_profile['transform']:
                return qa_src.read(1)

            QMessageBox.information(self, "Resampling", "SCL band does not match input raster. Resampling with Nearest Neighbor...")

            dst_crs = target_profile.get('crs')
            if qa_src.crs is None or dst_crs is None:
                # No georeferencing to align on: stretch onto the target shape.
                return qa_src.read(1, out_shape=target_shape, resampling=Resampling.nearest)

            from rasterio.warp import reproject

            aligned = np.zeros(target_shape, dtype=qa_src.dtypes[0])
            reproject(
                source=qa_src.read(1),
                destination=aligned,
                src_transform=qa_src.transform,
                src_crs=qa_src.crs,
                src_nodata=qa_src.nodata,
                dst_transform=target_profile['transform'],
                dst_crs=dst_crs,
                resampling=Resampling.nearest,
            )
            return aligned

    @staticmethod
    def _qa_bad_pixel_mask(qa_scl_path, qa_scl_band):
        """Boolean mask of pixels the QA/SCL band flags as unusable.

        Files whose name contains 'pixel_qa' or 'qa_pixel' are treated as a
        Landsat QA bit field (bits 3, 4 and 5 checked; assumed to be cloud,
        cloud shadow and snow per the Landsat QA layout). Anything else is
        treated as a Sentinel-2 SCL class raster.
        """
        name = os.path.basename(qa_scl_path).lower()
        if "pixel_qa" in name or "qa_pixel" in name:
            flag_bits = (1 << 3) | (1 << 4) | (1 << 5)
            return np.bitwise_and(qa_scl_band, flag_bits) > 0
        bad_classes = [1, 3, 8, 9, 10, 11] # SCL classes for non-water/clouds/shadows
        return np.isin(qa_scl_band, bad_classes)

    def _add_styled_mask_layer(self, output_path):
        """Add the 0/1 mask at ``output_path`` to the project with a black/white palette.

        Returns True if the layer was valid and added, False otherwise.
        """
        new_layer = QgsRasterLayer(output_path, os.path.basename(output_path))
        if not new_layer.isValid():
            return False
        color_map = [
            QgsPalettedRasterRenderer.Class(0, QColor("black"), "0 (Not Water)"),
            QgsPalettedRasterRenderer.Class(1, QColor("white"), "1 (Water)")
        ]
        renderer = QgsPalettedRasterRenderer(new_layer.dataProvider(), 1, color_map)
        new_layer.setRenderer(renderer)
        QgsProject.instance().addMapLayer(new_layer)
        new_layer.triggerRepaint()
        return True

    def generate_water_mask(self):
        """Classify the loaded raster with the chosen threshold and save a GeoTIFF mask.

        Output codes: 1 = water, 0 = not water, 255 = nodata. In 'nir' mode values
        below the threshold are water; in 'ndwi' mode values at or above it are.
        Optional QA/SCL refinement turns flagged valid pixels into 0 while leaving
        nodata pixels untouched.
        """
        final_threshold = self.custom_threshold_input.value() if self.custom_radio.isChecked() else self.threshold
        if final_threshold is None or not self.raster_path:
            QMessageBox.critical(self, "Error", "A threshold must be available.")
            return
        output_path = self.output_path_widget.filePath()
        if not output_path:
            QMessageBox.critical(self, "Output Error", "Please specify a final output path.")
            return
        if not self.classification_mode:
            QMessageBox.critical(self, "Error", "Could not determine classification mode. Please reload data.")
            return

        try:
            with rasterio.open(self.raster_path) as src:
                data = src.read(1, masked=True)
                profile = src.profile # Keep profile for reference

            out_nodata = 255
            values = np.ma.getdata(data)
            valid_data_mask = ~np.ma.getmaskarray(data)
            below = valid_data_mask & (values < final_threshold)
            at_or_above = valid_data_mask & (values >= final_threshold)

            water_is_below = self.classification_mode == 'nir'
            classified_data = np.full(data.shape, out_nodata, dtype=np.uint8)
            classified_data[below] = 1 if water_is_below else 0
            classified_data[at_or_above] = 0 if water_is_below else 1

            if self.refine_checkbox.isChecked():
                qa_scl_path = self.qa_scl_path_widget.filePath()
                if not qa_scl_path:
                    QMessageBox.warning(self, "Input Error", "Please select a QA/SCL band for refinement.")
                    return

                qa_scl_band = self._read_aligned_qa_band(qa_scl_path, classified_data.shape, profile)
                bad_pixels_mask = self._qa_bad_pixel_mask(qa_scl_path, qa_scl_band)

                # Set bad pixels to non-water (0), but keep nodata pixels as nodata.
                classified_data[bad_pixels_mask & (classified_data != out_nodata)] = 0

            # Force GeoTIFF: the input may be JPEG2000 while the output is a .tif.
            profile.update(driver='GTiff', dtype=rasterio.uint8, count=1, compress='lzw', nodata=out_nodata)
            with rasterio.open(output_path, 'w', **profile) as dst:
                dst.write(classified_data, 1)

            if self._add_styled_mask_layer(output_path):
                QMessageBox.information(self, "Success", f"Final water mask saved to:\n{output_path}\nand added to the project.")
            else:
                QMessageBox.critical(self, "Error", "Failed to save or load the final water mask.")

        except Exception as e:
            import traceback
            QMessageBox.critical(self, "Processing Error", f"An error occurred: {e}\n{traceback.format_exc()}")

    def toggle_refinement_widgets(self, checked):
        """Enable the QA/SCL input widgets only when refinement is requested."""
        self.qa_scl_label.setEnabled(checked)
        self.qa_scl_path_widget.setEnabled(checked)

    def draw_shape(self):
        """Activate a rectangle or polygon map tool on the preview canvas."""
        if not self.raster_layer:
            QMessageBox.warning(self, "Warning", "Please load a raster first.")
            return
        if self.rubber_band:
            self.rubber_band.reset(QgsWkbTypes.PolygonGeometry)
        else:
            self.rubber_band = QgsRubberBand(self.canvas, QgsWkbTypes.PolygonGeometry)
            self.rubber_band.setColor(QColor(255, 0, 0, 128))
            self.rubber_band.setWidth(2)
        self.map_tool = RectangleMapTool(self.canvas) if self.shape_combo.currentText() == "Rectangle" else PolygonMapTool(self.canvas)
        self.map_tool.deactivated.connect(self.process_geometry)
        self.canvas.setMapTool(self.map_tool)

    def process_geometry(self):
        """Copy the finished map-tool geometry into the rubber band and release the tool."""
        if self.map_tool is None:
            return
        if self.map_tool.geometry:
            self.rubber_band.setToGeometry(self.map_tool.geometry, None)
            self.rubber_band.show()
        self.canvas.unsetMapTool(self.map_tool)
        self.map_tool = None

    @staticmethod
    def _remove_file(path):
        """Best-effort delete of ``path``; returns the OSError on failure, else None."""
        if not path:
            return None
        try:
            os.remove(path)
        except OSError as e:
            return e
        return None

    def _cleanup(self):
        """Release the preview layer, delete the temporary NDWI file and reset state.

        Safe to call more than once.
        """
        self._release_raster_layer()
        self.raster_path = None
        self.threshold = None
        self.result_label.setText("Not calculated")
        error = self._remove_file(self.temp_ndwi_file)
        if error is not None:
            print(f"Error removing temp file: {error}")
        self.temp_ndwi_file = None
        self.update_input_mode()
        self._update_generate_button_state()

    def closeEvent(self, event):
        """Clean up temporary files when the dialog is closed."""
        self._cleanup()
        event.accept()

    def done(self, result):
        """Also clean up when the dialog ends via accept()/reject() (e.g. Esc), which bypasses closeEvent."""
        self._cleanup()
        super(WaterMaskDialog, self).done(result)


# --- REGRESSION ANALYSIS DIALOG CLASS ---
class RegressionAnalysisDialog(QDialog):
    """Fit a one-variable linear regression of a water-quality field on a spectral-index field.

    Sample values are read from the attribute table of a vector layer. After a
    successful fit, the model can be applied to a raster through
    ApplyModelDialog / apply_model_to_raster.
    """

    def __init__(self, parent=None):
        """Constructor."""
        super(RegressionAnalysisDialog, self).__init__(parent)
        self.setWindowTitle("Linear Regression Analysis")

        main_layout = QVBoxLayout(self)
        controls_layout = QGridLayout()

        # Input Layer
        self.layer_label = QLabel("Input Shapefile:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.VectorLayer)
        self.layer_combo.layerChanged.connect(self.update_combos)
        controls_layout.addWidget(self.layer_label, 0, 0)
        controls_layout.addWidget(self.layer_combo, 0, 1, 1, 2)

        # Water Quality Parameter
        self.param_label = QLabel("Water Quality Parameter:")
        self.param_combo = QComboBox()
        self.param_combo.addItems(["Turbidity", "Chlorophyl", "CDOM"])
        self.param_combo.currentIndexChanged.connect(self.update_combos)
        controls_layout.addWidget(self.param_label, 1, 0)
        controls_layout.addWidget(self.param_combo, 1, 1, 1, 2)

        # Spectral Index
        self.index_label = QLabel("Spectral Index:")
        self.index_combo = QComboBox()
        controls_layout.addWidget(self.index_label, 2, 0)
        controls_layout.addWidget(self.index_combo, 2, 1, 1, 2)

        # Process Buttons Layout
        buttons_layout = QHBoxLayout()
        self.process_button = QPushButton("Run Analysis")
        self.process_button.clicked.connect(self.process_regression)
        buttons_layout.addWidget(self.process_button)

        self.apply_model_button = QPushButton("Apply Model to Image")
        self.apply_model_button.clicked.connect(self.apply_model_to_image)
        self.apply_model_button.setVisible(False)
        buttons_layout.addWidget(self.apply_model_button)
        controls_layout.addLayout(buttons_layout, 3, 1)

        main_layout.addLayout(controls_layout)

        # Results Layout
        results_layout = QHBoxLayout()
        self.results_text = QPlainTextEdit()
        self.results_text.setReadOnly(True)
        results_layout.addWidget(self.results_text, 1)

        self.figure = plt.figure()
        self.plot_canvas = FigureCanvas(self.figure)
        results_layout.addWidget(self.plot_canvas, 2)

        main_layout.addLayout(results_layout)

        self.model = None          # fitted LinearRegression, None until a run succeeds
        self.model_x_field = None  # name of the index field the model was fitted on

        self.update_combos()

    def update_combos(self):
        """Populate the spectral index combo box based on the selected layer and parameter."""
        self.apply_model_button.setVisible(False)

        self.index_combo.clear()
        layer = self.layer_combo.currentLayer()
        if not layer:
            return

        parameter = self.param_combo.currentText()
        model_map = {
            "Turbidity": ['RB_Ratio', 'NDTI', 'GB_Ratio', 'RedPlusNIR'],
            "Chlorophyl": ['NDCI', 'TWOBDA', 'THREEBDA', 'BG_Ratio', 'RG_Ratio', 'SABI'],
            "CDOM": ['GB_Ratio', 'RB_Ratio', 'RMinusG', 'GR_Ratio']
        }
        possible_indices = model_map.get(parameter, [])
        layer_fields = {field.name() for field in layer.fields()}

        # Keep the predefined order, offering only indices present on the layer.
        available_indices = [field for field in possible_indices if field in layer_fields]
        self.index_combo.addItems(available_indices)

    def _reset_results(self):
        """Forget any previous run: model, apply button, result text and plot."""
        self.model = None
        self.model_x_field = None
        self.apply_model_button.setVisible(False)
        self.results_text.clear()
        self.figure.clear()
        self.plot_canvas.draw()

    def process_regression(self):
        """Fit ``parameter ~ index`` on the layer's attributes and report metrics and a plot.

        Non-numeric values are coerced to NaN and rows with NaN in either field
        are dropped. At least two valid pairs are required. The model (and the
        "Apply Model to Image" button) becomes available only once the whole run
        succeeds; any earlier result is cleared at the start of the run.
        """
        layer = self.layer_combo.currentLayer()
        if not layer:
            QMessageBox.critical(self, "Input Error", "Please select an input shapefile.")
            return

        parameter = self.param_combo.currentText()
        x_field = self.index_combo.currentText()
        y_field = parameter

        if not x_field:
            QMessageBox.critical(self, "Input Error", "No spectral index selected or available.")
            return

        # A failed run must not leave a stale model/plot that could still be applied.
        self._reset_results()

        df = pd.DataFrame([f.attributes() for f in layer.getFeatures()], columns=[field.name() for field in layer.fields()])
        self.results_text.appendPlainText(f"--- Linear Regression Results for {parameter} ---")

        if y_field not in df.columns:
            self.results_text.appendPlainText(f"\nERROR: Dependent variable field '{y_field}' not found.")
            return

        try:
            df[x_field] = pd.to_numeric(df[x_field], errors='coerce')
            df[y_field] = pd.to_numeric(df[y_field], errors='coerce')
            data = df[[y_field, x_field]].dropna()

            if data.empty:
                self.results_text.appendPlainText(f"\nModel: {x_field} vs {y_field}\n  - No valid data pairs after cleaning.")
                return
            if len(data) < 2:
                # One point cannot define a line and leaves R-squared undefined.
                self.results_text.appendPlainText(f"\nModel: {x_field} vs {y_field}\n  - At least two valid data pairs are required to fit a regression (found {len(data)}).")
                return

            X = np.array(data[x_field]).reshape(-1, 1)
            y = np.array(data[y_field])

            model = LinearRegression()
            model.fit(X, y)
            y_pred = model.predict(X)

            r_squared = model.score(X, y)
            mae = mean_absolute_error(y, y_pred)
            mse = mean_squared_error(y, y_pred)
            rmse = np.sqrt(mse)

            self.results_text.appendPlainText(f"\nModel: {x_field} vs {y_field}")
            self.results_text.appendPlainText(f"\n--- Performance Metrics ---")
            self.results_text.appendPlainText(f"  R-squared: {r_squared:.4f}")
            self.results_text.appendPlainText(f"  Mean Absolute Error (MAE): {mae:.4f}")
            self.results_text.appendPlainText(f"  Root Mean Squared Error (RMSE): {rmse:.4f}")
            self.results_text.appendPlainText(f"\n--- Equation ---")
            self.results_text.appendPlainText(f"  {y_field} = {model.coef_[0]:.4f} * {x_field} + {model.intercept_:.4f}")

            self.figure.clear()
            ax = self.figure.add_subplot(111)
            ax.scatter(y, y_pred, alpha=0.5, label='Predicted vs. Actual')
            lims = [np.min([y.min(), y_pred.min()]), np.max([y.max(), y_pred.max()])]
            ax.plot(lims, lims, 'r--', alpha=0.75, zorder=0, label="1:1 Line")
            ax.set_aspect('equal', adjustable='box')
            ax.set_xlabel("Actual Values")
            ax.set_ylabel("Predicted Values")
            ax.set_title(f"Linear Regression: {y_field} vs {x_field}")
            ax.legend()
            ax.grid(True)
            self.plot_canvas.draw()

            # Publish the model only after everything above succeeded.
            self.model = model
            self.model_x_field = x_field
            self.apply_model_button.setVisible(True)

        except Exception as e:
            self.results_text.appendPlainText(f"\nCould not process model for '{x_field}': {e}")
            self.model = None
            self.model_x_field = None
            self.apply_model_button.setVisible(False)

    def apply_model_to_image(self):
        """Shows a dialog and calls the shared helper function to apply the model."""
        if self.model is None:
            QMessageBox.critical(self, "Error", "A valid regression model must be run first.")
            return

        dialog = ApplyModelDialog(self.model_x_field, self)
        if not dialog.exec_():
            return

        apply_model_to_raster(
            parent=self, model=self.model, model_x_field=self.model_x_field,
            image_path=dialog.imagePath(), band_num=dialog.bandNumber(),
            output_path=dialog.outputPath(), use_mask=dialog.useMask(),
            mask_path=dialog.maskPath()
        )


# --- RANDOM FOREST REGRESSION DIALOG CLASS ---
class RandomForestRegressionDialog(QDialog):
    """Dialog that fits a single-predictor Random Forest regression.

    The user picks a vector layer, a water-quality parameter (used as the
    response field) and one spectral-index field (the predictor). After a
    successful run, the fitted model can be applied to a raster band through
    ``apply_model_to_image``.
    """

    # Number of trees in the forest; also echoed in the results pane.
    N_ESTIMATORS = 100

    # Candidate spectral-index fields for each water-quality parameter. Only the
    # ones actually present on the selected layer are offered in the combo box.
    INDICES_BY_PARAMETER = {
        "Turbidity": ['RB_Ratio', 'NDTI', 'GB_Ratio', 'RedPlusNIR'],
        "Chlorophyl": ['NDCI', 'TWOBDA', 'THREEBDA', 'BG_Ratio', 'RG_Ratio', 'SABI'],
        "CDOM": ['GB_Ratio', 'RB_Ratio', 'RMinusG', 'GR_Ratio'],
    }

    def __init__(self, parent=None):
        """Build the input controls, the results text pane and the plot canvas."""
        # pyplot.figure() registers every figure in pyplot's global figure
        # manager, where it stays until plt.close() is called; a plain Figure
        # embedded directly in the Qt canvas is not registered there, so
        # opening this dialog repeatedly does not accumulate figures.
        from matplotlib.figure import Figure

        super(RandomForestRegressionDialog, self).__init__(parent)
        self.setWindowTitle("Random Forest Regression Analysis")

        main_layout = QVBoxLayout(self)
        controls_layout = QGridLayout()

        self.layer_label = QLabel("Input Shapefile:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.VectorLayer)
        # Connected only after setFilters so no signal fires before the
        # widgets used by update_combos exist.
        self.layer_combo.layerChanged.connect(self.update_combos)
        controls_layout.addWidget(self.layer_label, 0, 0)
        controls_layout.addWidget(self.layer_combo, 0, 1, 1, 2)

        self.param_label = QLabel("Water Quality Parameter:")
        self.param_combo = QComboBox()
        self.param_combo.addItems(["Turbidity", "Chlorophyl", "CDOM"])
        self.param_combo.currentIndexChanged.connect(self.update_combos)
        controls_layout.addWidget(self.param_label, 1, 0)
        controls_layout.addWidget(self.param_combo, 1, 1, 1, 2)

        self.index_label = QLabel("Spectral Index:")
        self.index_combo = QComboBox()
        controls_layout.addWidget(self.index_label, 2, 0)
        controls_layout.addWidget(self.index_combo, 2, 1, 1, 2)

        buttons_layout = QHBoxLayout()
        self.process_button = QPushButton("Run Analysis")
        self.process_button.clicked.connect(self.process_regression)
        buttons_layout.addWidget(self.process_button)

        self.apply_model_button = QPushButton("Apply Model to Image")
        self.apply_model_button.clicked.connect(self.apply_model_to_image)
        # Hidden until a model has been fitted successfully.
        self.apply_model_button.setVisible(False)
        buttons_layout.addWidget(self.apply_model_button)
        controls_layout.addLayout(buttons_layout, 3, 1)

        main_layout.addLayout(controls_layout)

        results_layout = QHBoxLayout()
        self.results_text = QPlainTextEdit()
        self.results_text.setReadOnly(True)
        results_layout.addWidget(self.results_text, 1)

        self.figure = Figure()
        self.plot_canvas = FigureCanvas(self.figure)
        results_layout.addWidget(self.plot_canvas, 2)

        main_layout.addLayout(results_layout)

        # Fitted estimator and the name of its predictor field (None until a run succeeds).
        self.model = None
        self.model_x_field = None

        self.update_combos()

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or NaN when it cannot be converted.

        NULL attribute values from QGIS may arrive as QVariant objects rather
        than ``None`` (TODO confirm for the targeted QGIS version), which
        ``DataFrame.dropna`` would not recognise. Mapping anything
        non-convertible to NaN lets ``dropna`` discard such rows together
        with non-numeric entries.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    @classmethod
    def _numeric_pairs(cls, df, y_field, x_field):
        """Return a two-column frame of rows where both fields are numeric."""
        return pd.DataFrame(
            {name: df[name].map(cls._to_float) for name in (y_field, x_field)}
        ).dropna()

    def update_combos(self):
        """Refill the spectral-index combo for the current layer and parameter.

        Also hides the "Apply Model to Image" button, since a previously fitted
        model no longer matches the new selection.
        """
        self.apply_model_button.setVisible(False)
        self.index_combo.clear()
        layer = self.layer_combo.currentLayer()
        if not layer:
            return
        candidates = self.INDICES_BY_PARAMETER[self.param_combo.currentText()]
        layer_fields = {field.name() for field in layer.fields()}
        self.index_combo.addItems([name for name in candidates if name in layer_fields])

    def _offer_plot_data_export(self, y, y_pred, X):
        """Ask for a CSV path and, if one is chosen, save actual/predicted/input values.

        Saving is best-effort: any failure is reported in the results pane and
        does not abort the analysis.
        """
        try:
            save_path, _ = QtWidgets.QFileDialog.getSaveFileName(
                self,
                "Save Plot Data",
                "",
                "CSV Files (*.csv)"
            )
            if save_path:
                plot_data = pd.DataFrame({
                    'Actual_Y': y,
                    'Predicted_Y': y_pred,
                    'Input_X': X.flatten(),
                })
                plot_data.to_csv(save_path, index=False)
                self.results_text.appendPlainText(f"\nPlot data saved to:\n{save_path}")
        except Exception as e:
            self.results_text.appendPlainText(f"\nCould not save plot data: {e}")

    def process_regression(self):
        """Fit a Random Forest of the selected parameter on the selected index.

        Rows whose response or predictor value is missing or non-numeric are
        dropped. R-squared, MAE and RMSE are computed on the training data
        itself. The user is offered a CSV export of the actual/predicted
        pairs, an actual-vs-predicted scatter plot is drawn, and the
        "Apply Model to Image" button is revealed on success.
        """
        layer = self.layer_combo.currentLayer()
        if not layer:
            QMessageBox.critical(self, "Input Error", "Please select an input shapefile.")
            return

        parameter = self.param_combo.currentText()
        x_field = self.index_combo.currentText()
        y_field = parameter

        if not x_field:
            QMessageBox.critical(self, "Input Error", "No spectral index selected or available.")
            return

        field_names = [field.name() for field in layer.fields()]
        df = pd.DataFrame([f.attributes() for f in layer.getFeatures()], columns=field_names)
        self.results_text.clear()
        self.results_text.appendPlainText(f"--- Random Forest Regression Results for {parameter} ---")

        if y_field not in df.columns:
            self.results_text.appendPlainText(f"\nERROR: Field '{y_field}' not found.")
            return

        try:
            data = self._numeric_pairs(df, y_field, x_field)
            if data.empty:
                self.results_text.appendPlainText(f"\nModel: {x_field} vs {y_field}\n  - No valid data pairs.")
                return

            X = data[x_field].to_numpy(dtype=float).reshape(-1, 1)
            y = data[y_field].to_numpy(dtype=float)

            model = RandomForestRegressor(n_estimators=self.N_ESTIMATORS, random_state=42)
            model.fit(X, y)

            self.model = model
            self.model_x_field = x_field

            y_pred = model.predict(X)
            r_squared = model.score(X, y)
            mae = mean_absolute_error(y, y_pred)
            rmse = np.sqrt(mean_squared_error(y, y_pred))

            self._offer_plot_data_export(y, y_pred, X)

            self.results_text.appendPlainText(f"\nModel: {self.model_x_field} vs {y_field}")
            self.results_text.appendPlainText("\n--- Model Specifications ---")
            self.results_text.appendPlainText(f"  n_estimators: {self.N_ESTIMATORS}")

            self.results_text.appendPlainText("\n--- Performance Metrics ---")
            self.results_text.appendPlainText(f"  R-squared: {r_squared:.4f}")
            self.results_text.appendPlainText(f"  Mean Absolute Error (MAE): {mae:.4f}")
            self.results_text.appendPlainText(f"  Root Mean Squared Error (RMSE): {rmse:.4f}")

            self.figure.clear()
            ax = self.figure.add_subplot(111)
            ax.scatter(y, y_pred, alpha=0.5, label='Predicted vs. Actual')
            ax.plot([y.min(), y.max()], [y.min(), y.max()], '--r', linewidth=2, label="1:1 Line")
            ax.set_xlabel("Actual Values")
            ax.set_ylabel("Predicted Values")
            ax.set_title(f"Random Forest: Actual vs. Predicted for {self.model_x_field}")
            ax.grid(True)
            ax.legend()
            self.plot_canvas.draw()

            self.apply_model_button.setVisible(True)

        except Exception as e:
            self.results_text.appendPlainText(f"\nCould not process model for '{x_field}': {e}")
            self.apply_model_button.setVisible(False)

    def apply_model_to_image(self):
        """Show the apply-model dialog and forward its choices to ``apply_model_to_raster``."""
        if self.model is None:
            QMessageBox.critical(self, "Error", "A valid Random Forest model must be run first.")
            return

        dialog = ApplyModelDialog(self.model_x_field, self)
        if not dialog.exec_():
            return

        apply_model_to_raster(
            parent=self, model=self.model, model_x_field=self.model_x_field,
            image_path=dialog.imagePath(), band_num=dialog.bandNumber(),
            output_path=dialog.outputPath(), use_mask=dialog.useMask(),
            mask_path=dialog.maskPath()
        )


# --- SVM REGRESSION DIALOG CLASS ---
class SvmRegressionDialog(QDialog):
    """Dialog that fits a single-predictor Support Vector Regression (SVR).

    The user picks a vector layer, a water-quality parameter (used as the
    response field), one spectral-index field (the predictor) and the SVR
    kernel, C and gamma settings. After a successful run, the fitted model can
    be applied to a raster band through ``apply_model_to_image``.
    """

    # Candidate spectral-index fields for each water-quality parameter. Only the
    # ones actually present on the selected layer are offered in the combo box.
    INDICES_BY_PARAMETER = {
        "Turbidity": ['RB_Ratio', 'NDTI', 'GB_Ratio', 'RedPlusNIR'],
        "Chlorophyl": ['NDCI', 'TWOBDA', 'THREEBDA', 'BG_Ratio', 'RG_Ratio', 'SABI'],
        "CDOM": ['GB_Ratio', 'RB_Ratio', 'RMinusG', 'GR_Ratio'],
    }

    def __init__(self, parent=None):
        """Build the input and hyper-parameter controls, results pane and plot canvas."""
        # pyplot.figure() registers every figure in pyplot's global figure
        # manager, where it stays until plt.close() is called; a plain Figure
        # embedded directly in the Qt canvas is not registered there, so
        # opening this dialog repeatedly does not accumulate figures.
        from matplotlib.figure import Figure

        super(SvmRegressionDialog, self).__init__(parent)
        self.setWindowTitle("SVM Regression Analysis")

        main_layout = QVBoxLayout(self)
        controls_layout = QGridLayout()

        self.layer_label = QLabel("Input Shapefile:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.VectorLayer)
        # Connected only after setFilters so no signal fires before the
        # widgets used by update_combos exist.
        self.layer_combo.layerChanged.connect(self.update_combos)
        controls_layout.addWidget(self.layer_label, 0, 0)
        controls_layout.addWidget(self.layer_combo, 0, 1, 1, 2)

        self.param_label = QLabel("Water Quality Parameter:")
        self.param_combo = QComboBox()
        self.param_combo.addItems(["Turbidity", "Chlorophyl", "CDOM"])
        self.param_combo.currentIndexChanged.connect(self.update_combos)
        controls_layout.addWidget(self.param_label, 1, 0)
        controls_layout.addWidget(self.param_combo, 1, 1, 1, 2)

        self.index_label = QLabel("Spectral Index:")
        self.index_combo = QComboBox()
        controls_layout.addWidget(self.index_label, 2, 0)
        controls_layout.addWidget(self.index_combo, 2, 1, 1, 2)

        self.kernel_label = QLabel("Kernel:")
        self.kernel_combo = QComboBox()
        self.kernel_combo.addItems(['rbf', 'linear', 'poly'])
        controls_layout.addWidget(self.kernel_label, 3, 0)
        controls_layout.addWidget(self.kernel_combo, 3, 1)

        self.c_label = QLabel("C:")
        self.c_spinbox = QDoubleSpinBox()
        self.c_spinbox.setRange(0.01, 1000.0)
        self.c_spinbox.setValue(1.0)
        controls_layout.addWidget(self.c_label, 4, 0)
        controls_layout.addWidget(self.c_spinbox, 4, 1)

        self.gamma_label = QLabel("Gamma:")
        self.gamma_combo = QComboBox()
        self.gamma_combo.addItems(['scale', 'auto'])
        controls_layout.addWidget(self.gamma_label, 5, 0)
        controls_layout.addWidget(self.gamma_combo, 5, 1)

        buttons_layout = QHBoxLayout()
        self.process_button = QPushButton("Run Analysis")
        self.process_button.clicked.connect(self.process_regression)
        buttons_layout.addWidget(self.process_button)

        self.apply_model_button = QPushButton("Apply Model to Image")
        self.apply_model_button.clicked.connect(self.apply_model_to_image)
        # Hidden until a model has been fitted successfully.
        self.apply_model_button.setVisible(False)
        buttons_layout.addWidget(self.apply_model_button)
        controls_layout.addLayout(buttons_layout, 6, 1)

        main_layout.addLayout(controls_layout)

        results_layout = QHBoxLayout()
        self.results_text = QPlainTextEdit()
        self.results_text.setReadOnly(True)
        results_layout.addWidget(self.results_text, 1)
        self.figure = Figure()
        self.plot_canvas = FigureCanvas(self.figure)
        results_layout.addWidget(self.plot_canvas, 2)
        main_layout.addLayout(results_layout)

        # Fitted estimator and the name of its predictor field (None until a run succeeds).
        self.model = None
        self.model_x_field = None
        self.update_combos()

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or NaN when it cannot be converted.

        NULL attribute values from QGIS may arrive as QVariant objects rather
        than ``None`` (TODO confirm for the targeted QGIS version), which
        ``DataFrame.dropna`` would not recognise. Mapping anything
        non-convertible to NaN lets ``dropna`` discard such rows together
        with non-numeric entries.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    @classmethod
    def _numeric_pairs(cls, df, y_field, x_field):
        """Return a two-column frame of rows where both fields are numeric."""
        return pd.DataFrame(
            {name: df[name].map(cls._to_float) for name in (y_field, x_field)}
        ).dropna()

    def update_combos(self):
        """Refill the spectral-index combo for the current layer and parameter.

        Also hides the "Apply Model to Image" button, since a previously fitted
        model no longer matches the new selection.
        """
        self.apply_model_button.setVisible(False)
        self.index_combo.clear()
        layer = self.layer_combo.currentLayer()
        if not layer:
            return
        candidates = self.INDICES_BY_PARAMETER[self.param_combo.currentText()]
        layer_fields = {field.name() for field in layer.fields()}
        self.index_combo.addItems([name for name in candidates if name in layer_fields])

    def process_regression(self):
        """Fit an SVR of the selected parameter on the selected index.

        Rows whose response or predictor value is missing or non-numeric are
        dropped. R-squared, MAE and RMSE are computed on the training data
        itself and reported together with the SVR settings and the number of
        support vectors. An actual-vs-predicted scatter plot is drawn, and the
        "Apply Model to Image" button is revealed on success.
        """
        layer = self.layer_combo.currentLayer()
        if not layer:
            QMessageBox.critical(self, "Input Error", "Please select an input shapefile.")
            return

        parameter = self.param_combo.currentText()
        x_field = self.index_combo.currentText()
        y_field = parameter

        if not x_field:
            QMessageBox.critical(self, "Input Error", "No spectral index selected or available.")
            return

        field_names = [field.name() for field in layer.fields()]
        df = pd.DataFrame([f.attributes() for f in layer.getFeatures()], columns=field_names)
        self.results_text.clear()
        self.results_text.appendPlainText(f"--- SVM Regression Results for {parameter} ---")

        if y_field not in df.columns:
            self.results_text.appendPlainText(f"\nERROR: Field '{y_field}' not found.")
            return

        try:
            data = self._numeric_pairs(df, y_field, x_field)
            if data.empty:
                self.results_text.appendPlainText(f"\nModel: {x_field} vs {y_field}\n  - No valid data pairs.")
                return

            X = data[x_field].to_numpy(dtype=float).reshape(-1, 1)
            y = data[y_field].to_numpy(dtype=float)

            kernel = self.kernel_combo.currentText()
            c_value = self.c_spinbox.value()
            gamma_value = self.gamma_combo.currentText()

            model = SVR(kernel=kernel, C=c_value, gamma=gamma_value)
            model.fit(X, y)

            self.model = model
            self.model_x_field = x_field
            y_pred = model.predict(X)

            r_squared = model.score(X, y)
            mae = mean_absolute_error(y, y_pred)
            rmse = np.sqrt(mean_squared_error(y, y_pred))
            # len(support_) is the support-vector count on every scikit-learn
            # version; n_support_ for SVR was unreliable in older releases.
            n_support_vectors = len(model.support_)

            for line in (
                f"\nModel: {self.model_x_field} vs {y_field}",
                "\n--- Model Specifications ---",
                f"  Kernel: {kernel}",
                f"  C: {c_value}",
                f"  Gamma: {gamma_value}",
                "\n--- Performance Metrics ---",
                f"  R-squared: {r_squared:.4f}",
                f"  Mean Absolute Error (MAE): {mae:.4f}",
                f"  Root Mean Squared Error (RMSE): {rmse:.4f}",
                "\n--- Support Vectors ---",
                f"  Number of support vectors: {n_support_vectors}",
            ):
                self.results_text.appendPlainText(line)

            self.figure.clear()
            ax = self.figure.add_subplot(111)
            ax.scatter(y, y_pred, alpha=0.5, label='Predicted vs. Actual')
            ax.plot([y.min(), y.max()], [y.min(), y.max()], '--r', linewidth=2, label='1:1 Line')
            ax.set_xlabel("Actual Values")
            ax.set_ylabel("Predicted Values")
            ax.set_title("SVM Regression: Actual vs. Predicted")
            ax.grid(True)
            ax.legend()
            self.plot_canvas.draw()

            self.apply_model_button.setVisible(True)

        except Exception as e:
            self.results_text.appendPlainText(f"\nCould not process model for '{x_field}': {e}")
            self.apply_model_button.setVisible(False)

    def apply_model_to_image(self):
        """Show the apply-model dialog and forward its choices to ``apply_model_to_raster``."""
        if self.model is None:
            QMessageBox.critical(self, "Error", "A valid SVM model must be run first.")
            return

        dialog = ApplyModelDialog(self.model_x_field, self)
        if not dialog.exec_():
            return

        apply_model_to_raster(
            parent=self, model=self.model, model_x_field=self.model_x_field,
            image_path=dialog.imagePath(), band_num=dialog.bandNumber(),
            output_path=dialog.outputPath(), use_mask=dialog.useMask(),
            mask_path=dialog.maskPath()
        )
