"""Core Utilities Module
"""

import math
import re

from qgis.core import (
    QgsCoordinateReferenceSystem,
    QgsDistanceArea,
    QgsFeature,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsPointXY,
    QgsProject,
    QgsVectorFileWriter,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QVariant


def calculate_line_azimuth(line_geom):
    """Calculate the azimuth of the first segment of a line geometry.

    Z/M variants (e.g. LineStringZ, LineString25D) and multi-part lines are
    supported; for a MultiLineString the first segment of the first part is
    used. The azimuth is measured clockwise from the +Y (north) axis in the
    geometry's map units; no CRS/ellipsoid correction is applied.

    Args:
        line_geom (QgsGeometry): The line geometry to calculate azimuth for.

    Returns:
        float: Azimuth in degrees (0-360). Returns 0 for null geometries,
        non-line geometries (points, polygons, ...) and lines with fewer
        than two vertices.
    """
    if line_geom is None or line_geom.isNull():
        return 0

    # flatType strips the Z/M flags so every line flavour is treated alike.
    flat_type = QgsWkbTypes.flatType(line_geom.wkbType())
    if flat_type == QgsWkbTypes.LineString:
        vertices = line_geom.asPolyline()
    elif flat_type == QgsWkbTypes.MultiLineString:
        parts = line_geom.asMultiPolyline()
        vertices = parts[0] if parts else []
    else:
        return 0  # Points, polygons etc. have no meaningful azimuth

    if len(vertices) < 2:
        return 0

    p1, p2 = vertices[0], vertices[1]
    # atan2(dx, dy) measures the angle from +Y towards +X, i.e. a compass
    # bearing, in the range (-180, 180].
    azimuth = math.degrees(math.atan2(p2.x() - p1.x(), p2.y() - p1.y()))
    if azimuth < 0:
        azimuth += 360
    return azimuth


def calculate_step_size(geom, raster_lyr):
    """Calculate step size based on slope and raster resolution.

    The step along the line is chosen so that each step advances by about
    one pixel along the axis the line crosses fastest (in pixel units).
    The X and Y pixel sizes are both taken into account, so rasters with
    non-square pixels are handled. The chord from the first to the last
    vertex of the (first) line part gives the direction. For square pixels
    this equals ``length * res / max(dx, dy)``.

    Args:
        geom (QgsGeometry): The geometry to sample along.
        raster_lyr (QgsRasterLayer): The raster layer being sampled.

    Returns:
        float: Calculated step size in map units. Falls back to the X pixel
        size when the geometry cannot be parsed or is degenerate. If the
        raster reports a non-positive X resolution, that value is returned
        unchanged; callers must validate the result before stepping with it.
    """
    res_x = raster_lyr.rasterUnitsPerPixelX()
    res_y = raster_lyr.rasterUnitsPerPixelY()
    if res_y <= 0:
        # Y resolution unavailable: assume square pixels.
        res_y = res_x

    dist_step = res_x
    if res_x <= 0:
        # No usable resolution; let the caller decide how to handle it.
        return dist_step

    try:
        if geom.isMultipart():
            parts = geom.asMultiPolyline()
            line_pts = parts[0] if parts else []
        else:
            line_pts = geom.asPolyline()

        if len(line_pts) >= 2:
            p1, p2 = line_pts[0], line_pts[-1]
            dx = abs(p2.x() - p1.x())
            dy = abs(p2.y() - p1.y())
            # Number of pixels crossed along the dominant axis.
            pixels_crossed = max(dx / res_x, dy / res_y)
            if pixels_crossed > 0:
                dist_step = geom.length() / pixels_crossed
    except (AttributeError, IndexError, TypeError, ZeroDivisionError):
        # Fallback to simple resolution if geometry parsing fails
        # (TypeError: PyQGIS raises it for non-line geometries).
        pass
    return dist_step


def get_line_start_point(geometry):
    """Return the first vertex of a line or multi-line geometry.

    For multi-part geometries the first vertex of the first part is used.

    Args:
        geometry (QgsGeometry): Line geometry (single- or multi-part).

    Returns:
        QgsPointXY: The first vertex of the (first) line.

    Raises:
        IndexError: When the returned part or vertex list is empty.
    """
    vertices = (
        geometry.asMultiPolyline()[0]
        if geometry.isMultipart()
        else geometry.asPolyline()
    )
    return vertices[0]


def create_distance_area(crs):
    """Build a QgsDistanceArea configured for ``crs``.

    The source CRS uses the current project's transform context. The
    ellipsoid is set to the CRS's own ellipsoid acronym.

    Args:
        crs (QgsCoordinateReferenceSystem): The CRS to use for calculations.

    Returns:
        QgsDistanceArea: Configured distance area object.
    """
    context = QgsProject.instance().transformContext()
    measurer = QgsDistanceArea()
    measurer.setSourceCrs(crs, context)
    measurer.setEllipsoid(crs.ellipsoidAcronym())
    return measurer


def sample_elevation_along_line(
    geometry, raster_layer, band_number, distance_area, reference_point=None
):
    """Sample raster elevation values along a line geometry.

    Samples are taken every ``calculate_step_size(...)`` map units along the
    line, starting at its start. The line end is always included, so the
    profile covers the whole section. Positions are computed as
    ``i * step`` rather than by repeated addition, avoiding floating-point
    drift.

    Args:
        geometry: QgsGeometry (LineString) to sample along.
        raster_layer: QgsRasterLayer to sample from.
        band_number: Raster band number.
        distance_area: QgsDistanceArea for distance calculations.
        reference_point: Optional QgsPointXY to measure distance from.
                         If None, distance is measured from the start of the geometry.

    Returns:
        List of QgsPointXY(distance, elevation). The distance is the
        straight-line distance measured by ``distance_area`` from the
        reference/start point. Failed (or NaN) samples are stored as 0.0.

    Raises:
        ValueError: If the computed step size is not a positive number
            (e.g. the raster reports an invalid resolution). Stepping with
            it would otherwise never terminate.
    """
    dist_step = calculate_step_size(geometry, raster_layer)
    if not dist_step > 0:  # also rejects NaN
        raise ValueError(f"Invalid sampling step size: {dist_step!r}")

    length = geometry.length()
    provider = raster_layer.dataProvider()
    start_pt = (
        reference_point
        if reference_point is not None
        else geometry.interpolate(0).asPoint()
    )

    # Regular sample positions, plus the exact line end if it was not hit.
    n_steps = int(length // dist_step)
    sample_dists = [i * dist_step for i in range(n_steps + 1)]
    if length - sample_dists[-1] > 1e-9 * max(1.0, length):
        sample_dists.append(length)

    points = []
    for current_dist in sample_dists:
        pt = geometry.interpolate(current_dist).asPoint()
        dist_from_start = distance_area.measureLine(start_pt, pt)
        val, ok = provider.sample(pt, band_number)
        elev = val if ok and not math.isnan(val) else 0.0
        points.append(QgsPointXY(dist_from_start, elev))

    return points


def create_shapefile_writer(
    output_path, crs, fields, geometry_type=QgsWkbTypes.LineString
):
    """Open a UTF-8 "ESRI Shapefile" QgsVectorFileWriter.

    Args:
        output_path (Path or str): Path where shapefile will be created.
        crs (QgsCoordinateReferenceSystem): CRS for the shapefile.
        fields (QgsFields): Fields definition for the shapefile.
        geometry_type (QgsWkbTypes.Type): WKB geometry type of the layer
            (default: LineString).

    Returns:
        QgsVectorFileWriter: Initialized writer object.

    Raises:
        OSError: If the writer reports an error after creation.
    """
    path_str = str(output_path)
    writer = QgsVectorFileWriter(
        path_str, "UTF-8", fields, geometry_type, crs, "ESRI Shapefile"
    )
    if writer.hasError() == QgsVectorFileWriter.NoError:
        return writer
    raise OSError(
        f"Error creating shapefile {output_path}: {writer.errorMessage()}"
    )


def prepare_profile_context(line_lyr):
    """Prepare common context for profile operations.

    Only the first feature of ``line_lyr`` is used.

    Args:
        line_lyr (QgsVectorLayer): The cross-section line layer.

    Returns:
        tuple: (line_geom, line_start, distance_area)
            - line_geom (QgsGeometry): The geometry of the section line.
            - line_start (QgsPointXY): The starting point of the line.
            - distance_area (QgsDistanceArea): Configured distance area object.

    Raises:
        ValueError: If the layer has no features, or the first feature's
            geometry is missing, null, empty or not a line.
    """
    line_feat = next(line_lyr.getFeatures(), None)
    # Compare with None explicitly: the truthiness of a QgsFeature is not a
    # reliable "feature exists" test (it may reflect its attribute count).
    if line_feat is None:
        raise ValueError("Line layer has no features")

    line_geom = line_feat.geometry()
    if line_geom is None or line_geom.isNull() or line_geom.isEmpty():
        raise ValueError("Line geometry is not valid")

    try:
        line_start = get_line_start_point(line_geom)
    except (IndexError, TypeError) as err:
        # No vertices, or PyQGIS refused to read the geometry as a line.
        raise ValueError("Line geometry is not valid") from err

    da = create_distance_area(line_lyr.crs())
    return line_geom, line_start, da


def calculate_apparent_dip(true_strike, true_dip, line_azimuth):
    """Convert a true dip into the apparent dip seen in a vertical section.

    Uses the standard relation

        tan(apparent_dip) = tan(true_dip) * sin(strike - section_azimuth)

    so the result is 0 when the section runs parallel to strike, and equals
    +/- the true dip when the section is perpendicular to strike.

    Args:
        true_strike (float): Strike of the geological plane (0-360 degrees).
        true_dip (float): True dip of the geological plane (0-90 degrees).
        line_azimuth (float): Azimuth of the cross-section line (0-360 degrees).

    Returns:
        float: Apparent dip in degrees, in (-90, 90). The sign follows
        sin(strike - azimuth) and is not normalised. For example, strike 0
        viewed along azimuth 90 gives a negative value. Presumably the sign
        encodes which way the plane descends along the section, under a
        right-hand-rule strike convention; confirm against callers.
    """
    strike_rad, dip_rad, azimuth_rad = (
        math.radians(angle) for angle in (true_strike, true_dip, line_azimuth)
    )
    tan_apparent = math.tan(dip_rad) * math.sin(strike_rad - azimuth_rad)
    return math.degrees(math.atan(tan_apparent))


# Preview rendering utilities


def calculate_bounds(topo_data, geol_data=None):
    """Calculate padded min/max bounds for all profile data.

    Degenerate ranges are widened before padding: 100 units for distance
    and 10 units for elevation. This avoids zero-width axes.

    Args:
        topo_data (list): List of (distance, elevation) pairs for topography.
        geol_data (list, optional): List of (name, points) pairs, where
            ``points`` is a list of (distance, elevation) pairs.

    Returns:
        dict: Dictionary containing 'min_d', 'max_d', 'min_e', 'max_e' with 5% padding.

    Raises:
        ValueError: If neither ``topo_data`` nor ``geol_data`` contains any point.
    """
    dists = [p[0] for p in topo_data]
    elevs = [p[1] for p in topo_data]

    for _, points in geol_data or ():
        dists.extend(p[0] for p in points)
        elevs.extend(p[1] for p in points)

    if not dists:
        raise ValueError("Cannot calculate bounds: no data points provided")

    min_d, max_d = min(dists), max(dists)
    min_e, max_e = min(elevs), max(elevs)

    # Avoid division by zero in later scale computations
    if max_d == min_d:
        max_d = min_d + 100
    if max_e == min_e:
        max_e = min_e + 10

    # Add 5% padding
    d_pad = (max_d - min_d) * 0.05
    e_pad = (max_e - min_e) * 0.05

    return {
        "min_d": min_d - d_pad,
        "max_d": max_d + d_pad,
        "min_e": min_e - e_pad,
        "max_e": max_e + e_pad,
    }


def create_coordinate_transform(bounds, view_w, view_h, margin, vert_exag=1.0):
    """Create coordinate transformation function.

    Both axes share one base scale (1:1 aspect when ``vert_exag == 1``). The
    vertical axis is then multiplied by ``vert_exag``. The base scale is the
    largest one at which the exaggerated data still fits inside the view
    minus margins.

    Args:
        bounds: Dictionary with min_d, max_d, min_e, max_e
        view_w: View width in pixels
        view_h: View height in pixels
        margin: Margin in pixels
        vert_exag: Vertical exaggeration factor (default 1.0 = no exaggeration, i.e., 1:1 scale)

    Returns:
        Function that transforms (distance, elevation) to (x, y) screen coordinates.
        Screen Y grows downward, so higher elevations map to smaller y.
    """
    data_w = bounds["max_d"] - bounds["min_d"]
    data_h = bounds["max_e"] - bounds["min_e"]

    # Largest scale that fits each axis in the available space
    fit_scale_x = (view_w - 2 * margin) / data_w
    fit_scale_y = (view_h - 2 * margin) / data_h
    if vert_exag > 0:
        # The Y axis will be stretched by vert_exag, so leave room for it
        fit_scale_y /= vert_exag

    # The smaller scale guarantees everything fits on both axes
    base_scale = min(fit_scale_x, fit_scale_y)

    scale_x = base_scale
    scale_y = base_scale * vert_exag  # Apply vertical exaggeration

    def transform(dist, elev):
        x = margin + (dist - bounds["min_d"]) * scale_x
        y = view_h - margin - (elev - bounds["min_e"]) * scale_y
        return x, y

    return transform


def calculate_interval(data_range):
    """Calculate a 'nice' grid/label interval for an axis.

    With ``magnitude = 10 ** floor(log10(|range|))``, the interval is
    ``0.5 * magnitude``, ``magnitude`` or ``2 * magnitude``. The choice
    depends on whether ``|range| / magnitude`` is below 2, below 5, or
    larger.

    Args:
        data_range (float): The total range of data values. The sign is
            ignored, so reversed ranges also work.

    Returns:
        float: A 'nice' interval (e.g., 0.5, 1, 2, 5, 10, 20, ...) for grid lines.

    Raises:
        ValueError: If the range is zero or not finite.
    """
    span = abs(data_range)
    if span == 0 or not math.isfinite(span):
        raise ValueError(f"Cannot compute an axis interval for range {data_range!r}")

    magnitude = 10 ** math.floor(math.log10(span))
    normalized = span / magnitude

    if normalized < 2:
        return magnitude * 0.5
    if normalized < 5:
        return magnitude
    return magnitude * 2


def interpolate_elevation(topo_data, distance):
    """Linearly interpolate the elevation at a given distance.

    Assumes ``topo_data`` is ordered by ascending distance. Outside the
    sampled range, the elevation of the nearest end is returned.

    Args:
        topo_data (list): List of (distance, elevation) pairs.
        distance (float): Distance at which to interpolate elevation.

    Returns:
        float: Interpolated elevation value (0 if ``topo_data`` is empty).
    """
    if not topo_data:
        return 0

    for (dist1, elev1), (dist2, elev2) in zip(topo_data, topo_data[1:]):
        if dist1 <= distance <= dist2:
            if dist2 == dist1:
                # Duplicate distance: nothing to interpolate, avoid 0-division
                return elev1
            ratio = (distance - dist1) / (dist2 - dist1)
            return elev1 + (elev2 - elev1) * ratio

    # Outside the sampled range: clamp to the nearest end of the profile
    if distance < topo_data[0][0]:
        return topo_data[0][1]
    return topo_data[-1][1]


"""
Utilities for parsing geological structural measurements.

Supports:
- Numeric strike/dip (e.g. strike=345, dip=22)
- Field notation (e.g. "N 15° W", "22° SW")
"""


# ------------------------------------
#  STRIKE PARSER
# ------------------------------------
def parse_strike(value):
    """Parse a strike measurement into an azimuth.

    Accepts:
        - Numeric azimuth: int, float or numeric string, optionally with a
          degree symbol (345, "345", "345°", "-15")
        - Quadrant notation ("N 30° E", "S 15° W")

    Returns:
        float | None: Strike as an azimuth normalised with ``% 360``
        (i.e. 0–360). Returns None if the value is missing, non-finite, or
        not a recognised notation.
    """
    if value is None:
        return None

    # Normalize value
    text = (
        str(value)
        .replace("°", "")
        .replace("º", "")
        .replace("ø", "")  # Support for alternative degree symbol
        .strip()
        .upper()
    )

    # Numeric azimuth: try the raw value, then the symbol-stripped text
    for candidate in (value, text):
        try:
            azimuth = float(candidate)
        except (ValueError, TypeError):
            continue
        return azimuth % 360 if math.isfinite(azimuth) else None

    # Regex for quadrant notation: N/S + angle + E/W
    # Supports integers and decimals for the angle
    match = re.match(r"([NS])\s*(\d+\.?\d*)\s*([EW])", text)
    if not match:
        return None  # invalid notation

    d1, ang, d2 = match.groups()
    ang = float(ang)

    # Quadrant rules (the regex guarantees one of these four keys)
    quadrant_strikes = {
        ("N", "E"): ang,
        ("N", "W"): 360 - ang,
        ("S", "E"): 180 - ang,
        ("S", "W"): 180 + ang,
    }
    return quadrant_strikes[(d1, d2)] % 360


# ------------------------------------
#  DIP PARSER
# ------------------------------------
def parse_dip(value):
    """Parse a dip measurement.

    Accepts:
        - Numeric dip: 22, 45.5, "22", "45.5", "30.0"
        - Field notation: "22° SW", "45 NE", "10 S"

    Returns:
        tuple: (dip_angle, dip_direction_azimuth). The direction is None for
        purely numeric input. The direction is also None when the cardinal
        letters are not recognised by ``cardinal_to_azimuth``. Returns
        (None, None) for missing or unparseable input, including negative
        or non-finite numbers.
    """
    if value is None:
        return None, None

    # Real numbers are handled directly: str() of small/large floats uses
    # scientific notation (e.g. "1E-05"), which the regexes below would
    # misread as "dip 1 toward E".
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        dip = float(value)
        if math.isfinite(dip) and dip >= 0:
            return dip, None
        return None, None

    text = (
        str(value)
        .replace("°", "")
        .replace("º", "")
        .replace("ø", "")  # Support for alternative degree symbol
        .strip()
        .upper()
    )

    # Case 1: numeric only (integer or decimal)
    numeric_only = re.match(r"^(\d+\.?\d*)$", text)
    if numeric_only:
        return float(text), None

    # Case 2: full dip + direction
    match = re.match(r"(\d+\.?\d*)\s*([NSEW]{1,2})", text)
    if not match:
        return None, None

    dip, cardinal = match.groups()
    dip = float(dip)

    dip_dir = cardinal_to_azimuth(cardinal)

    return dip, dip_dir


# ------------------------------------
#  Helper for converting cardinal directions to azimuth
# ------------------------------------
def cardinal_to_azimuth(text):
    """Map an 8-point compass direction to its azimuth.

    Recognised (case-sensitive, upper-case) values are N, NE, E, SE, S, SW,
    W and NW, spaced 45 degrees apart clockwise from N = 0.

    Returns:
        int | None: Azimuth in degrees (0–315), or None for any other value.
    """
    compass_points = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
    azimuths = {point: step * 45 for step, point in enumerate(compass_points)}
    return azimuths.get(text)


def calculate_drillhole_trajectory(
    collar_point, collar_z, survey_data, section_azimuth, densify_step=1.0
):
    """Calculate 3D trajectory of a drillhole using survey data.

    Uses the tangential method: each survey station's azimuth/inclination
    is applied to the whole interval ending at that station's depth. Each
    interval is densified into roughly ``densify_step``-long pieces, giving
    intermediate points for continuous interval projection. The trajectory
    ends at the deepest survey station.

    Args:
        collar_point: QgsPointXY of collar location (X, Y)
        collar_z: Elevation of collar (Z)
        survey_data: Iterable of tuples (depth, azimuth, inclination). The
            stations are sorted by depth here. Stations at depth <= 0 and
            repeated depths are skipped; the first station at a given depth
            wins.
        section_azimuth: Azimuth of the section line in degrees. Currently
            unused by this function.
        densify_step: Distance in meters between interpolated points
            (default 1.0m). A non-positive value disables densification,
            giving one point per survey interval.

    Returns:
        List of tuples (depth, x, y, z, dist_along_section, offset_from_section).
        The last two fields are 0.0 placeholders, filled in later by
        ``project_trajectory_to_section``.
    """
    if not survey_data:
        return []

    # Start at collar
    x, y, z = collar_point.x(), collar_point.y(), collar_z
    prev_depth = 0.0
    trajectory = [(0.0, x, y, z, 0.0, 0.0)]

    # sorted() is stable, so among equal depths the first listed station is kept
    for depth, azimuth, inclination in sorted(survey_data, key=lambda s: s[0]):
        if depth <= prev_depth:
            continue

        interval = depth - prev_depth
        azim_rad = math.radians(azimuth)

        # Inclination convention of the input: -90° = vertical down,
        # 0° = horizontal. Convert to the standard convention (0° = down,
        # 90° = horizontal): standard_incl = 90 + inclination.
        standard_incl_rad = math.radians(90 + inclination)

        # Vertical component (negative because Z decreases downward)
        total_dz = -interval * math.cos(standard_incl_rad)

        # Horizontal components (East, North)
        horizontal = interval * math.sin(standard_incl_rad)
        total_dx = horizontal * math.sin(azim_rad)
        total_dy = horizontal * math.cos(azim_rad)

        # Densify: generate intermediate points along this segment
        if densify_step > 0:
            num_steps = max(1, int(interval / densify_step))
        else:
            num_steps = 1

        for i in range(1, num_steps + 1):
            fraction = i / num_steps
            trajectory.append(
                (
                    prev_depth + interval * fraction,
                    x + total_dx * fraction,
                    y + total_dy * fraction,
                    z + total_dz * fraction,
                    0.0,
                    0.0,
                )
            )

        # Update position to end of segment
        x += total_dx
        y += total_dy
        z += total_dz
        prev_depth = depth

    return trajectory


def project_trajectory_to_section(trajectory, line_geom, line_start, distance_area):
    """Project drillhole trajectory points onto the section line.

    Each trajectory point is snapped to its nearest point on ``line_geom``
    (the "foot"). The distance along the section is the ``distance_area``
    straight-line distance from ``line_start`` to that foot. The offset is
    the ``distance_area`` distance from the trajectory point to the foot,
    always non-negative.

    Args:
        trajectory: List of (depth, x, y, z, _, _) from calculate_drillhole_trajectory
        line_geom: QgsGeometry of section line
        line_start: QgsPointXY of section line start
        distance_area: QgsDistanceArea for measurements

    Returns:
        List of tuples (depth, x, y, z, dist_along, offset)
    """
    result = []
    for depth, x, y, z, _dist, _offset in trajectory:
        station = QgsPointXY(x, y)
        foot = line_geom.nearestPoint(QgsGeometry.fromPointXY(station)).asPoint()
        along = distance_area.measureLine(line_start, foot)
        away = distance_area.measureLine(station, foot)
        result.append((depth, x, y, z, along, away))
    return result


def _sample_trajectory_at(trajectory, depth):
    """Linearly interpolate (dist_along, z, offset) at ``depth``; None if outside.

    ``trajectory`` holds (depth, x, y, z, dist_along, offset) tuples ordered
    by increasing depth.
    """
    for (d1, _, _, z1, s1, o1), (d2, _, _, z2, s2, o2) in zip(
        trajectory, trajectory[1:]
    ):
        if d1 <= depth <= d2:
            if d2 == d1:
                return s1, z1, o1
            t = (depth - d1) / (d2 - d1)
            return s1 + (s2 - s1) * t, z1 + (z2 - z1) * t, o1 + (o2 - o1) * t
    if len(trajectory) == 1 and trajectory[0][0] == depth:
        only = trajectory[0]
        return only[4], only[3], only[5]
    return None


def interpolate_intervals_on_trajectory(trajectory, intervals, buffer_width):
    """Interpolate interval attributes along drillhole trajectory.

    For every interval, the section-space points are collected in order:
    - the trajectory interpolated exactly at ``from_depth``;
    - every trajectory point strictly inside the interval;
    - the trajectory interpolated exactly at ``to_depth``.
    Intervals shorter than the trajectory spacing are therefore still drawn,
    and adjacent intervals meet instead of leaving gaps. Points whose
    offset from the section exceeds ``buffer_width`` are dropped.

    Args:
        trajectory: List of (depth, x, y, z, dist_along, offset) tuples,
            ordered by increasing depth.
        intervals: List of (from_depth, to_depth, attribute) tuples
        buffer_width: Maximum offset to include

    Returns:
        List of (attribute, list of (dist_along, elevation)) tuples. Intervals
        with no qualifying points, or with to_depth < from_depth, are omitted.
    """
    geol_segments = []

    for from_depth, to_depth, attribute in intervals:
        if to_depth < from_depth:
            continue

        samples = []
        start = _sample_trajectory_at(trajectory, from_depth)
        if start is not None:
            samples.append(start)

        for depth, _x, _y, z, dist_along, offset in trajectory:
            if from_depth < depth < to_depth:
                samples.append((dist_along, z, offset))

        if to_depth > from_depth:
            end = _sample_trajectory_at(trajectory, to_depth)
            if end is not None:
                samples.append(end)

        interval_points = [
            (dist_along, z)
            for dist_along, z, offset in samples
            if offset <= buffer_width
        ]

        # Add segment if we have points
        if interval_points:
            geol_segments.append((attribute, interval_points))

    return geol_segments


class ShapefileExporter:
    """Handles exporting profile data to shapefiles.

    Geometries are written in section space: X is the distance along the
    section and Y is the elevation.
    """

    def __init__(self, output_folder, base_name, crs=None):
        """Initialize exporter.

        Args:
            output_folder (str): Directory to save shapefiles.
            base_name (str): Base name for files (e.g. "Section_A").
            crs (QgsCoordinateReferenceSystem): CRS to assign to shapefiles.
                When omitted, an empty QgsCoordinateReferenceSystem is used.
        """
        self.output_folder = output_folder
        self.base_name = base_name
        # Use provided CRS or default to empty (no CRS)
        self.crs = crs if crs else QgsCoordinateReferenceSystem()

    def export_profile_data(self, topo_data, drillhole_data, collars_data):
        """Export all profile data to shapefiles.

        Files are named ``<output_folder>/<base_name>_<suffix>.shp``, with
        suffixes topo, traces, geology and collars. Empty inputs are skipped.

        Args:
            topo_data: List of (dist, elev) tuples.
            drillhole_data: List of (hole_id, trajectory_points, geology_segments).
                            trajectory_points: List of (dist, elev).
                            geology_segments: List of (unit, points).
            collars_data: List of (hole_id, dist, elev).

        Returns:
            list: List of created file paths.
        """
        exports = (
            ("topo", topo_data, self._export_topography),
            ("traces", drillhole_data, self._export_traces),
            ("geology", drillhole_data, self._export_geology),
            ("collars", collars_data, self._export_collars),
        )

        created_files = []
        for suffix, data, export in exports:
            if not data:
                continue
            path = f"{self.output_folder}/{self.base_name}_{suffix}.shp"
            if export(data, path):
                created_files.append(path)
        return created_files

    def _open_writer(self, output_path, fields, geometry_type):
        """Create a shapefile writer, or return None if the writer reports an error."""
        writer = QgsVectorFileWriter(
            output_path,
            "UTF-8",
            fields,
            geometry_type,
            self.crs,
            "ESRI Shapefile",
        )
        if writer.hasError() != QgsVectorFileWriter.NoError:
            del writer  # release the failed writer immediately
            return None
        return writer

    def _write_features(self, output_path, fields, geometry_type, records):
        """Write an iterable of (geometry, attributes) records to a new shapefile.

        Returns:
            bool: False if the writer could not be created, True otherwise.
        """
        writer = self._open_writer(output_path, fields, geometry_type)
        if writer is None:
            return False
        try:
            for geometry, attributes in records:
                feat = QgsFeature()
                feat.setGeometry(geometry)
                feat.setAttributes(attributes)
                writer.addFeature(feat)
        finally:
            # Release the writer (as the original code did with `del writer`)
            # even if building a record raised part-way through.
            del writer
        return True

    @staticmethod
    def _polyline(pairs):
        """Build a line geometry from (distance, elevation) pairs."""
        return QgsGeometry.fromPolylineXY([QgsPointXY(d, e) for d, e in pairs])

    def _export_topography(self, topo_data, output_path):
        """Export the topographic profile as a single line feature.

        Returns False (and writes nothing) when fewer than two points are
        available, because a line needs at least two vertices.
        """
        points = [QgsPointXY(d, e) for d, e in topo_data]
        if len(points) < 2:
            return False

        fields = QgsFields()
        fields.append(QgsField("id", QVariant.Int))

        records = [(QgsGeometry.fromPolylineXY(points), [1])]
        return self._write_features(
            output_path, fields, QgsWkbTypes.LineString, records
        )

    def _export_traces(self, drillhole_data, output_path):
        """Export drillhole traces; traces with fewer than two points are skipped."""
        fields = QgsFields()
        fields.append(QgsField("HoleID", QVariant.String))

        records = (
            (self._polyline(points), [hole_id])
            for hole_id, points, _ in drillhole_data
            if len(points) >= 2
        )
        return self._write_features(
            output_path, fields, QgsWkbTypes.LineString, records
        )

    def _export_geology(self, drillhole_data, output_path):
        """Export geological intervals; segments with fewer than two points are skipped."""
        fields = QgsFields()
        fields.append(QgsField("HoleID", QVariant.String))
        fields.append(QgsField("Unit", QVariant.String))

        records = (
            (self._polyline(points), [hole_id, str(unit)])
            for hole_id, _, segments in drillhole_data
            for unit, points in segments
            if len(points) >= 2
        )
        return self._write_features(
            output_path, fields, QgsWkbTypes.LineString, records
        )

    def _export_collars(self, collars_data, output_path):
        """Export collar locations as points."""
        fields = QgsFields()
        fields.append(QgsField("HoleID", QVariant.String))
        fields.append(QgsField("Dist", QVariant.Double))
        fields.append(QgsField("Elev", QVariant.Double))

        records = (
            (
                QgsGeometry.fromPointXY(QgsPointXY(dist, elev)),
                [hole_id, float(dist), float(elev)],
            )
            for hole_id, dist, elev in collars_data
        )
        return self._write_features(output_path, fields, QgsWkbTypes.Point, records)
