# -*- coding: utf-8 -*-
"""
Analysis utilities for Supervised Classifier plugin.
Contains functions for accuracy assessment, statistics, and reporting.
"""
import numpy as np
import os


def _aligned_labels(y_true, y_pred, class_names):
    """Return 0..len(class_names)-1 when every observed label indexes class_names, else None."""
    if class_names is None or len(class_names) == 0:
        return None
    observed = np.union1d(np.asarray(y_true).ravel(), np.asarray(y_pred).ravel())
    if observed.size == 0 or not np.issubdtype(observed.dtype, np.integer):
        return None
    if observed.min() < 0 or observed.max() >= len(class_names):
        return None
    return np.arange(len(class_names))


def calculate_accuracy_metrics(y_true, y_pred, class_names=None):
    """
    Calculate comprehensive accuracy metrics.

    Args:
        y_true: True labels (encoded)
        y_pred: Predicted labels (encoded)
        class_names: List of class names for reporting. When every observed
            label is an integer index into this list, the confusion matrix and
            classification report cover all of these classes (in this order),
            even ones absent from y_true/y_pred.

    Returns:
        Dictionary with keys 'overall_accuracy', 'kappa', 'confusion_matrix',
        'precision', 'recall', 'f1_score' (weighted averages),
        'classification_report' (text) and, if class_names was given,
        'class_names'.
    """
    from sklearn.metrics import (
        accuracy_score,
        confusion_matrix,
        classification_report,
        cohen_kappa_score,
        precision_score,
        recall_score,
        f1_score
    )

    # Without explicit labels, sklearn only uses the classes that occur in the
    # data: a class missing from the validation set would shrink the matrix
    # and make the report reject target_names of a different length.
    labels = _aligned_labels(y_true, y_pred, class_names)

    metrics = {
        'overall_accuracy': accuracy_score(y_true, y_pred),
        'kappa': cohen_kappa_score(y_true, y_pred),
        'confusion_matrix': confusion_matrix(y_true, y_pred, labels=labels),
        'precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
        'f1_score': f1_score(y_true, y_pred, average='weighted', zero_division=0),
    }

    # Per-class metrics
    if class_names is not None:
        metrics['classification_report'] = classification_report(
            y_true, y_pred,
            labels=labels,
            # The report measures name widths with len(), so names must be strings
            target_names=[str(name) for name in class_names],
            zero_division=0,
        )
        metrics['class_names'] = class_names
    else:
        metrics['classification_report'] = classification_report(y_true, y_pred, zero_division=0)

    return metrics


def calculate_class_statistics(predictions, label_encoder, pixel_size=None):
    """
    Calculate class statistics from classification results.

    Args:
        predictions: Array of predicted class labels, of any shape (e.g. a 1-D
            vector or a 2-D classified raster); every element is counted.
        label_encoder: LabelEncoder with class names. NOTE(review): currently
            unused; class keys are str() of the predicted values. Kept for
            interface compatibility.
        pixel_size: Tuple (width, height) in map units for area calculation.
            The area fields are labelled m²/ha/km², which assumes metre units.

    Returns:
        Dictionary with 'total_pixels', 'num_classes' and 'classes', where
        'classes' maps str(label) to 'pixel_count' and 'percentage', plus
        'area_m2', 'area_ha' and 'area_km2' when pixel_size is given.
    """
    values = np.asarray(predictions)
    labels, counts = np.unique(values, return_counts=True)
    # np.unique flattens its input, so the total must count every element as
    # well; len() would only count the rows of a 2-D raster.
    total_pixels = int(values.size)

    # The per-pixel area is the same for every class, so compute it once.
    pixel_area = None
    if pixel_size is not None:
        pixel_area = abs(pixel_size[0] * pixel_size[1])

    classes = {}
    for label, count in zip(labels, counts):
        class_stats = {
            'pixel_count': int(count),
            'percentage': round(float(count) / total_pixels * 100, 2),
        }

        # Calculate area if pixel size provided
        if pixel_area is not None:
            area_m2 = float(count * pixel_area)
            class_stats['area_m2'] = round(area_m2, 2)
            class_stats['area_ha'] = round(area_m2 / 10000, 4)
            class_stats['area_km2'] = round(area_m2 / 1000000, 6)

        classes[str(label)] = class_stats

    return {
        'total_pixels': total_pixels,
        'num_classes': len(labels),
        'classes': classes
    }


def check_training_balance(y_train, class_names=None, threshold_ratio=10):
    """
    Check if training samples are balanced across classes.

    Args:
        y_train: Training labels (encoded)
        class_names: Optional list or array of class names, matched
            positionally to the sorted unique labels; labels beyond its
            length are reported as str(label).
        threshold_ratio: Warn if max/min ratio exceeds this

    Returns:
        Dictionary with 'class_counts', 'max_samples', 'min_samples',
        'imbalance_ratio', 'is_balanced' and a list of 'warnings'.

    Raises:
        ValueError: if y_train contains no samples.
    """
    labels, counts = np.unique(np.asarray(y_train), return_counts=True)
    if counts.size == 0:
        # counts.max() would otherwise fail with an opaque numpy reduction error
        raise ValueError("y_train is empty; cannot check training balance")

    max_count = int(counts.max())
    min_count = int(counts.min())
    ratio = max_count / min_count if min_count > 0 else float('inf')

    # Compare against None explicitly: `if class_names` raises "truth value
    # of an array is ambiguous" when a numpy array of names is passed.
    names = [] if class_names is None else list(class_names)
    class_counts = {}
    for i, (label, count) in enumerate(zip(labels, counts)):
        name = names[i] if i < len(names) else str(label)
        class_counts[name] = int(count)

    warnings = []
    if ratio > threshold_ratio:
        warnings.append(f"Class imbalance detected: ratio {ratio:.1f}:1")
        warnings.append(f"Max samples: {max_count}, Min samples: {min_count}")
        warnings.append("Consider balancing samples or using class weights")

    if min_count < 10:
        warnings.append(f"Very few samples ({min_count}) for some classes")

    return {
        'class_counts': class_counts,
        'max_samples': max_count,
        'min_samples': min_count,
        'imbalance_ratio': round(float(ratio), 2),
        'is_balanced': bool(ratio <= threshold_ratio),
        'warnings': warnings
    }


def cross_validate(X_train, y_train, classifier, n_folds=5):
    """
    Perform k-fold cross-validation.

    Args:
        X_train: Training features
        y_train: Training labels
        classifier: Classifier instance (passed unchanged to sklearn's
            cross_val_score, which clones it per fold)
        n_folds: Number of folds, passed as sklearn's `cv` argument

    Returns:
        Dictionary with 'fold_scores' (list of per-fold accuracies),
        'mean_accuracy' and 'std_accuracy' (rounded to 4 decimals, plain
        floats) and 'n_folds'.
    """
    from sklearn.model_selection import cross_val_score

    scores = cross_val_score(classifier, X_train, y_train, cv=n_folds, scoring='accuracy')

    return {
        'fold_scores': scores.tolist(),
        'mean_accuracy': round(float(scores.mean()), 4),
        'std_accuracy': round(float(scores.std()), 4),
        'n_folds': n_folds
    }


def _rank_features(scores, band_names):
    """Build the importance mapping and descending ranking shared by all supported model types."""
    scores = np.asarray(scores, dtype=float)
    supplied = [] if band_names is None else list(band_names)
    # Any feature without a supplied name gets a generic 1-based "Band N" label
    names = [
        supplied[i] if i < len(supplied) else f"Band {i+1}"
        for i in range(len(scores))
    ]

    # Sort by importance, highest first
    order = np.argsort(scores)[::-1]

    return {
        'importances': {names[i]: round(float(scores[i]), 4) for i in order},
        'ranked_features': [names[i] for i in order],
        'supported': True
    }


def get_feature_importance(classifier, band_names=None):
    """
    Get feature importance from classifier (if supported).

    Args:
        classifier: Trained classifier
        band_names: Optional list of band names; features beyond its length
            are labelled "Band N" (1-based).

    Returns:
        Dictionary with 'importances' (name -> score rounded to 4 decimals, in
        descending order), 'ranked_features' and 'supported': True, or
        {'supported': False} if the classifier exposes neither
        feature_importances_ nor coef_.
    """
    # Random Forest, Gradient Boosting
    if hasattr(classifier, 'feature_importances_'):
        return _rank_features(classifier.feature_importances_, band_names)

    # LDA and other linear models: use |coef|, averaged over the first axis when 2-D
    if hasattr(classifier, 'coef_'):
        coef = np.abs(np.asarray(classifier.coef_))
        if coef.ndim > 1:
            coef = coef.mean(axis=0)
        return _rank_features(coef, band_names)

    return {'supported': False}


def export_accuracy_report(metrics, class_stats, output_path, method_name="Classification"):
    """
    Export accuracy assessment report as text file.

    Args:
        metrics: Accuracy metrics dictionary; reads 'overall_accuracy',
            'kappa', 'precision', 'recall', 'f1_score', 'confusion_matrix',
            'classification_report' and, when present, 'class_names'
        class_stats: Class statistics dictionary ('total_pixels',
            'num_classes', 'classes'); a falsy value skips that section
        output_path: Path to save report (written as UTF-8)
        method_name: Classification method name shown in the header

    Returns:
        output_path
    """
    cm = np.asarray(metrics['confusion_matrix'])
    # Render every cell on one line per row: by default numpy summarises
    # large arrays with "..." and wraps long rows, which garbles the matrix.
    cm_text = np.array2string(cm, threshold=max(cm.size, 1), max_line_width=1_000_000)

    # Explicit UTF-8: the report contains non-ASCII text ("km²") and the
    # platform default encoding may not be able to represent it.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("=" * 60 + "\n")
        f.write(f"ACCURACY ASSESSMENT REPORT - {method_name}\n")
        f.write("=" * 60 + "\n\n")

        # Overall metrics
        f.write("OVERALL METRICS\n")
        f.write("-" * 40 + "\n")
        f.write(f"Overall Accuracy: {metrics['overall_accuracy']*100:.2f}%\n")
        f.write(f"Kappa Coefficient: {metrics['kappa']:.4f}\n")
        f.write(f"Weighted Precision: {metrics['precision']:.4f}\n")
        f.write(f"Weighted Recall: {metrics['recall']:.4f}\n")
        f.write(f"Weighted F1-Score: {metrics['f1_score']:.4f}\n\n")

        # Confusion Matrix
        f.write("CONFUSION MATRIX\n")
        f.write("-" * 40 + "\n")
        if 'class_names' in metrics:
            # str() each name so numeric or numpy class labels can be joined
            f.write("Classes: " + ", ".join(str(name) for name in metrics['class_names']) + "\n")
        f.write(cm_text + "\n\n")

        # Classification Report
        f.write("DETAILED CLASSIFICATION REPORT\n")
        f.write("-" * 40 + "\n")
        f.write(metrics['classification_report'] + "\n")

        # Class Statistics
        if class_stats:
            f.write("CLASS STATISTICS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Total Pixels: {class_stats['total_pixels']:,}\n")
            f.write(f"Number of Classes: {class_stats['num_classes']}\n\n")

            for class_name, stats in class_stats['classes'].items():
                f.write(f"  {class_name}:\n")
                f.write(f"    Pixels: {stats['pixel_count']:,} ({stats['percentage']:.2f}%)\n")
                if 'area_ha' in stats:
                    f.write(f"    Area: {stats['area_ha']:.2f} ha ({stats['area_km2']:.4f} km²)\n")

        f.write("\n" + "=" * 60 + "\n")

    return output_path


def export_training_samples_csv(X_train, y_train, label_encoder, output_path, band_names=None):
    """
    Write training samples to a CSV file for analysis outside the plugin.

    Columns are the band values (named by band_names, or "Band_1", "Band_2",
    ... by default), then 'Class_Encoded' (y_train as given) and 'Class_Name'
    (label_encoder.inverse_transform(y_train)). No index column is written.

    Args:
        X_train: Training features array (rows = samples, columns = bands)
        y_train: Training labels (encoded)
        label_encoder: LabelEncoder with class names
        output_path: Path to save CSV
        band_names: Optional list of band names

    Returns:
        output_path
    """
    import pandas as pd

    band_count = X_train.shape[1]
    column_names = band_names
    if column_names is None:
        column_names = [f"Band_{idx + 1}" for idx in range(band_count)]

    # Band values first, then the encoded label and its decoded class name
    samples = pd.DataFrame(X_train, columns=column_names).assign(
        Class_Encoded=y_train,
        Class_Name=label_encoder.inverse_transform(y_train),
    )

    samples.to_csv(output_path, index=False)
    return output_path


def load_training_from_csv(csv_path):
    """
    Read training samples back from a CSV produced by export_training_samples_csv.

    Every column other than 'Class_Encoded' and 'Class_Name' is treated as a
    band. Labels come from 'Class_Name' when that column exists, otherwise
    from 'Class_Encoded'. They are re-encoded with a freshly fitted
    LabelEncoder. A short summary is printed to stdout.

    Args:
        csv_path: Path to CSV file with training samples

    Returns:
        Tuple of (X_train, y_train, label_encoder)

    Raises:
        KeyError: if the CSV has neither a 'Class_Name' nor a
            'Class_Encoded' column.
    """
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    frame = pd.read_csv(csv_path)

    label_columns = ('Class_Encoded', 'Class_Name')
    feature_columns = [name for name in frame.columns if name not in label_columns]
    X_train = frame[feature_columns].values

    # Prefer human-readable names; fall back to the stored encoded values
    source_column = 'Class_Name' if 'Class_Name' in frame.columns else 'Class_Encoded'

    encoder = LabelEncoder()
    y_train = encoder.fit_transform(frame[source_column].values)

    print(f"Loaded {len(X_train)} training samples with {len(feature_columns)} bands")
    print(f"Classes: {list(encoder.classes_)}")

    return X_train, y_train, encoder
