# -*- coding: utf-8 -*-
"""
YOLOX Utilities

This module contains utility functions for YOLOX model loading,
preprocessing, inference, and postprocessing.

NOTE: This is a simplified implementation. For production use,
consider integrating the full YOLOX repository.
"""

from typing import List, Dict, Tuple
import numpy as np
import torch
import cv2
import os
import urllib.request


# COCO class names: the 80 object categories of the COCO detection benchmark.
# NOTE(review): presumably listed in the same index order as the model's class
# outputs, so a class id can index this list directly — confirm against the
# YOLOX training config before relying on it.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
    "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
]

# Model download URLs for the YOLOX 0.1.1rc0 release checkpoints, keyed by the
# model names that download_pretrained_weights accepts (any other name is
# rejected with ValueError there).
# NOTE(review): the remote files are named e.g. 'yolox_s.pth', but they are
# saved locally as '<model_name>.pth' (e.g. 'yolox-s.pth').
MODEL_URLS = {
    'yolox-nano': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano.pth',
    'yolox-tiny': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny.pth',
    'yolox-s': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth',
    'yolox-m': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m.pth',
    'yolox-l': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l.pth',
    'yolox-x': 'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth'
}


def load_yolox_model(model_name: str = 'yolox-s', device: str = 'cuda',
                     weights_path: str = None):
    """
    Load YOLOX model with pretrained weights.

    NOTE: This is a placeholder implementation. It resolves the target device
    and the weights file, but does not build a network and always returns
    ``None``. In production, you would:
    1. Import from the YOLOX package
    2. Initialize the model architecture
    3. Load weights
    4. Set to eval mode

    Args:
        model_name (str): Model name ('yolox-nano', 'yolox-tiny', 'yolox-s',
                         'yolox-m', 'yolox-l', 'yolox-x')
        device (str): Device to load model on ('cuda', an indexed CUDA device
            such as 'cuda:0', or 'cpu'). Any CUDA device falls back to 'cpu'
            when CUDA is not available.
        weights_path (str): Path to weights file. If None, the pretrained
            weights for ``model_name`` are downloaded (unless already present)
            into './resources/models'.

    Returns:
        None: placeholder result — no model is constructed yet, so callers
        must be prepared for ``None`` until full YOLOX integration lands.
    """
    # Match indexed devices like 'cuda:0' too, not only the bare 'cuda'
    # string; otherwise they would never fall back on CPU-only machines.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = 'cpu'

    if weights_path is None:
        weights_path = download_pretrained_weights(model_name, './resources/models')

    print(f"Model {model_name} loaded on {device}")
    print(f"Weights: {weights_path}")
    print("NOTE: This is a placeholder. Integrate full YOLOX for production.")

    return None


def download_pretrained_weights(model_name: str, save_dir: str) -> str:
    """
    Download pretrained YOLOX weights from official repository.

    The file is saved as ``<save_dir>/<model_name>.pth``. If that file already
    exists it is returned as-is without re-downloading. The download goes to a
    temporary ``.part`` file that is only renamed into place once it has
    completed. A failed or interrupted download therefore never leaves a
    truncated file that later calls would mistake for valid weights.

    Args:
        model_name (str): Model name (e.g., 'yolox-s'); must be a key of MODEL_URLS
        save_dir (str): Directory to save the weights (created if missing)

    Returns:
        str: Path to the downloaded weights file

    Raises:
        ValueError: If ``model_name`` is not a known model.
        Exception: Any error raised by the download is re-raised after being
            reported.
    """
    if model_name not in MODEL_URLS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_URLS.keys())}")

    os.makedirs(save_dir, exist_ok=True)

    weights_filename = f"{model_name}.pth"
    weights_path = os.path.join(save_dir, weights_filename)

    if os.path.exists(weights_path):
        return weights_path

    print(f"Downloading {model_name} weights...")
    url = MODEL_URLS[model_name]
    partial_path = weights_path + ".part"

    try:
        urllib.request.urlretrieve(url, partial_path)
        # Only a fully downloaded file ever appears at the final path.
        os.replace(partial_path, weights_path)
    except Exception as e:
        print(f"Failed to download weights: {e}")
        raise
    finally:
        # On success the file was already moved; on any failure (including
        # KeyboardInterrupt) drop the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"Downloaded to {weights_path}")
    return weights_path


def preprocess_image(image: np.ndarray, input_size: Tuple[int, int] = (640, 640)) -> torch.Tensor:
    """
    Preprocess image for YOLOX inference.

    Processing steps:
    1. Resize the image, preserving its aspect ratio, so it fits inside
       ``input_size``.
    2. Paste it into the top-left corner of a canvas filled with the value 114.
    3. Convert the canvas from BGR to RGB.
    4. Return it as a float tensor. No normalization is applied, so pixel
       values stay in the 0-255 range.

    The resize factor ``min(input_w / W, input_h / H)`` is not returned.
    Postprocessing must recompute it from the original and input sizes to map
    boxes back onto the original image.

    Args:
        image (np.ndarray): Input image in BGR format (H, W, C). Grayscale
            (H, W) or (H, W, 1) and BGRA (H, W, 4) images are also accepted
            and converted to 3-channel BGR first.
        input_size (Tuple[int, int]): Target input size (width, height)

    Returns:
        torch.Tensor: Preprocessed float32 image tensor (1, 3, height, width)

    Raises:
        ValueError: If ``image`` is None or has no pixels.
    """
    # cv2.imread returns None on failure; report that clearly instead of an
    # AttributeError / ZeroDivisionError further down.
    if image is None or image.size == 0:
        raise ValueError("preprocess_image() received an empty image")

    # Bring the input to 3-channel BGR so it fits the 3-channel canvas.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = np.ascontiguousarray(image[:, :, :3])  # drop the alpha channel

    h, w = image.shape[:2]
    scale = min(input_size[0] / w, input_size[1] / h)
    # Keep at least one pixel per side: for extreme aspect ratios int() would
    # otherwise round a side down to 0, which cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    padded = np.full((input_size[1], input_size[0], 3), 114, dtype=np.uint8)
    padded[:new_h, :new_w] = resized

    rgb = cv2.cvtColor(padded, cv2.COLOR_BGR2RGB)

    # HWC uint8 -> CHW float32, then add the batch dimension.
    img_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()
    img_tensor = img_tensor.unsqueeze(0)

    return img_tensor


def postprocess_detections(outputs: torch.Tensor, img_size: Tuple[int, int],
                           input_size: Tuple[int, int], conf_threshold: float = 0.5,
                           nms_threshold: float = 0.45) -> List[Dict]:
    """
    Turn raw YOLOX outputs into a list of detection dictionaries.

    NOTE: Placeholder only — decoding is not implemented yet, so every
    argument is ignored and an empty list is always returned. Replace this
    with YOLOX's own postprocessing for real use.

    Args:
        outputs (torch.Tensor): Raw model outputs (currently unused)
        img_size (Tuple[int, int]): Original image size (width, height)
        input_size (Tuple[int, int]): Model input size (width, height)
        conf_threshold (float): Confidence threshold for filtering
        nms_threshold (float): NMS IoU threshold

    Returns:
        List[Dict]: A new, empty list of detections
    """
    no_detections: List[Dict] = []
    return no_detections


def visualize_detections(image: np.ndarray, detections: List[Dict],
                         show_labels: bool = True, show_scores: bool = True) -> np.ndarray:
    """
    Visualize detection results on an image.

    Args:
        image (np.ndarray): Input image (H, W, C)
        detections (List[Dict]): List of detection dictionaries
        show_labels (bool): Whether to show class labels
        show_scores (bool): Whether to show confidence scores

    Returns:
        np.ndarray: Image with visualized detections
    """
    output = image.copy()

    for detection in detections:
        bbox = detection['bbox']
        class_name = detection.get('class_name', 'unknown')
        confidence = detection.get('confidence', 0.0)

        x, y, w, h = [int(v) for v in bbox]

        color = (0, 255, 0)
        thickness = 2
        cv2.rectangle(output, (x, y), (x + w, y + h), color, thickness)

        if show_labels or show_scores:
            label_parts = []
            if show_labels:
                label_parts.append(class_name)
            if show_scores:
                label_parts.append(f"{confidence:.2f}")

            label_text = " ".join(label_parts)

            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            font_thickness = 1

            (text_w, text_h), baseline = cv2.getTextSize(
                label_text, font, font_scale, font_thickness
            )

            cv2.rectangle(
                output,
                (x, y - text_h - baseline - 5),
                (x + text_w, y),
                color,
                -1
            )

            cv2.putText(
                output,
                label_text,
                (x, y - baseline - 2),
                font,
                font_scale,
                (0, 0, 0),
                font_thickness
            )

    return output


def get_coco_class_names() -> List[str]:
    """
    Return the 80 COCO class names as a new list.

    A fresh list is built on every call, so callers may modify the result
    without affecting the module-level COCO_CLASSES constant.

    Returns:
        List[str]: Class names, in the same order as COCO_CLASSES
    """
    names = list(COCO_CLASSES)
    return names


def nms(bboxes: List[List[float]], scores: List[float], threshold: float = 0.45) -> List[int]:
    """
    Apply Non-Maximum Suppression.

    Boxes are visited in descending score order. Each visited box is kept, and
    every remaining box whose IoU with it exceeds ``threshold`` is discarded.
    Boxes with equal scores are visited in input order, so results are
    deterministic. A pair whose union area is zero (degenerate boxes) has
    IoU 0, matching ``compute_iou``; such boxes are not suppressed.

    Args:
        bboxes (List[List[float]]): List of bounding boxes [[x, y, w, h], ...]
        scores (List[float]): List of confidence scores, one per box
        threshold (float): IoU threshold; boxes with IoU <= threshold survive

    Returns:
        List[int]: Indices (plain Python ints) of boxes to keep, highest score first
    """
    if len(bboxes) == 0:
        return []

    boxes = np.asarray(bboxes, dtype=np.float64)
    scores_arr = np.asarray(scores, dtype=np.float64)

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 0] + boxes[:, 2]
    y2 = boxes[:, 1] + boxes[:, 3]

    areas = (x2 - x1) * (y2 - y1)

    # Stable sort on negated scores: descending order, ties keep input order.
    order = np.argsort(-scores_arr, kind="stable")

    keep: List[int] = []

    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]

        inter_w = np.maximum(0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]))
        inter_h = np.maximum(0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]))
        intersection = inter_w * inter_h
        union = areas[best] + areas[rest] - intersection

        # 0/0 would give NaN, and NaN <= threshold is False, which would
        # silently suppress degenerate boxes; treat those IoUs as 0 instead.
        iou = np.divide(
            intersection, union,
            out=np.zeros_like(intersection),
            where=union > 0,
        )

        order = rest[iou <= threshold]

    return keep


def compute_iou(box1: List[float], box2: List[float]) -> float:
    """
    Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Boxes are given as top-left corner plus size. Boxes that do not overlap,
    or that only touch along an edge, score 0. A zero union area also yields
    0 instead of dividing by zero.

    Args:
        box1 (List[float]): First box [x, y, width, height]
        box2 (List[float]): Second box [x, y, width, height]

    Returns:
        float: IoU value (0-1)
    """
    ax, ay, aw, ah = box1
    bx, by, bw, bh = box2

    overlap_w = min(ax + aw, bx + bw) - max(ax, bx)
    overlap_h = min(ay + ah, by + bh) - max(ay, by)
    if overlap_w < 0 or overlap_h < 0:
        return 0.0

    inter = overlap_w * overlap_h
    union = aw * ah + bw * bh - inter
    return 0.0 if union == 0 else inter / union
