
import math
from qgis.core import (
    Qgis,
    QgsPoint,
    QgsGeometry,
    QgsWkbTypes,
    QgsFeatureRequest
)
from .geology import TopographicProfile, LithologicalSegment, StructuralPoint, BeddingPoint, GeologicalProfile
import matplotlib.pyplot as plt



def get_dem_resolution(raster_layer):
    """Return the finer of the two pixel sizes of a raster layer.

    The pixel size along each axis is the layer extent on that axis divided
    by the number of pixels on that axis; the smaller of the two is returned.

    :param raster_layer: raster layer exposing ``extent()``, ``width()`` and
        ``height()`` (e.g. a ``QgsRasterLayer``).
    :return: the smaller of the X and Y pixel sizes, in layer map units.
    """
    layer_extent = raster_layer.extent()
    columns = raster_layer.width()
    rows = raster_layer.height()

    pixel_size_x = layer_extent.width() / columns
    pixel_size_y = layer_extent.height() / rows
    return min(pixel_size_x, pixel_size_y)


def line_to_multipoints(feature_line, step, invert_line=False):
    """Sample a straight two-vertex line at regular intervals.

    The line is divided into ``ceil(length / step)`` equal intervals (at
    least one), so the actual spacing never exceeds ``step``. The first
    sample is the start of the line and the last sample is taken at exactly
    the line length, so each endpoint appears exactly once.

    :param feature_line: feature whose geometry is a single straight line
        (two vertices); a multipart geometry is accepted only if it contains
        exactly one part.
    :param step: maximum spacing between consecutive samples, in map units.
    :param invert_line: if True, sample from the last vertex to the first.
    :return: list of points (as returned by ``asPoint()``) ordered from the
        start of the (possibly inverted) line to its end.
    :raises ValueError: if ``step`` is not positive, the geometry has
        several parts, or the line does not have exactly two vertices.
    """
    if step <= 0:
        raise ValueError('The sampling step must be positive.')

    geometry = feature_line.geometry()

    # A multipart geometry is tolerated only when it wraps a single line
    if geometry.isMultipart():
        lines = geometry.asMultiPolyline()
        if len(lines) != 1:
            raise ValueError('The selected object has multiple lines.')
        geometry = QgsGeometry.fromPolylineXY(lines[0])

    # Only straight sections (exactly two vertices) are supported
    polyline = geometry.asPolyline()
    if len(polyline) != 2:
        raise ValueError('Invalid object, use a straight line (2 vertices).')

    if invert_line:
        polyline.reverse()
        geometry = QgsGeometry.fromPolylineXY(polyline)

    line_length = geometry.length()

    # Equal intervals no longer than ``step``; at least one so that even a
    # zero-length line yields a start and an end sample.
    n_intervals = max(1, int(math.ceil(line_length / step)))
    interval = line_length / n_intervals

    # The final sample uses ``line_length`` itself: ``n_intervals * interval``
    # can fall a rounding error short of it, which would leave the end point
    # slightly off or produce a near-duplicate end point.
    return [
        geometry.interpolate(line_length if i == n_intervals else i * interval).asPoint()
        for i in range(n_intervals + 1)
    ]


def calculate_distance(p1, p2):
    """Return the planar Euclidean distance between two points.

    Only the ``x()`` and ``y()`` coordinates are used, so any Z value is
    ignored.
    """
    delta_x, delta_y = p2.x() - p1.x(), p2.y() - p1.y()
    return math.sqrt(delta_x ** 2 + delta_y ** 2)


def calculate_azimuth(p1, p2):
    """Return the azimuth from ``p1`` to ``p2`` in degrees.

    The angle is measured clockwise from the +Y axis towards the +X axis
    (``atan2`` is called with ``dx`` first on purpose). Negative angles are
    shifted by +360 so the result is non-negative.
    """
    heading = math.degrees(math.atan2(p2.x() - p1.x(), p2.y() - p1.y()))
    return heading + 360 if heading < 0 else heading


def calculate_alpha(az1, az2):
    """Return the acute angle, in degrees (0-90), between two line directions.

    The directions are treated as undirected lines, so azimuths that differ
    by 180 degrees are parallel. Reducing the difference modulo 180 first
    makes e.g. 0 vs 350 give 10 (not 170). It also handles azimuths outside
    [0, 360). The sine of the result is the same as before, so apparent-dip
    calculations are unaffected.

    :param az1: first azimuth in degrees.
    :param az2: second azimuth in degrees.
    :return: angle between the two lines in the range [0, 90].
    """
    angle = abs(az1 - az2) % 180
    return min(angle, 180 - angle)


def get_intersection(plane, azimuth, dip, topography):
    """Intersect an oriented plane with a 3D topographic polyline.

    The plane passes through ``plane``. Its orientation is given by the
    strike ``azimuth`` and the ``dip`` (degrees), following the right-hand
    rule (the dip direction is ``azimuth + 90``).

    Every segment of ``topography`` is tested. When the plane crosses the
    profile more than once, the crossing horizontally closest to ``plane``
    is returned. That crossing is the trace of the measured bed rather than
    a distant re-intersection of the extrapolated plane.

    :param plane: point on the plane exposing ``x()``, ``y()`` and ``z()``.
    :param azimuth: strike of the plane, in degrees.
    :param dip: dip angle of the plane, in degrees.
    :param topography: ordered sequence of 3D points (``x()``, ``y()``,
        ``z()``) describing the profile.
    :return: the intersection as a ``QgsPoint``, or ``None`` if the plane
        does not cross the profile.
    """
    x0, y0, z0 = plane.x(), plane.y(), plane.z()

    # Unit normal (a, b, c) of a plane dipping ``dip`` degrees towards its
    # dip direction, assuming x = east, y = north and z = up.
    dip_direction_rad = math.radians((azimuth + 90) % 360)  # right-hand rule
    dip_rad = math.radians(dip)

    a = math.sin(dip_rad) * math.sin(dip_direction_rad)
    b = math.sin(dip_rad) * math.cos(dip_direction_rad)
    c = math.cos(dip_rad)

    def plane_offset(p):
        # Signed distance of ``p`` from the plane (sign = side of the plane)
        return a * (p.x() - x0) + b * (p.y() - y0) + c * (p.z() - z0)

    closest_point = None
    closest_distance = math.inf

    for p1, p2 in zip(topography, topography[1:]):
        f1 = plane_offset(p1)
        f2 = plane_offset(p2)

        # Both ends strictly on the same side (or a NaN offset): no crossing
        if not f1 * f2 <= 0:
            continue

        # Linear interpolation of the offset along the segment. If f1 == f2
        # the test above forces f1 == f2 == 0: the segment lies in the plane
        # and its start point is a valid intersection.
        t = 0.0 if f1 == f2 else f1 / (f1 - f2)

        xi = p1.x() + t * (p2.x() - p1.x())
        yi = p1.y() + t * (p2.y() - p1.y())
        zi = p1.z() + t * (p2.z() - p1.z())

        horizontal_distance = math.hypot(xi - x0, yi - y0)
        if horizontal_distance < closest_distance:
            closest_distance = horizontal_distance
            closest_point = QgsPoint(xi, yi, zi)

    return closest_point


def calculate_apparent_dip(dip, alpha):
    """Return the apparent dip, in degrees, seen in a cross-section.

    Uses ``tan(apparent) = tan(dip) * sin(alpha)``, where ``alpha`` is the
    angle in degrees between the bed strike and the section line.
    """
    tan_apparent = math.tan(math.radians(dip)) * math.sin(math.radians(alpha))
    return math.degrees(math.atan(tan_apparent))


def get_dip_inclination(section_azimuth, dip_direction):
    """Classify which way a bed dips as seen along a section.

    :param section_azimuth: azimuth of the section line (start -> end), in
        degrees.
    :param dip_direction: dip direction of the bed, in degrees.
    :return: ``'normal'`` when the dip direction is exactly perpendicular to
        the section. Otherwise ``'right'`` when it lies within 90 degrees of
        the section azimuth (dipping towards the section end), and ``'left'``
        when it does not.
    """
    # math.remainder maps any difference exactly into [-180, 180], including
    # inputs outside [0, 360) that a single +/-360 correction would miss.
    # Only the magnitude matters for the classification below.
    angle_diff = abs(math.remainder(dip_direction - section_azimuth, 360))

    if angle_diff == 90:
        return 'normal'
    if angle_diff < 90:
        return 'right'
    return 'left'


def create_topographic_profile(line, dem, invert_line=False):
    """Build a topographic profile along a straight section line.

    The line is sampled with ``line_to_multipoints`` at the DEM resolution
    returned by ``get_dem_resolution``. Band 1 of the DEM is then sampled at
    every point.

    :param line: feature holding the straight section line.
    :param dem: raster layer sampled on band 1 for elevations.
    :param invert_line: if True, sample the line from its end to its start.
    :return: ``TopographicProfile`` built from the distances to the first
        sample and the corresponding band-1 values.
    """
    samples = line_to_multipoints(feature_line=line,
                                  step=get_dem_resolution(raster_layer=dem),
                                  invert_line=invert_line)

    origin = samples[0]
    provider = dem.dataProvider()

    distances = [calculate_distance(p1=origin, p2=sample) for sample in samples]
    elevations = [provider.sample(sample, 1)[0] for sample in samples]

    return TopographicProfile(distances, elevations)


def _extract_lithology(line_geometry, start_point, dem_provider, lithology, lithology_names, lithology_colors):
    """Clip the section line against every lithology polygon.

    :return: ``(segments, topography)``. ``segments`` is a list of
        ``LithologicalSegment``. ``topography`` holds the unique 3D vertices
        of the clipped segments, ordered by distance from ``start_point``.
    """
    visited_coordinates = set()
    points_3d = {}
    lithological_segments = []

    for litho in lithology.getFeatures():
        litho_intersects = line_geometry.intersection(litho.geometry())

        if litho_intersects.isEmpty():
            continue

        if litho_intersects.isMultipart():
            segments = litho_intersects.asMultiPolyline()
        else:
            segments = [litho_intersects.asPolyline()]

        name = litho[lithology_names]
        color = lithology_colors[name]

        for segment in segments:
            distances = []
            elevations = []

            for point in segment:
                # Distances are rounded to 4 decimals; they also key points_3d
                distance = round(calculate_distance(p1=start_point, p2=point), 4)
                elevation = dem_provider.sample(point, 1)[0]

                distances.append(distance)
                elevations.append(elevation)

                # Each XY vertex contributes a single 3D topography point
                coordinate = (point.x(), point.y())
                if coordinate not in visited_coordinates:
                    points_3d[distance] = QgsPoint(point.x(), point.y(), elevation)
                    visited_coordinates.add(coordinate)

            lithological_segments.append(LithologicalSegment(distances=distances,
                                                             elevations=elevations,
                                                             name=name,
                                                             color=color))

    topography = [points_3d[dist] for dist in sorted(points_3d)]
    return lithological_segments, topography


def _extract_structures(line_geometry, start_point, dem_provider, structural, structural_names, structural_color):
    """Return a ``StructuralPoint`` for every point where a structure crosses the section."""
    structural_points = []

    for structure in structural.getFeatures():
        structure_intersects = line_geometry.intersection(structure.geometry())

        if structure_intersects.isEmpty():
            continue

        name = structure[structural_names]

        # Only point intersections are plotted; other geometry types are skipped
        if structure_intersects.type() != QgsWkbTypes.PointGeometry:
            continue

        if structure_intersects.isMultipart():
            points = structure_intersects.asMultiPoint()
        else:
            points = [structure_intersects.asPoint()]

        for point in points:
            structural_points.append(StructuralPoint(distance=calculate_distance(p1=start_point, p2=point),
                                                     elevation=dem_provider.sample(point, 1)[0],
                                                     name=name,
                                                     color=structural_color))

    return structural_points


def _extract_bedding(line_geometry, section_points, dem_provider, topography,
                     bedding, bedding_azimuth, bedding_dip, bedding_buffer):
    """Project bedding measurements inside a buffer of the section onto the profile.

    :return: list of ``BeddingPoint``. Measurements whose plane does not
        cross ``topography`` are skipped.
    :raises ValueError: if a bedding feature holds a multipart geometry.
    """
    start_point = section_points[0]

    # Flat-ended buffer around the section; the bounding box pre-filters features
    buffer_geometry = line_geometry.buffer(bedding_buffer, 2, Qgis.EndCapStyle.Flat, Qgis.JoinStyle.Miter, 1.0)
    buffer_request = QgsFeatureRequest().setFilterRect(buffer_geometry.boundingBox())

    section_azimuth = calculate_azimuth(p1=start_point, p2=section_points[-1])

    bedding_points = []

    for bed in bedding.getFeatures(buffer_request):
        bed_geometry = bed.geometry()

        # The request only filters by bounding box; keep beds really inside the buffer
        if not buffer_geometry.contains(bed_geometry):
            continue

        if bed_geometry.isMultipart():
            raise ValueError('The bedding layer has multiple points in a feature.')

        bed_xy = bed_geometry.asPoint()
        bed_elevation = dem_provider.sample(bed_xy, 1)[0]
        bed_plane = QgsPoint(bed_xy.x(), bed_xy.y(), bed_elevation)

        bed_azimuth = bed[bedding_azimuth]
        bed_dip = bed[bedding_dip]

        bed_intersection = get_intersection(plane=bed_plane, azimuth=bed_azimuth, dip=bed_dip,
                                            topography=topography)
        if bed_intersection is None:
            # The plane does not cross the sampled topography: nothing to plot
            continue

        alpha = calculate_alpha(az1=section_azimuth, az2=bed_azimuth)
        dip_direction = (bed_azimuth + 90) % 360  # right-hand rule

        bedding_points.append(BeddingPoint(
            distance=calculate_distance(p1=start_point, p2=bed_intersection),
            elevation=bed_intersection.z(),
            dip_inclination=get_dip_inclination(section_azimuth=section_azimuth, dip_direction=dip_direction),
            apparent_dip=calculate_apparent_dip(dip=bed_dip, alpha=alpha),
        ))

    return bedding_points


def create_geological_profile(line, dem, invert_line, lithology, lithology_names, lithology_colors,
                              structural, structural_names, structural_color,
                              bedding, bedding_azimuth, bedding_dip, bedding_buffer):
    """Build a geological profile along a straight section line.

    The section is sampled at the DEM resolution and then processed in three
    stages:

    * it is clipped by the lithology polygons, which also yields the 3D
      topography;
    * crossings with the optional structural layer are collected;
    * the optional bedding measurements within ``bedding_buffer`` of the
      section are projected onto the profile.

    :param line: feature holding the straight section line.
    :param dem: raster layer sampled on band 1 for elevations.
    :param invert_line: if True, run the section from its end to its start.
    :param lithology: polygon layer of lithologies.
    :param lithology_names: field holding the lithology name.
    :param lithology_colors: mapping from lithology name to plot color.
    :param structural: optional layer of structures (falsy to skip).
    :param structural_names: field holding the structure name.
    :param structural_color: plot color for structures.
    :param bedding: optional point layer of bedding measurements (falsy to skip).
    :param bedding_azimuth: field holding the bedding strike (degrees).
    :param bedding_dip: field holding the bedding dip (degrees).
    :param bedding_buffer: search distance around the section for bedding.
    :return: ``GeologicalProfile``. ``structures`` / ``bedding`` are ``None``
        when the corresponding layer is not given.
    :raises ValueError: on invalid section geometry or multipart bedding points.
    """
    step = get_dem_resolution(raster_layer=dem)
    section_points = line_to_multipoints(feature_line=line, step=step, invert_line=invert_line)

    line_geometry = QgsGeometry.fromPolylineXY(section_points)  # LineString
    start_point = section_points[0]
    dem_provider = dem.dataProvider()

    lithological_segments, topography = _extract_lithology(line_geometry, start_point, dem_provider,
                                                           lithology, lithology_names, lithology_colors)

    structural_points = None
    if structural:
        structural_points = _extract_structures(line_geometry, start_point, dem_provider,
                                                structural, structural_names, structural_color)

    bedding_points = None
    if bedding:
        bedding_points = _extract_bedding(line_geometry, section_points, dem_provider, topography,
                                          bedding, bedding_azimuth, bedding_dip, bedding_buffer)

    return GeologicalProfile(lithologies=lithological_segments,
                             structures=structural_points,
                             bedding=bedding_points)


def get_figure_size(size, orientation='landscape'):
    """Return the size, in inches, of an ISO A-series sheet.

    :param size: sheet name, ``'A0'`` to ``'A5'``.
    :param orientation: ``'portrait'`` or ``'landscape'`` (default).
    :return: ``(width, height)`` tuple in inches.
    :raises KeyError: for an unknown ``size`` or ``orientation``.
    """
    mm_per_inch = 25.4

    # Portrait dimensions (short side, long side) in millimetres
    portrait_mm = {
        'A0': (841, 1189),
        'A1': (594, 841),
        'A2': (420, 594),
        'A3': (297, 420),
        'A4': (210, 297),
        'A5': (148, 210),
    }

    short_side, long_side = portrait_mm[size]
    layouts = {
        'portrait': (short_side, long_side),
        'landscape': (long_side, short_side),
    }
    width_mm, height_mm = layouts[orientation]

    return width_mm / mm_per_inch, height_mm / mm_per_inch


def save_geological_profile(path, lithology, depth, x_limits, y_limits, structures=None, bedding=None,
                            bedding_length=None, title=None, y_label=None, start_label=None, end_label=None,
                            size=(8, 5)):
    """Draw a geological profile with matplotlib and save it to ``path``.

    The figure is always closed afterwards, even when drawing or saving
    fails.

    :param path: output path passed to ``Figure.savefig``.
    :param lithology: iterable of segments with ``distances``, ``elevations``,
        ``name`` and ``color`` attributes, each drawn as a line.
    :param depth: extra depth shown below ``y_limits[0]``.
    :param x_limits: ``(min, max)`` distance limits of the plot.
    :param y_limits: ``(min, max)`` elevation limits of the plot.
    :param structures: optional points with ``distance``, ``elevation``,
        ``color`` and ``name``, drawn as ``+`` markers with labelled arrows.
    :param bedding: optional points with ``distance``, ``elevation``,
        ``apparent_dip`` (degrees) and ``dip_inclination``, drawn as dip ticks.
    :param bedding_length: length of the dip ticks in data units; 50 when
        not given (or zero).
    :param title: figure title.
    :param y_label: label of the elevation axis.
    :param start_label: label drawn above the left (start) end of the axes.
    :param end_label: label drawn above the right (end) end of the axes.
    :param size: figure size in inches (matplotlib ``figsize``).
    """
    fig, ax = plt.subplots(figsize=size, dpi=300)
    try:
        ax.set_ylabel(y_label, fontsize=9)
        ax.tick_params(axis='both', which='both', labelsize=8, labelright=True, right=True)
        ax.spines.top.set_visible(False)
        ax.spines.left.set_visible(False)
        ax.spines.right.set_visible(False)
        ax.spines.bottom.set_linewidth(1)
        ax.minorticks_on()
        ax.set_axisbelow(True)

        ax.set_title(title, fontsize=11, pad=25)

        # Section end labels, placed in axes-fraction coordinates above the plot
        ax.annotate(start_label, xy=(0, 0), xytext=(0, 1.05), xycoords='axes fraction', textcoords='axes fraction',
                    fontsize=9, ha='center',
                    arrowprops=dict(arrowstyle="<-, head_width=0.35", lw=1, shrinkA=0, shrinkB=0))
        ax.annotate(end_label, xy=(1, 0), xytext=(1, 1.05), xycoords='axes fraction', textcoords='axes fraction',
                    fontsize=9, ha='center',
                    arrowprops=dict(arrowstyle="<-, head_width=0.35", lw=1, shrinkA=0, shrinkB=0))

        for segment in lithology:
            ax.plot(segment.distances, segment.elevations, label=segment.name, color=segment.color)

        ax.set_xlim(x_limits[0], x_limits[1])
        # Headroom above the profile is scaled by the section length because
        # the axes use an equal aspect ratio (set below).
        ax.set_ylim(y_limits[0] - depth, y_limits[1] + (0.1 * x_limits[1]))

        if structures:
            s_x = [s.distance for s in structures]
            s_y = [s.elevation for s in structures]
            s_colors = [s.color for s in structures]
            s_names = [s.name for s in structures]

            ax.scatter(s_x, s_y, color=s_colors, marker='+', s=80, label=s_names)

            for x, y, name in zip(s_x, s_y, s_names):
                ax.annotate(name, xy=(x, y + (0.015 * x_limits[1])),
                            xytext=(x, y_limits[1] + (0.05 * x_limits[1])), fontsize=9, ha='center',
                            arrowprops=dict(arrowstyle="->, head_width=0.25", lw=1, shrinkA=0, shrinkB=0))

        if bedding:
            length = bedding_length if bedding_length else 50

            for bed in bedding:
                # Tick pointing down-dip at the apparent dip angle
                beta = math.radians(bed.apparent_dip)
                dx = length * math.cos(beta)
                dz = - length * math.sin(beta)

                if bed.dip_inclination == 'left':
                    dx = - dx

                ax.plot([bed.distance, bed.distance + dx], [bed.elevation, bed.elevation + dz], color='black')

        ax.set_aspect('equal', adjustable='box')

        fig.savefig(path, bbox_inches='tight')
    finally:
        # Release the figure even on failure so repeated exports do not leak
        # open pyplot figures.
        plt.close(fig)
