# -*- coding: utf-8 -*-
"""
Batch Processing Algorithm

This module defines a QGIS Processing algorithm that runs YOLOX object
detection on every video in a folder and exports the detections to
GeoPackage files.
"""

import os
import sys
import glob
from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterFile,
    QgsProcessingParameterFolderDestination,
    QgsProcessingParameterNumber,
    QgsProcessingParameterEnum,
    QgsProcessingParameterBoolean,
    QgsProcessingException,
    QgsCoordinateReferenceSystem,
    QgsProject,
)

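# Put the directory two levels above this file's folder on sys.path so the
# top-level `models` and `controllers` packages below can be imported.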
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from models.video_model import VideoModel
from models.detection_model import DetectionModel
from controllers.detection_controller import DetectionController
from controllers.export_controller import ExportController


class BatchProcessingAlgorithm(QgsProcessingAlgorithm):
    """
    Processing algorithm for batch video processing with YOLOX.

    Every video file found directly inside the input folder is run through a
    single, shared YOLOX model. The detections are exported either to one
    GeoPackage per video or to one merged GeoPackage.
    """

    # Parameter names
    INPUT_FOLDER = 'INPUT_FOLDER'
    OUTPUT_FOLDER = 'OUTPUT_FOLDER'
    FRAME_INTERVAL = 'FRAME_INTERVAL'
    CONFIDENCE_THRESHOLD = 'CONFIDENCE_THRESHOLD'
    MODEL_NAME = 'MODEL_NAME'
    DEVICE = 'DEVICE'
    MERGE_RESULTS = 'MERGE_RESULTS'
    ADD_TO_MAP = 'ADD_TO_MAP'

    # Enum choices. The enum parameters store an index into these lists, so
    # initAlgorithm() and processAlgorithm() must use the very same lists.
    MODEL_OPTIONS = ['yolox-nano', 'yolox-tiny', 'yolox-s', 'yolox-m', 'yolox-l', 'yolox-x']
    DEVICE_OPTIONS = ['cuda', 'cpu']

    # File extensions treated as videos (compared case-insensitively).
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

    # CRS handed to the GeoPackage export (WGS 84 geographic coordinates).
    OUTPUT_CRS_AUTHID = 'EPSG:4326'

    def __init__(self):
        """
        Initialize the algorithm.
        """
        super().__init__()

    def initAlgorithm(self, config=None):
        """
        Define algorithm parameters.

        Args:
            config: Algorithm configuration (unused)
        """
        # Input folder containing videos
        self.addParameter(
            QgsProcessingParameterFile(
                self.INPUT_FOLDER,
                'Input folder with video files',
                behavior=QgsProcessingParameterFile.Folder
            )
        )

        # Output folder
        self.addParameter(
            QgsProcessingParameterFolderDestination(
                self.OUTPUT_FOLDER,
                'Output folder for GeoPackages'
            )
        )

        # Frame interval
        self.addParameter(
            QgsProcessingParameterNumber(
                self.FRAME_INTERVAL,
                'Frame interval (process every Nth frame)',
                type=QgsProcessingParameterNumber.Integer,
                defaultValue=30,
                minValue=1
            )
        )

        # Confidence threshold
        self.addParameter(
            QgsProcessingParameterNumber(
                self.CONFIDENCE_THRESHOLD,
                'Confidence threshold',
                type=QgsProcessingParameterNumber.Double,
                defaultValue=0.5,
                minValue=0.0,
                maxValue=1.0
            )
        )

        # Model selection
        self.addParameter(
            QgsProcessingParameterEnum(
                self.MODEL_NAME,
                'YOLOX model',
                options=self.MODEL_OPTIONS,
                defaultValue=self.MODEL_OPTIONS.index('yolox-s')
            )
        )

        # Device selection
        self.addParameter(
            QgsProcessingParameterEnum(
                self.DEVICE,
                'Processing device',
                options=self.DEVICE_OPTIONS,
                defaultValue=self.DEVICE_OPTIONS.index('cuda')
            )
        )

        # Merge results
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.MERGE_RESULTS,
                'Merge all results into single GeoPackage',
                defaultValue=False
            )
        )

        # Add to map
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.ADD_TO_MAP,
                'Add result layers to map',
                defaultValue=True
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Execute the batch processing algorithm.

        The selected YOLOX model is loaded once. Every video in the input
        folder is then run through it, and the detections are exported to
        GeoPackages. A failure on one video is reported as a warning, and the
        batch continues with the next video.

        Args:
            parameters: Algorithm parameters
            context: Processing context
            feedback: Feedback object for progress reporting and cancellation

        Returns:
            dict: ``{OUTPUT_FOLDER: <output folder path>}``

        Raises:
            QgsProcessingException: If the input folder is missing or has no
                videos, if no output folder is given, or if model loading or
                the merged export fails.
        """
        # Get parameters
        input_folder = self.parameterAsFile(parameters, self.INPUT_FOLDER, context)
        # parameterAsFileOutput (unlike parameterAsString) resolves a
        # "temporary output" choice into a real folder path.
        output_folder = self.parameterAsFileOutput(parameters, self.OUTPUT_FOLDER, context)
        frame_interval = self.parameterAsInt(parameters, self.FRAME_INTERVAL, context)
        confidence_threshold = self.parameterAsDouble(parameters, self.CONFIDENCE_THRESHOLD, context)
        model_name = self.MODEL_OPTIONS[self.parameterAsEnum(parameters, self.MODEL_NAME, context)]
        device = self.DEVICE_OPTIONS[self.parameterAsEnum(parameters, self.DEVICE, context)]
        merge_results = self.parameterAsBoolean(parameters, self.MERGE_RESULTS, context)
        add_to_map = self.parameterAsBoolean(parameters, self.ADD_TO_MAP, context)

        if not input_folder or not os.path.isdir(input_folder):
            raise QgsProcessingException(f'Input folder does not exist: {input_folder}')
        if not output_folder:
            raise QgsProcessingException('No output folder specified')

        # Create output folder if it doesn't exist
        os.makedirs(output_folder, exist_ok=True)

        video_files = self._find_video_files(input_folder)
        if not video_files:
            raise QgsProcessingException(f'No video files found in: {input_folder}')
        total = len(video_files)

        feedback.pushInfo(f'Found {total} video files')
        feedback.pushInfo(f'Model: {model_name}, Device: {device}')

        try:
            # Load YOLOX model once for all videos
            detection_model = DetectionModel()
            detection_controller = DetectionController(detection_model)

            feedback.pushInfo(f'Loading YOLOX model: {model_name}...')
            detection_controller.load_yolox_model(model_name=model_name, device=device)

            # Use the QgsProject singleton directly. context.project() can be
            # None, and calling .instance() on it would then raise.
            export_controller = ExportController(QgsProject.instance())
            crs = QgsCoordinateReferenceSystem(self.OUTPUT_CRS_AUTHID)

            all_detections = []
            processed_count = 0

            for idx, video_path in enumerate(video_files, start=1):
                if feedback.isCanceled():
                    feedback.pushInfo('Batch processing canceled')
                    break

                video_name = os.path.basename(video_path)
                feedback.pushInfo(f'\nProcessing video {idx}/{total}: {video_name}')

                try:
                    detections = self._detect_in_video(
                        video_path,
                        video_name,
                        detection_model,
                        detection_controller,
                        frame_interval,
                        confidence_threshold,
                        feedback,
                    )
                    if detections is None:
                        continue  # load failure was already reported

                    if merge_results:
                        all_detections.extend(detections)
                    elif len(detections) > 0:
                        stem = os.path.splitext(video_name)[0]
                        self._export_detections(
                            export_controller,
                            detections,
                            os.path.join(output_folder, stem + '.gpkg'),
                            stem,
                            crs,
                            add_to_map,
                        )

                    processed_count += 1

                except Exception as e:
                    feedback.pushWarning(f'Error processing {video_name}: {str(e)}')
                finally:
                    # Advance progress even when a video is skipped or fails.
                    feedback.setProgress(int(idx / total * 100))

            # Merge results if requested
            if merge_results and all_detections:
                feedback.pushInfo(f'\nMerging {len(all_detections)} total detections...')
                self._export_detections(
                    export_controller,
                    all_detections,
                    os.path.join(output_folder, 'merged_detections.gpkg'),
                    'merged_detections',
                    crs,
                    add_to_map,
                )

            feedback.setProgress(100)
            feedback.pushInfo('\nBatch processing complete!')
            feedback.pushInfo(f'Processed {processed_count}/{total} videos')

            return {self.OUTPUT_FOLDER: output_folder}

        except Exception as e:
            raise QgsProcessingException(f'Error during batch processing: {str(e)}') from e

    @classmethod
    def _find_video_files(cls, folder):
        """
        List the video files directly inside *folder* (not recursive).

        Extensions are matched case-insensitively. Each file appears once,
        even on case-insensitive filesystems. Dot-files (e.g. macOS
        ``._clip.mp4`` resource forks) are skipped. The folder path is used
        literally, never as a glob pattern.

        Args:
            folder: Directory to scan

        Returns:
            list[str]: Sorted full paths of the matching files
        """
        with os.scandir(folder) as entries:
            return sorted(
                entry.path
                for entry in entries
                if not entry.name.startswith('.')
                and entry.is_file()
                and os.path.splitext(entry.name)[1].lower() in cls.VIDEO_EXTENSIONS
            )

    @staticmethod
    def _detect_in_video(video_path, video_name, detection_model, detection_controller,
                         frame_interval, confidence_threshold, feedback):
        """
        Run detection on a single video with the shared model.

        Args:
            video_path: Path of the video file
            video_name: Display name used in feedback messages
            detection_model: Shared DetectionModel; cleared before this video
            detection_controller: Controller with the YOLOX model loaded
            frame_interval: Process every Nth frame
            confidence_threshold: Minimum detection confidence
            feedback: Feedback object for messages

        Returns:
            The result of ``detection_model.get_detections()``, or None if
            the video could not be loaded.
        """
        video_model = VideoModel()
        if not video_model.load_video(video_path):
            feedback.pushWarning(f'Failed to load: {video_name}')
            return None

        try:
            gps_track = video_model.extract_gps_metadata()
            if not gps_track:
                # Only a warning: detection still runs without a GPS track.
                feedback.pushWarning(f'No GPS data in: {video_name}')

            # The model is reused across videos, so drop earlier results first.
            detection_model.clear()
            detection_controller.process_video(
                video_model=video_model,
                frame_interval=frame_interval,
                confidence_threshold=confidence_threshold,
                gps_track=gps_track
            )

            detections = detection_model.get_detections()
            feedback.pushInfo(f'Found {len(detections)} detections in {video_name}')
            return detections
        finally:
            # Release the video even if GPS extraction or detection raises.
            video_model.close()

    @staticmethod
    def _export_detections(export_controller, detections, output_path, layer_name,
                           crs, add_to_map):
        """
        Write detections to a GeoPackage and optionally add it to the map.

        Args:
            export_controller: ExportController used for writing and loading
            detections: Detections to export
            output_path: Destination .gpkg path
            layer_name: Name of the layer inside the GeoPackage and on the map
            crs: QgsCoordinateReferenceSystem for the output layer
            add_to_map: Whether to add the written layer to the map
        """
        export_controller.export_to_geopackage(
            detections=detections,
            output_path=output_path,
            crs=crs,
            layer_name=layer_name
        )

        if add_to_map:
            # NOTE(review): processAlgorithm may run outside the main thread.
            # Confirm add_layer_to_map is safe there, or consider
            # context.addLayerToLoadOnCompletion() instead.
            export_controller.add_layer_to_map(output_path, layer_name)

    def name(self) -> str:
        """
        Returns the algorithm name.

        Returns:
            str: Algorithm name (used internally)
        """
        return 'batch_video_detection'

    def displayName(self) -> str:
        """
        Returns the translated algorithm name for display.

        Returns:
            str: Display name
        """
        return 'Batch Process Videos'

    def group(self) -> str:
        """
        Returns the algorithm group name.

        Returns:
            str: Group name
        """
        return 'Object Detection'

    def groupId(self) -> str:
        """
        Returns the algorithm group ID.

        Returns:
            str: Group ID
        """
        return 'object_detection'

    def shortHelpString(self) -> str:
        """
        Returns a short help string for the algorithm.

        Returns:
            str: Help text
        """
        return """
        Process multiple GPS-synchronized videos at once using YOLOX object detection.

        This algorithm processes all video files in a folder, detects objects using YOLOX,
        and exports the results as GeoPackage files.

        Parameters:
        - Input folder: Folder containing video files
        - Output folder: Where to save GeoPackage results
        - Frame interval: Process every Nth frame (higher = faster)
        - Confidence threshold: Minimum detection confidence (0-1)
        - Model: YOLOX model variant (nano, tiny, s, m, l, x)
        - Device: CUDA (GPU) or CPU processing
        - Merge results: Combine all detections into one GeoPackage
        - Add to map: Automatically add result layers to QGIS

        Output:
        - One GeoPackage per video (or merged if selected)
        - Each GeoPackage contains detection points with:
          * GPS coordinates
          * Object class and confidence
          * Frame index and timestamp
          * Bounding box coordinates
        """

    def createInstance(self):
        """
        Create a new instance of the algorithm.

        Returns:
            BatchProcessingAlgorithm: New instance
        """
        return BatchProcessingAlgorithm()
