# -*- coding: utf-8 -*-
"""
Classifier utilities for Supervised Classifier plugin.
Contains functions for getting classifiers and classification methods.
"""
import numpy as np
from scipy.spatial.distance import cdist


def get_classifier(method, num_iterations):
    """
    Get a classifier instance based on the method name.

    scikit-learn is imported lazily inside each branch, so "Minimum Distance"
    (which needs no estimator) and invalid method names are handled without
    requiring scikit-learn to be importable.

    Args:
        method: Classification method name, e.g. "Random Forest" or
            "SVM (Support Vector Machine)" (same names as the keys of
            METHOD_ABBREVIATIONS).
        num_iterations: Passed as ``n_estimators`` for Random Forest and as
            ``max_iter`` for SVM; ignored by every other method.

    Returns:
        A new, unfitted scikit-learn estimator, or None for
        "Minimum Distance" (no estimator is built; see
        classify_minimum_distance).

    Raises:
        ValueError: If ``method`` is not a recognised method name.
            (Subclass of Exception, so existing ``except Exception``
            handlers still catch it.)
    """
    if method == "Random Forest":
        from sklearn.ensemble import RandomForestClassifier
        return RandomForestClassifier(n_estimators=num_iterations, n_jobs=-1)
    if method == "SVM (Support Vector Machine)":
        from sklearn.svm import SVC
        return SVC(max_iter=num_iterations)
    if method == "KNN (K-Nearest Neighbors)":
        from sklearn.neighbors import KNeighborsClassifier
        try:
            return KNeighborsClassifier(n_neighbors=5, weights='distance', n_jobs=-1)
        except TypeError:
            # Fallback for scikit-learn versions whose constructor rejects n_jobs.
            return KNeighborsClassifier(n_neighbors=5, weights='distance')
    if method == "Maximum Likelihood":
        # NOTE(review): GaussianNB assumes independent features (per-class
        # diagonal covariance); classical maximum-likelihood classification
        # uses a full covariance matrix (closer to QDA) — confirm intended.
        from sklearn.naive_bayes import GaussianNB
        return GaussianNB()
    if method == "LDA (Linear Discriminant Analysis)":
        from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
        return LinearDiscriminantAnalysis()
    if method == "QDA (Quadratic Discriminant Analysis)":
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
        return QuadraticDiscriminantAnalysis()
    if method == "Minimum Distance":
        return None
    raise ValueError("Invalid method selected")


def classify_minimum_distance(train_data, X_test, return_labels=False):
    """
    Classify using minimum Euclidean distance to class means.

    Each test sample is assigned to the class whose training-sample mean is
    nearest.

    Args:
        train_data: Tuple of (X_train, y_train). X_train is array-like of
            shape (n_train, n_features); y_train is array-like of length
            n_train. Plain Python lists are accepted.
        X_test: Array-like of shape (n_test, n_features) to classify.
        return_labels: If False (default), return positions into the sorted
            unique labels ``np.unique(y_train)``. If True, return the labels
            themselves.

    Returns:
        1-D array of length n_test containing predicted class indices. When
        ``return_labels`` is True, it contains the predicted labels instead.
        Indices equal labels only if y_train's labels are exactly
        0..K-1.
    """
    X_train, y_train = train_data
    # Convert up front: with a plain list, ``y_train == k`` is a single bool
    # rather than an element-wise mask, which silently breaks the indexing.
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    classes = np.unique(y_train)  # sorted; row i of class_means belongs to classes[i]
    class_means = np.array([X_train[y_train == k].mean(axis=0) for k in classes])
    # cdist's default metric is Euclidean.
    nearest = np.argmin(cdist(X_test, class_means), axis=1)
    return classes[nearest] if return_labels else nearest


# Short tag for each classification method name, used to build output
# filenames (looked up by get_method_abbreviation). Order matches the
# original literal.
METHOD_ABBREVIATIONS = dict(
    (
        ("Minimum Distance", "MinDist"),
        ("Random Forest", "RF"),
        ("SVM (Support Vector Machine)", "SVM"),
        ("KNN (K-Nearest Neighbors)", "KNN"),
        ("Maximum Likelihood", "MaxLik"),
        ("LDA (Linear Discriminant Analysis)", "LDA"),
        ("QDA (Quadratic Discriminant Analysis)", "QDA"),
    )
)


def get_method_abbreviation(method):
    """Return the short filename tag for *method*.

    Names listed in METHOD_ABBREVIATIONS map to their configured tag; any
    other name falls back to its first three characters, upper-cased.
    """
    fallback = method[:3].upper()
    return METHOD_ABBREVIATIONS.get(method, fallback)
