# -*- coding: utf-8 -*-
"""
Training utilities for Supervised Classifier plugin.
Contains functions for preparing training data from vector files and raster images.
"""
import math

import numpy as np
from osgeo import ogr, gdal
from sklearn.preprocessing import LabelEncoder


def prepare_training_data(shapefile_path, image_path, label_fields, selected_bands=None):
    """
    Prepare training data from one vector file and one raster image.

    Every feature of the first layer of the vector file is converted into
    sample locations (Point, Line and Polygon geometries are supported), and
    the raster values found at those locations become training rows.

    Label rules:
    - One label field: the field value is the class label. Features whose
      value is NULL are skipped.
    - Several label fields: they are treated as binary flags and the name of
      the first field whose value is 1 (or the string "1") is the class label.
      Features with no flag set are skipped.

    NOTE: only the origin and pixel-size terms of the geotransform are used;
    the rotation terms are ignored, and no reprojection is done between the
    vector file and the raster.

    Args:
        shapefile_path: Path to vector file containing training samples
        image_path: Path to raster image
        label_fields: Field name (str) or list of field names containing class labels
        selected_bands: List of band indices to use (1-indexed), or None for all bands

    Returns:
        X_train: numpy array of pixel values, one row per sample
        y_train: numpy array of encoded labels (raw, empty array when no sample was found)
        label_encoder: LabelEncoder fitted to the labels (unfitted when no sample was found)

    Raises:
        Exception: if the vector file, its layer, the image, or any requested
            raster band cannot be opened or read.
    """
    shapefile_ds = ogr.Open(shapefile_path)
    if shapefile_ds is None:
        raise Exception(f"Failed to open shapefile: {shapefile_path}")

    image_ds = gdal.Open(image_path)
    if image_ds is None:
        raise Exception(f"Failed to open image: {image_path}")

    layer = shapefile_ds.GetLayer()
    # Compare against None explicitly: an OGR layer reports its feature count
    # as its length, so a valid but empty layer would be falsy.
    if layer is None:
        raise Exception("Failed to get layer from shapefile")

    x_origin, pixel_width, _, y_origin, _, pixel_height = image_ds.GetGeoTransform()
    x_min, x_max = 0, image_ds.RasterXSize
    y_min, y_max = 0, image_ds.RasterYSize

    # Get bands (GDAL band indices are 1-based)
    band_indices = selected_bands or range(1, image_ds.RasterCount + 1)
    bands = [image_ds.GetRasterBand(i) for i in band_indices]

    # GetRasterBand returns None for an invalid index, so check every entry
    # instead of only testing for an empty list.
    if not bands or any(band is None for band in bands):
        raise Exception("Failed to get raster bands")

    # Read all band data at once for faster access
    band_arrays = [band.ReadAsArray() for band in bands]
    if any(arr is None for arr in band_arrays):
        raise Exception("Failed to read raster band data")
    nodata_values = [band.GetNoDataValue() for band in bands]
    n_bands = len(bands)

    X_train, y_train = [], []
    label_encoder = LabelEncoder()

    # Handle multiple label fields
    if isinstance(label_fields, str):
        label_fields = [label_fields]

    multiple_labels = len(label_fields) > 1

    for feature in layer:
        geom = feature.GetGeometryRef()
        if not geom:
            continue

        # Get label
        if multiple_labels:
            # Binary-flag mode: the first flagged column names the class.
            # Accept both numeric 1 and the string "1" so text-typed flag
            # columns are not silently ignored.
            label = next(
                (name for name in label_fields if feature.GetField(name) in (1, "1")),
                None,
            )
            if not label:
                continue
        else:
            label = feature.GetField(label_fields[0])
            if label is None:
                continue

        # Get geometry type and extract sample points
        sample_points = _extract_sample_points(
            geom, x_origin, pixel_width, y_origin, pixel_height, x_max, y_max
        )

        # Extract pixel values for each sample point
        for sample in sample_points:
            pixel_x, pixel_y = _get_pixel_coords(
                sample, x_origin, pixel_width, y_origin, pixel_height
            )

            # Skip samples that fall outside the raster
            if not (x_min <= pixel_x < x_max and y_min <= pixel_y < y_max):
                continue

            pixel_values = _extract_pixel_values(
                band_arrays, nodata_values, int(pixel_y), int(pixel_x)
            )

            # None means at least one band holds nodata at this pixel
            if pixel_values is not None and len(pixel_values) == n_bands:
                X_train.append(pixel_values)
                y_train.append(label)

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    # Nothing to encode: hand back the empty arrays and the unfitted encoder
    if X_train.size == 0 or y_train.size == 0:
        return X_train, y_train, label_encoder

    y_train = label_encoder.fit_transform(y_train)

    print(f"Training data: {X_train.shape}, Labels: {list(label_encoder.classes_)}")

    return X_train, y_train, label_encoder


def normalize_and_merge_training_data(selected_references, image_path, selected_bands=None):
    """
    Merge training data from multiple reference files with class normalization.

    Class normalization rules:
    - Multiple columns selected (binary 0/1): column names become class labels
    - Single column: column values (converted to str) become class labels
    - Same class name/value across files: treated as same class

    Reference files that have no selected fields, cannot be opened, or expose
    no layer are skipped rather than aborting the merge.

    NOTE: only the origin and pixel-size terms of the geotransform are used;
    the rotation terms are ignored, and no reprojection is done between the
    vector files and the raster.

    Args:
        selected_references: List of reference file info dicts with 'path' and
            'selected_fields'. An optional 'name' is used in progress messages;
            the path is used when it is missing.
        image_path: Path to raster image
        selected_bands: List of band indices to use (1-indexed), or None for all bands

    Returns:
        X_train: numpy array of merged pixel values, one row per sample
        y_train: numpy array of encoded labels (raw, empty array when no sample was found)
        label_encoder: LabelEncoder fitted to unified labels (unfitted when no sample was found)
        unified_classes: List of all unique class names. When no sample was
            found this is the sorted list of labels seen on the features.

    Raises:
        Exception: if the image or any requested raster band cannot be opened or read.
    """
    image_ds = gdal.Open(image_path)
    if image_ds is None:
        raise Exception(f"Failed to open image: {image_path}")

    x_origin, pixel_width, _, y_origin, _, pixel_height = image_ds.GetGeoTransform()
    x_min, x_max = 0, image_ds.RasterXSize
    y_min, y_max = 0, image_ds.RasterYSize

    # Get bands (GDAL band indices are 1-based)
    band_indices = selected_bands or range(1, image_ds.RasterCount + 1)
    bands = [image_ds.GetRasterBand(i) for i in band_indices]

    # GetRasterBand returns None for an invalid index, so check every entry
    # instead of only testing for an empty list.
    if not bands or any(band is None for band in bands):
        raise Exception("Failed to get raster bands")

    # Read all band data at once
    band_arrays = [band.ReadAsArray() for band in bands]
    if any(arr is None for arr in band_arrays):
        raise Exception("Failed to read raster band data")
    nodata_values = [band.GetNoDataValue() for band in bands]
    n_bands = len(bands)

    all_X_train = []
    all_y_train = []
    all_classes = set()

    # Process each reference file
    for ref_info in selected_references:
        shapefile_path = ref_info['path']
        label_fields = ref_info['selected_fields']
        # 'name' is only used for messages; fall back to the path so a dict
        # without it does not abort the whole merge with a KeyError.
        ref_name = ref_info.get('name', shapefile_path)

        if not label_fields:
            print(f"Skipping {ref_name}: No label fields selected")
            continue

        shapefile_ds = ogr.Open(shapefile_path)
        if shapefile_ds is None:
            print(f"Skipping {ref_name}: Failed to open")
            continue

        layer = shapefile_ds.GetLayer()
        # Compare against None explicitly: an OGR layer reports its feature
        # count as its length, so a valid but empty layer would be falsy.
        if layer is None:
            continue

        # Determine binary mode vs value mode
        if isinstance(label_fields, str):
            label_fields = [label_fields]

        use_column_names_as_labels = len(label_fields) > 1

        for feature in layer:
            geom = feature.GetGeometryRef()
            if not geom:
                continue

            # Get label
            if use_column_names_as_labels:
                # Binary-flag mode: the first flagged column names the class.
                # True == 1 in Python, so boolean flags are covered as well.
                label = next(
                    (name for name in label_fields if feature.GetField(name) in (1, "1")),
                    None,
                )
                if not label:
                    continue
            else:
                label = feature.GetField(label_fields[0])
                if label is None:
                    continue
                # Normalize to str so equal values from different files
                # (e.g. 3 and "3") end up in the same class
                label = str(label)

            all_classes.add(label)

            # Extract sample points
            sample_points = _extract_sample_points(
                geom, x_origin, pixel_width, y_origin, pixel_height, x_max, y_max
            )

            # Extract pixel values
            for sample in sample_points:
                pixel_x, pixel_y = _get_pixel_coords(
                    sample, x_origin, pixel_width, y_origin, pixel_height
                )

                # Skip samples that fall outside the raster
                if not (x_min <= pixel_x < x_max and y_min <= pixel_y < y_max):
                    continue

                pixel_values = _extract_pixel_values(
                    band_arrays, nodata_values, int(pixel_y), int(pixel_x)
                )

                # None means at least one band holds nodata at this pixel
                if pixel_values is not None and len(pixel_values) == n_bands:
                    all_X_train.append(pixel_values)
                    all_y_train.append(label)

        print(f"Processed {ref_name}: {layer.GetFeatureCount()} features")

    X_train = np.array(all_X_train)
    y_train = np.array(all_y_train)

    if X_train.size == 0 or y_train.size == 0:
        # Sorted so the result does not depend on set iteration order
        return X_train, y_train, LabelEncoder(), sorted(all_classes)

    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(y_train)
    unified_classes = list(label_encoder.classes_)

    print(f"Merged training data: {X_train.shape}")
    # Count the classes that actually produced samples, so the number always
    # matches the list printed next to it.
    print(f"Unified classes ({len(unified_classes)}): {unified_classes}")

    return X_train, y_train, label_encoder, unified_classes


def prepare_test_data(image_path, selected_bands=None):
    """
    Prepare test data from raster image - optimized for performance.

    Each requested band is read in full, flattened in row-major order and
    stacked as one column, so row ``k`` of the result holds the values of all
    bands at flattened pixel index ``k``.

    Args:
        image_path: Path to raster image
        selected_bands: List of band indices (1-indexed), or None for all bands

    Returns:
        X_test: numpy array of all pixel values, one row per pixel and one
            column per band

    Raises:
        Exception: if the image or any requested raster band cannot be opened or read.
    """
    image_ds = gdal.Open(image_path)
    if image_ds is None:
        raise Exception(f"Failed to open image: {image_path}")

    # Get bands (GDAL band indices are 1-based)
    band_indices = selected_bands or range(1, image_ds.RasterCount + 1)
    bands = [image_ds.GetRasterBand(i) for i in band_indices]

    # GetRasterBand returns None for an invalid index, so check every entry
    # instead of only testing for an empty list.
    if not bands or any(band is None for band in bands):
        raise Exception("Failed to get raster bands")

    # Read every band and flatten it into one column
    columns = []
    for band in bands:
        data = band.ReadAsArray()
        if data is None:
            raise Exception("Failed to read raster band data")
        # ravel avoids an intermediate copy; column_stack copies anyway
        columns.append(data.ravel())
    X_test = np.column_stack(columns)

    print(f"Test data shape: {X_test.shape}")
    return X_test


def _extract_sample_points(geom, x_origin, pixel_width, y_origin, pixel_height, x_max, y_max):
    """
    Extract sample locations from a geometry (supports Point, Line, Polygon).

    Point and MultiPoint geometries yield ``(x, y)`` tuples in map
    coordinates. Every other geometry is scanned over the part of its envelope
    that overlaps the raster, and each selected pixel yields a
    ``(col, row, True)`` tuple in pixel indices. The third element only makes
    the tuple three items long, which is how _get_pixel_coords tells pixel
    indices apart from map coordinates.

    Pixel selection rules:
    - Line geometries: pixels whose centre lies closer to the line than half
      a pixel width.
    - Anything else (treated as polygon): pixels whose centre intersects the
      geometry.

    Args:
        geom: OGR geometry of the training feature
        x_origin, y_origin: map coordinates of the raster origin
        pixel_width, pixel_height: pixel size from the geotransform
            (pixel_height may be negative or positive)
        x_max, y_max: raster width and height in pixels

    Returns:
        List of ``(x, y)`` or ``(col, row, True)`` tuples.
    """
    # Flatten first so Z / M / ZM variants are recognised as their 2D base type
    # instead of falling through to the polygon branch.
    flat_type = ogr.GT_Flatten(geom.GetGeometryType())

    # Point geometry
    if flat_type == ogr.wkbPoint:
        return [(geom.GetX(), geom.GetY())]

    if flat_type == ogr.wkbMultiPoint:
        members = (geom.GetGeometryRef(i) for i in range(geom.GetGeometryCount()))
        return [(pt.GetX(), pt.GetY()) for pt in members]

    # Line or Polygon geometry: scan the envelope clipped to the raster.
    min_x, max_x, min_y, max_y = geom.GetEnvelope()
    col_a = int((min_x - x_origin) / pixel_width)
    col_b = int((max_x - x_origin) / pixel_width)
    row_a = int((max_y - y_origin) / pixel_height)
    row_b = int((min_y - y_origin) / pixel_height)
    # Order the bounds with min/max so the window is correct whatever the sign
    # of the pixel size (north-up and south-up rasters alike).
    px_start = max(0, min(col_a, col_b))
    px_end = min(x_max, max(col_a, col_b) + 1)
    py_start = max(0, min(row_a, row_b))
    py_end = min(y_max, max(row_a, row_b) + 1)

    is_line = flat_type in (ogr.wkbLineString, ogr.wkbMultiLineString)
    max_line_distance = abs(pixel_width) * 0.5

    sample_points = []
    for px in range(px_start, px_end):
        center_x = x_origin + (px + 0.5) * pixel_width
        for py in range(py_start, py_end):
            center_y = y_origin + (py + 0.5) * pixel_height

            point = ogr.Geometry(ogr.wkbPoint)
            point.AddPoint(center_x, center_y)

            if is_line:
                selected = geom.Distance(point) < max_line_distance
            else:
                # Intersects already covers every case Contains would accept
                selected = geom.Intersects(point)

            if selected:
                sample_points.append((px, py, True))

    return sample_points


def _get_pixel_coords(sample, x_origin, pixel_width, y_origin, pixel_height):
    """Convert sample to pixel coordinates."""
    if len(sample) == 3:
        return sample[0], sample[1]
    else:
        x, y = sample
        return int((x - x_origin) / pixel_width), int((y - y_origin) / pixel_height)


def _extract_pixel_values(band_arrays, nodata_values, row, col):
    """Extract pixel values from all bands, returns None if nodata."""
    pixel_values = []
    for i, arr in enumerate(band_arrays):
        val = arr[row, col]
        if nodata_values[i] is not None and val == nodata_values[i]:
            return None
        pixel_values.append(val)
    return pixel_values
