from __future__ import annotations

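"""3D drillhole exporters.

This module provides Shapefile exporters for 3D drillhole data: one writes
each hole's trace as a LineStringZ feature, the other writes each interval
segment as its own LineStringZ feature with depth and unit attributes.
"""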

from typing import Any

from qgis.core import (
    QgsFeature,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsPoint,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QMetaType

from sec_interp.core import utils as scu
from sec_interp.logger_config import get_logger

from .base_exporter import BaseExporter

# Module-level logger obtained from the project's logging configuration.
logger = get_logger(__name__)


class DrillholeTrace3DExporter(BaseExporter):
    """Exports 3D drillhole traces to a Shapefile.

    Each hole becomes one ``LineStringZ`` feature carrying a ``hole_id``
    attribute. Holes with fewer than two usable vertices are skipped.
    """

    def get_supported_extensions(self) -> list[str]:
        """Return the file extensions this exporter writes (Shapefile only)."""
        return [".shp"]

    def export(self, output_path: Any, data: dict[str, Any]) -> bool:
        """Export 3D drillhole traces to a Shapefile.

        Args:
            output_path: Path to the output Shapefile (converted with ``str``).
            data: Dictionary containing 'drillhole_data' (an iterable of
                  per-hole tuples, see ``_process_hole_trace``) and 'crs'.
                  Can include 'use_projected' (bool, default False) to write
                  the projected coordinates instead of the 3D ones.

        Returns:
            bool: True if export successful, False if the inputs are missing
            or an error occurred while writing (the error is logged).

        """
        drillhole_data = data.get("drillhole_data")
        crs = data.get("crs")
        use_projected = data.get("use_projected", False)
        if not drillhole_data or not crs:
            logger.warning(
                "3D trace export to %s skipped: missing drillhole data or CRS",
                output_path,
            )
            return False

        writer = None
        try:
            fields = self._prepare_fields()
            writer = scu.create_shapefile_writer(
                str(output_path), crs, fields, QgsWkbTypes.LineStringZ
            )

            for hole_data in drillhole_data:
                self._process_hole_trace(writer, fields, hole_data, use_projected)
        except Exception as e:
            logger.exception(f"Error exporting 3D traces to {output_path}: {e}")
            return False
        finally:
            # The writer is presumably a QgsVectorFileWriter, which only
            # flushes/closes the file when destroyed: release it on every
            # path, including errors, instead of only on success.
            del writer
        return True

    def _process_hole_trace(
        self, writer: Any, fields: QgsFields, hole_data: tuple, use_projected: bool
    ) -> None:
        """Build and write the trace polyline for a single hole.

        Two tuple layouts are accepted:

        * Standard: ``(hole_id, spatial_points, segments)`` where each spatial
          point exposes ``x_3d``/``y_3d``, ``x_proj``/``y_proj`` and ``z``.
        * Legacy/test: ``(hole_id, trace_2d, trace_3d, trace_3d_proj, segments)``
          where the 3D traces are sequences of ``(x, y, z)`` tuples.

        Standard-format vertices whose selected X or Y is missing are dropped
        (rather than being placed at 0). Holes left with fewer than two
        vertices are skipped; unknown layouts are logged and skipped.
        """
        hole_id = hole_data[0]

        if len(hole_data) == 3:
            # New format with SpatialMeta objects; tolerate a missing list.
            spatial_points = hole_data[1] or []
            if use_projected:
                points = [
                    QgsPoint(p.x_proj, p.y_proj, p.z)
                    for p in spatial_points
                    if p.x_proj is not None and p.y_proj is not None
                ]
            else:
                points = [
                    QgsPoint(p.x_3d, p.y_3d, p.z)
                    for p in spatial_points
                    if p.x_3d is not None and p.y_3d is not None
                ]
        elif len(hole_data) == 5:
            # Legacy/Integration Test format
            _, _, traces_3d, traces_3d_proj, _ = hole_data
            points_source = (traces_3d_proj if use_projected else traces_3d) or []
            points = [QgsPoint(x, y, z) for x, y, z in points_source]
        else:
            logger.warning(
                f"Unexpected hole data format (length {len(hole_data)}) for hole {hole_id}"
            )
            return

        # A polyline needs at least two vertices.
        if len(points) < 2:
            return

        geom = QgsGeometry.fromPolyline(points)
        if not geom or geom.isNull():
            return

        feat = QgsFeature(fields)
        feat.setGeometry(geom)
        feat.setAttribute("hole_id", str(hole_id))
        writer.addFeature(feat)

    def _prepare_fields(self) -> QgsFields:
        """Create the attribute schema for trace features (``hole_id`` only)."""
        fields = QgsFields()
        fields.append(QgsField("hole_id", QMetaType.Type.QString))
        return fields


class DrillholeInterval3DExporter(BaseExporter):
    """Exports 3D drillhole intervals to a Shapefile.

    Each interval segment becomes its own ``LineStringZ`` feature with
    ``hole_id``, ``from_depth``, ``to_depth`` and ``unit`` attributes.
    """

    def get_supported_extensions(self) -> list[str]:
        """Return the file extensions this exporter writes (Shapefile only)."""
        return [".shp"]

    def export(self, output_path: Any, data: dict[str, Any]) -> bool:
        """Export 3D drillhole intervals to a Shapefile.

        Args:
            output_path: Path to the output Shapefile (converted with ``str``).
            data: Dictionary containing 'drillhole_data' (an iterable of
                  per-hole tuples whose last element is the segment list)
                  and 'crs'. Can include 'use_projected' (bool, default
                  False) to write the projected segment coordinates.

        Returns:
            bool: True if export successful, False if the inputs are missing
            or an error occurred while writing (the error is logged).

        """
        drillhole_data = data.get("drillhole_data")
        crs = data.get("crs")
        use_projected = data.get("use_projected", False)
        if not drillhole_data or not crs:
            logger.warning(
                "3D interval export to %s skipped: missing drillhole data or CRS",
                output_path,
            )
            return False

        writer = None
        try:
            fields = self._prepare_fields()
            writer = scu.create_shapefile_writer(
                str(output_path), crs, fields, QgsWkbTypes.LineStringZ
            )

            for hole_data in drillhole_data:
                self._process_hole_intervals(writer, fields, hole_data, use_projected)
        except Exception as e:
            logger.exception(f"Error exporting 3D intervals to {output_path}: {e}")
            return False
        finally:
            # The writer is presumably a QgsVectorFileWriter, which only
            # flushes/closes the file when destroyed: release it on every
            # path, including errors, instead of only on success.
            del writer
        return True

    def _process_hole_intervals(
        self, writer: Any, fields: QgsFields, hole_data: tuple, use_projected: bool
    ) -> None:
        """Write one feature per interval segment of a single hole.

        The segment sequence is the last element of both the standard 3-tuple
        and the legacy 5-tuple layouts. Each segment is expected to expose
        ``points_3d`` / ``points_3d_projected`` (sequences of ``(x, y, z)``),
        an ``attributes`` mapping with optional ``"from"``/``"to"`` depths
        (default 0.0), and a ``unit_name``. Segments with fewer than two
        vertices are skipped.
        """
        hole_id = hole_data[0]
        segments = hole_data[-1]

        # Accept a list or tuple of segments; anything else means no intervals.
        if not segments or not isinstance(segments, (list, tuple)):
            return

        for segment in segments:
            points_source = (
                segment.points_3d_projected if use_projected else segment.points_3d
            )
            if not points_source or len(points_source) < 2:
                continue

            points = [QgsPoint(x, y, z) for x, y, z in points_source]
            geom = QgsGeometry.fromPolyline(points)
            if not geom or geom.isNull():
                continue

            # A segment without attributes falls back to the 0.0 defaults
            # instead of aborting the export of every remaining hole.
            attrs = segment.attributes or {}
            feat = QgsFeature(fields)
            feat.setGeometry(geom)
            feat.setAttribute("hole_id", str(hole_id))
            feat.setAttribute("from_depth", attrs.get("from", 0.0))
            feat.setAttribute("to_depth", attrs.get("to", 0.0))
            feat.setAttribute("unit", segment.unit_name)
            writer.addFeature(feat)

    def _prepare_fields(self) -> QgsFields:
        """Create the attribute schema for interval features.

        Field names stay within the 10-character Shapefile (DBF) limit.
        """
        fields = QgsFields()
        fields.append(QgsField("hole_id", QMetaType.Type.QString))
        fields.append(QgsField("from_depth", QMetaType.Type.Double))
        fields.append(QgsField("to_depth", QMetaType.Type.Double))
        fields.append(QgsField("unit", QMetaType.Type.QString))
        return fields
