# -*- coding: utf-8 -*-
"""
Classifier utilities for Supervised Classifier plugin.
Contains functions for getting classifiers and classification methods.
"""
import numpy as np
from scipy.spatial.distance import cdist


# Default hyper-parameters per classification method, keyed by the method
# names used throughout this module. Methods without tunable parameters map
# to an empty dict.
# NOTE: get_classifier takes n_estimators (Random Forest, Gradient Boosting)
# and max_iter (SVM) from its num_iterations argument, so it does not read
# those particular entries from this table.
DEFAULT_PARAMS = {
    "Random Forest": dict(
        n_estimators=100,
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
    ),
    "SVM (Support Vector Machine)": dict(
        kernel="rbf",
        C=1.0,
        gamma="scale",
        max_iter=1000,
    ),
    "KNN (K-Nearest Neighbors)": dict(
        n_neighbors=5,
        weights="distance",
        algorithm="auto",
    ),
    "Gradient Boosting": dict(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=3,
    ),
    "Maximum Likelihood": dict(),
    "LDA (Linear Discriminant Analysis)": dict(solver="svd"),
    "QDA (Quadratic Discriminant Analysis)": dict(),
}


def get_classifier(method, num_iterations=100, custom_params=None):
    """
    Get an unfitted classifier instance based on the method name.

    Args:
        method: Classification method name, e.g. "Random Forest" or
            "Minimum Distance".
        num_iterations: Used as ``n_estimators`` for Random Forest and
            Gradient Boosting and as ``max_iter`` for SVM. It takes precedence
            over any ``n_estimators``/``max_iter`` value in the parameters;
            the other methods ignore it.
        custom_params: Optional dict of parameter overrides, merged over the
            method's entry in DEFAULT_PARAMS. Maximum Likelihood and QDA are
            built without parameters, so overrides for them have no effect.

    Returns:
        A scikit-learn estimator instance, or None for "Minimum Distance"
        (that method is implemented by classify_minimum_distance).

    Raises:
        ValueError: If ``method`` is not a recognised method name.
    """
    # Minimum Distance needs no estimator; returning before the sklearn
    # imports keeps it usable even when scikit-learn is not installed.
    if method == "Minimum Distance":
        return None

    # Imported here rather than at module level, so importing this module
    # does not import scikit-learn.
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.svm import SVC
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.naive_bayes import GaussianNB
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis

    # Copy before merging so the shared DEFAULT_PARAMS entry is never mutated.
    params = dict(DEFAULT_PARAMS.get(method, {}))
    if custom_params:
        params.update(custom_params)

    if method == "Random Forest":
        return RandomForestClassifier(
            n_estimators=num_iterations,
            max_depth=params.get('max_depth'),
            min_samples_split=params.get('min_samples_split', 2),
            min_samples_leaf=params.get('min_samples_leaf', 1),
            n_jobs=-1,  # use all available cores
            random_state=42,  # fixed seed for reproducible results
        )

    if method == "SVM (Support Vector Machine)":
        return SVC(
            kernel=params.get('kernel', 'rbf'),
            C=params.get('C', 1.0),
            gamma=params.get('gamma', 'scale'),
            max_iter=num_iterations,
            random_state=42,
        )

    if method == "KNN (K-Nearest Neighbors)":
        knn_params = {
            'n_neighbors': params.get('n_neighbors', 5),
            'weights': params.get('weights', 'distance'),
            'algorithm': params.get('algorithm', 'auto'),
        }
        try:
            return KNeighborsClassifier(n_jobs=-1, **knn_params)
        except TypeError:
            # Presumably for scikit-learn releases whose KNeighborsClassifier
            # does not accept n_jobs; all other user settings (including
            # ``algorithm``) are still honoured.
            return KNeighborsClassifier(**knn_params)

    if method == "Gradient Boosting":
        return GradientBoostingClassifier(
            n_estimators=num_iterations,
            learning_rate=params.get('learning_rate', 0.1),
            max_depth=params.get('max_depth', 3),
            random_state=42,
        )

    if method == "Maximum Likelihood":
        return GaussianNB()

    if method == "LDA (Linear Discriminant Analysis)":
        return LinearDiscriminantAnalysis(
            solver=params.get('solver', 'svd')
        )

    if method == "QDA (Quadratic Discriminant Analysis)":
        return QuadraticDiscriminantAnalysis()

    # ValueError subclasses Exception, so existing ``except Exception``
    # handlers still catch this.
    raise ValueError(f"Invalid method selected: {method}")


def classify_minimum_distance(train_data, X_test, *, return_labels=False):
    """
    Classify using minimum (Euclidean) distance to class means.

    Args:
        train_data: Tuple of (X_train, y_train). X_train is 2-D with one row
            per training sample; y_train holds one label per row. Either may
            be a NumPy array or any array-like (e.g. a list).
        X_test: 2-D array-like of samples to classify, with the same number
            of columns as X_train.
        return_labels: If False (default), return indices into the sorted
            unique labels of y_train. If True, return the labels themselves.

    Returns:
        1-D array with one prediction per row of X_test.

    Raises:
        ValueError: If train_data contains no samples.
    """
    X_train, y_train = train_data
    X_train = np.asarray(X_train)
    # asarray matters: with a plain list, ``y_train == k`` is a single bool
    # rather than an element-wise mask, which would corrupt the class means.
    y_train = np.asarray(y_train)
    if y_train.size == 0:
        raise ValueError("train_data contains no samples")

    # np.unique returns the labels sorted, so prediction index i always
    # refers to classes[i].
    classes = np.unique(y_train)
    class_means = np.array([X_train[y_train == k].mean(axis=0) for k in classes])
    nearest = np.argmin(cdist(X_test, class_means), axis=1)
    return classes[nearest] if return_labels else nearest


# Short tags used when building output filenames, keyed by method name.
METHOD_ABBREVIATIONS = dict(
    [
        ("Minimum Distance", "MinDist"),
        ("Random Forest", "RF"),
        ("SVM (Support Vector Machine)", "SVM"),
        ("KNN (K-Nearest Neighbors)", "KNN"),
        ("Gradient Boosting", "GB"),
        ("Maximum Likelihood", "MaxLik"),
        ("LDA (Linear Discriminant Analysis)", "LDA"),
        ("QDA (Quadratic Discriminant Analysis)", "QDA"),
    ]
)


# Every classification method name handled by get_classifier, in a fixed order.
CLASSIFICATION_METHODS = [
    "Minimum Distance", "Random Forest", "Gradient Boosting",
    "SVM (Support Vector Machine)", "KNN (K-Nearest Neighbors)",
    "Maximum Likelihood", "LDA (Linear Discriminant Analysis)",
    "QDA (Quadratic Discriminant Analysis)",
]


# Methods whose estimator is built without the num_iterations argument.
NON_ITERATIVE_METHODS = [
    "Minimum Distance", "Maximum Likelihood",
    "LDA (Linear Discriminant Analysis)", "QDA (Quadratic Discriminant Analysis)",
    "KNN (K-Nearest Neighbors)",
]


def get_method_abbreviation(method):
    """
    Get an abbreviated method name suitable for use in filenames.

    Known methods map through METHOD_ABBREVIATIONS. Any other name falls back
    to its first three alphanumeric characters, upper-cased. Non-alphanumeric
    characters are dropped first, so spaces or path separators never end up
    in a generated filename. The result may be empty if the name contains no
    alphanumeric characters.
    """
    abbreviation = METHOD_ABBREVIATIONS.get(method)
    if abbreviation is not None:
        return abbreviation
    # e.g. "A/B test" -> "ABT" rather than "A/B", which would be a path.
    alphanumeric = "".join(ch for ch in method if ch.isalnum())
    return alphanumeric[:3].upper()


def get_algorithm_parameters(method):
    """
    Get adjustable parameters for a classification method.

    Returns a fresh copy of the method's DEFAULT_PARAMS entry, or an empty
    dict for methods without parameters and for unknown names. Because it is
    a copy, callers may edit the result without changing the module-wide
    defaults that get_classifier reads. A shallow copy suffices because the
    default values are plain scalars, strings or None.
    """
    return dict(DEFAULT_PARAMS.get(method, {}))


def is_iterative_method(method):
    """Tell whether *method* takes the iterations parameter.

    Returns False for names listed in NON_ITERATIVE_METHODS and True for
    everything else, including unrecognised names.
    """
    if method in NON_ITERATIVE_METHODS:
        return False
    return True


def get_method_description(method):
    """Return a one-line, human-readable summary of *method*.

    Unrecognised method names yield an empty string.
    """
    summaries = dict(
        [
            ("Minimum Distance", "Assigns pixels to nearest class centroid (Euclidean distance)"),
            ("Random Forest", "Ensemble of decision trees with bagging"),
            ("Gradient Boosting", "Sequential ensemble with gradient descent optimization"),
            ("SVM (Support Vector Machine)", "Finds optimal hyperplane for class separation"),
            ("KNN (K-Nearest Neighbors)", "Classifies based on k nearest training samples"),
            ("Maximum Likelihood", "Assumes Gaussian distribution, uses Bayes theorem"),
            ("LDA (Linear Discriminant Analysis)", "Linear boundaries, assumes shared covariance"),
            ("QDA (Quadratic Discriminant Analysis)", "Quadratic boundaries, class-specific covariance"),
        ]
    )
    try:
        return summaries[method]
    except KeyError:
        return ""
