# -*- coding: utf-8 -*-
"""
/***************************************************************************
 SupervisedClassification
                                 A QGIS plugin
 A plugin to classify selected raster file with reference
                              -------------------
        begin                : 2024-06-15
        git sha              : $Format:%H$
        email                : mastools.help@gmail.com
 ***************************************************************************/
"""
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication, Qt
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import (QAction, QMessageBox, QToolBar, QDialog, QVBoxLayout, 
                                 QHBoxLayout, QTableWidget, QTableWidgetItem, QPushButton,
                                 QFileDialog, QHeaderView, QLabel)
from qgis.core import QgsRasterLayer, QgsProject
import os
import sys
from .resources_rc import *
import time
import json

# Flag to track if dependencies are available
DEPENDENCIES_AVAILABLE = False

# Try to import required packages
try:
    from .classification_dialog import ClassificationDialog
    from osgeo import ogr, gdal
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.svm import SVC
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.preprocessing import LabelEncoder
    from sklearn.naive_bayes import GaussianNB
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
    from scipy.spatial.distance import cdist
    import numpy as np
    import joblib
    import pandas as pd
    DEPENDENCIES_AVAILABLE = True
except ImportError:
    # Dependencies not available - will be handled in __init__
    pass


class LabelMappingDialog(QDialog):
    """Show which output raster value stands for which class label.

    The mapping is rendered as a read-only two-column table ordered by raster
    value, with buttons to export it as CSV or JSON.

    Args:
        label_mappings: dict mapping label name -> value written to the
            classified raster.
        parent: optional parent widget.
    """

    def __init__(self, label_mappings, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Label Mappings")
        self.resize(500, 400)
        self.label_mappings = label_mappings

        root = QVBoxLayout()

        heading = QLabel("Output Raster Values and Label Names:")
        heading.setStyleSheet("font-weight: bold; font-size: 12pt;")
        root.addWidget(heading)

        self.mapping_table = self._build_table(label_mappings)
        root.addWidget(self.mapping_table)

        # Action row: export buttons followed by Close, in that order.
        buttons = QHBoxLayout()
        for caption, slot in (
            ("Export as CSV", self.export_as_csv),
            ("Export as JSON", self.export_as_json),
            ("Close", self.accept),
        ):
            button = QPushButton(caption)
            button.clicked.connect(slot)
            buttons.addWidget(button)
        root.addLayout(buttons)

        self.setLayout(root)

    @staticmethod
    def _build_table(label_mappings):
        """Return a read-only table with one (raster value, label) row per mapping."""
        table = QTableWidget()
        table.setColumnCount(2)
        table.setHorizontalHeaderLabels(["Raster Value", "Label Name"])
        header = table.horizontalHeader()
        header.setStretchLastSection(True)
        header.setSectionResizeMode(0, QHeaderView.ResizeToContents)
        table.setRowCount(len(label_mappings))

        # Rows are ordered by raster value; every cell is non-editable.
        ordered = sorted(label_mappings.items(), key=lambda pair: pair[1])
        for row, (name, value) in enumerate(ordered):
            for column, content in ((0, value), (1, name)):
                cell = QTableWidgetItem(str(content))
                cell.setFlags(cell.flags() & ~Qt.ItemIsEditable)
                table.setItem(row, column, cell)
        return table

    def export_as_csv(self):
        """Prompt for a path and write the mapping as a CSV sorted by raster value."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Label Mappings as CSV", "", "CSV Files (*.csv)"
        )
        if not file_path:
            return  # user cancelled the dialog

        try:
            frame = pd.DataFrame(
                list(self.label_mappings.items()),
                columns=['Label_Name', 'Raster_Value'],
            ).sort_values('Raster_Value')
            frame.to_csv(file_path, index=False)
            QMessageBox.information(self, "Export Successful",
                                    f"Label mappings exported to:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "Export Failed",
                                 f"Failed to export CSV: {str(e)}")

    def export_as_json(self):
        """Prompt for a path and dump the mapping as indented JSON."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Label Mappings as JSON", "", "JSON Files (*.json)"
        )
        if not file_path:
            return  # user cancelled the dialog

        try:
            with open(file_path, 'w') as handle:
                json.dump(self.label_mappings, handle, indent=4)
            QMessageBox.information(self, "Export Successful",
                                    f"Label mappings exported to:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "Export Failed",
                                 f"Failed to export JSON: {str(e)}")


class AccuracyAssessmentDialog(QDialog):
    """Dialog to display accuracy assessment results.

    Shows overall accuracy, kappa, precision/recall/F1, the confusion matrix,
    optional pixel statistics and (when available) the highest-ranked feature
    importances, and lets the user export a text report.

    Args:
        metrics: dict read for 'overall_accuracy', 'kappa', 'precision',
            'recall', 'f1_score' and 'confusion_matrix' (indexed as
            ``cm[i][j]``). Optional keys: 'class_names' and
            'feature_importance' (a dict with a 'supported' flag and an
            'importances' mapping of feature name -> score).
        class_stats: optional dict; only 'total_pixels' is displayed.
        output_folder: optional folder used as the starting directory of the
            report save dialog.
        method: method name forwarded to the report exporter.
        parent: optional parent widget.
    """

    # Number of feature importances listed in the dialog.
    MAX_FEATURES_SHOWN = 5

    def __init__(self, metrics, class_stats=None, output_folder=None, method="Classification", parent=None):
        super().__init__(parent)
        self.setWindowTitle("Accuracy Assessment")
        self.resize(600, 500)
        self.metrics = metrics
        self.class_stats = class_stats
        self.output_folder = output_folder
        self.method = method

        layout = QVBoxLayout()

        # Title
        title_label = QLabel("Classification Accuracy Assessment")
        title_label.setStyleSheet("font-weight: bold; font-size: 14pt;")
        layout.addWidget(title_label)

        # Overall Metrics
        metrics_layout = QHBoxLayout()

        oa_label = QLabel(f"Overall Accuracy: {metrics['overall_accuracy']*100:.2f}%")
        oa_label.setStyleSheet("font-weight: bold; font-size: 12pt; color: #2e7d32;")
        metrics_layout.addWidget(oa_label)

        kappa_label = QLabel(f"Kappa: {metrics['kappa']:.4f}")
        kappa_label.setStyleSheet("font-weight: bold; font-size: 12pt; color: #1565c0;")
        metrics_layout.addWidget(kappa_label)

        layout.addLayout(metrics_layout)

        # Additional metrics
        add_metrics = QLabel(
            f"Precision: {metrics['precision']:.4f}  |  "
            f"Recall: {metrics['recall']:.4f}  |  "
            f"F1-Score: {metrics['f1_score']:.4f}"
        )
        layout.addWidget(add_metrics)

        # Confusion Matrix
        layout.addWidget(QLabel("Confusion Matrix:"))
        cm = metrics['confusion_matrix']
        n_classes = len(cm)
        cm_table = QTableWidget()
        cm_table.setRowCount(n_classes)
        cm_table.setColumnCount(n_classes)

        # Qt header labels must be Python str. Class names may arrive as ints
        # or numpy scalars, or the key may be present but set to None, so
        # fall back to indices and coerce every entry.
        class_names = metrics.get('class_names')
        if class_names is None:
            class_names = range(n_classes)
        header_labels = [str(name) for name in class_names]
        cm_table.setHorizontalHeaderLabels(header_labels)
        cm_table.setVerticalHeaderLabels(header_labels)

        for i in range(n_classes):
            for j in range(n_classes):
                item = QTableWidgetItem(str(cm[i][j]))
                item.setFlags(item.flags() & ~Qt.ItemIsEditable)
                cm_table.setItem(i, j, item)

        cm_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        layout.addWidget(cm_table)

        # Class Statistics
        if class_stats:
            layout.addWidget(QLabel(f"Total Pixels: {class_stats['total_pixels']:,}"))

        # Feature Importance (for RF, GB, LDA)
        importance = metrics.get('feature_importance')
        if importance and importance.get('supported'):
            layout.addWidget(QLabel("Top Feature Importances:"))
            # Rank by magnitude so the label really shows the top entries,
            # whatever order the incoming mapping happens to be in.
            ranked = sorted(
                importance['importances'].items(),
                key=lambda kv: abs(kv[1]),
                reverse=True,
            )
            top_features = ranked[:self.MAX_FEATURES_SHOWN]
            importance_text = "  |  ".join(f"{k}: {v:.3f}" for k, v in top_features)
            fi_label = QLabel(importance_text)
            fi_label.setStyleSheet("color: #7b1fa2;")
            layout.addWidget(fi_label)

        # Buttons
        button_layout = QHBoxLayout()

        export_btn = QPushButton("Export Report")
        export_btn.clicked.connect(self.export_report)
        button_layout.addWidget(export_btn)

        close_btn = QPushButton("Close")
        close_btn.clicked.connect(self.accept)
        button_layout.addWidget(close_btn)

        layout.addLayout(button_layout)
        self.setLayout(layout)

    def export_report(self):
        """Ask for a .txt path and write the accuracy report there.

        The save dialog starts in ``self.output_folder`` when one was given.
        Errors are reported to the user via a message box.
        """
        from .analysis_utils import export_accuracy_report

        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Accuracy Report", self.output_folder or "", "Text Files (*.txt)"
        )

        if file_path:
            try:
                export_accuracy_report(self.metrics, self.class_stats, file_path, self.method)
                QMessageBox.information(self, "Export Successful",
                                       f"Accuracy report exported to:\n{file_path}")
            except Exception as e:
                QMessageBox.critical(self, "Export Failed",
                                    f"Failed to export report: {str(e)}")


class CRSAnalysisDialog(QDialog):
    """Dialog to display CRS analysis and mismatch warnings.

    Lists every input raster/vector with its CRS and whether it matches the
    target CRS. When a mismatch exists, it shows fix suggestions and offers
    "Continue Anyway" (accept) or "Cancel" (reject).

    Args:
        crs_analysis: dict with 'has_mismatch' (bool), 'target_epsg',
            and 'rasters' / 'vectors' lists. Each entry has 'name', 'type',
            'epsg' and 'matches_target'.
        parent: optional parent widget.
    """

    # Column index of the match/mismatch status in the table.
    _STATUS_COLUMN = 3

    def __init__(self, crs_analysis, parent=None):
        super().__init__(parent)
        self.setWindowTitle("CRS Analysis - Coordinate Reference System Check")
        self.resize(800, 500)
        self.crs_analysis = crs_analysis

        layout = QVBoxLayout()

        # Warning header
        if crs_analysis['has_mismatch']:
            header = QLabel("⚠️ CRS Mismatch Detected!")
            header.setStyleSheet("font-weight: bold; font-size: 14pt; color: #d32f2f;")
        else:
            header = QLabel("✅ All CRS Match")
            header.setStyleSheet("font-weight: bold; font-size: 14pt; color: #2e7d32;")
        layout.addWidget(header)

        # Analysis Table
        layout.addWidget(QLabel("Input Files CRS Analysis:"))
        table = QTableWidget()
        table.setColumnCount(4)
        table.setHorizontalHeaderLabels(["File Name", "Type", "CRS (EPSG)", "Status"])
        table.horizontalHeader().setStretchLastSection(True)
        table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)

        all_items = crs_analysis['rasters'] + crs_analysis['vectors']
        table.setRowCount(len(all_items))

        for row, entry in enumerate(all_items):
            matches = entry['matches_target']
            # Convert every value to text explicitly. QTableWidgetItem(int)
            # selects the "item type" overload and would leave an integer
            # EPSG code blank, and None would raise TypeError.
            cells = (
                self._cell_text(entry['name']),
                self._cell_text(entry['type']),
                self._cell_text(entry['epsg']),
                "✅ Match" if matches else "❌ Mismatch",
            )
            for column, text in enumerate(cells):
                cell = QTableWidgetItem(text)
                cell.setFlags(cell.flags() & ~Qt.ItemIsEditable)
                if column == self._STATUS_COLUMN and not matches:
                    cell.setBackground(Qt.red)
                    cell.setForeground(Qt.white)
                table.setItem(row, column, cell)

        layout.addWidget(table)

        # Target CRS info
        target_label = QLabel(f"Target CRS (from first raster): EPSG:{crs_analysis['target_epsg']}")
        target_label.setStyleSheet("font-weight: bold;")
        layout.addWidget(target_label)

        # Suggestions
        if crs_analysis['has_mismatch']:
            layout.addWidget(QLabel(""))
            suggestions_label = QLabel("📋 Suggestions to Fix CRS Mismatches:")
            suggestions_label.setStyleSheet("font-weight: bold; font-size: 11pt;")
            layout.addWidget(suggestions_label)

            suggestions = [
                "1. Reproject mismatched files to match the target CRS using:",
                "   • QGIS: Vector → Data Management Tools → Reproject Layer",
                "   • QGIS: Raster → Projections → Warp (Reproject)",
                "",
                "2. Or use 'On-the-fly' reprojection (less accurate):",
                "   • Enable Project CRS in QGIS Project Properties",
                "",
                "3. Verify all inputs use the same CRS before classification",
                "   • Mismatched CRS can cause incorrect sample extraction"
            ]

            for suggestion in suggestions:
                layout.addWidget(QLabel(suggestion))

        # Buttons
        button_layout = QHBoxLayout()

        continue_btn = QPushButton("Continue Anyway" if crs_analysis['has_mismatch'] else "Continue")
        continue_btn.clicked.connect(self.accept)
        button_layout.addWidget(continue_btn)

        if crs_analysis['has_mismatch']:
            cancel_btn = QPushButton("Cancel")
            cancel_btn.clicked.connect(self.reject)
            button_layout.addWidget(cancel_btn)

        layout.addLayout(button_layout)
        self.setLayout(layout)

    @staticmethod
    def _cell_text(value):
        """Return a display string for a table cell; None becomes 'Unknown'."""
        return "Unknown" if value is None else str(value)


class SupervisedClassification:
    """QGIS Plugin Implementation."""
    
    # Third-party packages the plugin needs, as {pip distribution name:
    # importable module name}. The two differ for scikit-learn -> sklearn.
    # This mapping is handed to DependencyInstaller in
    # _check_and_install_dependencies(). GDAL (osgeo) is not listed;
    # presumably it is provided by the QGIS installation itself. Verify.
    REQUIRED_PACKAGES = {
        'scikit-learn': 'sklearn',
        'scipy': 'scipy',
        'numpy': 'numpy',
        'pandas': 'pandas',
        'joblib': 'joblib'
    }
    
    def __init__(self, iface):
        """Remember the QGIS interface and set up empty plugin state.

        Only the result of the module-level import attempt is recorded here.
        Checking for and installing missing dependencies is deferred until
        the user actually launches the plugin (see ``run``).
        """
        self.iface = iface
        self.plugin_dir = os.path.dirname(__file__)
        self.menu = self.tr(u'&MAS Raster Processing')
        self.actions = []
        self.toolbar = None
        self.first_start = None
        self.dependencies_ok = DEPENDENCIES_AVAILABLE
        # Dependencies are checked when user clicks the plugin, not at startup

    def _check_and_install_dependencies(self):
        """Run the dependency installer, then try to import the packages.

        ``self.dependencies_ok`` becomes True only if the installer reports
        success AND ``_reload_dependencies()`` then succeeds. In that case the
        user gets a "Ready to Use" message. If installation succeeded but the
        reload did not, a QGIS restart is presumably required. Any exception
        is logged to the QGIS message log and leaves the flag False.
        """
        try:
            from .dependency_installer import DependencyInstaller

            installer = DependencyInstaller(self.iface, self.REQUIRED_PACKAGES)
            installer.PLUGIN_NAME = "Supervised Classifier"

            # Reload is only attempted when the installer succeeded.
            ready = bool(
                installer.check_and_install(silent_if_ok=True)
                and self._reload_dependencies()
            )
            self.dependencies_ok = ready
            if ready:
                QMessageBox.information(
                    self.iface.mainWindow(),
                    "Ready to Use",
                    "Dependencies installed successfully!\n\n"
                    "The plugin is now ready to use."
                )
        except Exception as e:
            from qgis.core import Qgis, QgsMessageLog
            QgsMessageLog.logMessage(
                f"Failed to check dependencies: {str(e)}",
                "Supervised Classifier",
                Qgis.Warning
            )
            self.dependencies_ok = False
    
    def _reload_dependencies(self):
        """
        Try to (re)import every optional dependency, e.g. after installation.

        On success, the module-level names bound by the import block at the
        top of this file are rebound here, including GradientBoostingClassifier,
        which was previously missed. DEPENDENCIES_AVAILABLE is then set to
        True. On failure, nothing is rebound.

        Returns:
            bool: True if all imports succeed, False otherwise. Any exception
            raised while importing counts as failure, not only ImportError,
            because compiled packages can fail with other errors such as
            binary-incompatibility ValueErrors. ``run()`` calls this without
            its own try/except.
        """
        global DEPENDENCIES_AVAILABLE
        global ClassificationDialog, ogr, gdal
        global RandomForestClassifier, GradientBoostingClassifier, SVC
        global KNeighborsClassifier, LabelEncoder
        global GaussianNB, LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
        global cdist, np, joblib, pd

        from qgis.core import Qgis, QgsMessageLog

        try:
            # Import everything into local aliases first, so a partial failure
            # leaves the module globals untouched.
            from .classification_dialog import ClassificationDialog as CD
            from osgeo import ogr as _ogr, gdal as _gdal
            from sklearn.ensemble import RandomForestClassifier as RFC
            from sklearn.ensemble import GradientBoostingClassifier as GBC
            from sklearn.svm import SVC as _SVC
            from sklearn.neighbors import KNeighborsClassifier as KNC
            from sklearn.preprocessing import LabelEncoder as LE
            from sklearn.naive_bayes import GaussianNB as _GNB
            from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as _LDA
            from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as _QDA
            from scipy.spatial.distance import cdist as _cdist
            import numpy as _np
            import joblib as _joblib
            import pandas as _pd
        except Exception as e:
            QgsMessageLog.logMessage(
                f"Failed to reload dependencies: {str(e)}",
                "Supervised Classifier",
                Qgis.Warning
            )
            return False

        # Update global references
        ClassificationDialog = CD
        ogr = _ogr
        gdal = _gdal
        RandomForestClassifier = RFC
        GradientBoostingClassifier = GBC
        SVC = _SVC
        KNeighborsClassifier = KNC
        LabelEncoder = LE
        GaussianNB = _GNB
        LinearDiscriminantAnalysis = _LDA
        QuadraticDiscriminantAnalysis = _QDA
        cdist = _cdist
        np = _np
        joblib = _joblib
        pd = _pd

        DEPENDENCIES_AVAILABLE = True

        QgsMessageLog.logMessage(
            "Successfully reloaded all dependencies",
            "Supervised Classifier",
            Qgis.Success
        )
        return True

    def tr(self, message):
        """Return *message* translated via Qt in this plugin's context."""
        context = 'SupervisedClassification'
        return QCoreApplication.translate(context, message)

    def add_action(self, icon_path, text, callback, enabled_flag=True, add_to_menu=True, add_to_toolbar=True, status_tip=None, whats_this=None, parent=None):
        """Create a QAction and register it with the plugin's menu/toolbar.

        Args:
            icon_path: path or Qt resource path of the action icon.
            text: action label.
            callback: slot connected to the action's ``triggered`` signal.
            enabled_flag: initial enabled state.
            add_to_menu: add the action under ``self.menu`` in the Raster menu.
            add_to_toolbar: add the action to ``self.toolbar``. The toolbar
                must already exist, e.g. created in initGui.
            status_tip: optional status-bar text shown on hover.
            whats_this: optional "What's This?" help text.
            parent: optional QObject parent of the action.

        Returns:
            QAction: the created action. It is also appended to ``self.actions``
            so that ``unload`` can remove it.
        """
        icon = QIcon(icon_path)
        action = QAction(icon, text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)

        # These were previously accepted but silently ignored.
        if status_tip is not None:
            action.setStatusTip(status_tip)
        if whats_this is not None:
            action.setWhatsThis(whats_this)

        if add_to_toolbar:
            self.toolbar.addAction(action)
        if add_to_menu:
            self.iface.addPluginToRasterMenu(self.menu, action)
        self.actions.append(action)
        return action

    def initGui(self):
        """Create the plugin action and hook it into the Raster menu and toolbar.

        The 'MASRasterProcessingToolbar' is looked up by object name and
        reused if it already exists (it may have been created elsewhere).
        Otherwise a new toolbar is created and given that name.
        """
        main_window = self.iface.mainWindow()
        toolbar_name = 'MASRasterProcessingToolbar'

        toolbar = main_window.findChild(QToolBar, toolbar_name)
        if toolbar is None:
            toolbar = self.iface.addToolBar(u'MAS Raster Processing')
            toolbar.setObjectName(toolbar_name)
        self.toolbar = toolbar

        action = QAction(QIcon(':/supervised.png'), u"&Supervised Classifier", main_window)
        action.triggered.connect(self.run)
        self.action_SpvClassification = action

        self.iface.addPluginToRasterMenu(self.menu, action)
        self.toolbar.addAction(action)
        self.actions.append(action)

    def unload(self):
        """Remove this plugin's actions from the QGIS UI.

        Each action is removed from the Raster menu and from the toolbar. The
        toolbar is found by object name in ``initGui`` and may be shared, so
        it is only torn down when no actions remain on it.
        """
        for action in self.actions:
            self.iface.removePluginRasterMenu(self.menu, action)
            if self.toolbar is not None:
                self.toolbar.removeAction(action)
        self.actions = []

        if self.toolbar is not None and not self.toolbar.actions():
            self.iface.mainWindow().removeToolBar(self.toolbar)
            # removeToolBar() only hides the toolbar; it stays a child of the
            # main window. A later findChild() in initGui() (e.g. on plugin
            # reload) would then pick up this hidden toolbar, and the icon
            # would never be visible. Detach it immediately and delete it.
            self.toolbar.setParent(None)
            self.toolbar.deleteLater()
        self.toolbar = None

    def run(self):
        """Entry point of the plugin action.

        Ensures the Python dependencies are importable before opening the
        classification dialog. It first retries the imports (the packages may
        have been installed since startup), and otherwise offers installation.
        If dependencies are still unavailable, the dialog is not opened.
        """
        if not self.dependencies_ok:
            if self._reload_dependencies():
                self.dependencies_ok = True
            else:
                self._check_and_install_dependencies()

        if self.dependencies_ok:
            self.show_classification_dialog()

    def show_classification_dialog(self):
        """Create the classification dialog, wire it to ``classify``, and run it modally."""
        dlg = ClassificationDialog()
        # Stored before exec_() because classify() reports progress through
        # self.dialog while the dialog is open.
        self.dialog = dlg
        dlg.classify_signal.connect(self.classify)
        # Presumably syncs method-dependent widgets with the current selection.
        dlg.on_method_change()
        dlg.exec_()

    def update_progress(self, step, total_steps):
        """Show ``step / total_steps`` as a percentage on the dialog's progress bar.

        Args:
            step: number of completed steps.
            total_steps: total number of steps. A value <= 0 shows 0% instead
                of raising ZeroDivisionError.

        The percentage is clamped to [0, 100]. QProgressBar ignores values
        outside its range, so an overshoot would otherwise freeze the bar.
        Pending Qt events are processed so the UI repaints during long work.
        """
        if total_steps <= 0:
            progress = 0
        else:
            progress = int((step / total_steps) * 100)
            progress = max(0, min(100, progress))
        self.dialog.progress_bar.setValue(progress)
        QCoreApplication.processEvents()

    def classify(self, selected_rasters, selected_references, output_folder, method, open_in_qgis, save_model, num_iterations, use_pretrained_model, show_accuracy=True, export_training_csv=False, training_csv_path="", collect_from_all_rasters=False):
        """
        Handle batch classification for multiple rasters and references
        """
        if not selected_rasters:
            QMessageBox.critical(None, "Missing Input", "Please select at least one raster image.")
            return
        
        if not selected_references:
            QMessageBox.critical(None, "Missing Input", "Please select at least one reference file.")
            return
        
        # Warning for multiple pre-trained models
        if use_pretrained_model and len(selected_references) > 1:
            msg_box = QMessageBox()
            msg_box.setIcon(QMessageBox.Warning)
            msg_box.setWindowTitle("Multiple Models Selected")
            msg_box.setText(f"You have selected {len(selected_references)} pre-trained models.")
            msg_box.setInformativeText(
                "Each raster will be classified using ALL selected models.\n\n"
                f"Total outputs: {len(selected_rasters)} rasters × {len(selected_references)} models = "
                f"{len(selected_rasters) * len(selected_references)} files.\n\n"
                "Do you want to continue?"
            )
            msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
            msg_box.setDefaultButton(QMessageBox.No)
            
            if msg_box.exec_() == QMessageBox.No:
                return
        
        # CRS Analysis Check (skip for pre-trained models)
        if not use_pretrained_model and not training_csv_path:
            crs_analysis = self.analyze_all_crs(selected_rasters, selected_references)
            if crs_analysis['has_mismatch']:
                dialog = CRSAnalysisDialog(crs_analysis)
                if dialog.exec_() != QDialog.Accepted:
                    self.dialog.show_processing_info("Classification cancelled due to CRS mismatch.")
                    return
        
        start_time = time.time()
        processed_count = 0
        failed_count = 0
        last_label_mappings = None
        last_accuracy_metrics = None
        last_class_stats = None
        outputs_to_load = []  # Collect outputs to load after processing completes
        
        try:
            if use_pretrained_model:
                # ========== PRE-TRAINED MODEL MODE ==========
                # Process: Model 1 → All Rasters → Model 2 → All Rasters
                total_items = len(selected_references) * len(selected_rasters)
                current_item = 0
                
                for model_idx, ref_info in enumerate(selected_references):
                    model_path = ref_info['path']
                    model_name = os.path.basename(model_path)
                    
                    print(f"\n{'='*70}")
                    print(f"LOADING MODEL {model_idx + 1}/{len(selected_references)}: {model_name}")
                    print(f"{'='*70}")
                    
                    # Load model once
                    classifier = self.load_model(model_path)
                    if not classifier:
                        print(f"❌ Failed to load model: {model_name}")
                        failed_count += len(selected_rasters)
                        current_item += len(selected_rasters)
                        continue
                    
                    # Get class names once
                    class_names = self.get_class_names_from_model(model_path)
                    if not class_names:
                        print(f"❌ Failed to get class names for: {model_name}")
                        failed_count += len(selected_rasters)
                        current_item += len(selected_rasters)
                        continue
                    
                    print(f"✓ Model loaded successfully")
                    print(f"✓ Classes: {class_names}")
                    
                    # Classify all rasters with this model
                    for raster_idx, raster_info in enumerate(selected_rasters):
                        current_item += 1
                        image_path = raster_info['path']
                        output_name = raster_info['output_name']
                        selected_bands = raster_info['selected_bands']
                        
                        # Modify output name to include model name
                        base_name, ext = os.path.splitext(output_name)
                        model_suffix = os.path.splitext(model_name)[0].replace('trained_model_', '')
                        modified_output_name = f"{base_name}_{model_suffix}{ext}"
                        
                        self.dialog.show_processing_info(
                            f"Model {model_idx + 1}/{len(selected_references)} - "
                            f"Raster {raster_idx + 1}/{len(selected_rasters)}: {raster_info['name']}"
                        )
                        QCoreApplication.processEvents()
                        
                        try:
                            # Determine output folder
                            if output_folder:
                                current_output_folder = output_folder
                            else:
                                current_output_folder = os.path.dirname(image_path)
                            
                            # Prepare test data
                            X_test = self.prepare_test_data(image_path, selected_bands)
                            if X_test.size == 0:
                                raise Exception("No test data")
                            
                            # Predict
                            predictions = classifier.predict(X_test)
                            
                            # Convert to label names
                            label_predictions = np.array([class_names[int(p)] for p in predictions])
                            
                            # Create temp encoder
                            class TempEncoder:
                                def __init__(self, classes):
                                    self.classes_ = classes
                            
                            encoder = TempEncoder(class_names)
                            
                            # Save
                            classified_image_path, label_mappings = self.save_classified_image(
                                label_predictions, current_output_folder, image_path, 
                                encoder, "Pretrained", modified_output_name
                            )
                            
                            last_label_mappings = label_mappings
                            
                            # Collect for loading later (after all processing)
                            if open_in_qgis:
                                outputs_to_load.append((classified_image_path, label_mappings))
                            
                            processed_count += 1
                            print(f"✓ Saved: {os.path.basename(classified_image_path)}")
                        
                        except Exception as e:
                            failed_count += 1
                            error_msg = f"Failed {raster_info['name']}: {str(e)}"
                            print(f"❌ {error_msg}")
                            import traceback
                            traceback.print_exc()
                            self.dialog.show_processing_info(error_msg)
                            QCoreApplication.processEvents()
                        
                        # Update progress
                        progress = int((current_item / total_items) * 100)
                        self.dialog.progress_bar.setValue(progress)
                        QCoreApplication.processEvents()
            
            else:
                # ========== TRAINING MODE ==========
                # Process: Raster 1 → All References → Raster 2 → All References
                total_items = len(selected_rasters)
                
                for idx, raster_info in enumerate(selected_rasters):
                    image_path = raster_info['path']
                    output_name = raster_info['output_name']
                    selected_bands = raster_info['selected_bands']
                    
                    self.dialog.show_processing_info(f"Processing {idx + 1}/{total_items}: {raster_info['name']}")
                    QCoreApplication.processEvents()
                    
                    try:
                        # Determine output folder
                        if output_folder:
                            current_output_folder = output_folder
                        else:
                            current_output_folder = os.path.dirname(image_path)
                        
                        # Check if collecting from multiple rasters
                        if collect_from_all_rasters and len(selected_rasters) > 1 and idx == 0:
                            # Only collect once on first raster, then reuse for all
                            self.dialog.show_processing_info(f"Collecting samples from {len(selected_rasters)} rasters...")
                            QCoreApplication.processEvents()
                            X_train, y_train, label_encoder = self.collect_training_from_multiple_rasters(
                                selected_rasters, selected_references, selected_bands
                            )
                            # Store for subsequent rasters
                            self._multi_raster_training = (X_train, y_train, label_encoder)
                        elif collect_from_all_rasters and hasattr(self, '_multi_raster_training'):
                            # Reuse collected training data for subsequent rasters
                            X_train, y_train, label_encoder = self._multi_raster_training
                        elif training_csv_path and os.path.exists(training_csv_path):
                            self.dialog.show_processing_info(f"Loading training data from CSV...")
                            QCoreApplication.processEvents()
                            from .analysis_utils import load_training_from_csv
                            X_train, y_train, label_encoder = load_training_from_csv(training_csv_path)
                        elif len(selected_references) == 1:
                            # Single reference - use simple approach
                            ref_info = selected_references[0]
                            shapefile_path = ref_info['path']
                            label_fields = ref_info['selected_fields']
                            
                            if not label_fields:
                                raise Exception(f"No label fields selected for {ref_info['name']}")
                            
                            X_train, y_train, label_encoder = self.prepare_training_data(
                                shapefile_path, image_path, label_fields, selected_bands
                            )
                        else:
                            # Multiple references - use merged approach with class normalization
                            self.dialog.show_processing_info(f"Merging {len(selected_references)} reference files...")
                            QCoreApplication.processEvents()
                            
                            X_train, y_train, label_encoder, unified_classes = self.normalize_and_merge_training_data(
                                selected_references, image_path, selected_bands
                            )
                            print(f"Unified classes from {len(selected_references)} files: {unified_classes}")
                        
                        if X_train.size == 0 or y_train.size == 0:
                            raise Exception("No valid training data found.")
                        
                        # Check training balance and warn if imbalanced
                        try:
                            from .analysis_utils import check_training_balance
                            balance_info = check_training_balance(y_train, list(label_encoder.classes_))
                            if balance_info['warnings']:
                                for warning in balance_info['warnings']:
                                    print(f"⚠️ {warning}")
                        except Exception as e:
                            print(f"Could not check balance: {e}")
                        
                        classifier = self.get_classifier(method, num_iterations)
                        X_test = self.prepare_test_data(image_path, selected_bands)
                        
                        if X_test.size == 0:
                            raise Exception("No valid test data found.")
                        
                        self.dialog.show_processing_info(f"Training classifier with {X_train.shape[0]} samples...")
                        QCoreApplication.processEvents()
                        
                        if method == "Minimum Distance":
                            predictions = self.classify_minimum_distance((X_train, y_train), X_test)
                        else:
                            # Train classifier once (n_estimators/max_iter is set in constructor)
                            classifier.fit(X_train, y_train)
                            self.update_progress(100, 100)  # Training complete
                            predictions = classifier.predict(X_test)
                        
                        predictions = label_encoder.inverse_transform(predictions)
                        
                        classified_image_path, label_mappings = self.save_classified_image(
                            predictions, current_output_folder, image_path, label_encoder, method, output_name
                        )
                        
                        last_label_mappings = label_mappings
                        
                        # Export training samples if requested
                        if export_training_csv:
                            try:
                                from .analysis_utils import export_training_samples_csv
                                csv_path = os.path.join(current_output_folder, f"training_samples_{os.path.splitext(os.path.basename(image_path))[0]}.csv")
                                export_training_samples_csv(X_train, y_train, label_encoder, csv_path)
                                print(f"Training samples exported to: {csv_path}")
                            except Exception as e:
                                print(f"Warning: Could not export training samples: {e}")
                        
                        # Calculate accuracy metrics if requested
                        if show_accuracy:
                            try:
                                from .analysis_utils import calculate_accuracy_metrics, calculate_class_statistics
                                if method == "Minimum Distance":
                                    y_pred = self.classify_minimum_distance((X_train, y_train), X_train)
                                else:
                                    y_pred = classifier.predict(X_train)
                                
                                last_accuracy_metrics = calculate_accuracy_metrics(
                                    y_train, y_pred, list(label_encoder.classes_)
                                )
                                last_class_stats = calculate_class_statistics(
                                    predictions, label_encoder
                                )
                                
                                # Get feature importance for RF/GB
                                try:
                                    from .analysis_utils import get_feature_importance
                                    band_names = [f"Band {b}" for b in (selected_bands if selected_bands else range(1, X_train.shape[1]+1))]
                                    importance_info = get_feature_importance(classifier, band_names)
                                    if importance_info.get('supported'):
                                        last_accuracy_metrics['feature_importance'] = importance_info
                                        print(f"Top features: {list(importance_info['importances'].items())[:3]}")
                                except Exception as e:
                                    print(f"Could not get feature importance: {e}")
                                
                                print(f"Accuracy: {last_accuracy_metrics['overall_accuracy']*100:.2f}%, Kappa: {last_accuracy_metrics['kappa']:.4f}")
                            except Exception as e:
                                print(f"Warning: Could not calculate accuracy: {e}")
                        
                        if save_model:
                            self.save_model_info(
                                classifier, current_output_folder, method, X_train, y_train, label_encoder
                            )
                        
                        # Collect for loading later (after all processing)
                        if open_in_qgis:
                            outputs_to_load.append((classified_image_path, label_mappings))
                        
                        processed_count += 1
                    
                    except Exception as e:
                        failed_count += 1
                        error_msg = f"Failed to process {raster_info['name']}: {str(e)}"
                        print(f"❌ {error_msg}")
                        import traceback
                        traceback.print_exc()
                        self.dialog.show_processing_info(error_msg)
                        QCoreApplication.processEvents()
                    
                    # Update progress
                    progress = int((idx + 1) / total_items * 100)
                    self.dialog.progress_bar.setValue(progress)
                    QCoreApplication.processEvents()
            
            end_time = time.time()
            elapsed_time = end_time - start_time
            
            if use_pretrained_model:
                total_expected = len(selected_rasters) * len(selected_references)
                info = f"Classification process completed in {elapsed_time:.2f} seconds.\n"
                info += f"Successfully processed: {processed_count}/{total_expected} classifications.\n"
            else:
                info = f"Classification process completed in {elapsed_time:.2f} seconds.\n"
                info += f"Successfully processed: {processed_count}/{len(selected_rasters)} raster(s).\n"
            
            if failed_count > 0:
                info += f"Failed: {failed_count} item(s)."
            
            self.dialog.show_processing_info(info)
            
            # ========== LOAD ALL OUTPUTS WITH SYMBOLOGY ==========
            if outputs_to_load:
                self.dialog.show_processing_info("Loading outputs into QGIS...")
                QCoreApplication.processEvents()
                
                for output_path, mappings in outputs_to_load:
                    try:
                        self.open_output_in_qgis(output_path, mappings)
                    except Exception as e:
                        print(f"Warning: Could not load {output_path}: {e}")
                
                # Force UI refresh after all layers loaded
                QCoreApplication.processEvents()
            
            # Show completion info again
            self.dialog.show_processing_info(info)
            
            # Show appropriate completion message
            if processed_count > 0:
                if use_pretrained_model:
                    QMessageBox.information(
                        None, "Classification Complete", 
                        f"Processed {processed_count} classifications in {elapsed_time:.2f} seconds.\n\n"
                        "Please check the model info files for label mapping details."
                    )
                elif show_accuracy and last_accuracy_metrics:
                    # Show accuracy assessment dialog
                    accuracy_dialog = AccuracyAssessmentDialog(
                        last_accuracy_metrics, 
                        last_class_stats,
                        output_folder if output_folder else os.path.dirname(selected_rasters[0]['path']),
                        method
                    )
                    accuracy_dialog.exec_()
                elif last_label_mappings:
                    msg_box = QMessageBox()
                    msg_box.setIcon(QMessageBox.Information)
                    msg_box.setWindowTitle("Classification Complete")
                    msg_box.setText(f"Processed {processed_count}/{len(selected_rasters)} raster(s) in {elapsed_time:.2f} seconds.")
                    msg_box.setInformativeText("Would you like to view the label mappings?")
                    msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
                    msg_box.setDefaultButton(QMessageBox.Yes)
                    
                    if msg_box.exec_() == QMessageBox.Yes:
                        mapping_dialog = LabelMappingDialog(last_label_mappings)
                        mapping_dialog.exec_()
                else:
                    QMessageBox.information(None, "Classification Complete", 
                                        f"Processed {processed_count}/{len(selected_rasters)} raster(s) in {elapsed_time:.2f} seconds.")
        
        except Exception as e:
            QMessageBox.critical(None, "Classification Failed", f"Critical error: {str(e)}")
            import traceback
            traceback.print_exc()

    def get_class_names_from_model(self, model_path):
        """Read the class labels recorded in the info file saved next to a model.

        Two file names are tried in the model's directory, in order:
          1. ``<stem>_info.txt``  (``model.pkl`` -> ``model_info.txt``)
          2. ``<stem>.txt`` with ``trained_model_`` replaced by ``model_info_``
             (``trained_model_rf.pkl`` -> ``model_info_rf.txt``)
        Only the file name is rewritten, never the directory part, and the model
        file itself is never mistaken for the info file.

        Returns:
            The Python literal following ``Labels:`` on the first line that starts
            with it (parsed with ``ast.literal_eval``), or None if no info file
            exists, no such line is found, or reading/parsing fails.
        """
        import ast

        directory, filename = os.path.split(model_path)
        stem = os.path.splitext(filename)[0]
        candidates = [
            os.path.join(directory, f"{stem}_info.txt"),
            os.path.join(directory, stem.replace('trained_model_', 'model_info_') + '.txt'),
        ]
        info_path = next((path for path in candidates if os.path.exists(path)), None)
        if info_path is None:
            print(f"Model info not found: {candidates[-1]}")
            return None

        try:
            with open(info_path, 'r') as f:
                for line in f:
                    if line.startswith("Labels:"):
                        # Split once only, so a "Labels:" inside the literal survives.
                        return ast.literal_eval(line.partition("Labels:")[2].strip())
        except Exception as e:
            print(f"Error reading model info: {e}")

        return None

    def load_model(self, model_path):
        """Deserialize a model saved with joblib.

        Returns the loaded object, or None if loading fails for any reason
        (the error is printed).
        NOTE(review): joblib.load unpickles arbitrary objects; only load model
        files from trusted sources.
        """
        model = None
        try:
            model = joblib.load(model_path)
        except Exception as err:
            print(f"Model loading error: {err}")
        return model

    def validate_crs_match(self, raster_path, vector_path):
        """Compare the CRS of a raster file with that of a vector file's first layer.

        Returns:
            ``(ok, message)``. ``ok`` is False only when both CRSs were read and
            OSR's ``IsSame`` reports them as different; the message then lists
            both authority codes. Unopenable files, missing CRS information and
            any exception are treated permissively (``ok`` is True) with an
            explanatory message.
        """
        try:
            raster_ds = gdal.Open(raster_path)
            vector_ds = ogr.Open(vector_path)
            if not (raster_ds and vector_ds):
                return True, "Could not open files for CRS check"

            raster_wkt = raster_ds.GetProjection()
            vector_ref = vector_ds.GetLayer().GetSpatialRef()
            if not (raster_wkt and vector_ref):
                return True, "CRS information not available"

            from osgeo import osr
            raster_ref = osr.SpatialReference()
            raster_ref.ImportFromWkt(raster_wkt)

            if raster_ref.IsSame(vector_ref):
                return True, "CRS match"

            raster_code = raster_ref.GetAuthorityCode(None) or "Unknown"
            vector_code = vector_ref.GetAuthorityCode(None) or "Unknown"
            return False, f"CRS mismatch: Raster EPSG:{raster_code}, Vector EPSG:{vector_code}"
        except Exception as err:
            return True, f"CRS check failed: {err}"

    def analyze_all_crs(self, selected_rasters, selected_references):
        """Summarise the CRS of every selected raster and reference (vector) file.

        The authority code (normally EPSG) of the first raster that opens becomes
        the target, and every file is compared with it by code. A missing code on
        either side ("Unknown") is treated as compatible because it cannot prove a
        mismatch. Rasters and vectors are handled the same way.

        Returns:
            dict with keys:
              'rasters' / 'vectors': lists of dicts with 'name', 'path', 'type',
                  'epsg' (display string) and 'matches_target' (bool). Files that
                  cannot be opened or raise are listed with an "Error: ..." epsg
                  and matches_target=True.
              'target_epsg': the target code string, or None if no raster opened.
              'has_mismatch': True if any file's known code differs from a known
                  target code.
        """
        from osgeo import osr

        def code_of(spatial_ref):
            # Authority code, or "Unknown" when the SRS is absent or carries no
            # authority (e.g. a .prj file without an AUTHORITY node).
            if spatial_ref is None:
                return "Unknown"
            return spatial_ref.GetAuthorityCode(None) or "Unknown"

        def is_compatible(code, target):
            return target is None or code == target or "Unknown" in (code, target)

        def error_entry(info, file_type, message):
            return {
                'name': info['name'],
                'path': info['path'],
                'type': file_type,
                'epsg': f"Error: {message}",
                'matches_target': True
            }

        analysis = {
            'rasters': [],
            'vectors': [],
            'target_epsg': None,
            'has_mismatch': False
        }

        target_epsg = None

        # Analyze rasters
        for raster_info in selected_rasters:
            try:
                ds = gdal.Open(raster_info['path'])
                if not ds:
                    analysis['rasters'].append(error_entry(raster_info, 'Raster', "could not open file"))
                    continue

                srs_wkt = ds.GetProjection()
                sr = None
                if srs_wkt:  # an empty WKT means the raster has no CRS
                    sr = osr.SpatialReference()
                    sr.ImportFromWkt(srs_wkt)
                epsg = code_of(sr)
                ds = None

                if target_epsg is None:
                    target_epsg = epsg
                    analysis['target_epsg'] = epsg

                # Compare codes directly (IsSame is stricter than needed here)
                matches = is_compatible(epsg, target_epsg)
                if not matches:
                    analysis['has_mismatch'] = True

                analysis['rasters'].append({
                    'name': raster_info['name'],
                    'path': raster_info['path'],
                    'type': 'Raster',
                    'epsg': f"EPSG:{epsg}",
                    'matches_target': matches
                })
            except Exception as e:
                analysis['rasters'].append(error_entry(raster_info, 'Raster', e))

        # Analyze vectors (first layer of each data source)
        for ref_info in selected_references:
            try:
                ds = ogr.Open(ref_info['path'])
                if not ds:
                    analysis['vectors'].append(error_entry(ref_info, 'Vector', "could not open file"))
                    continue

                epsg = code_of(ds.GetLayer().GetSpatialRef())
                ds = None

                matches = is_compatible(epsg, target_epsg)
                if not matches:
                    analysis['has_mismatch'] = True

                analysis['vectors'].append({
                    'name': ref_info['name'],
                    'path': ref_info['path'],
                    'type': 'Vector',
                    'epsg': f"EPSG:{epsg}",
                    'matches_target': matches
                })
            except Exception as e:
                analysis['vectors'].append(error_entry(ref_info, 'Vector', e))

        return analysis

    def collect_training_from_multiple_rasters(self, selected_rasters, selected_references, selected_bands=None):
        """
        Collect training samples from multiple rasters.
        For every raster, samples are extracted for each reference file that has
        at least one selected label field (via ``prepare_training_data``), and all
        samples are merged into one training set.
        
        Args:
            selected_rasters: List of raster info dicts ('path', 'name' and
                optionally 'selected_bands', which overrides ``selected_bands``)
            selected_references: List of reference info dicts ('path', 'name',
                'selected_fields')
            selected_bands: Optional band selection used when a raster dict
                has no 'selected_bands' key
            
        Returns:
            Tuple of (X_train, y_train, label_encoder) with merged samples. The
            labels are encoded by a single LabelEncoder fitted on the original
            class values from all rasters/references.

        Raises:
            Exception: if no training samples could be extracted from any raster.
        """
        from sklearn.preprocessing import LabelEncoder
        
        feature_blocks = []
        class_values = []
        n_bands = None
        
        for raster_info in selected_rasters:
            image_path = raster_info['path']
            bands = raster_info.get('selected_bands', selected_bands)
            
            self.dialog.show_processing_info(f"Extracting samples from: {raster_info['name']}")
            QCoreApplication.processEvents()
            
            for ref_info in selected_references:
                label_fields = ref_info.get('selected_fields', [])
                if not label_fields:
                    continue
                
                try:
                    X, y, ref_encoder = self.prepare_training_data(
                        ref_info['path'], image_path, label_fields, bands
                    )
                    if X.size == 0:
                        continue

                    # All merged samples must have the same number of features.
                    if n_bands is None:
                        n_bands = X.shape[1]
                    elif X.shape[1] != n_bands:
                        print(f"Skipping {raster_info['name']}: {X.shape[1]} bands, expected {n_bands}")
                        continue

                    # ``y`` is encoded by a LabelEncoder fitted on this raster/reference
                    # pair alone, so the same integer can mean different classes in
                    # different pairs. Decode to the original class values before
                    # merging; decode first so a failure leaves nothing half-appended.
                    decoded = ref_encoder.inverse_transform(y)
                    feature_blocks.append(X)
                    class_values.extend(decoded)
                    print(f"Collected {len(y)} samples from {raster_info['name']}")
                except Exception as e:
                    print(f"Could not extract from {raster_info['name']}: {e}")
        
        if not feature_blocks:
            raise Exception("No training samples could be extracted from any raster")
        
        # Merge all samples
        X_train = np.vstack(feature_blocks)
        
        # One encoder over the real class values -> consistent codes across rasters
        label_encoder = LabelEncoder()
        y_train = label_encoder.fit_transform(np.array(class_values))
        
        print(f"Total merged samples: {len(y_train)} from {len(selected_rasters)} rasters")
        print(f"Classes: {list(label_encoder.classes_)}")
        
        return X_train, y_train, label_encoder

    def get_classifier(self, method, num_iterations):
        """Build an unfitted scikit-learn estimator for a UI method name.

        ``num_iterations`` becomes ``n_estimators`` for Random Forest and Gradient
        Boosting and ``max_iter`` for SVM; the other methods ignore it.
        "Minimum Distance" has no estimator and yields None.

        Raises:
            Exception: if ``method`` is not one of the supported names.
        """
        def build_knn():
            # 5 neighbours, distance-weighted. Try parallel search first and fall
            # back if the installed scikit-learn rejects the n_jobs argument.
            try:
                return KNeighborsClassifier(n_neighbors=5, weights='distance', n_jobs=-1)
            except TypeError:
                return KNeighborsClassifier(n_neighbors=5, weights='distance')

        # Lazy factories: nothing is constructed until the chosen one is called.
        factories = {
            "Random Forest": lambda: RandomForestClassifier(
                n_estimators=num_iterations, n_jobs=-1, random_state=42),
            "Gradient Boosting": lambda: GradientBoostingClassifier(
                n_estimators=num_iterations, random_state=42),
            "SVM (Support Vector Machine)": lambda: SVC(max_iter=num_iterations, random_state=42),
            "KNN (K-Nearest Neighbors)": build_knn,
            # GaussianNB approximates maximum likelihood (assumes feature independence)
            "Maximum Likelihood": lambda: GaussianNB(),
            # Shared covariance across classes
            "LDA (Linear Discriminant Analysis)": lambda: LinearDiscriminantAnalysis(),
            # Full covariance per class
            "QDA (Quadratic Discriminant Analysis)": lambda: QuadraticDiscriminantAnalysis(),
            # Handled separately by classify_minimum_distance
            "Minimum Distance": lambda: None,
        }

        factory = factories.get(method)
        if factory is None:
            raise Exception("Invalid method selected")
        return factory()

    def classify_minimum_distance(self, train_data, X_test):
        """Assign each test sample to the class whose mean is nearest (Euclidean).

        Args:
            train_data: ``(X_train, y_train)`` with one training sample per row of
                ``X_train`` and its class label in ``y_train``.
            X_test: samples to classify, one per row, same number of columns as
                ``X_train``.

        Returns:
            numpy array with one predicted label per test sample, drawn from the
            values in ``y_train``. For LabelEncoder-encoded labels (0..K-1) this is
            the same as the index of the nearest class mean.
        """
        X_train, y_train = train_data
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)  # lists would make `y_train == cls` a scalar
        classes = np.unique(y_train)
        class_means = np.array([X_train[y_train == cls].mean(axis=0) for cls in classes])
        nearest = np.argmin(cdist(np.asarray(X_test), class_means), axis=1)
        # Map positions in `classes` back to the actual label values so that
        # non-contiguous labels (e.g. [2, 5]) are returned correctly.
        return classes[nearest]

    def normalize_and_merge_training_data(self, selected_references, image_path, selected_bands=None):
        """
        Merge training data from multiple reference files with class normalization.
        
        Class normalization rules:
        - Multiple columns selected (binary 0/1): column names become class labels
        - Single column: column values become class labels (as strings)
        - Same class name/value across files: treated as same class

        Pixels whose value equals a band's declared nodata (including a NaN nodata)
        in any selected band are excluded.

        Returns:
            (X_train, y_train, label_encoder, class_list). When no samples are found,
            X_train and y_train are empty arrays, the encoder is unfitted and
            class_list holds every label seen on a feature.

        Raises:
            Exception: if the image cannot be opened or its bands cannot be read.
        """
        image_ds = gdal.Open(image_path)
        if not image_ds:
            raise Exception(f"Failed to open image: {image_path}")

        geotransform = image_ds.GetGeoTransform()
        x_size, y_size = image_ds.RasterXSize, image_ds.RasterYSize

        # Bands are 1-based in GDAL
        band_indices = selected_bands if selected_bands else range(1, image_ds.RasterCount + 1)
        bands = [image_ds.GetRasterBand(i) for i in band_indices]
        if not bands or any(band is None for band in bands):
            raise Exception("Failed to get raster bands")

        # Read every band once, and precompute which pixels are valid in all bands.
        band_arrays = []
        valid_mask = np.ones((y_size, x_size), dtype=bool)
        for band in bands:
            arr = band.ReadAsArray()
            nodata = band.GetNoDataValue()
            if nodata is not None:
                if np.isnan(nodata):
                    # NaN never compares equal to itself, so test it explicitly
                    valid_mask &= ~np.isnan(arr)
                else:
                    valid_mask &= arr != nodata
            band_arrays.append(arr)

        feature_chunks = []
        all_labels = []
        all_classes = set()

        # Process each reference file
        for ref_info in selected_references:
            label_fields = ref_info['selected_fields']
            if not label_fields:
                print(f"Skipping {ref_info['name']}: No label fields selected")
                continue

            # Keep the data source referenced while its layer is in use
            shapefile_ds = ogr.Open(ref_info['path'])
            if not shapefile_ds:
                print(f"Skipping {ref_info['name']}: Failed to open")
                continue

            layer = shapefile_ds.GetLayer()
            if not layer:
                continue

            if isinstance(label_fields, str):
                label_fields = [label_fields]

            # Several columns -> binary mode (column name is the class);
            # one column -> value mode (column value is the class).
            use_column_names_as_labels = len(label_fields) > 1

            for feature in layer:
                geom = feature.GetGeometryRef()
                if not geom:
                    continue

                if use_column_names_as_labels:
                    label = next(
                        (name for name in label_fields if feature.GetField(name) in (1, "1")),
                        None,
                    )
                    if not label:
                        continue
                else:
                    value = feature.GetField(label_fields[0])
                    if value is None:
                        continue
                    # Strings make the same value from different files compare equal
                    label = str(value)

                all_classes.add(label)

                rows, cols = self._geometry_pixel_indices(geom, geotransform, x_size, y_size)
                keep = valid_mask[rows, cols]
                rows, cols = rows[keep], cols[keep]
                if rows.size == 0:
                    continue

                feature_chunks.append(np.column_stack([arr[rows, cols] for arr in band_arrays]))
                all_labels.extend([label] * rows.size)

            print(f"Processed {ref_info['name']}: {layer.GetFeatureCount()} features")

        if not feature_chunks:
            return np.array([]), np.array([]), LabelEncoder(), list(all_classes)

        X_train = np.vstack(feature_chunks)

        # Encode labels
        label_encoder = LabelEncoder()
        y_train = label_encoder.fit_transform(np.array(all_labels))

        print(f"Merged training data: {X_train.shape}")
        print(f"Unified classes ({len(label_encoder.classes_)}): {list(label_encoder.classes_)}")

        return X_train, y_train, label_encoder, list(label_encoder.classes_)

    def _geometry_pixel_indices(self, geom, geotransform, x_size, y_size):
        """Return ``(rows, cols)`` integer arrays of raster pixels sampled by an OGR geometry.

        Point/MultiPoint: the pixel containing each point.
        LineString/MultiLineString: pixels whose centre lies closer than half a
            pixel width to the line.
        Any other geometry (polygons, ...): pixels whose centre intersects it.
        Pixels outside ``x_size`` x ``y_size`` are dropped. The rotation terms of
        the geotransform are ignored, so a north-up raster is assumed.
        """
        import math

        x_origin, pixel_width, _, y_origin, _, pixel_height = geotransform
        geom_type = geom.GetGeometryType()
        cols, rows = [], []

        if geom_type in (ogr.wkbPoint, ogr.wkbPoint25D, ogr.wkbMultiPoint, ogr.wkbMultiPoint25D):
            if geom_type in (ogr.wkbMultiPoint, ogr.wkbMultiPoint25D):
                points = [geom.GetGeometryRef(i) for i in range(geom.GetGeometryCount())]
            else:
                points = [geom]
            for pt in points:
                # floor(), not int(): int() truncates toward zero, which would pull
                # points just left of / above the raster edge into pixel 0.
                cols.append(math.floor((pt.GetX() - x_origin) / pixel_width))
                rows.append(math.floor((pt.GetY() - y_origin) / pixel_height))
        else:
            # Envelope is (minX, maxX, minY, maxY); convert to a pixel window.
            min_x, max_x, min_y, max_y = geom.GetEnvelope()
            c0 = (min_x - x_origin) / pixel_width
            c1 = (max_x - x_origin) / pixel_width
            r0 = (max_y - y_origin) / pixel_height
            r1 = (min_y - y_origin) / pixel_height
            col_start = max(0, math.floor(min(c0, c1)))
            col_end = min(x_size, math.floor(max(c0, c1)) + 1)
            row_start = max(0, math.floor(min(r0, r1)))
            row_end = min(y_size, math.floor(max(r0, r1)) + 1)

            is_line = geom_type in (ogr.wkbLineString, ogr.wkbLineString25D,
                                    ogr.wkbMultiLineString, ogr.wkbMultiLineString25D)
            half_pixel = abs(pixel_width) * 0.5

            for col in range(col_start, col_end):
                centre_x = x_origin + (col + 0.5) * pixel_width
                for row in range(row_start, row_end):
                    centre_y = y_origin + (row + 0.5) * pixel_height
                    point = ogr.Geometry(ogr.wkbPoint)
                    point.AddPoint(centre_x, centre_y)
                    if is_line:
                        hit = geom.Distance(point) < half_pixel
                    else:
                        hit = geom.Intersects(point)
                    if hit:
                        cols.append(col)
                        rows.append(row)

        cols = np.asarray(cols, dtype=np.intp)
        rows = np.asarray(rows, dtype=np.intp)
        inside = (cols >= 0) & (cols < x_size) & (rows >= 0) & (rows < y_size)
        return rows[inside], cols[inside]

    def prepare_training_data(self, shapefile_path, image_path, label_fields, selected_bands=None):
        """Prepare training data with support for multiple label columns and geometry types (Point, Line, Polygon)"""
        shapefile_ds = ogr.Open(shapefile_path)
        if not shapefile_ds:
            raise Exception(f"Failed to open shapefile: {shapefile_path}")

        image_ds = gdal.Open(image_path)
        if not image_ds:
            raise Exception(f"Failed to open image: {image_path}")

        layer = shapefile_ds.GetLayer()
        if not layer:
            raise Exception("Failed to get layer from shapefile")

        geotransform = image_ds.GetGeoTransform()
        x_origin, pixel_width, _, y_origin, _, pixel_height = geotransform
        x_min, x_max = 0, image_ds.RasterXSize
        y_min, y_max = 0, image_ds.RasterYSize

        # Get bands
        if selected_bands:
            bands = [image_ds.GetRasterBand(i) for i in selected_bands]
        else:
            bands = [image_ds.GetRasterBand(i+1) for i in range(image_ds.RasterCount)]

        if not bands:
            raise Exception("Failed to get raster bands")

        # Read all band data at once for faster access
        band_arrays = [band.ReadAsArray() for band in bands]
        nodata_values = [band.GetNoDataValue() for band in bands]

        X_train, y_train = [], []
        label_encoder = LabelEncoder()
        
        # Handle multiple label fields
        if isinstance(label_fields, str):
            label_fields = [label_fields]
        
        multiple_labels = len(label_fields) > 1

        for feature in layer:
            geom = feature.GetGeometryRef()
            if not geom:
                continue
            
            # Get label
            if multiple_labels:
                label = None
                for field_name in label_fields:
                    if feature.GetField(field_name) == 1:
                        label = field_name
                        break
                if not label:
                    continue
            else:
                label = feature.GetField(label_fields[0])
                if label is None:
                    continue
            
            # Get geometry type and extract sample points
            geom_type = geom.GetGeometryType()
            sample_points = []
            
            # Point geometry - use the point directly
            if geom_type in [ogr.wkbPoint, ogr.wkbPoint25D, ogr.wkbMultiPoint, ogr.wkbMultiPoint25D]:
                if geom_type in [ogr.wkbMultiPoint, ogr.wkbMultiPoint25D]:
                    for i in range(geom.GetGeometryCount()):
                        pt = geom.GetGeometryRef(i)
                        sample_points.append((pt.GetX(), pt.GetY()))
                else:
                    sample_points.append((geom.GetX(), geom.GetY()))
            
            # Line or Polygon geometry - sample pixels along/within the geometry
            else:
                # Get envelope (bounding box) of the geometry
                env = geom.GetEnvelope()  # (minX, maxX, minY, maxY)
                
                # Calculate pixel range to check
                px_start = max(0, int((env[0] - x_origin) / pixel_width))
                px_end = min(x_max, int((env[1] - x_origin) / pixel_width) + 1)
                py_start = max(0, int((env[3] - y_origin) / pixel_height))  # Note: pixel_height is negative
                py_end = min(y_max, int((env[2] - y_origin) / pixel_height) + 1)
                
                # Check each pixel in the bounding box
                for px in range(px_start, px_end):
                    for py in range(py_start, py_end):
                        # Get center of pixel
                        pixel_x = x_origin + (px + 0.5) * pixel_width
                        pixel_y = y_origin + (py + 0.5) * pixel_height
                        
                        # Create point and check if it's within/on the geometry
                        point = ogr.Geometry(ogr.wkbPoint)
                        point.AddPoint(pixel_x, pixel_y)
                        
                        # For lines, use distance check (within half pixel)
                        if geom_type in [ogr.wkbLineString, ogr.wkbLineString25D, 
                                        ogr.wkbMultiLineString, ogr.wkbMultiLineString25D]:
                            if geom.Distance(point) < abs(pixel_width) * 0.5:
                                sample_points.append((px, py, True))  # Already pixel coords
                        else:
                            # For polygons, check if point is within
                            if geom.Contains(point) or geom.Intersects(point):
                                sample_points.append((px, py, True))  # Already pixel coords
            
            # Extract pixel values for each sample point
            for sample in sample_points:
                if len(sample) == 3:  # Already pixel coordinates
                    pixel_x, pixel_y, _ = sample
                else:  # World coordinates, need to convert
                    x, y = sample
                    pixel_x = int((x - x_origin) / pixel_width)
                    pixel_y = int((y - y_origin) / pixel_height)
                
                if not (x_min <= pixel_x < x_max and y_min <= pixel_y < y_max):
                    continue
                
                pixel_values = []
                valid = True
                for i, arr in enumerate(band_arrays):
                    val = arr[int(pixel_y), int(pixel_x)]
                    if nodata_values[i] is not None and val == nodata_values[i]:
                        valid = False
                        break
                    pixel_values.append(val)
                
                if valid and len(pixel_values) == len(bands):
                    X_train.append(pixel_values)
                    y_train.append(label)

        X_train = np.array(X_train)
        y_train = np.array(y_train)

        if X_train.size == 0 or y_train.size == 0:
            return X_train, y_train, label_encoder

        y_train = label_encoder.fit_transform(y_train)
        
        print(f"Training data: {X_train.shape}, Labels: {list(label_encoder.classes_)}")

        return X_train, y_train, label_encoder

    def prepare_test_data(self, image_path, selected_bands=None):
        """Return the pixel feature matrix used for prediction.

        Each row of the result is one pixel, in the row-major order produced by
        flattening ``ReadAsArray``. Each column is one band.

        Args:
            image_path: Path of the raster to read.
            selected_bands: 1-based GDAL band indices to use. When empty or
                ``None``, every band of the raster is used.

        Returns:
            numpy.ndarray: Array of shape ``(pixels, bands)``.

        Raises:
            Exception: If the image cannot be opened, there are no bands to
                read, or a band or its data cannot be read.
        """
        dataset = gdal.Open(image_path)
        if not dataset:
            raise Exception(f"Failed to open image: {image_path}")

        # Fall back to all bands (GDAL numbers bands from 1) when no explicit
        # selection was given.
        indices = selected_bands or list(range(1, dataset.RasterCount + 1))
        if not indices:
            raise Exception("Failed to get raster bands")

        # One ReadAsArray call per band instead of per-pixel reads.
        columns = []
        for idx in indices:
            raster_band = dataset.GetRasterBand(idx)
            if raster_band is None:
                raise Exception(f"Failed to read band {idx}")
            values = raster_band.ReadAsArray()
            if values is None:
                raise Exception(f"Failed to read data from band {idx}")
            columns.append(values.flatten())

        # Bands become columns, so each row holds one pixel's band values.
        return np.column_stack(columns)

    def save_classified_image(self, predictions, output_folder, image_path, label_encoder, method, output_name=None):
        """Write per-pixel class predictions to a single-band Int32 GeoTIFF.

        The output copies the geotransform and projection of *image_path*.
        Each predicted label is written as the integer it maps to in the label
        mapping.

        Args:
            predictions: One predicted label per pixel of *image_path*, in
                row-major pixel order.
            output_folder: Directory that receives the GeoTIFF.
            image_path: Source raster. It provides the size, geotransform and
                projection.
            label_encoder: Either a dict mapping label -> raster value, or a
                fitted encoder whose ``classes_`` are mapped to 1, 2, 3, ...
            method: Classification method name, used to build the default
                file name.
            output_name: Optional explicit output file name inside
                *output_folder*.

        Returns:
            tuple: ``(output_path, label_mappings)``, where *label_mappings*
            maps each label to the integer stored in the raster.

        Raises:
            Exception: If the image cannot be opened, the prediction count does
                not match the raster size, a predicted label has no mapping, or
                the output file cannot be created.
        """
        image_ds = gdal.Open(image_path)
        if not image_ds:
            raise Exception(f"Failed to open image: {image_path}")

        method_abbr = {
            "Minimum Distance": "MinDist", "Random Forest": "RF",
            "Gradient Boosting": "GB",
            "SVM (Support Vector Machine)": "SVM",
            "KNN (K-Nearest Neighbors)": "KNN", "Maximum Likelihood": "MaxLik",
            "LDA (Linear Discriminant Analysis)": "LDA",
            "QDA (Quadratic Discriminant Analysis)": "QDA"
        }.get(method, method[:3].upper())

        if output_name:
            output_path = os.path.join(output_folder, output_name)
        else:
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            output_path = os.path.join(output_folder, f"{base_name}_Classified_{method_abbr}.tif")

        # Label -> raster value. Encoder classes start at 1 (value 0 is never
        # assigned). This mapping is also what the caller gets back.
        if isinstance(label_encoder, dict):
            label_mappings = label_encoder
        else:
            label_mappings = {label: i + 1 for i, label in enumerate(label_encoder.classes_)}

        rows, cols = image_ds.RasterYSize, image_ds.RasterXSize
        flat_predictions = np.asarray(predictions).ravel()
        if flat_predictions.size != rows * cols:
            raise Exception(
                f"Prediction count ({flat_predictions.size}) does not match "
                f"raster size ({rows} x {cols})"
            )

        # Validate and convert before creating the output file, so a bad
        # mapping does not leave a half-written GeoTIFF behind. Each distinct
        # label is looked up once (instead of a Python call per pixel), then
        # spread back to every pixel through the inverse index.
        unique_labels, inverse = np.unique(flat_predictions, return_inverse=True)
        unique_list = unique_labels.tolist()
        missing = [label for label in unique_list if label not in label_mappings]
        if missing:
            raise Exception(f"Predicted labels have no raster value mapping: {missing}")
        codes = np.array([label_mappings[label] for label in unique_list], dtype=np.int32)
        classified_image = codes[inverse].reshape(rows, cols)

        driver = gdal.GetDriverByName("GTiff")
        output_ds = driver.Create(output_path, cols, rows, 1, gdal.GDT_Int32)
        if output_ds is None:
            raise Exception(f"Failed to create output image: {output_path}")
        output_ds.SetGeoTransform(image_ds.GetGeoTransform())
        output_ds.SetProjection(image_ds.GetProjection())
        output_band = output_ds.GetRasterBand(1)
        output_band.WriteArray(classified_image)
        output_band.FlushCache()
        # Drop the references so GDAL flushes and closes the file.
        output_band = None
        output_ds = None

        return output_path, label_mappings

    # (Removed a duplicate open_output_in_qgis that was shadowed by the later definition in this class.)

    # (Removed a duplicate show_processing_info that was shadowed by the later definition in this class.)

    def save_model_info(self, classifier, output_folder, method, X_train, y_train, label_encoder):
        """Save the trained model and a plain-text summary of its fit on the training data.

        The classifier is serialized with ``joblib.dump`` to
        ``trained_model_<suffix>.pkl``. A report is written to
        ``model_info_<suffix>.txt``. Both files go in *output_folder*.
        *suffix* is a short abbreviation of *method* (e.g. ``"RF"`` for
        ``"Random Forest"``).

        Args:
            classifier: Fitted estimator. It is always dumped to disk. For
                ``"Minimum Distance"``, training predictions come from
                ``self.classify_minimum_distance`` rather than from
                ``classifier.predict``.
            output_folder: Directory that receives the model and report files.
            method: Display name of the classification method.
            X_train: Training feature matrix.
            y_train: Training targets, as produced by *label_encoder*.
            label_encoder: Fitted encoder whose ``classes_`` are the original
                labels.

        Returns:
            tuple: ``(accuracy, class_distribution)``. *accuracy* is measured
            on the training data itself, not a held-out set.
            *class_distribution* is the ``value_counts`` table of *y_train*
            rendered as a string.
        """
        # Use short names for model files.
        suffix_map = {
            "Minimum Distance": "MinDist", "Random Forest": "RF",
            "Gradient Boosting": "GB",
            "SVM (Support Vector Machine)": "SVM",
            "KNN (K-Nearest Neighbors)": "KNN", "Maximum Likelihood": "MaxLik",
            "LDA (Linear Discriminant Analysis)": "LDA",
            "QDA (Quadratic Discriminant Analysis)": "QDA"
        }
        suffix = suffix_map.get(method, method.replace(" ", "_").lower())
        model_path = os.path.join(output_folder, f"trained_model_{suffix}.pkl")
        joblib.dump(classifier, model_path)

        from sklearn.metrics import accuracy_score, classification_report

        if method == "Minimum Distance":
            y_pred = self.classify_minimum_distance((X_train, y_train), X_train)
        else:
            y_pred = classifier.predict(X_train)

        # Plain Python labels, so the report shows 'water' rather than numpy
        # reprs such as np.str_('water').
        class_labels = np.asarray(label_encoder.classes_).tolist()
        # classification_report takes len() of every target name. Non-string
        # labels (e.g. integer class codes) must therefore be stringified.
        target_names = [str(label) for label in class_labels]

        accuracy = accuracy_score(y_train, y_pred)
        report = classification_report(y_train, y_pred, target_names=target_names)

        model_info_path = os.path.join(output_folder, f"model_info_{suffix}.txt")
        # NOTE(review): counts are keyed by the raw y_train values (presumably
        # the 0-based encoded indices), not by the 1-based raster values in
        # label_mappings.
        class_distribution = pd.Series(y_train).value_counts().to_string()
        label_mappings = {label: i + 1 for i, label in enumerate(class_labels)}

        # Explicit UTF-8 so non-ASCII class names do not fail under a narrow
        # platform default encoding.
        with open(model_info_path, 'w', encoding='utf-8') as file:
            file.write(f"Model Path: {model_path}\n")
            file.write(f"Accuracy: {accuracy}\n")
            file.write(f"Classification Report:\n{report}\n")
            file.write(f"Labels: {class_labels}\n")
            file.write(f"Class distribution in training data:\n{class_distribution}\n")
            file.write(f"Label mappings (label to numerical value): {label_mappings}\n")

        print(f"Model saved: {model_path}")
        return accuracy, class_distribution

    def open_output_in_qgis(self, classified_image_path, method="Classified"):
        """Add the classified raster to the current QGIS project.

        The layer is named after the file name without its extension, with
        underscores shown as spaces. *method* is accepted for caller
        compatibility but is not used.

        Raises:
            Exception: If QGIS cannot load the file as a valid raster layer.
        """
        stem, _extension = os.path.splitext(os.path.basename(classified_image_path))
        raster_layer = QgsRasterLayer(classified_image_path, stem.replace("_", " "))
        if raster_layer.isValid():
            QgsProject.instance().addMapLayer(raster_layer)
            return
        raise Exception("Failed to load classified image in QGIS.")

    def show_processing_info(self, info):
        """Display *info* in the dialog's processing-status label."""
        status_label = self.dialog.processing_info_label
        status_label.setText(info)
