import os
import datetime
#from commonTools import printProgressBar, read_config
import sys
from shapely.geometry import Polygon,MultiPolygon,Point,LinearRing,LineString,MultiLineString,Polygon,MultiPoint
import numpy as np
from scipy.spatial import cKDTree
from geopandas import GeoDataFrame
import geopandas as gpd
from collections import defaultdict
import shutil
from datetime import datetime
sys.path.append(os.path.abspath('../telemac-mascaret/scripts/python3'))
from pyteltools.slf import Serafin
from pyteltools.slf.variable.variables_2d import basic_2D_vars_IDs
from data_manip.extraction.telemac_file import TelemacFile
#from data_manip.extraction.telemac_file import write_selafin

import numpy as np

from shapely.strtree import STRtree

try:
    # TELEMAC Python modules
    from pretel.extract_contour import extract_contour, write_gis_file
    from pretel.manip_telfile import alter
    from run_telfile import alter_parser, contour_parser

    HAS_TELEMAC = True
except ImportError:
    HAS_TELEMAC = False



def lire_slf(fichier):
    """Read the 2D mesh of a SELAFIN file.

    Args:
        fichier: path to the SELAFIN file.

    Returns:
        ``(coords, ikle)`` where ``coords`` is a list of ``(x, y)`` tuples built
        from ``meshx``/``meshy`` and ``ikle`` is ``TelemacFile.ikle2``
        (presumably an (n_elements, 3) 0-based connectivity array — confirm).
    """
    slf = TelemacFile(fichier)
    try:
        # Materialise everything we need before the handle is released.
        coords = list(zip(slf.meshx, slf.meshy))
        ikle = slf.ikle2
    finally:
        slf.close()
    return coords, ikle

def ecrire_t3s(fichier_t3s, coords, ikle, source_files):
    """Write a 2D triangular mesh to a BlueKenue ``.t3s`` ASCII file.

    Args:
        fichier_t3s: output file path.
        coords: sequence of ``(x, y)`` node coordinates; z is written as ``0.0``.
        ikle: sequence of ``(a, b, c)`` 0-based node indices; written 1-based.
        source_files: names written as ``:SourceFile`` header entries.
    """
    # Local import: at module level the name ``datetime`` is rebound to the class.
    from datetime import datetime as _datetime

    creation_date = _datetime.now().strftime("%a, %b %d, %Y %I:%M %p")
    with open(fichier_t3s, "w") as f:
        f.write("#########################################################################\n")
        f.write(":FileType t3s  ASCII  EnSim 1.0\n")
        f.write("# Canadian Hydraulics Centre/National Research Council (c) 1998-2012\n")
        f.write("# DataType                 2D T3 Scalar Mesh\n#\n")
        f.write(":Application              BlueKenue\n")
        f.write(":Version                  3.3.4\n")
        f.write(":WrittenBy                fusion_script\n")
        f.write(f":CreationDate             {creation_date}\n#\n")
        f.write("#------------------------------------------------------------------------\n")
        for src in source_files:
            f.write(f":SourceFile   {src}\n")
        f.write("#\n#\n")
        f.write(":AttributeName 1 BOTTOM\n")
        f.write(":AttributeUnits 1 M\n#\n")
        f.write(f":NodeCount {len(coords)}\n")
        f.write(f":ElementCount {len(ikle)}\n")
        f.write(":ElementType  T3\n#\n")
        f.write(":EndHeader\n#\n")

        # Node records follow the header directly: the node count is already
        # declared by :NodeCount, a bare count line would be read as a node.
        for x, y in coords:
            f.write(f"{x:.6f} {y:.6f} 0.0\n")

        for a, b, c in ikle:
            f.write(f"{a + 1} {b + 1} {c + 1}\n")

class TriangleIndexe:
    """Plain record describing one mesh triangle for overlap filtering.

    Attributes:
        polygon: geometry of the triangle.
        index: position of the triangle in the source connectivity table.
        points: node indices of the triangle's vertices.
        origine: tag identifying which source mesh the triangle came from.
    """

    def __init__(self, polygon, index, points, origine):
        self.polygon, self.index, self.points, self.origine = polygon, index, points, origine

def resample_line(line, step=1.0):
    """Resample a LineString with evenly spaced points at most ``step`` apart.

    Both endpoints are kept.  Only x/y are retained in the output.

    Args:
        line: shapely LineString to resample.
        step: maximal spacing between consecutive output points (> 0).

    Returns:
        A new LineString, or ``line`` itself when it has zero length.

    Raises:
        ValueError: if ``step`` is not strictly positive.
    """
    if step <= 0:
        raise ValueError("step must be strictly positive")
    if line.length == 0:
        return line
    # Fencepost: n segments of length <= step require n + 1 points.
    n_segments = max(int(np.ceil(line.length / step)), 1)
    distances = np.linspace(0.0, line.length, n_segments + 1)
    sampled = (line.interpolate(d) for d in distances)
    return LineString([(pt.x, pt.y) for pt in sampled])


def filtrer_superpositions(coords, ikle, origine_triangle, seuil_aire=1, distance_seuil=1e-3):
    """Remove triangles of one source mesh that overlap triangles of another.

    Only valid triangles with an area above ``seuil_aire`` take part in the
    overlap test; others are kept untouched.  When two triangles from
    different origins overlap and share fewer than two nodes, the one with
    the larger ikle index is removed.  Orphan nodes are then dropped and the
    connectivity is renumbered.

    Args:
        coords: sequence of node coordinates indexed by ikle entries.
        ikle: sequence of ``(a, b, c)`` 0-based node indices.
        origine_triangle: per-triangle origin tag (same length as ``ikle``).
        seuil_aire: minimal triangle / intersection area.
        distance_seuil: distance under which two triangles are considered touching.

    Returns:
        ``(coords_nettoyes, ikle_final)`` — kept nodes and renumbered triangles.
    """
    # Step 1: build the candidate triangles.
    triangles = []
    for idx, (a, b, c) in enumerate(ikle):
        poly = Polygon([coords[a], coords[b], coords[c]])
        if poly.is_valid and poly.area > seuil_aire:
            triangles.append(TriangleIndexe(poly, idx, {a, b, c}, origine_triangle[idx]))

    supprimes = set()

    # Step 2: conflict detection.  tree.query returns positions in
    # ``triangles`` (which may differ from ikle indices), so identity is
    # always checked through TriangleIndexe.index.
    if triangles:
        tree = STRtree([t.polygon for t in triangles])
        for t in triangles:
            if t.index in supprimes:
                continue
            for pos in tree.query(t.polygon):
                voisin = triangles[pos]
                # Only the element with the larger index is ever removed;
                # this also skips t itself.
                if voisin.index <= t.index or voisin.index in supprimes:
                    continue
                if t.origine == voisin.origine:
                    continue
                intersection = t.polygon.intersection(voisin.polygon)
                # NOTE(review): only a pure Polygon intersection counts;
                # MultiPolygon / GeometryCollection overlaps are ignored.
                if type(intersection) != Polygon:
                    continue
                # NOTE(review): two geometries with a non-empty polygonal
                # intersection are at distance 0, so the distance test makes
                # any polygonal overlap qualify regardless of seuil_aire.
                is_superpose = (
                    intersection.area > seuil_aire
                    or t.polygon.distance(voisin.polygon) < distance_seuil
                )
                if is_superpose and len(t.points & voisin.points) < 2:
                    # Keep scanning the other neighbours (no break).
                    supprimes.add(voisin.index)

    # Step 3: drop removed elements.
    ikle_filtre = [tri for i, tri in enumerate(ikle) if i not in supprimes]

    # Step 4: drop orphan nodes and renumber.
    used_pts = sorted({i for tri in ikle_filtre for i in tri})
    map_pts = {old: new for new, old in enumerate(used_pts)}
    coords_nettoyes = [coords[i] for i in used_pts]
    ikle_final = [(map_pts[a], map_pts[b], map_pts[c]) for (a, b, c) in ikle_filtre]

    return coords_nettoyes, ikle_final

def separate_bank_to_i2s(bank_file, work_directory, shp=False):
    """Write each bank line to its own BlueKenue ``.i2s`` file (and optional shapefile).

    Args:
        bank_file: path readable by ``geopandas.read_file`` or a GeoDataFrame.
            Each row needs a ``geometry`` with ``.coords`` (e.g. LineString)
            and a ``Name`` column used as the output file stem.
        work_directory: output directory.
        shp: also write ``<Name>.shp`` for each line when True.
    """
    gdf_banks = bank_file if isinstance(bank_file, GeoDataFrame) else gpd.read_file(bank_file)

    for _, row in gdf_banks.iterrows():
        bank_line = row['geometry']
        name_bank = str(row['Name'])
        if shp:
            # Build the frame with an explicit geometry column and CRS so that
            # to_file has an active geometry and a .prj is written.
            gdf_temp = GeoDataFrame(geometry=[bank_line], crs=gdf_banks.crs)
            gdf_temp.to_file(os.path.join(work_directory, name_bank + '.shp'))

        coords = list(bank_line.coords)
        with open(os.path.join(work_directory, name_bank + '.i2s'), 'w') as f:
            f.write("#########################################################################\n")
            f.write(":FileType i2s  ASCII  EnSim 1.0\n")
            f.write(":Application BlueKenue\n")
            f.write(":Version 3.3.4\n")
            f.write(":WrittenBy density_map_from_poly\n")
            f.write(":AttributeUnits 1 m\n")
            f.write(":EndHeader\n")

            f.write(f"{len(coords)}\n")
            for coord in coords:
                f.write(f"{coord[0]:.6f} {coord[1]:.6f}\n")





def reconnect_maillage(coords, ikle, tolerance=1e-6):
    """Clean a triangular mesh.

    Merges nodes closer than ``tolerance``, drops degenerate and duplicated
    triangles, removes orphan nodes and renumbers the connectivity.  The
    vertex order (orientation) of every kept triangle is preserved, and
    triangles keep their input order.

    Args:
        coords (list of tuple): node coordinates ``(x, y)``.
        ikle (list of tuple): ``(i1, i2, i3)`` 0-based node indices per triangle.
        tolerance (float): merge distance for nearby nodes.

    Returns:
        coords_new (list of tuple), ikle_new (list of tuple)
    """
    if len(coords) == 0:
        return [], []

    coords_np = np.asarray(coords, dtype=float)

    # Step 1: merge nearby nodes onto the first node of each cluster.
    kdtree = cKDTree(coords_np)
    n = len(coords_np)
    point_map = np.arange(n)
    visited = np.zeros(n, dtype=bool)
    neighbourhoods = kdtree.query_ball_point(coords_np, r=tolerance)

    for i in range(n):
        if visited[i]:
            continue
        for j in neighbourhoods[i]:
            if j != i:
                point_map[j] = i
                visited[j] = True

    # Step 2: remap triangles, dropping degenerate and duplicated ones.
    # Duplicates are detected on the vertex set, but the original vertex
    # order is kept so element orientation is not altered.
    seen = set()
    ikle_unique = []
    for tri in ikle:
        a, b, c = (int(point_map[k]) for k in tri)
        if a == b or b == c or a == c:
            continue
        key = frozenset((a, b, c))
        if key in seen:
            continue
        seen.add(key)
        ikle_unique.append((a, b, c))

    # Step 3: drop unused nodes and renumber.
    used_points = sorted({i for tri in ikle_unique for i in tri})
    remap_index = {old: new for new, old in enumerate(used_points)}
    coords_new = [coords[i] for i in used_points]
    ikle_new = [(remap_index[a], remap_index[b], remap_index[c]) for (a, b, c) in ikle_unique]

    return coords_new, ikle_new


def detect_crest_from_area_acc(gdf_polygone, gdf_crest):
    """Clip crest lines to the polygons they intersect.

    Args:
        gdf_polygone: GeoDataFrame of clipping polygons.
        gdf_crest: GeoDataFrame of crest lines.

    Returns:
        GeoDataFrame (CRS of ``gdf_polygone``) of the non-empty linear pieces
        (LineString / MultiLineString) of each crest/polygon intersection,
        without duplicates and with a fresh 0..n-1 index.
    """
    try:
        joined = gpd.sjoin(gdf_crest, gdf_polygone, how='inner', predicate='intersects')
    except TypeError:
        # fallback for older geopandas versions (``op`` keyword)
        joined = gpd.sjoin(gdf_crest, gdf_polygone, how='inner', op='intersects')

    if joined.empty:
        return gpd.GeoDataFrame(geometry=[], crs=gdf_polygone.crs)

    # Exact intersection with the matched polygon (index_right = polygon label).
    poly_geoms = gdf_polygone.geometry
    clipped = [
        crest.intersection(poly_geoms.loc[idx_right])
        for crest, idx_right in zip(joined.geometry, joined['index_right'])
    ]
    gdf_intersection = gpd.GeoDataFrame(geometry=clipped, crs=gdf_polygone.crs)
    gdf_intersection = gdf_intersection[~gdf_intersection.geometry.is_empty]
    gdf_intersection = gdf_intersection[
        gdf_intersection.geometry.geom_type.isin(["LineString", "MultiLineString"])]
    # Deduplicate first, then renumber so the index has no gaps.
    gdf_intersection = gdf_intersection.drop_duplicates(subset='geometry')
    return gdf_intersection.reset_index(drop=True)

from shapely.ops import unary_union
def _write_density_i2s(path, lines, step, drop_closing_point=False):
    """Write polylines to a BlueKenue i2s file, tagging each one with value ``step``.

    When ``drop_closing_point`` is True the last coordinate of every line is
    omitted (and not counted).
    """
    with open(path, 'w') as f:
        f.write("#########################################################################\n")
        f.write(":FileType i2s  ASCII  EnSim 1.0\n")
        f.write(":Application BlueKenue\n")
        f.write(":Version 3.3.4\n")
        f.write(":WrittenBy density_map_from_poly\n")
        f.write(":AttributeUnits 1 m\n")
        f.write(":EndHeader\n")

        for line in lines:
            pts = list(line.coords)
            if drop_closing_point:
                pts = pts[:-1]
            f.write(f"{len(pts)} {step}\n")
            for coord in pts:
                f.write(f"{coord[0]:.6f} {coord[1]:.6f}\n")


def _clip_outlines_to_boundary(outlines, boundary_line, dilate_distance):
    """Keep only outline points inside the domain and outside the boundary exclusion zones.

    The domain is the first geometry of ``boundary_line``; the exclusion zones
    are the boundaries of its Polygon features buffered by ``5 * dilate_distance``.
    Kept point sequences with more than one point are re-closed.
    """
    gdf_boundary = (boundary_line if isinstance(boundary_line, GeoDataFrame)
                    else gpd.read_file(boundary_line))
    exclusion_zones = [
        geom.boundary.buffer(5 * dilate_distance)
        for geom in gdf_boundary.geometry
        if isinstance(geom, Polygon) and geom.exterior.is_ring
    ]
    domain = gdf_boundary.geometry.iloc[0]

    clipped = []
    for outline in outlines:
        kept = []
        for pt in outline.coords:
            point = Point(pt)
            # Each point is kept at most once, whatever the number of zones.
            if domain.contains(point) and not any(z.contains(point) for z in exclusion_zones):
                kept.append(pt)
        if len(kept) > 1:
            if kept[0] != kept[-1]:  # re-close if needed
                kept.append(kept[0])
            clipped.append(LineString(kept))
    return clipped


def density_map_from_poly(polyfile, step, dilate_distance, boundary_line=None, gdf_bank=None, save_path=None):
    """Build mesh-density polylines around polygons/lines and optionally export them.

    Each Polygon exterior (or LineString) is buffered by ``dilate_distance``;
    an open piece of every buffer outline is extracted with ``cut_linestring``
    and resampled with spacing about ``step``.

    Args:
        polyfile: path readable by ``geopandas.read_file`` or a GeoDataFrame
            of Polygon / LineString features (other geometry types are ignored).
        step: resampling spacing; also written as the polyline value in i2s files.
        dilate_distance: buffer distance around the source geometries.
        boundary_line: optional path or GeoDataFrame.  When given, outline
            points outside its first geometry, or inside its polygon
            boundaries buffered by ``5 * dilate_distance``, are removed.
        gdf_bank: optional GeoDataFrame of bank LineStrings, buffered by
            ``2 * dilate_distance`` for export.
        save_path: output directory; nothing is written when falsy.
    """
    gdf_polys = polyfile if isinstance(polyfile, GeoDataFrame) else gpd.read_file(polyfile)

    sources = [geom for geom in gdf_polys.geometry
               if geom is not None and geom.geom_type in ('Polygon', 'LineString')]

    # Dilate each source and keep an open piece of each buffer outline.
    buffers = []
    for geom in sources:
        line = LineString(geom.exterior.coords) if geom.geom_type == 'Polygon' else geom
        dilated = line.buffer(dilate_distance)
        if dilated.geom_type == 'Polygon':
            parts = [dilated]
        elif dilated.geom_type == 'MultiPolygon':
            parts = list(dilated.geoms)
        else:
            parts = []
        for part in parts:
            piece = cut_linestring(part.exterior, dilate_distance)
            if piece is not None:  # rings with < 2 distinct points yield None
                buffers.append(piece)

    def _resample_outline(line, step):
        # n_points + 1 points evenly distributed along the line (normalized).
        n_points = max(2, int(line.length / step))
        return LineString([line.interpolate(float(n) / n_points, normalized=True)
                           for n in range(n_points + 1)])

    outlines = [_resample_outline(outline, step) for outline in buffers]

    # ``is not None`` rather than truthiness: a GeoDataFrame has no truth value.
    if boundary_line is not None:
        outlines = _clip_outlines_to_boundary(outlines, boundary_line, dilate_distance)

    # Distance to the banks.
    bank_list = []
    if gdf_bank is not None:
        bank_list = [geom.buffer(2 * dilate_distance).exterior
                     for geom in gdf_bank.geometry if isinstance(geom, LineString)]

    if save_path:
        gpd.GeoDataFrame(geometry=outlines, crs=gdf_polys.crs).to_file(
            os.path.join(save_path, 'dike_density.shp'))
        _write_density_i2s(
            os.path.join(save_path, 'density_contour_' + str(step) + 'm_all.i2s'), outlines, step)

        if gdf_bank is not None:
            gdf_ls = rings_to_linestrings(gpd.GeoDataFrame(geometry=bank_list, crs=gdf_bank.crs))
            gdf_ls.to_file(os.path.join(save_path, 'bank_density.shp'))
            _write_density_i2s(
                os.path.join(save_path, 'density_bank_' + str(step) + 'm_all.i2s'), bank_list, step)

        _write_density_i2s(
            os.path.join(save_path, 'density_openline_' + str(step) + 'm_all.i2s'),
            outlines, step, drop_closing_point=True)

def rings_to_linestrings(gdf):
    """Return a copy of *gdf* whose LinearRing geometries become open LineStrings.

    The duplicated closing coordinate of each ring is dropped; every other
    geometry is passed through unchanged.  The input frame is not modified.
    """
    def _open_ring(geom):
        if not isinstance(geom, LinearRing):
            return geom
        return LineString(geom.coords[:-1])

    converted = [_open_ring(g) for g in gdf.geometry]
    result = gdf.copy()
    result["geometry"] = converted
    return result

def cut_linestring(ring, step, tol=1e-6):
    """Extract an open sub-line of a closed ring whose end-points are about ``step`` apart.

    Starting from the whole ring (without its closing point), one end is
    trimmed at a time — the end whose removal keeps the remaining end-points
    farther apart — while the end-point gap exceeds ``step``.  The candidate
    whose gap is closest to ``step`` is returned.

    Args:
        ring: closed geometry exposing ``.coords`` (e.g. LinearRing).
        step: target distance between the returned line's end-points.
        tol: stop as soon as the gap is within ``tol`` of ``step``.

    Returns:
        A LineString, or None when the ring has fewer than two distinct points.
    """
    pts = list(ring.coords[:-1])
    if len(pts) < 2:
        return None

    def _gap(a, b):
        return Point(pts[a]).distance(Point(pts[b]))

    lo, hi = 0, len(pts) - 1
    best = LineString(pts[lo:hi + 1])
    best_gap = _gap(lo, hi)

    while lo < hi:
        gap = _gap(lo, hi)

        if abs(gap - step) < abs(best_gap - step):
            best, best_gap = LineString(pts[lo:hi + 1]), gap

        # Stop when close enough, or when the gap is no longer above step.
        if abs(gap - step) <= tol or gap <= step:
            break

        if _gap(lo + 1, hi) > _gap(lo, hi - 1):
            lo += 1
        else:
            hi -= 1

    return best

def write_i2s(gdf, work_directory, size_mesh, name):
    """Write the resampled lines of *gdf* to ``<name>_<size_mesh>m.i2s``.

    Each geometry is resampled with ``resample_line`` at ``size_mesh``.  For
    closed lines the duplicated closing point is dropped.  The point count
    written before each polyline always matches the coordinates that follow.

    Args:
        gdf: GeoDataFrame of LineStrings.
        work_directory: output directory.
        size_mesh: resampling step, also written as the polyline value.
        name: output file stem.
    """
    path = os.path.join(work_directory, name + '_' + str(size_mesh) + 'm.i2s')
    with open(path, 'w') as f:
        f.write("#########################################################################\n")
        f.write(":FileType i2s  ASCII  EnSim 1.0\n")
        f.write(":Application BlueKenue\n")
        f.write(":Version 3.3.4\n")
        f.write(":WrittenBy density_map_from_poly\n")
        f.write(":AttributeUnits 1 m\n")
        f.write(":EndHeader\n")

        for _, row in gdf.iterrows():
            line = resample_line(row['geometry'], size_mesh)
            pts = list(line.coords)
            # Drop the closing point BEFORE writing the count so both agree.
            if len(pts) > 1 and pts[0] == pts[-1]:
                pts = pts[:-1]
            f.write(f"{len(pts)} {size_mesh}\n")
            for coord in pts:
                f.write(f"{coord[0]:.6f} {coord[1]:.6f}\n")

def write_selafin(slf_file, meshx, meshy, ikle, ipobo, values, times, varnames=None, varunits=None):
    """Write a 2D triangular mesh and its time steps to a SELAFIN file.

    Args:
        slf_file: output path.
        meshx, meshy: node coordinates.
        ikle: triangle connectivity (stored as int32, ``ndp2 = 3``).
        ipobo: boundary-node array (stored as int32).
        values: per-time-step value sets, zipped with ``times``.
        times: time stamps.
        varnames: variable names (default ``['BOTTOM']``).
        varunits: variable units (default ``['M']``).

    NOTE(review): ``TelemacFile`` is opened with its default access mode;
    confirm that this mode allows writing a new file.
    """
    # Fresh lists per call: a mutable default would be shared between calls.
    varnames = ['BOTTOM'] if varnames is None else list(varnames)
    varunits = ['M'] if varunits is None else list(varunits)

    slf = TelemacFile(slf_file)
    try:
        # General header
        slf.title = 'SLF créé depuis T3S'
        slf.varnames = varnames
        slf.varunits = varunits
        slf.nbv1 = len(varnames)

        slf.meshx = np.array(meshx)
        slf.meshy = np.array(meshy)
        slf.npoin2 = len(meshx)
        slf.nelem2 = len(ikle)
        slf.ndp2 = 3
        slf.ikle2 = np.array(ikle, dtype=np.int32)
        slf.ipobo = np.array(ipobo, dtype=np.int32)

        slf.append_header()

        for t, valset in zip(times, values):
            slf.append_core_time_step(t, valset)
    finally:
        # Release the file handle even if writing fails.
        slf.close()
    print(f"Fichier SLF écrit : {slf_file}")


from shapely.geometry import Polygon, LineString
from shapely.ops import split


def cut_closed_lines(closed_lines, polygons):
    """Remove from closed lines the parts lying inside any mask polygon.

    Args:
        closed_lines: iterable of LineStrings; non-closed ones are skipped.
        polygons: sequence of Polygons forming the mask (may be empty, in
            which case closed lines are returned unchanged).

    Returns:
        List of LineStrings remaining outside the mask.
    """
    polygons = list(polygons)
    mask = unary_union(polygons) if polygons else None

    result = []
    for line in closed_lines:
        # only closed lines are processed
        if not line.is_ring:
            continue

        diff = line if mask is None else line.difference(mask)

        # diff may be a LineString, a MultiLineString, or empty
        if diff.is_empty:
            continue
        if diff.geom_type == "LineString":
            result.append(diff)
        elif diff.geom_type == "MultiLineString":
            result.extend(diff.geoms)

    return result



def adjust_point_to_buffer(pt, gdf_buffer, size_mesh):
    """Snap *pt* onto the nearest buffer line when it lies closer than ``size_mesh``.

    Args:
        pt: point geometry (exposes ``x``, ``y`` and ``distance``).
        gdf_buffer: GeoDataFrame of buffer lines, or None.
        size_mesh: snapping distance threshold.

    Returns:
        ``(x, y)`` of the snapped point, or of *pt* when no snapping applies.
    """
    if gdf_buffer is None or gdf_buffer.empty:
        return (pt.x, pt.y)

    closest, closest_dist = None, np.inf
    for candidate in gdf_buffer.geometry:
        dist = pt.distance(candidate)
        if dist < closest_dist:
            closest, closest_dist = candidate, dist

    if closest is None or not closest_dist < size_mesh:
        return (pt.x, pt.y)

    snapped = closest.interpolate(closest.project(pt))
    return (snapped.x, snapped.y)

def _push_points_away(points, other_points, size_mesh, gdf_lignes_buffer, use_buffer_distance, step):
    """Return *points* optionally snapped to the buffer, then pushed away from *other_points*.

    A point closer than ``size_mesh`` to the polyline through *other_points*
    is moved along the direction away from its nearest point on that
    polyline by ``(size_mesh - distance) * step``.
    """
    other = LineString([q[:2] for q in other_points])  # built once per line
    moved = []
    for p in points:
        pt = Point(p[:2])
        if use_buffer_distance:
            pt = Point(adjust_point_to_buffer(pt, gdf_lignes_buffer, size_mesh))
        xy = pt.coords[0][:2]
        d_pair = pt.distance(other)
        if d_pair < size_mesh:
            nearest = other.interpolate(other.project(pt))
            v = np.array(xy) - np.array(nearest.coords[0])
            norm = np.linalg.norm(v)
            if norm > 0:
                xy = tuple(np.array(xy) + (size_mesh - d_pair) * (v / norm) * step)
        moved.append(xy)
    return moved


def enforce_min_distance(
    gdf,
    size_mesh,
    gdf_lignes_buffer=None,
    use_buffer_distance=True,
    iterations=1,
    step=1.0
):
    """Move line points to keep a minimal distance between paired lines.

    Lines are paired by position (0<->1, 2<->3, ...; a trailing odd line is
    left as resampled).  Every line is first resampled at ``size_mesh``.  In
    each iteration, each point is optionally snapped to ``gdf_lignes_buffer``
    (see ``adjust_point_to_buffer``) and then pushed away from the paired
    line if closer than ``size_mesh``.  Both lines of a pair are adjusted
    against the other's coordinates from before the adjustment.

    Args:
        gdf: GeoDataFrame of LineStrings (MultiLineStrings are exploded).
        size_mesh: resampling step and minimal distance.
        gdf_lignes_buffer: optional GeoDataFrame of lines points may snap to.
        use_buffer_distance: enable snapping to ``gdf_lignes_buffer``.
        iterations: number of adjustment passes.
        step: scale factor applied to each push.

    Returns:
        A copy of *gdf* (rows with fewer than two points dropped) with the
        adjusted geometries.
    """
    if gdf.geometry.type.isin(["MultiLineString"]).any():
        gdf = gdf.explode(index_parts=False).reset_index(drop=True)

    coords_list = [list(resample_line(line, size_mesh).coords) for line in gdf.geometry]

    for _ in range(iterations):
        # Successive pairs 0<->1, 2<->3, ...; an unpaired last line is skipped.
        for i in range(0, len(coords_list) - 1, 2):
            old1, old2 = coords_list[i], coords_list[i + 1]
            coords_list[i] = _push_points_away(
                old1, old2, size_mesh, gdf_lignes_buffer, use_buffer_distance, step)
            coords_list[i + 1] = _push_points_away(
                old2, old1, size_mesh, gdf_lignes_buffer, use_buffer_distance, step)

    # Keep rows and geometries aligned when short lines are dropped.
    keep = [len(c) > 1 for c in coords_list]
    new_gdf = gdf[keep].copy()
    new_gdf.geometry = gpd.GeoSeries(
        [LineString(c) for c in coords_list if len(c) > 1],
        index=new_gdf.index,
        crs=gdf.crs,
    )
    return new_gdf


from shapely.ops import unary_union

def extract_lines_outside_buffers(gdf_poly, gdf_line1, gdf_line2, buffer_dist=0.1):
    """Extract the parts of polygon outlines lying outside buffers of two line sets.

    For each Polygon (or each part of a MultiPolygon) of *gdf_poly*, the
    exterior ring is opened (closing point dropped) and the portions inside
    the union of the ``buffer_dist`` buffers of *gdf_line1* and *gdf_line2*
    are removed.

    Returns:
        GeoDataFrame (CRS of *gdf_poly*) of the remaining open LineStrings.
    """
    # Union of all buffers, computed once.
    buffer_union = unary_union(
        list(gdf_line1.geometry.buffer(buffer_dist))
        + list(gdf_line2.geometry.buffer(buffer_dist))
    )

    new_lines = []
    for geom in gdf_poly.geometry:
        if geom is None or geom.is_empty:
            continue

        parts = list(geom.geoms) if geom.geom_type == "MultiPolygon" else [geom]
        for poly in parts:
            # open line built from the exterior ring
            line = LineString(poly.exterior.coords[:-1])

            outside = line.difference(buffer_union)

            # outside may be a LineString, a MultiLineString, or empty
            if outside.is_empty:
                continue
            if isinstance(outside, LineString):
                new_lines.append(outside)
            elif isinstance(outside, MultiLineString):
                new_lines.extend(outside.geoms)

    return gpd.GeoDataFrame(geometry=new_lines, crs=gdf_poly.crs)