# -*- coding: utf-8 -*-
"""
Video Utilities

This module contains utility functions for video processing,
GPS metadata extraction, and frame extraction.
"""

import json
import os
import re
import subprocess
from typing import Dict, Generator, List, Optional, Tuple

import cv2
import numpy as np
from PyQt5.QtGui import QImage


def extract_gps_from_video(video_path: str) -> List[Dict]:
    """
    Extract a GPS fix from a video file's metadata using exiftool.

    exiftool must be available on PATH. This function does not raise for:
    - a missing or non-executable exiftool (OSError)
    - a timeout
    - unparseable JSON output
    In those cases it prints a message and returns an empty list. A non-zero
    exiftool exit status also returns an empty list, but without a message.

    Only a single file-level fix is returned, and its timestamp is always 0.0.

    Args:
        video_path (str): Path to the video file

    Returns:
        List[Dict]: Zero or one GPS data point with keys:
            - timestamp (float): Video timestamp in seconds (always 0.0)
            - latitude (float): GPS latitude in decimal degrees
            - longitude (float): GPS longitude in decimal degrees
            - altitude (float): GPS altitude; 0.0 when absent or unparseable
    """

    def _altitude_to_float(raw) -> float:
        # exiftool may report altitude as a number or as text such as
        # "123.4 m Above Sea Level"; keep the leading number and treat a
        # "Below" qualifier as negative.
        if isinstance(raw, (int, float)):
            return float(raw)
        if isinstance(raw, str):
            match = re.match(r"\s*([+-]?\d+(?:\.\d+)?)", raw)
            if match:
                value = float(match.group(1))
                return -value if 'below' in raw.lower() else value
        return 0.0

    gps_data = []

    try:
        # -ee: also read embedded metadata; -G3: prefix each JSON key with a
        # group name ("<Group>:<Tag>").
        result = subprocess.run(
            ['exiftool', '-ee', '-G3', '-json', video_path],
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode != 0:
            return gps_data
        data = json.loads(result.stdout)
    except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
        # OSError covers FileNotFoundError (exiftool not installed) as well as
        # PermissionError and other launch failures.
        print(f"GPS extraction failed: {e}")
        return gps_data

    if not isinstance(data, list) or not data or not isinstance(data[0], dict):
        return gps_data
    video_data = data[0]

    # Try the known layouts first: DJI ("Camera"), GoPro ("Track1"), then
    # generic "GPS". After those, fall back to any other group exposing
    # GPSLatitude. The group names exiftool actually emits depend on the -G
    # family and on the camera, so the known names may be absent.
    groups = ['Camera', 'Track1', 'GPS']
    groups += [key.split(':', 1)[0] for key in video_data
               if key.endswith(':GPSLatitude')]

    for group in groups:
        lat_key = f'{group}:GPSLatitude'
        lon_key = f'{group}:GPSLongitude'
        # Require both keys so a partial record cannot raise KeyError.
        if lat_key not in video_data or lon_key not in video_data:
            continue
        gps_data.append({
            'timestamp': 0.0,
            'latitude': _parse_gps_coordinate(video_data[lat_key]),
            'longitude': _parse_gps_coordinate(video_data[lon_key]),
            'altitude': _altitude_to_float(video_data.get(f'{group}:GPSAltitude')),
        })
        break

    return gps_data


def _parse_gps_coordinate(coord_str) -> float:
    """
    Parse a GPS coordinate from various formats to decimal degrees.

    Accepted inputs:
      - int / float: returned as float
      - DMS text as printed by exiftool: "37 deg 34' 0.00\" N"
        (the degree sign "°" is accepted in place of "deg")
      - plain decimal text: "37.5665" or "-122.4194"

    Args:
        coord_str: GPS coordinate string or number

    Returns:
        float: Decimal degrees, negative for the S and W hemispheres.
        Unparseable input returns 0.0, which callers cannot distinguish from
        a genuine 0.0 coordinate.
    """
    if isinstance(coord_str, (int, float)):
        return float(coord_str)

    if not isinstance(coord_str, str):
        return 0.0

    dms_match = re.match(
        r"\s*(\d+(?:\.\d*)?)\s*(?:deg|°)\s*(\d+(?:\.\d*)?)\s*'\s*(\d+(?:\.\d*)?|\.\d+)\s*\"\s*([NSEW])",
        coord_str,
    )
    if dms_match:
        degrees, minutes, seconds, hemisphere = dms_match.groups()
        decimal = float(degrees) + float(minutes) / 60 + float(seconds) / 3600
        return -decimal if hemisphere in ('S', 'W') else decimal

    # Fall back to a plain decimal number.
    try:
        return float(coord_str)
    except ValueError:
        return 0.0


def parse_gps_string(gps_str: str) -> Tuple[float, float]:
    """
    Parse a "latitude, longitude" string into decimal degrees.

    Handles formats like:
    - "37°34'00\"N, 126°58'40\"E"  (whitespace between components is allowed,
      as are the prime/double-prime symbols ′ and ″)
    - "37.5665, 126.9780"

    For DMS input, the N/S component becomes the latitude and the E/W
    component the longitude. So "126°58'40\"E, 37°34'00\"N" parses the same
    as the conventional order.

    Args:
        gps_str (str): GPS coordinate string

    Returns:
        Tuple[float, float]: (latitude, longitude) in decimal degrees, or
        (0.0, 0.0) if the string cannot be parsed.
    """
    # Simple "lat, lon" decimal form.
    parts = gps_str.split(',')
    if len(parts) == 2:
        try:
            return (float(parts[0].strip()), float(parts[1].strip()))
        except ValueError:
            pass

    # DMS form: degrees ° minutes ' seconds " hemisphere.
    dms_pattern = (
        r"(\d+(?:\.\d*)?)\s*°\s*(\d+(?:\.\d*)?)\s*['′]\s*"
        r"(\d+(?:\.\d*)?)\s*[\"″]\s*([NSEW])"
    )
    matches = re.findall(dms_pattern, gps_str)
    if len(matches) < 2:
        return (0.0, 0.0)

    coords = []
    for degrees, minutes, seconds, hemisphere in matches[:2]:
        decimal = float(degrees) + float(minutes) / 60 + float(seconds) / 3600
        coords.append((-decimal if hemisphere in ('S', 'W') else decimal, hemisphere))

    (first, first_hemi), (second, second_hemi) = coords
    # Longitude given first: swap so the result is always (lat, lon).
    if first_hemi in ('E', 'W') and second_hemi in ('N', 'S'):
        first, second = second, first
    return (first, second)


def extract_frames(video_path: str, interval: int = 1) -> Generator[Tuple[int, np.ndarray], None, None]:
    """
    Extract frames from a video file as a generator (memory efficient).

    Args:
        video_path (str): Path to the video file
        interval (int): Extract every Nth frame (default: 1 = all frames).
            Must be >= 1.

    Yields:
        Tuple[int, np.ndarray]: (frame_index, frame_image)
            - frame_index (int): Index of the frame
            - frame_image (np.ndarray): Frame as numpy array (H, W, C) in BGR format

    Raises:
        ValueError: If ``interval`` < 1 or the video cannot be opened. Because
            this is a generator, both errors surface on the first ``next()``.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    frame_idx = 0

    try:
        # grab() advances to the next frame; retrieve() (the decode step of
        # read()) runs only for frames that are actually yielded.
        while cap.grab():
            if frame_idx % interval == 0:
                ret, frame = cap.retrieve()
                if not ret:
                    break
                yield (frame_idx, frame)

            frame_idx += 1

    finally:
        # Runs on normal exhaustion and when the consumer closes the generator early.
        cap.release()


def get_video_info(video_path: str) -> Dict:
    """
    Get video file information.

    Args:
        video_path (str): Path to the video file

    Returns:
        Dict: Video information with keys:
            - fps (float): Frames per second
            - total_frames (int): Total number of frames (never negative)
            - duration (float): Duration in seconds (0.0 when fps is unknown)
            - width (int): Video width in pixels
            - height (int): Video height in pixels
            - codec (str): Video codec FourCC, without trailing NUL characters
              (empty when the backend reports no FourCC)
            - has_gps (bool): Whether exiftool found GPS metadata
              (see extract_gps_from_video)

    Raises:
        ValueError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # The frame count may be reported as a negative value by some
        # backends/streams; clamp it so duration can never be negative.
        total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    finally:
        # Release the capture even if a property query raises.
        cap.release()

    # FourCC packs four ASCII bytes, least-significant first. Strip NUL
    # padding so a zero FourCC becomes "" instead of "\x00\x00\x00\x00".
    codec = "".join(chr((fourcc >> 8 * i) & 0xFF) for i in range(4)).rstrip("\x00")

    duration = total_frames / fps if fps > 0 else 0.0

    # Check for GPS metadata (invokes exiftool).
    has_gps = bool(extract_gps_from_video(video_path))

    return {
        'fps': fps,
        'total_frames': total_frames,
        'duration': duration,
        'width': width,
        'height': height,
        'codec': codec,
        'has_gps': has_gps
    }


def convert_frame_to_qimage(frame: np.ndarray) -> QImage:
    """
    Convert a numpy frame to a QImage for display in Qt widgets.

    Args:
        frame (np.ndarray): Frame as a numpy array, either (H, W, 3) in BGR
            format or single-channel grayscale (H, W) / (H, W, 1).
            Assumes 8-bit pixels, as required by QImage.Format_RGB888.
            NOTE(review): confirm callers never pass other dtypes.

    Returns:
        QImage: An RGB888 Qt image that owns its own pixel buffer.
    """
    # Grayscale input has no channel axis (or a single channel); expand it to
    # RGB so every frame produces the same QImage format.
    if frame.ndim == 2 or frame.shape[2] == 1:
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    else:
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    height, width = rgb_frame.shape[:2]
    # Use the array's actual row stride rather than assuming width * channels.
    bytes_per_line = rgb_frame.strides[0]

    qimage = QImage(rgb_frame.data, width, height, bytes_per_line, QImage.Format_RGB888)

    # The QImage above only wraps rgb_frame's buffer; copy() detaches it so
    # the result stays valid after rgb_frame is freed.
    return qimage.copy()


def resize_frame(frame: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """
    Resize a frame to fit inside target_size while maintaining aspect ratio.

    The result fits within target_size in both dimensions and is at least
    1x1 pixel. It is not padded, so one side may be smaller than requested.

    Args:
        frame (np.ndarray): Input frame
        target_size (Tuple[int, int]): Target (width, height)

    Returns:
        np.ndarray: Resized frame

    Raises:
        ValueError: If the frame is empty or target_size has a non-positive
            dimension.
    """
    h, w = frame.shape[:2]
    target_w, target_h = target_size

    # Guard against division by zero and zero-sized output, either of which
    # would otherwise fail with a less helpful error.
    if h <= 0 or w <= 0:
        raise ValueError(f"Cannot resize an empty frame of shape {frame.shape}")
    if target_w <= 0 or target_h <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    aspect = w / h
    target_aspect = target_w / target_h

    if aspect > target_aspect:
        # Width is the limiting factor.
        new_w = target_w
        new_h = max(1, int(target_w / aspect))
    else:
        # Height is the limiting factor.
        new_h = target_h
        new_w = max(1, int(target_h * aspect))

    return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)


def draw_bboxes_on_frame(frame: np.ndarray, bboxes: List[List[float]],
                         labels: Optional[List[str]] = None,
                         scores: Optional[List[float]] = None,
                         color: Tuple[int, int, int] = (0, 255, 0),
                         thickness: int = 2) -> np.ndarray:
    """
    Draw bounding boxes, with optional labels and scores, on a copy of a frame.

    Args:
        frame (np.ndarray): Input frame (not modified)
        bboxes (List[List[float]]): List of bounding boxes [[x, y, w, h], ...]
        labels (List[str]): List of class labels (optional). A score is drawn
            only for boxes that also have a label.
        scores (List[float]): List of confidence scores (optional)
        color (Tuple[int, int, int]): Box and label-background colour, in the
            frame's channel order (default green for BGR frames)
        thickness (int): Box line thickness in pixels

    Returns:
        np.ndarray: Frame with drawn bounding boxes
    """
    output = frame.copy()

    for i, bbox in enumerate(bboxes):
        x, y, w, h = [int(v) for v in bbox]
        cv2.rectangle(output, (x, y), (x + w, y + h), color, thickness)

        if labels is None or i >= len(labels):
            continue

        # str() so non-string labels (e.g. class ids) can be drawn.
        label_text = str(labels[i])
        if scores is not None and i < len(scores):
            label_text += f" {scores[i]:.2f}"

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        font_thickness = 1
        (text_w, text_h), baseline = cv2.getTextSize(
            label_text, font, font_scale, font_thickness
        )

        label_h = text_h + baseline + 5
        # Place the label just above the box. If that would run off the top
        # of the image, draw it just inside the box's top edge instead.
        label_top = y - label_h if y - label_h >= 0 else y

        # Filled background behind the text.
        cv2.rectangle(
            output,
            (x, label_top),
            (x + text_w, label_top + label_h),
            color,
            -1
        )

        # Black text; the origin is the text baseline.
        cv2.putText(
            output,
            label_text,
            (x, label_top + label_h - baseline - 2),
            font,
            font_scale,
            (0, 0, 0),
            font_thickness
        )

    return output
