# -*- coding: utf-8 -*-
"""
Detection Data Model

This module contains the DetectionModel class for managing
YOLOX detection results.
"""

from collections import Counter
from typing import List, Dict, Tuple, Optional
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from utils.yolox_utils import get_coco_class_names


class DetectionModel:
    """
    Detection Data Model

    Manages YOLOX detection results including storage, filtering, and retrieval.
    Each stored detection is a plain dict (see ``get_detections`` for its keys)
    and receives a sequential integer ``id`` starting at 1.
    """

    def __init__(self):
        """
        Initialize the detection data model.

        Loads the COCO class-name list used to resolve ``class_id`` values into
        human-readable ``class_name`` strings.
        """
        self.detections: List[Dict] = []
        self.class_names = get_coco_class_names()
        # ID that will be assigned to the next added detection.
        self._next_id = 1

    def __repr__(self) -> str:
        return f"{type(self).__name__}(detections={len(self.detections)})"

    def add_detection(self, frame_idx: int, bbox: List[float],
                      class_id: int, confidence: float,
                      gps_coords: Tuple[float, float],
                      timestamp: Optional[float] = None):
        """
        Add a single detection result.

        Args:
            frame_idx (int): Frame index where object was detected
            bbox (List[float]): Bounding box [x, y, width, height]; stored by
                reference, not copied
            class_id (int): COCO class ID
            confidence (float): Detection confidence score (0-1)
            gps_coords (Tuple[float, float]): (longitude, latitude) in WGS84
            timestamp (float): Video timestamp in seconds (optional; stored
                as 0.0 when omitted)
        """
        # IDs outside the known class list are labelled 'unknown' rather than
        # raising, so a model/label mismatch does not abort ingestion.
        if 0 <= class_id < len(self.class_names):
            class_name = self.class_names[class_id]
        else:
            class_name = 'unknown'

        detection = {
            'id': self._next_id,
            'frame_idx': frame_idx,
            'timestamp': timestamp if timestamp is not None else 0.0,
            'bbox': bbox,
            'class_id': class_id,
            'class_name': class_name,
            'confidence': confidence,
            'gps_coords': gps_coords  # (longitude, latitude)
        }

        self.detections.append(detection)
        self._next_id += 1

    def add_detections_batch(self, detections: List[Dict]):
        """
        Add multiple detections at once, all-or-nothing.

        Args:
            detections (List[Dict]): List of detection dictionaries with keys:
                - frame_idx, bbox, class_id, confidence, gps_coords, timestamp (optional)

        Raises:
            KeyError: If an entry lacks a required key. In that case, and for
                any other exception raised while adding, every detection added
                by this call is removed and the ID counter is restored, so the
                model is left exactly as it was before the call.
        """
        # Snapshot the pre-batch state so a bad entry part-way through cannot
        # leave a half-applied batch (and skipped IDs) behind.
        start_len = len(self.detections)
        start_id = self._next_id
        try:
            for det in detections:
                self.add_detection(
                    frame_idx=det['frame_idx'],
                    bbox=det['bbox'],
                    class_id=det['class_id'],
                    confidence=det['confidence'],
                    gps_coords=det['gps_coords'],
                    timestamp=det.get('timestamp', None)
                )
        except Exception:
            del self.detections[start_len:]
            self._next_id = start_id
            raise

    def get_detections(self) -> List[Dict]:
        """
        Get all detection results.

        The returned list is a new list, but the detection dicts inside it are
        the same objects the model stores (shallow copy); mutating them mutates
        the model.

        Returns:
            List[Dict]: List of detection dictionaries with keys:
                - id (int)
                - frame_idx (int)
                - timestamp (float)
                - bbox (List[float])
                - class_id (int)
                - class_name (str)
                - confidence (float)
                - gps_coords (Tuple[float, float])
        """
        return self.detections.copy()

    def get_detections_by_frame(self, frame_idx: int) -> List[Dict]:
        """
        Get detections for a specific frame.

        Args:
            frame_idx (int): Frame index

        Returns:
            List[Dict]: List of detections in the frame, in insertion order
        """
        return [det for det in self.detections if det['frame_idx'] == frame_idx]

    def filter_by_class(self, class_ids: List[int]) -> List[Dict]:
        """
        Filter detections by class IDs.

        Args:
            class_ids (List[int]): Collection of COCO class IDs to keep. A
                single (non-iterable) class ID is also accepted.

        Returns:
            List[Dict]: Filtered detection list (dicts shared with the model)
        """
        try:
            wanted = set(class_ids)
        except TypeError:
            # A lone ID such as 3 is not iterable; treat it as a one-item set.
            wanted = {class_ids}
        return [det for det in self.detections if det['class_id'] in wanted]

    def filter_by_confidence(self, min_confidence: float) -> List[Dict]:
        """
        Filter detections by minimum confidence score.

        Args:
            min_confidence (float): Minimum confidence threshold (0-1), inclusive

        Returns:
            List[Dict]: Filtered detection list (dicts shared with the model)
        """
        return [det for det in self.detections if det['confidence'] >= min_confidence]

    def filter_by_class_names(self, class_names: List[str]) -> List[Dict]:
        """
        Filter detections by class names (exact match).

        Args:
            class_names (List[str]): Collection of class names to keep. A single
                string is also accepted and treated as one name.

        Returns:
            List[Dict]: Filtered detection list (dicts shared with the model)
        """
        # Without this, `name in "hot dog"` would do substring matching and
        # also return "dog" detections.
        if isinstance(class_names, str):
            class_names = [class_names]
        wanted = set(class_names)
        return [det for det in self.detections if det['class_name'] in wanted]

    def get_detection_count(self) -> int:
        """
        Get the total number of detections.

        Returns:
            int: Total detection count
        """
        return len(self.detections)

    def get_class_statistics(self) -> Dict[int, int]:
        """
        Get detection count statistics by class.

        Returns:
            Dict[int, int]: Dictionary mapping class_id to count, keyed in
            first-seen order
        """
        return dict(Counter(det['class_id'] for det in self.detections))

    def get_class_name_statistics(self) -> Dict[str, int]:
        """
        Get detection count statistics by class name.

        Returns:
            Dict[str, int]: Dictionary mapping class_name to count, keyed in
            first-seen order
        """
        return dict(Counter(det['class_name'] for det in self.detections))

    def get_frame_indices(self) -> List[int]:
        """
        Get list of unique frame indices that have detections.

        Returns:
            List[int]: Sorted list of frame indices
        """
        return sorted({det['frame_idx'] for det in self.detections})

    def clear(self):
        """
        Clear all detection results and restart ID numbering at 1.
        """
        self.detections = []
        self._next_id = 1
