# -*- coding: utf-8 -*-
"""
Detection Controller

This module contains the DetectionController class for handling
YOLOX object detection operations.
"""

from typing import List, Dict, Tuple, Optional
import numpy as np
import os
import sys
from PyQt5.QtCore import QObject, pyqtSignal

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from models.detection_model import DetectionModel
from models.video_model import VideoModel
from utils import yolox_utils, coordinate_utils


class DetectionController(QObject):
    """
    Detection Controller

    Handles YOLOX model loading and object detection execution.

    Detections are written into the injected ``DetectionModel``; progress,
    completion and failures are reported through the Qt signals below.
    """

    # Signals
    detection_progress = pyqtSignal(int, str)  # (progress_percent, message)
    detection_completed = pyqtSignal(list)  # (detections_list)
    frame_processed = pyqtSignal(int, int)  # (current_frame, total_frames)
    error_occurred = pyqtSignal(str)  # (error_message)

    def __init__(self, detection_model: DetectionModel):
        """
        Initialize the detection controller.

        Args:
            detection_model (DetectionModel): Detection data model instance
                that receives all detections produced by this controller.
        """
        super().__init__()
        self.detection_model = detection_model
        self.yolox_model = None  # set by load_yolox_model()
        self.confidence_threshold = 0.5
        self.nms_threshold = 0.45
        self.device = 'cuda'  # device of the most recently *successfully* loaded model
        self.model_loaded = False

    def load_yolox_model(self, model_path: Optional[str] = None,
                         model_name: str = 'yolox-s', device: str = 'cuda') -> None:
        """
        Load YOLOX model.

        Controller state (``yolox_model``, ``device``, ``model_loaded``) is only
        committed after the load succeeds; on failure the previous model is
        discarded and ``model_loaded`` is reset before the error is emitted.

        Args:
            model_path (Optional[str]): Path to model weights file (optional)
            model_name (str): Model name ('yolox-s', 'yolox-m', etc.)
            device (str): Device to use ('cuda' or 'cpu')

        Emits:
            detection_progress: Signal with loading progress
            error_occurred: Signal if model loading fails
        """
        try:
            self.detection_progress.emit(0, f"Loading YOLOX model: {model_name}...")

            # Use yolox_utils to load the model
            model = yolox_utils.load_yolox_model(
                model_name=model_name,
                device=device,
                weights_path=model_path
            )
        except Exception as e:
            # Clear state *before* emitting, so slots connected directly to
            # error_occurred observe a consistent "not loaded" controller.
            self.yolox_model = None
            self.model_loaded = False
            self.error_occurred.emit(f"Error loading YOLOX model: {str(e)}")
            return

        self.yolox_model = model
        self.device = device

        # Note: The current yolox_utils.load_yolox_model is a placeholder
        # In production, this will return an actual YOLOX model
        # For now, we mark it as loaded to allow testing of other components
        self.model_loaded = True

        self.detection_progress.emit(100, f"Model {model_name} loaded on {device}")

    def process_video(self, video_model: VideoModel, frame_interval: int = 30,
                      confidence_threshold: float = 0.5,
                      gps_track: Optional[List[Dict]] = None) -> None:
        """
        Process video and detect objects in frames.

        Previous detections in the detection model are cleared first. The given
        confidence threshold is clamped to [0, 1] and stored on the controller
        (it replaces any value set via ``set_confidence_threshold``).

        Args:
            video_model (VideoModel): Video model containing the video data
            frame_interval (int): Process every Nth frame (must be >= 1)
            confidence_threshold (float): Minimum confidence for detections
            gps_track (List[Dict]): GPS track data (optional, will extract if not provided)

        Emits:
            detection_progress: Signal with progress updates
            frame_processed: Signal after each frame is processed
            detection_completed: Signal when all frames are processed
            error_occurred: Signal on invalid input or processing failure
        """
        if not self.model_loaded:
            self.error_occurred.emit("YOLOX model not loaded")
            return

        if not video_model.is_loaded():
            self.error_occurred.emit("No video loaded")
            return

        # range() raises on a zero step and silently yields nothing for a
        # negative one; reject both up front with a clear message.
        if frame_interval < 1:
            self.error_occurred.emit(
                f"Invalid frame interval: {frame_interval} (must be >= 1)"
            )
            return

        try:
            # Route through the setter so the value is clamped like every
            # other threshold update.
            self.set_confidence_threshold(confidence_threshold)

            # Get GPS track if not provided
            if gps_track is None:
                gps_track = video_model.extract_gps_metadata()

            # Get video info
            total_frames = video_model.get_total_frames()
            fps = video_model.get_fps()

            self.detection_progress.emit(0, "Starting object detection...")

            # Clear previous detections
            self.detection_model.clear()

            frame_count = 0
            detection_count = 0

            # Process frames at specified interval
            for frame_idx in range(0, total_frames, frame_interval):
                frame = video_model.get_frame_at_index(frame_idx)

                # Unreadable frames are skipped without emitting progress.
                if frame is None:
                    continue

                detections = self.detect_objects(frame)

                # Frame time in seconds; 0.0 when the video reports no valid fps.
                timestamp = frame_idx / fps if fps > 0 else 0.0

                # Interpolate GPS position for this frame
                gps_coords = self._interpolate_gps_for_frame(
                    frame_idx, gps_track, fps
                )

                for det in detections:
                    if det['confidence'] < self.confidence_threshold:
                        continue
                    self.detection_model.add_detection(
                        frame_idx=frame_idx,
                        bbox=det['bbox'],
                        class_id=det['class_id'],
                        confidence=det['confidence'],
                        # (0.0, 0.0) is stored as a placeholder when no GPS
                        # position could be determined for this frame.
                        gps_coords=gps_coords if gps_coords else (0.0, 0.0),
                        timestamp=timestamp
                    )
                    detection_count += 1

                frame_count += 1

                # total_frames > 0 here, otherwise the range would be empty.
                progress = int((frame_idx / total_frames) * 100)
                self.detection_progress.emit(
                    progress,
                    f"Processed {frame_count} frames, found {detection_count} objects"
                )
                self.frame_processed.emit(frame_idx, total_frames)

            self.detection_progress.emit(
                100,
                f"Detection complete: {detection_count} objects in {frame_count} frames"
            )

            detections_list = self.detection_model.get_detections()
            self.detection_completed.emit(detections_list)

        except Exception as e:
            error_msg = f"Error during video processing: {str(e)}"
            self.error_occurred.emit(error_msg)

    def detect_objects(self, frame: np.ndarray) -> List[Dict]:
        """
        Detect objects in a single frame.

        Currently a placeholder: the frame is preprocessed but no inference is
        run, so the result is always an empty list. Errors are printed and an
        empty list is returned (best-effort, so one bad frame does not abort
        a whole video run).

        Args:
            frame (np.ndarray): Input frame image

        Returns:
            List[Dict]: List of detections with keys:
                - bbox (List[float]): [x, y, width, height]
                - class_id (int): COCO class ID
                - class_name (str): Class name
                - confidence (float): Detection confidence
        """
        if not self.model_loaded:
            return []

        try:
            # Preprocess image (result is unused until inference is wired up;
            # kept so preprocessing failures still surface here).
            img_tensor = yolox_utils.preprocess_image(frame)  # noqa: F841

            # TODO: Run YOLOX inference
            # outputs = self.yolox_model(img_tensor)
            # detections = yolox_utils.postprocess_detections(
            #     outputs, frame.shape[:2], (640, 640),
            #     self.confidence_threshold, self.nms_threshold
            # )

            # Placeholder: Return empty list until full YOLOX integration
            detections: List[Dict] = []

            return detections

        except Exception as e:
            print(f"Error in object detection: {e}")
            return []

    def _interpolate_gps_for_frame(self, frame_idx: int, gps_track: List[Dict],
                                   fps: float) -> Optional[Tuple[float, float]]:
        """
        Interpolate GPS coordinates for a given frame.

        Args:
            frame_idx (int): Frame index
            gps_track (List[Dict]): GPS track data
            fps (float): Frames per second; non-positive values map every
                frame to timestamp 0.0

        Returns:
            Optional[Tuple[float, float]]: (longitude, latitude), or None when
            there is no GPS track (or the interpolator returns None).
        """
        if not gps_track:
            return None

        timestamp = frame_idx / fps if fps > 0 else 0.0

        return coordinate_utils.interpolate_gps_position(gps_track, timestamp)

    def set_confidence_threshold(self, threshold: float) -> None:
        """
        Set the confidence threshold for detections.

        Args:
            threshold (float): Confidence threshold; clamped to [0, 1]
        """
        self.confidence_threshold = max(0.0, min(1.0, threshold))

    def set_nms_threshold(self, threshold: float) -> None:
        """
        Set the NMS (Non-Maximum Suppression) threshold.

        Args:
            threshold (float): NMS threshold; clamped to [0, 1]
        """
        self.nms_threshold = max(0.0, min(1.0, threshold))

    def get_class_names(self) -> List[str]:
        """
        Get list of COCO class names.

        Returns:
            List[str]: Class names as provided by yolox_utils.get_coco_class_names()
        """
        return yolox_utils.get_coco_class_names()

    def is_model_loaded(self) -> bool:
        """
        Check if YOLOX model is loaded.

        Returns:
            bool: True if the last call to load_yolox_model succeeded
        """
        return self.model_loaded

    def get_detection_statistics(self) -> Dict:
        """
        Get statistics about current detections.

        Returns:
            Dict: ``total_detections``, ``class_statistics`` and
            ``frame_indices`` as reported by the detection model
        """
        return {
            'total_detections': self.detection_model.get_detection_count(),
            'class_statistics': self.detection_model.get_class_name_statistics(),
            'frame_indices': self.detection_model.get_frame_indices()
        }
