# -*- coding: utf-8 -*-
import os
import numpy as np
import random
from qgis.PyQt.QtCore import QSettings, QTranslator, qVersion, QCoreApplication, Qt
from qgis.PyQt.QtGui import QIcon, QColor
from qgis.PyQt.QtWidgets import QAction, QFileDialog, QMessageBox, QToolBar
from qgis.core import (QgsProject, QgsRasterLayer, QgsPalettedRasterRenderer,
                       QgsColorRampShader, QgsRasterShader, QgsSingleBandPseudoColorRenderer)
from osgeo import gdal, osr
from .classify_dialog import UnsupervisedClassifierDialog, ClusterStatisticsDialog
from . import resources_rc

# Suppress all warnings
import warnings
warnings.filterwarnings('ignore')

# ========== DEPENDENCY DETECTION ==========
# sklearn is OPTIONAL - plugin works without it using numpy implementations
sklearn_available = False
try:
    from sklearn.cluster import KMeans as SklearnKMeans
    sklearn_available = True
except ImportError:
    pass

# scipy is OPTIONAL - used for optimized distance calculations
scipy_available = False
try:
    from scipy.spatial.distance import cdist as scipy_cdist
    scipy_available = True
except ImportError:
    pass


# ========== PURE NUMPY IMPLEMENTATIONS (OPTIMIZED FOR LARGE RASTERS) ==========

def compute_distances_to_centroids(data, centroids, chunk_size=50000):
    """
    Assign every sample to its nearest centroid.

    Despite the name, the distances themselves are not returned; only the
    index of the closest centroid is. Samples are processed ``chunk_size``
    rows at a time so that only a (chunk_size, n_clusters) distance matrix
    is held in memory at once.

    Parameters:
        data: (n_samples, n_features) array of samples.
        centroids: (n_clusters, n_features) array of cluster centres.
        chunk_size: number of rows handled per pass; must be positive.

    Returns:
        (n_samples,) int32 array holding the index of the nearest centroid.

    Raises:
        ValueError: if chunk_size is not positive (a negative step would
            skip the loop and leave the output buffer uninitialised).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    n_samples = data.shape[0]
    labels = np.empty(n_samples, dtype=np.int32)

    # The centroid terms do not depend on the chunk, so compute them once.
    cent_sq = np.sum(centroids ** 2, axis=1)  # (n_clusters,)
    centroids_t = centroids.T

    for start in range(0, n_samples, chunk_size):
        end = min(start + chunk_size, n_samples)
        chunk = data[start:end]

        # ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b  (avoids a 3D temporary)
        chunk_sq = np.sum(chunk ** 2, axis=1, keepdims=True)  # (rows, 1)
        distances_sq = chunk_sq + cent_sq - 2 * (chunk @ centroids_t)
        # Rounding can push tiny distances slightly below zero.
        distances_sq = np.maximum(distances_sq, 0)

        labels[start:end] = np.argmin(distances_sq, axis=1)

    return labels


def compute_min_distances_sq(data, centroids, chunk_size=50000):
    """
    Compute the squared distance from each sample to its closest centroid.

    Used for K-means++ seeding and for the inertia calculation. Samples are
    processed ``chunk_size`` rows at a time to bound memory use.

    Parameters:
        data: (n_samples, n_features) array of samples.
        centroids: (n_clusters, n_features) array of cluster centres.
        chunk_size: number of rows handled per pass; must be positive.

    Returns:
        (n_samples,) float array of minimum squared distances (never negative).

    Raises:
        ValueError: if chunk_size is not positive (a negative step would
            skip the loop and leave the output buffer uninitialised).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    n_samples = data.shape[0]
    min_dist_sq = np.empty(n_samples)

    # Loop-invariant centroid terms are hoisted out of the chunk loop.
    cent_sq = np.sum(centroids ** 2, axis=1)
    centroids_t = centroids.T

    for start in range(0, n_samples, chunk_size):
        end = min(start + chunk_size, n_samples)
        chunk = data[start:end]

        # ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b
        chunk_sq = np.sum(chunk ** 2, axis=1, keepdims=True)
        distances_sq = chunk_sq + cent_sq - 2 * (chunk @ centroids_t)
        # Clamp small negative values caused by floating-point rounding.
        distances_sq = np.maximum(distances_sq, 0)

        min_dist_sq[start:end] = np.min(distances_sq, axis=1)

    return min_dist_sq


def small_cdist(A, B):
    """
    Return the matrix of Euclidean distances between the rows of A and B.

    The full (len(A), len(B)) matrix is built in one go, so this is meant
    for small inputs such as sets of cluster centroids, not whole rasters.
    """
    # Expand ||a-b||^2 into ||a||^2 + ||b||^2 - 2*a.b
    row_norms = np.sum(A ** 2, axis=1, keepdims=True)
    col_norms = np.sum(B ** 2, axis=1)
    squared = row_norms + col_norms - 2 * (A @ B.T)
    # Rounding may yield tiny negatives; clamp before taking the root.
    return np.sqrt(np.maximum(squared, 0))


def numpy_kmeans(data, n_clusters, max_iter=300, n_init=3, tol=1e-4):
    """
    K-means clustering written with numpy only.

    Runs ``n_init`` independent restarts, each seeded K-means++ style, and
    returns the label array of the restart with the lowest inertia. All
    distance work is delegated to the chunked helpers so large rasters do
    not need a full distance matrix in memory.

    Parameters:
        data: (n_samples, n_features) array of samples.
        n_clusters: number of clusters to form.
        max_iter: iteration cap per restart (lowered for large inputs).
        n_init: number of restarts (lowered for large inputs).
        tol: a restart stops once the centroids move less than this
            (norm of the centroid difference).

    Returns:
        (n_samples,) int32 label array of the best restart.
    """
    n_samples, n_features = data.shape

    # Large inputs get fewer restarts and a lower iteration cap.
    if n_samples > 500000:
        n_init, iteration_cap = 1, 100
    elif n_samples > 100000:
        n_init, iteration_cap = 2, 150
    else:
        iteration_cap = max_iter
    max_iter = min(max_iter, iteration_cap)

    def seed_centroids():
        """Pick starting centroids: one uniform draw, the rest distance-weighted."""
        seeds = np.empty((n_clusters, n_features))
        seeds[0] = data[np.random.randint(n_samples)]
        for k in range(1, n_clusters):
            nearest_sq = compute_min_distances_sq(data, seeds[:k])
            mass = nearest_sq.sum()
            if mass == 0:
                # Every sample coincides with a seed: fall back to a uniform draw.
                pick = np.random.randint(n_samples)
            else:
                pick = np.random.choice(n_samples, p=nearest_sq / mass)
            seeds[k] = data[pick]
        return seeds

    def recompute_centroids(assignments, previous):
        """Return the mean of each cluster; empty clusters get a random sample."""
        updated = np.empty_like(previous)
        for k in range(n_clusters):
            members = assignments == k
            if members.any():
                updated[k] = data[members].mean(axis=0)
            else:
                updated[k] = data[np.random.randint(n_samples)]
        return updated

    winner = None
    lowest_inertia = np.inf

    for _restart in range(n_init):
        centroids = seed_centroids()
        labels = np.zeros(n_samples, dtype=np.int32)

        for _step in range(max_iter):
            assigned = compute_distances_to_centroids(data, centroids)

            # Stop when no sample changed cluster.
            if np.array_equal(labels, assigned):
                break
            labels = assigned

            moved = recompute_centroids(labels, centroids)

            # Stop (keeping the previous centroids) when they barely moved.
            if np.linalg.norm(moved - centroids) < tol:
                break
            centroids = moved

        # Inertia: summed squared distance of each sample to its closest centroid.
        inertia = np.sum(compute_min_distances_sq(data, centroids))

        if inertia < lowest_inertia:
            lowest_inertia = inertia
            winner = labels.copy()

    return winner


def fuzzy_cmeans(data, n_clusters, max_iter=100, fuzziness=2.0, tol=1e-4, chunk_size=50000):
    """
    Fuzzy C-Means clustering (pure numpy implementation).

    Each sample holds a membership degree for every cluster; the memberships
    are refined iteratively and the hard label returned for a sample is the
    cluster with the highest final membership. Distances are evaluated in
    chunks of ``chunk_size`` rows.

    The initial memberships come from a private, fixed-seed random generator,
    so results are reproducible and the global numpy random state is left
    untouched.

    Parameters:
        data: (n_samples, n_features) array of samples.
        n_clusters: number of clusters.
        max_iter: iteration cap (lowered for large inputs).
        fuzziness: >1, higher values = fuzzier clusters (default 2.0).
        tol: stop once the largest membership change drops below this.
        chunk_size: rows processed per distance pass; must be positive.

    Returns:
        (n_samples,) int32 array of hard labels.

    Raises:
        ValueError: if fuzziness is not greater than 1 or chunk_size is not
            positive.
    """
    if fuzziness <= 1:
        raise ValueError("fuzziness must be greater than 1")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    n_samples, n_features = data.shape

    # Local generator: same draws as seeding the global RNG with 42, but
    # without clobbering the random state other code relies on.
    rng = np.random.RandomState(42)
    membership = rng.rand(n_samples, n_clusters)
    membership = membership / membership.sum(axis=1, keepdims=True)  # rows sum to 1

    # Reduce iterations for large datasets
    if n_samples > 500000:
        max_iter = min(max_iter, 50)
    elif n_samples > 100000:
        max_iter = min(max_iter, 75)

    m = fuzziness
    exponent = 2.0 / (m - 1)  # constant for the whole run

    for _ in range(max_iter):
        # Centroids are the membership^m weighted means of the samples.
        membership_m = membership ** m
        centroids = np.zeros((n_clusters, n_features))

        for k in range(n_clusters):
            weights = membership_m[:, k:k + 1]
            centroids[k] = (data * weights).sum(axis=0) / weights.sum()

        new_membership = np.zeros_like(membership)
        cent_sq = np.sum(centroids ** 2, axis=1)
        centroids_t = centroids.T

        for start in range(0, n_samples, chunk_size):
            end = min(start + chunk_size, n_samples)
            chunk = data[start:end]

            # ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b
            chunk_sq = np.sum(chunk ** 2, axis=1, keepdims=True)
            distances_sq = chunk_sq + cent_sq - 2 * (chunk @ centroids_t)
            # Floor keeps the ratios below finite when a sample sits on a centroid.
            distances_sq = np.maximum(distances_sq, 1e-10)
            distances = np.sqrt(distances_sq)

            # u_ik = 1 / sum_j (d_ik / d_ij)^(2/(m-1))
            for k in range(n_clusters):
                ratio = distances[:, k:k + 1] / distances
                new_membership[start:end, k] = 1.0 / np.sum(ratio ** exponent, axis=1)

        # Converged when no membership moved by more than tol.
        diff = np.abs(new_membership - membership).max()
        membership = new_membership

        if diff < tol:
            break

    # Harden: pick the cluster with the highest membership.
    return np.argmax(membership, axis=1).astype(np.int32)


def minibatch_kmeans(data, n_clusters, max_iter=100, batch_size=10000):
    """
    Mini-Batch K-Means (pure numpy implementation).

    Each iteration draws a random batch, assigns it to the current centroids
    and moves every centroid towards the mean of the batch samples assigned
    to it. A centroid therefore tracks the running mean of all samples ever
    assigned to it. The final labels come from one chunked assignment pass
    over the whole data set.

    Initial centroids and batches come from a private, fixed-seed random
    generator, so results are reproducible and the global numpy random state
    is left untouched.

    Parameters:
        data: (n_samples, n_features) array of samples.
        n_clusters: number of clusters (must not exceed n_samples).
        max_iter: number of mini-batches to process.
        batch_size: samples per mini-batch (capped at n_samples).

    Returns:
        (n_samples,) int32 array with the index of the nearest centroid.
    """
    n_samples = data.shape[0]

    # Local generator: same draws as seeding the global RNG with 42.
    rng = np.random.RandomState(42)
    seed_rows = rng.choice(n_samples, n_clusters, replace=False)
    # Float copy so fractional centroid updates are not truncated.
    centroids = data[seed_rows].astype(float)

    # Number of samples each centroid has absorbed so far.
    counts = np.zeros(n_clusters)
    batch_len = min(batch_size, n_samples)

    for _ in range(max_iter):
        batch = data[rng.choice(n_samples, batch_len, replace=False)]

        # Assign batch to nearest centroids (||a-b||^2 expansion).
        batch_sq = np.sum(batch ** 2, axis=1, keepdims=True)
        cent_sq = np.sum(centroids ** 2, axis=1)
        distances_sq = batch_sq + cent_sq - 2 * (batch @ centroids.T)
        distances_sq = np.maximum(distances_sq, 0)
        labels = np.argmin(distances_sq, axis=1)

        for k in range(n_clusters):
            mask = labels == k
            n_assigned = int(mask.sum())
            if n_assigned == 0:
                continue
            counts[k] += n_assigned
            # The batch mean stands for n_assigned samples, so it must be
            # weighted by n_assigned / counts[k]; a weight of 1 / counts[k]
            # would leave the centroid almost frozen at its random seed.
            eta = n_assigned / counts[k]
            centroids[k] = (1 - eta) * centroids[k] + eta * batch[mask].mean(axis=0)

    # Final assignment of every sample using the chunked helper.
    return compute_distances_to_centroids(data, centroids)


def isodata_clustering_fast(data, num_clusters, max_iter=100, max_merge=0.5,
                            min_split_std=0.5, max_std=1.0, min_samples=10):
    """
    Simplified ISODATA clustering: K-means followed by merge/split passes.

    An initial K-means labelling (sklearn when available, otherwise the
    numpy implementation) is refined by at most ``min(max_iter // 20, 5)``
    passes. Each pass:

    1. computes mean, per-band standard deviation and size of every cluster
       with at least ``min_samples`` members (smaller clusters are dropped);
    2. merges a cluster with its nearest neighbour when their centroids are
       closer than ``max_merge`` (only attempted with more than two clusters);
    3. splits an unmerged cluster whose largest per-band standard deviation
       exceeds ``max_std`` and whose size exceeds ``2 * min_samples``;
    4. reassigns every sample to the nearest resulting centroid.

    Every centroid carries its own statistics through the merge phase, so
    the split test always uses the statistics of the cluster it concerns. A
    cluster produced by a merge is not split in the same pass.

    Parameters:
        data: (n_samples, n_features) array of samples.
        num_clusters: cluster count for the initial K-means; refinement stops
            once at least this many centroids exist after a pass.
        max_iter: only controls the number of refinement passes (see above).
        max_merge: centroid distance below which two clusters are merged.
        min_split_std: accepted for interface compatibility; not used here.
        max_std: per-band standard deviation above which a cluster is split.
        min_samples: minimum cluster size to survive a pass.

    Returns:
        (n_samples,) label array. After a refinement pass the labels index
        that pass's centroid list, so the number of distinct labels may
        differ from ``num_clusters``.
    """
    # Initial labelling: sklearn KMeans if available, otherwise numpy.
    if sklearn_available:
        model = SklearnKMeans(n_clusters=num_clusters, n_init=3, max_iter=50, random_state=42)
        labels = model.fit_predict(data)
    else:
        labels = numpy_kmeans(data, num_clusters, max_iter=50, n_init=3)

    # Limit ISODATA refinement iterations for speed
    isodata_iterations = min(max_iter // 20, 5)

    for _ in range(isodata_iterations):
        # ----- Statistics of every sufficiently large cluster -----
        cluster_means = []
        cluster_stds = []
        cluster_sizes = []

        for label in np.unique(labels):
            mask = labels == label
            cluster_size = int(np.sum(mask))
            if cluster_size < min_samples:
                continue
            cluster_data = data[mask]
            cluster_means.append(np.mean(cluster_data, axis=0))
            cluster_stds.append(np.std(cluster_data, axis=0))
            cluster_sizes.append(cluster_size)

        if not cluster_means:
            break

        cluster_means = np.array(cluster_means)
        n_valid = len(cluster_sizes)

        # ----- Merge phase -----
        # Each candidate is (centroid, per-band std or None, size); the std is
        # None for a freshly merged cluster because it was not measured.
        candidates = []
        if n_valid > 2:
            # Centroid-to-centroid distances only (small matrix).
            pairwise_distances = small_cdist(cluster_means, cluster_means)
            np.fill_diagonal(pairwise_distances, np.inf)  # Don't merge with self

            consumed = set()  # clusters already emitted, alone or merged
            for i in range(n_valid):
                if i in consumed:
                    continue
                consumed.add(i)

                j = int(np.argmin(pairwise_distances[i]))
                if j not in consumed and pairwise_distances[i, j] < max_merge:
                    consumed.add(j)
                    # Size-weighted average of the two centroids.
                    w1, w2 = cluster_sizes[i], cluster_sizes[j]
                    merged_mean = (cluster_means[i] * w1 + cluster_means[j] * w2) / (w1 + w2)
                    candidates.append((merged_mean, None, w1 + w2))
                else:
                    candidates.append((cluster_means[i], cluster_stds[i], cluster_sizes[i]))
        else:
            candidates = list(zip(cluster_means, cluster_stds, cluster_sizes))

        # ----- Split phase -----
        final_centroids = []
        for mean, std, size in candidates:
            splittable = (
                std is not None
                and np.max(std) > max_std
                and size > min_samples * 2
            )
            if splittable:
                # Replace the centroid by two centroids offset along the std vector.
                offset = std * 0.3
                final_centroids.append(mean + offset)
                final_centroids.append(mean - offset)
            else:
                final_centroids.append(mean)

        # Reassign labels using chunked approach
        centroids = np.array(final_centroids)
        labels = compute_distances_to_centroids(data, centroids)

        # Stop if we have enough clusters
        if len(final_centroids) >= num_clusters:
            break

    return labels


class UnsupervisedClassifier:
    """
    QGIS plugin class for unsupervised raster classification.

    Registers a toolbar/Raster-menu action that opens the classifier dialog,
    runs the selected clustering algorithm on each chosen raster, writes the
    class raster as a GeoTIFF and shows per-cluster statistics.
    """

    def __init__(self, iface):
        """Store the QGIS interface handle and initialise plugin state."""
        self.iface = iface
        self.plugin_dir = os.path.dirname(__file__)
        self.actions = []
        self.menu = self.tr(u'&MAS Raster Processing')
        self.toolbar = None
        self.first_start = None

    def tr(self, message):
        """Translate ``message`` in the 'UnsupervisedClassifier' context."""
        return QCoreApplication.translate('UnsupervisedClassifier', message)

    def initGui(self):
        """Create the plugin action and add it to the toolbar and Raster menu."""
        icon_path = ':/cluster.png'
        # Reuse the toolbar if it already exists, otherwise create it.
        self.toolbar = self.iface.mainWindow().findChild(QToolBar, 'MASRasterProcessingToolbar')
        if self.toolbar is None:
            self.toolbar = self.iface.addToolBar(u'MAS Raster Processing')
            self.toolbar.setObjectName('MASRasterProcessingToolbar')

        self.action_UnspvClassification = QAction(QIcon(icon_path), u"&Unsupervised Classifier", self.iface.mainWindow())
        self.action_UnspvClassification.setObjectName('UnsupervisedClassifierAction')
        self.action_UnspvClassification.triggered.connect(self.run)

        # Check if action already exists in toolbar (avoid duplicates on reload)
        existing_actions = [a.objectName() for a in self.toolbar.actions()]
        if 'UnsupervisedClassifierAction' not in existing_actions:
            self.toolbar.addAction(self.action_UnspvClassification)

        self.iface.addPluginToRasterMenu(self.menu, self.action_UnspvClassification)
        self.actions.append(self.action_UnspvClassification)

    def unload(self):
        """Remove the plugin actions from the Raster menu and the toolbar."""
        for action in self.actions:
            # The action was added with addPluginToRasterMenu, so it has to be
            # removed from the Raster menu (not the generic Plugins menu).
            self.iface.removePluginRasterMenu(self.menu, action)
            if self.toolbar:
                self.toolbar.removeAction(action)
        self.actions = []
        # Don't delete toolbar - it may be shared with other plugins

    def apply_classification_symbology(self, layer, num_classes, class_values=None):
        """
        Apply a paletted renderer with evenly spaced hues to a class raster.

        Parameters:
            layer: raster layer to style.
            num_classes: number of classes; used when ``class_values`` is
                not given, in which case the pixel values 0..num_classes-1
                are styled.
            class_values: optional iterable of the pixel values actually
                present in the raster. When given, exactly these values get
                a palette entry, so non-contiguous labels are styled too.

        Failures are reported with ``print`` and otherwise ignored because
        symbology is not essential to the result.
        """
        try:
            if class_values is None:
                values = list(range(num_classes))
            else:
                values = [int(v) for v in class_values]

            n_entries = len(values)
            classes = []
            for position, value in enumerate(values):
                # Distribute hues evenly around the color wheel
                hue = int((position * 360 / n_entries) % 360)
                # High saturation and value for vivid colors
                color = QColor.fromHsv(hue, 200, 220)
                classes.append(
                    QgsPalettedRasterRenderer.Class(
                        value, color, f"Class {value + 1}"
                    )
                )

            renderer = QgsPalettedRasterRenderer(layer.dataProvider(), 1, classes)
            layer.setRenderer(renderer)
            layer.triggerRepaint()
        except Exception as e:
            # Best effort only - symbology is nice to have, not critical
            print(f"Could not apply symbology: {e}")

    def run(self):
        """Create the classifier dialog on first use and show it modally."""
        if not hasattr(self, 'dlg'):
            self.dlg = UnsupervisedClassifierDialog(iface=self.iface, parent=self.iface.mainWindow())
            self.dlg.runButton.clicked.connect(self.run_clustering)
        self.dlg.show()
        self.dlg.exec_()

    def run_clustering(self):
        """
        Classify every raster selected in the dialog and report the outcome.

        Reads the algorithm settings from the dialog widgets, processes the
        rasters one by one, collects per-file failures and finally shows a
        summary message box. The run button is restored even if the batch
        raises unexpectedly.
        """
        selected_rasters = self.dlg.get_selected_rasters()

        if not selected_rasters:
            QMessageBox.warning(self.dlg, "Warning", "No rasters selected. Please add and select rasters to process.")
            return

        self.dlg.runButton.setEnabled(False)
        self.dlg.runButton.setText("Processing...")

        total_files = len(selected_rasters)
        success_count = 0
        failed_files = []

        try:
            clustering_method = self.dlg.algorithmComboBox.currentText()
            num_clusters = self.dlg.numClustersSpinBox.value()
            max_iter = self.dlg.maxIterSpinBox.value()
            max_merge = self.dlg.maxMergeDoubleSpinBox.value()
            min_split_std = self.dlg.minSplitStdDoubleSpinBox.value()
            max_std = self.dlg.maxStdDoubleSpinBox.value()
            min_samples = self.dlg.minSamplesSpinBox.value()
            open_in_qgis = self.dlg.openInQgisCheckBox.isChecked()
            export_stats = self.dlg.exportStatsCheckBox.isChecked()

            self.dlg.update_progress(0, total_files, "Starting batch processing...")

            for idx, raster_info in enumerate(selected_rasters, start=1):
                input_file = raster_info['input']
                output_file = raster_info['output']
                selected_bands = raster_info.get('bands', [])
                file_name = os.path.basename(input_file)

                self.dlg.update_progress(idx - 1, total_files, f"Processing ({idx}/{total_files}): {file_name}")

                try:
                    if not os.path.exists(input_file):
                        failed_files.append(f"{file_name}: File not found")
                        continue

                    if not selected_bands:
                        failed_files.append(f"{file_name}: No bands selected")
                        continue

                    output_dir = os.path.dirname(output_file)
                    if output_dir:
                        os.makedirs(output_dir, exist_ok=True)

                    success, error_msg = self.process_single_raster(
                        input_file, output_file, clustering_method, num_clusters,
                        selected_bands, max_iter, max_merge, min_split_std,
                        max_std, min_samples, open_in_qgis, export_stats
                    )

                    if success:
                        success_count += 1
                        self.dlg.update_progress(idx, total_files, f"Completed ({idx}/{total_files}): {file_name}")
                    else:
                        failed_files.append(f"{file_name}: {error_msg}")

                except Exception as e:
                    # One bad file must not abort the rest of the batch.
                    failed_files.append(f"{file_name}: {str(e)}")
        finally:
            self.dlg.hide_progress()
            self.dlg.runButton.setEnabled(True)
            self.dlg.runButton.setText("Run Classification")

        message = f"Successfully processed {success_count} out of {total_files} raster(s)."
        if failed_files:
            # Only the first ten failures are listed to keep the box readable.
            message += "\n\nFailed files:\n" + "\n".join(failed_files[:10])
            if len(failed_files) > 10:
                message += f"\n... and {len(failed_files) - 10} more"

        if success_count > 0:
            QMessageBox.information(self.dlg, "Classification Complete", message)
        else:
            QMessageBox.critical(self.dlg, "Classification Failed", message)

    def process_single_raster(self, input_file, output_file, clustering_method, num_clusters,
                             selected_bands, max_iter, max_merge, min_split_std,
                             max_std, min_samples, open_in_qgis, export_stats=False):
        """
        Cluster one raster and write the class raster to ``output_file``.

        Parameters:
            input_file: path of the raster to classify.
            output_file: path of the GeoTIFF to create.
            clustering_method: algorithm name as shown in the dialog.
            num_clusters: requested number of clusters.
            selected_bands: 1-based band numbers to use; numbers outside the
                file's band range are ignored.
            max_iter, max_merge, min_split_std, max_std, min_samples:
                forwarded to the ISODATA implementation.
            open_in_qgis: add the result to the project and style it.
            export_stats: also write the cluster statistics to a CSV file
                next to the output raster.

        Returns:
            (success, message) tuple; ``message`` describes the error on
            failure.
        """
        try:
            sat_dataset = gdal.Open(input_file)
            if sat_dataset is None:
                return False, "Could not open file"

            actual_band_count = sat_dataset.RasterCount
            # GDAL band numbers start at 1.
            valid_bands = [b for b in selected_bands if 1 <= b <= actual_band_count]

            if not valid_bands:
                return False, f"No valid bands (file has {actual_band_count} bands)"

            bands_data = [sat_dataset.GetRasterBand(i).ReadAsArray().astype(float) for i in valid_bands]

            nrows, ncols = bands_data[0].shape
            reshaped_data = np.stack(bands_data, axis=-1).reshape(-1, len(valid_bands))
            # Cleaned, unnormalized values are kept for the statistics.
            original_data = clean_data(reshaped_data)
            # Normalize the cleaned data: a single NaN/inf pixel would
            # otherwise turn the mean/std, and thus the whole band, into NaN.
            normalized_data = normalize_data(original_data)

            try:
                if clustering_method == 'K-means (Pure Numpy)':
                    # Pure numpy implementation - always available
                    labels = numpy_kmeans(normalized_data, num_clusters, max_iter=300, n_init=10)

                elif clustering_method == 'Mini-Batch K-means':
                    # Fast numpy implementation for large datasets
                    labels = minibatch_kmeans(normalized_data, num_clusters, max_iter=100, batch_size=10000)

                elif clustering_method == 'Fuzzy C-Means':
                    # Soft clustering - pure numpy
                    labels = fuzzy_cmeans(normalized_data, num_clusters, max_iter=100, fuzziness=2.0)

                elif clustering_method == 'ISODATA (Optimized)':
                    # Uses numpy with optional sklearn acceleration
                    labels = isodata_clustering_fast(normalized_data, num_clusters, max_iter,
                                               max_merge, min_split_std, max_std, min_samples)

                elif clustering_method == 'K-means (sklearn)':
                    # Uses sklearn if available
                    if sklearn_available:
                        model = SklearnKMeans(n_clusters=num_clusters, n_init=10, max_iter=300, random_state=42)
                        labels = model.fit_predict(normalized_data)
                    else:
                        return False, "sklearn not installed for this method"

                else:
                    return False, f"Unknown clustering method: {clustering_method}"

            except Exception as cluster_error:
                return False, f"Clustering error: {str(cluster_error)}"

            # Count actual clusters and compute statistics
            labels = np.asarray(labels)
            class_values = np.unique(labels)
            unique_clusters = len(class_values)
            stats_data = self.compute_cluster_statistics(labels, original_data, valid_bands)

            # Pick the smallest unsigned type that holds every label, so
            # labels above 255 do not wrap around in the output raster.
            max_label = int(class_values.max())
            if max_label <= np.iinfo(np.uint8).max:
                np_type, gdal_type = np.uint8, gdal.GDT_Byte
            elif max_label <= np.iinfo(np.uint16).max:
                np_type, gdal_type = np.uint16, gdal.GDT_UInt16
            else:
                np_type, gdal_type = np.uint32, gdal.GDT_UInt32

            clustered_image = labels.reshape(nrows, ncols).astype(np_type)

            driver = gdal.GetDriverByName('GTiff')
            out_dataset = driver.Create(output_file, ncols, nrows, 1, gdal_type)
            if out_dataset is None:
                return False, "Could not create output file"
            out_dataset.SetGeoTransform(sat_dataset.GetGeoTransform())
            out_dataset.SetProjection(sat_dataset.GetProjection())
            out_band = out_dataset.GetRasterBand(1)
            out_band.WriteArray(clustered_image)
            out_band.FlushCache()
            # Dropping the references closes the datasets and flushes to disk.
            out_band = None
            out_dataset = None
            sat_dataset = None

            if export_stats:
                self.export_cluster_statistics(output_file, labels, original_data, valid_bands)

            if open_in_qgis:
                layer_name = os.path.splitext(os.path.basename(output_file))[0]
                layer = self.iface.addRasterLayer(output_file, layer_name)
                if layer:
                    self.apply_classification_symbology(
                        layer, unique_clusters,
                        class_values=[int(v) for v in class_values]
                    )

            # Show statistics dialog
            self.show_statistics_dialog(stats_data, output_file)

            return True, f"Success ({unique_clusters} clusters)"

        except Exception as e:
            return False, str(e)

    def compute_cluster_statistics(self, labels, data, band_names):
        """
        Compute per-cluster pixel counts and per-band mean / std.

        Parameters:
            labels: (n_pixels,) array of cluster labels.
            data: (n_pixels, n_bands) array of band values.
            band_names: one identifier per column of ``data``; used in the
                ``Band<name>_Mean`` / ``Band<name>_Std`` keys.

        Returns:
            List with one dict per distinct label (ascending), holding
            "Cluster" (label + 1), "Pixels", "Percentage" and the band keys.
        """
        stats_data = []
        total_pixels = len(labels)

        for label in np.unique(labels):
            mask = labels == label
            count = int(np.sum(mask))
            cluster_data = data[mask]

            row = {
                "Cluster": int(label) + 1,
                "Pixels": count,
                "Percentage": round((count / total_pixels) * 100, 2)
            }

            # Add band statistics
            for b_idx, b_name in enumerate(band_names):
                band_values = cluster_data[:, b_idx]
                row[f"Band{b_name}_Mean"] = round(float(np.mean(band_values)), 4)
                row[f"Band{b_name}_Std"] = round(float(np.std(band_values)), 4)

            stats_data.append(row)

        return stats_data

    def show_statistics_dialog(self, stats_data, output_file):
        """Show the cluster statistics dialog; failures are only printed."""
        try:
            dialog = ClusterStatisticsDialog(stats_data, output_file, self.dlg)
            dialog.exec_()
        except Exception as e:
            print(f"Could not show statistics dialog: {e}")

    def export_cluster_statistics(self, output_file, labels, data, band_names):
        """
        Export cluster statistics to a CSV file next to ``output_file``.

        The CSV path is the output path with its extension replaced by
        ``_stats.csv``, so it can never coincide with the raster itself.
        Each row holds the cluster number (label + 1), pixel count,
        percentage and mean / std of every column of ``data``. Failures are
        reported with ``print`` and otherwise ignored.
        """
        try:
            stats_file = os.path.splitext(output_file)[0] + '_stats.csv'
            total_pixels = len(labels)

            with open(stats_file, 'w', encoding='utf-8') as f:
                # Header
                band_headers = [f"Band{i}_Mean,Band{i}_Std" for i in band_names]
                f.write("Cluster,PixelCount,Percentage," + ",".join(band_headers) + "\n")

                for label in np.unique(labels):
                    mask = labels == label
                    count = int(np.sum(mask))
                    percentage = (count / total_pixels) * 100
                    cluster_data = data[mask]

                    # Band statistics
                    band_stats = []
                    for b in range(cluster_data.shape[1]):
                        band_stats.append(f"{np.mean(cluster_data[:, b]):.4f}")
                        band_stats.append(f"{np.std(cluster_data[:, b]):.4f}")

                    f.write(f"{label + 1},{count},{percentage:.2f}," + ",".join(band_stats) + "\n")
        except Exception as e:
            print(f"Could not export statistics: {e}")


def clean_data(data):
    """Return a copy of ``data`` with NaN, +inf and -inf all replaced by 0.0."""
    fill = 0.0
    return np.nan_to_num(data, nan=fill, posinf=fill, neginf=fill)


def normalize_data(data):
    """
    Standardize each column of ``data`` to zero mean and unit variance.

    Columns with zero standard deviation are only centred (their divisor is
    set to 1), so constant columns come out as all zeros.
    """
    column_scale = np.std(data, axis=0)
    # Constant columns: divide by 1 instead of 0.
    column_scale[column_scale == 0] = 1
    column_center = np.mean(data, axis=0)
    return (data - column_center) / column_scale
