######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
from geopandas import read_file, GeoDataFrame
from scipy.spatial import cKDTree
import rasterio.mask
from os import path
from shapely.ops import substring
from shapely.geometry import LineString, mapping
from rasterio import features
import geopandas as gpd
# third-party
from tatooinemesher.section import *
from tatooinemesher.constraint_line import ConstraintLine
from tatooinemesher.mesh_constructor import MeshConstructor
from shapely.strtree import STRtree
# local
from core.Section import Section
from core.Hydraulics import *
from core.Tools import polygonize_water_mask


def import_XS_lines(BV, param):
    """ Import XS lines in a shp format. Associate them with the nearest reach

    Every cross-section (XS) line of ``param['I']['XS_filepath']`` is
    intersected with the 2D centreline of each reach. The sections crossing a
    reach at exactly one point are sorted by their curvilinear abscissa along
    the reach. When several sections share the same abscissa, only the one(s)
    with the smallest ``diff_angle`` (comparison between the local reach
    direction and the section orientation) are kept.

    2D sections are resampled with ``param['XS']['number_of_points']``
    segments. 3D sections keep their vertices, and their minimum elevation is
    stored in ``Zbed``. Reaches crossed by fewer than two sections receive no
    section and an empty ``line_int``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    :raises Exception: if the XS layer cannot be converted, or if it contains
        geometries that are not LineString
    """
    if not param['I']['XS_filepath']:
        BV.display('Error during XS import, XS_filepath not specified', 2)
        return
    XS_shp = read_file(param['I']['XS_filepath'])
    n_sections = 0

    try:
        multi_linestring = Convert_sections(XS_shp)
    except (ValueError, TypeError) as e:
        # keep the original error as the cause for easier debugging
        raise Exception('Erreur lors de la conversion des sections: {}'.format(e)) from e

    if BV.DEM_stack and len(BV.DEM_stack['file_list']) == 1:
        # Informative message only: the DEM itself is not needed in this function
        # (it used to be opened here and was never used nor closed).
        print('only one DEM considered mask can be used')

    for reach in BV.reach:
        reach.Xinterp = []
        reach.section_list = []
        line_reach = reach.geodata['geometry']
        # 2D copy of the reach centreline (Z, if any, is dropped)
        line_reach_xy = LineString(list(zip(line_reach.xy[0], line_reach.xy[1])))

        # Sections crossing the reach centreline at a single point
        crossings = []  # list of ((x, y) of the crossing point, section line)
        for line in multi_linestring.geoms:
            if not isinstance(line, LineString):
                raise Exception('Cross section are not LineString')
            line_xy = LineString(list(zip(line.xy[0], line.xy[1])))
            point_intersection = line_xy.intersection(line_reach_xy)
            if isinstance(point_intersection, Point):
                crossings.append(((point_intersection.x, point_intersection.y), line))

        # Sort the sections by their abscissa along the reach centreline
        crossings.sort(key=lambda pair: line_reach_xy.project(Point(pair[0])))
        sorted_coords_xy = [pair[0] for pair in crossings]
        sorted_list_of_lines = [pair[1] for pair in crossings]

        Xinterp = []  # abscissae of the selected sections
        point_xy = []  # centre points of the selected sections

        # At least two crossings are needed to define a direction between them
        if len(sorted_coords_xy) > 1:
            X_reach_tot = [line_reach_xy.project(Point(pt)) for pt in sorted_coords_xy]

            # Orientation of each section (vector from its last to its first vertex).
            # The previous code mixed x/y/z components of different vertices.
            angle_section = [np.arctan2(ln.coords[0][1] - ln.coords[-1][1],
                                        ln.coords[0][0] - ln.coords[-1][0])
                             for ln in sorted_list_of_lines]

            # Direction of the polyline joining consecutive crossing points, in [0, 2*pi];
            # the first point uses the segment towards the second one.
            angle_line = []
            for i in range(len(sorted_coords_xy)):
                j, k = (0, 1) if i == 0 else (i - 1, i)
                angle_line.append(np.arctan2(sorted_coords_xy[j][1] - sorted_coords_xy[k][1],
                                             sorted_coords_xy[j][0] - sorted_coords_xy[k][0]) + np.pi)

            diff_angle = [np.abs(a_line - a_section) - np.pi / 2
                          for a_line, a_section in zip(angle_line, angle_section)]

            for indpoint, point in enumerate(sorted_coords_xy):
                # among the sections sharing this abscissa, keep the one(s) with the smallest diff_angle
                list_diff = [diff_angle[i] for i, x in enumerate(X_reach_tot) if x == X_reach_tot[indpoint]]
                if diff_angle[indpoint] != np.min(list_diff):
                    continue

                line = sorted_list_of_lines[indpoint]
                point_xy.append(Point(point[0], point[1]))
                Xinterp.append(X_reach_tot[indpoint])

                if not line.has_z:
                    # 2D section: resample it with a regular number of points
                    nb_pts = param['XS']['number_of_points']
                    line_int = LineString([line.interpolate(float(n) / nb_pts, normalized=True)
                                           for n in range(nb_pts + 1)])
                    S1 = Section(line_int, indpoint)
                else:
                    S1 = Section(line, indpoint)
                    S1.Zbed = np.min([coord[2] for coord in line.coords])

                S1.normal = [np.sin(angle_line[indpoint]), np.cos(angle_line[indpoint])]
                S1.dist_proj_axe = Xinterp[-1]
                S1.name = str('river_' + str(reach.geodata['River']) + '_Reach_' + str(
                    reach.geodata['Reach']) + '_section_' + str(indpoint))
                S1.centre = point_xy[-1]  # centre point defined by the reach centreline (.shp)
                S1.distanceBord()

                if line.has_z:
                    # the origin of the lateral distance is placed at the lowest point
                    S1.originXSLocation(param['XS']['distSearchMin'])
                S1.coord.values['W'] = [10 for _ in range(len(S1.geom.coords))]
                reach.add_section(S1)

        # A LineString needs at least two vertices: fall back to an empty one otherwise
        if len(point_xy) > 1:
            reach.line_int = LineString([(pt.x, pt.y) for pt in point_xy])
        else:
            reach.line_int = LineString()
        reach.Xinterp = Xinterp
        n_sections += len(reach.section_list)

    if param['I']['riverbanks_filepath']:
        import_banks_from_layer(BV, param, param['I']['riverbanks_filepath'])
    # Reduce XS width to banks if necessary
    if param['XS']['width_from_banks'] and param['I']['riverbanks_filepath'] != "no layer":
        reduceXStobank(BV, param['XS']['distSearchMin'])

    BV.display('Number of cross-sections imported: ' + str(n_sections))


def create_XS_lines(BV, param, method='channel'):
    """ Create all sections for all reaches

    A straight section is built for every vertex of ``reach.line_int``:
    - it is perpendicular to the local reach direction given by
      ``angle_from_line``;
    - it extends ``param['XS']['width']`` on each side of the centreline;
    - it is resampled with ``param['XS']['number_of_points']`` segments.

    The bed level of the created sections is initialised to 0.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation; ``param['XS']['width']``
        is the half width of the sections
    :type param: Parameters
    :param method: method for XS bounds computation; unless it is ``'mesh'``,
        the sections are saved to ``XS_lines.shp`` in the working directory
    :type method: str, optional
    """
    nx = param['XS']['number_of_points']
    L = param['XS']['width']

    # Cleaning previous sections
    for reach1 in BV.reach:
        reach1.section_list = []

    # Creating new sections
    n_sections = 0
    for reach1 in BV.reach:
        reach1.id_first_section = n_sections
        Xs, Ys = reach1.line_int.xy
        # 2D copy of the original centreline; accepts 2D as well as 3D geometries
        line2 = LineString([coord[:2] for coord in reach1.geodata['geometry'].coords])
        # nearest vertex of the original centreline for each point of line_int
        c0 = np.transpose(np.array([line2.xy[0], line2.xy[1]]))
        c1 = np.transpose(np.array([Xs, Ys]))
        _, neighbours = cKDTree(c0).query(c1)

        for i, (xc, yc) in enumerate(zip(Xs, Ys)):
            angle, _, _ = angle_from_line(reach1.line_int, line2, neighbours, i)
            line = LineString([(xc + L * np.cos(angle + np.pi / 2), yc + L * np.sin(angle + np.pi / 2)),
                               (xc + L * np.cos(angle - np.pi / 2), yc + L * np.sin(angle - np.pi / 2))])
            if line.is_empty:
                continue

            line_int = LineString([line.interpolate(float(n) / nx, normalized=True) for n in range(nx + 1)])
            S1 = Section(line_int, i)
            S1.dist_proj_axe = reach1.Xinterp[i]  # longitudinal position of the section
            S1.name = str(
                'river_' + str(reach1.geodata['River']) + '_Reach_' + str(reach1.geodata['Reach']) + '_section_' + str(
                    i))
            S1.centre = Point(xc, yc)  # centre point defined by the reach centreline (.shp)
            S1.distanceBord()
            S1.normal = [np.sin(angle), np.cos(angle)]
            S1.Zbed = 0
            reach1.add_section(S1)

        reach1.Xinterp = [s.dist_proj_axe for s in reach1.section_list]
        n_sections += len(reach1.section_list)

    if param['I']['riverbanks_filepath']:
        import_banks_from_layer(BV, param, param['I']['riverbanks_filepath'])
    # Reduce XS width to banks if necessary
    if param['XS']['width_from_banks'] and param['I']['riverbanks_filepath'] != "no layer":
        reduceXStobank(BV, param['XS']['distSearchMin'])

    if method != 'mesh':
        param['I']['XS_filepath'] = path.join(param.work_path, 'XS_lines.shp')
        save_XS_lines(BV, param['I']['XS_filepath'])
        BV.display('Number of cross-sections created: ' + str(n_sections))
        BV.display('Number of cross-sections created: ' + str(n_sections))


def project_XS_lines(BV, params):
    """ Provide an elevation for each point of the sections. Each reach contains a list of all the sections linked to it.

    The elevation is obtained by 2D interpolation with the method specified.
    The projection is made DEM file by DEM file. A section can cover 2 different DEM files.
    The raster mode is not operational yet

    Points falling on no-data cells are discarded. A no-data cell is one equal
    to the raster no-data value, to -32768, or to NaN; in multi-DEM mode it can
    also be the configured ``no_data`` of the DEM when it is <= -99. Sections
    left with fewer than two points are removed from their reach, and
    ``reach.Xinterp`` is rebuilt from the remaining sections.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param params: Parameters requires for computation
    :type params: Parameters
    """
    BV.display("Starting cross-sections projection")
    reach = BV.reach

    def _is_nodata(z, nodata):
        # NaN never compares equal to anything (``z == np.nan`` is always False),
        # so it has to be detected with ``np.isnan``.
        return z == nodata or z == float(-32768) or np.isnan(z)

    if len(BV.DEM_stack['file_list']) == 1:
        DEMi = rasterio.open(BV.DEM_stack['file_list'][0], 'r', crs=BV.crs)
        wse = None
        try:
            # The water-surface raster (if any) is opened once, not once per section
            wse_file = BV.DEM_stack['data_DEM'][0].get('wse_file')
            if wse_file:
                wse = rasterio.open(wse_file, 'r', crs=BV.crs)
            nodata = DEMi.nodata

            for current_reach in reach:
                count = 0
                list_new = []
                for ids, section in enumerate(current_reach.section_list):
                    X = [coord[0] for coord in section.geom.coords]
                    Y = [coord[1] for coord in section.geom.coords]
                    Z = [float(z[0]) for z in projOnDEM(X, Y, DEMi)]

                    if wse is not None:
                        section.WSE = projOnDEM([section.centre.x], [section.centre.y], wse)[0][0]

                    valid = [k for k, z in enumerate(Z) if not _is_nodata(z, nodata)]
                    X = [X[k] for k in valid]
                    Y = [Y[k] for k in valid]
                    Z = [Z[k] for k in valid]

                    if len(X) > 1:
                        S1 = Section(LineString(list(zip(X, Y, Z))), count)
                        S1.dist_proj_axe = current_reach.Xinterp[ids]  # longitudinal position of the section
                        # name index matches the Section id (as in the multi-DEM branch)
                        S1.name = str('river_' + str(current_reach.geodata['River']) + '_Reach_' + str(
                            current_reach.geodata['Reach']) + '_section_' + str(count))
                        count += 1
                        S1.centre = section.centre  # centre point defined by the reach centreline (.shp)
                        S1.originXSLocation(params['XS']['distSearchMin'])
                        S1.normal = section.normal
                        if section.WSE != 0:
                            S1.WSE = section.WSE
                        list_new.append(S1)
                        if params.verbose:
                            print('projection at section ' + str(ids) + ' of reach ' + current_reach.name)
                    elif params.verbose:
                        print('remove section ' + str(ids) + ' of reach ' + current_reach.name)

                current_reach.section_list = list_new
                current_reach.Xinterp = [s.dist_proj_axe for s in list_new]
        finally:
            DEMi.close()
            if wse is not None:
                wse.close()

    else:
        n_dem = len(BV.DEM_stack['file_list'])
        # valid samples per section: {(River, Reach, section index): {point index: (x, y, z)}}
        samples = {}
        for idx, data_DEMi in enumerate(BV.DEM_stack['data_DEM']):
            if params.verbose:
                print('Cross section projection on  DEM n°' + str(idx + 1) + '/' + str(n_dem))
            polygon = Polygon(data_DEMi['polygon_coords'])
            dem_no_data = data_DEMi.get('no_data')

            # points of every section lying inside the footprint of this DEM
            keys, X, Y = [], [], []
            for current_reach in reach:
                river, reach_id = current_reach.geodata['River'], current_reach.geodata['Reach']
                for ids, section in enumerate(current_reach.section_list):
                    for ip, coord in enumerate(section.geom.coords):
                        pts = Point(coord[0], coord[1])
                        if polygon.contains(pts):
                            keys.append(((river, reach_id, ids), ip))
                            X.append(pts.x)
                            Y.append(pts.y)

            with rasterio.open(BV.DEM_stack['file_list'][idx], 'r', crs=BV.crs) as DEMi:
                nodata = DEMi.nodata
                Z = [z[0] for z in projOnDEM(X, Y, DEMi)]

            for (key, ip), x, y, z in zip(keys, X, Y, Z):
                # Also reject the configured no-data value of *this* DEM
                # (it was previously tested against the last DEM only).
                if _is_nodata(z, nodata) or (z == dem_no_data and z <= -99):
                    continue
                # a point inside overlapping DEMs is kept only once (first DEM wins)
                samples.setdefault(key, {}).setdefault(ip, (x, y, z))

        # reattribution of the points and elevations to their sections
        for reach1 in reach:
            if params.verbose:
                print('reattribution of  sections  for reach ' + reach1.name)
            river, reach_id = reach1.geodata['River'], reach1.geodata['Reach']
            count = 0
            list_new = []
            for ids, section in enumerate(reach1.section_list):
                section_samples = samples.get((river, reach_id, ids), {})
                if len(section_samples) < 2:
                    continue  # a section needs at least two valid points
                # points are put back in their original order along the section
                new_line = LineString([section_samples[ip] for ip in sorted(section_samples)])
                S1 = Section(new_line, count)
                # index of the original section (count would shift after a dropped section)
                S1.dist_proj_axe = reach1.Xinterp[ids]
                S1.name = str('river_' + str(reach1.geodata['River']) + '_Reach_' + str(
                    reach1.geodata['Reach']) + '_section_' + str(count))
                S1.centre = section.centre  # centre point defined by the reach centreline (.shp)
                S1.originXSLocation(params['XS']['distSearchMin'])
                S1.normal = section.normal
                if section.WSE != 0:
                    S1.WSE = section.WSE
                count += 1
                list_new.append(S1)
            reach1.section_list = list_new
            reach1.Xinterp = [s.dist_proj_axe for s in list_new]

    if params['I']['riverbanks_filepath']:
        import_banks_from_layer(BV, params, params['I']['riverbanks_filepath'])



def interpolate_XS_lines(BV, param):
    """ Interpolate cross-sections between the existing ones

    The method is selected with ``param['XS']['interpolation_method']``:

    * ``'CEPHEE'``: between two consecutive sections of a reach, new sections
      are created about every ``param['XS']['interpolation_step']`` along the
      centreline. Lateral positions are taken relative to each section's
      ``d0``. Elevations are linearly interpolated along the reach between the
      upstream and downstream profiles. The sections of each reach are then
      sorted by ``dist_proj_axe``; for duplicated abscissae, only the first
      section is kept.
    * ``'TATOOINE'``: sections are interpolated with tatooinemesher, exported
      to ``XS_interp_<reach name>.shp`` in the working directory and
      re-imported with ``import_XS_lines``.

    Any other method is reported through ``BV.display`` and nothing is done.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    """
    dxlong = param['XS']['interpolation_step']
    method = param['XS']['interpolation_method']

    if method == 'CEPHEE':
        BV.display("Starting cross-sections interpolation using CEPHEE method")

        for reach1 in BV.reach:
            sections = reach1.section_list
            new_section = []
            # portions of the reach centreline between consecutive sections
            list_of_subline = split_line_by_boundary(reach1.geodata['geometry'], reach1.line_int)

            for i in range(len(sections) - 1):
                subline = list_of_subline[i]
                if subline.is_empty:  # nothing to interpolate between these two sections
                    continue
                sec_up, sec_down = sections[i], sections[i + 1]
                z_up = [coord[2] for coord in sec_up.geom.coords]
                z_down = [coord[2] for coord in sec_down.geom.coords]

                # regularly spaced centres along the sub-line (at least one interval)
                num_vert = max(int(round(subline.length / dxlong)), 1)
                line_int = LineString([subline.interpolate(float(n) / num_vert, normalized=True)
                                       for n in range(num_vert + 1)])
                dxlong_real = subline.length / num_vert
                # nearest vertex of the original sub-line for each interpolated centre
                c0 = np.transpose(np.array([subline.xy[0], subline.xy[1]]))
                c1 = np.transpose(np.array([line_int.xy[0], line_int.xy[1]]))
                _, neighbours = cKDTree(c0).query(c1)

                # Lateral distances (relative to d0) of the union of both sections' points.
                # They do not depend on the position along the reach: computed once.
                distance_amont = [d - sec_up.d0 for d in sec_up.coord.array['Xt']]
                distance_aval = [d - sec_down.d0 for d in sec_down.coord.array['Xt']]
                distancetot = list(np.unique(distance_amont + distance_aval))  # sorted, unique
                # NOTE(review): distancetot is relative to d0 while the interpolation abscissae
                # below are the raw 'Xt' values; confirm whether 'Xt' should be shifted by d0.
                z_upinterp = np.interp(distancetot, sec_up.coord.array['Xt'], z_up)
                z_downinterp = np.interp(distancetot, sec_down.coord.array['Xt'], z_down)

                for ii in range(1, len(line_int.coords) - 1):  # interior centres only
                    length_interp = ii * dxlong_real
                    angle, _, _ = angle_from_line(line_int, subline, neighbours, ii)
                    x_c, y_c = line_int.coords[ii][0], line_int.coords[ii][1]

                    z_interp = [np.interp(length_interp, [0, subline.length], [zu, zd])
                                for zu, zd in zip(z_upinterp, z_downinterp)]
                    # points laid out from the centre, towards the right bank for positive distances
                    X = [x_c + d * np.cos(angle - np.pi / 2) for d in distancetot]
                    Y = [y_c + d * np.sin(angle - np.pi / 2) for d in distancetot]
                    # one point out of two is kept (distancetot is already sorted)
                    X, Y, z_interp = X[::2], Y[::2], z_interp[::2]
                    line_newSection = LineString(list(zip(X, Y, z_interp)))

                    S1 = Section(line_newSection, i)
                    S1.centre = Point(x_c, y_c)
                    S1.normal = [np.sin(angle), np.cos(angle)]
                    S1.dist_proj_axe = reach1.Xinterp[i] + length_interp  # longitudinal position of the section
                    S1.name = str('river_' + str(reach1.geodata['River']) + '_Reach_' + str(
                        reach1.geodata['Reach']) + '_section_' + str(i))
                    S1.coord.values['W'] = [10 for _ in range(len(S1.geom.xy[0]))]
                    # bed level = lowest interpolated elevation (was the min of 'W', i.e. always 10)
                    S1.Zbed = np.min(z_interp)
                    S1.originXSLocation(param['XS']['distSearchMin'])
                    # lateral abscissae expressed relative to the new origin d0
                    for i2 in range(len(S1.coord.array['Xt'])):
                        S1.coord.array['Xt'][i2] = S1.coord.array['Xt'][i2] - S1.d0
                    new_section.append(S1)

            for section in new_section:
                reach1.add_section(section)

            # sort the sections upstream to downstream, keeping the first section
            # met for each longitudinal abscissa
            seen = set()
            unique_sections = []
            for section in reach1.section_list:
                if section.dist_proj_axe not in seen:
                    seen.add(section.dist_proj_axe)
                    unique_sections.append(section)
            unique_sections.sort(key=lambda s: s.dist_proj_axe)
            reach1.section_list = unique_sections
            reach1.Xinterp = [s.dist_proj_axe for s in unique_sections]

        if param['I']['riverbanks_filepath']:
            import_banks_from_layer(BV, param, param['I']['riverbanks_filepath'])

    elif method == 'TATOOINE':
        BV.display("Starting cross-sections interpolation using TATOOINE method")

        for reach1 in BV.reach:
            section_seq = reach1
            section_seq.compute_dist_proj_axe(reach1.geodata['geometry'], 10)
            section_seq.check_intersections()
            section_seq.sort_by_dist()
            constraint_lines = ConstraintLine.get_lines_and_set_limits_from_sections(section_seq, 'CARDINAL')
            mesh_constr = MeshConstructor(section_seq=section_seq, lat_step=None,
                                          nb_pts_lat=param['XS']['number_of_points'], interp_values='LINEAR')
            mesh_constr.build_interp(constraint_lines, param['XS']['interpolation_step'], True)
            # 'path' is the os.path module imported at file level ('os' itself is not imported)
            param['I']['XS_filepath'] = path.join(param.work_path, 'XS_interp_' + reach1.name + '.shp')
            mesh_constr.export_sections(param['I']['XS_filepath'])
            # NOTE(review): import_XS_lines resets and re-associates the sections of *every*
            # reach from this single file; with several reaches, confirm this is intended.
            import_XS_lines(BV, param)

    else:
        BV.display('Unknown cross-sections interpolation method: ' + str(method), 2)




def optimize_XS_lines(BV,param, method="angle", max_iter = 30, verbose=True):
    """ Rebuild the cross sections of every reach to reduce intersections between them.

    Three methods are handled (any other value leaves the sections unchanged):

    * ``"angle"``: when at least one pair of section lines intersects, the
      lines are modified by ``remove_intersection``; every section is then
      rebuilt by resampling the resulting line.
    * ``"overbank"``: each section is split at its two bank points
      (``section.bank[1]``); the right (start -> bank 1) and left
      (bank 2 -> end) overbank segments are modified by ``edge_intersection``
      when some of them intersect, and each section is rebuilt as
      start -> bank 1 -> bank 2 -> end.
    * ``"bank"``: the bank line of the riverbanks file nearest to the reach
      axis and its closest bank line (by average distance) are resampled with
      one point per section; each section becomes the straight line between
      the matching points.

    Each processed reach gets a new ``section_list``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation (riverbanks file,
        work path, DEM CRS and water-mask value are read for "bank")
    :type param: Parameters
    :param method: "angle", "overbank" or "bank"
    :type method: str
    :param max_iter: maximum iterations forwarded to ``remove_intersection``
        / ``edge_intersection``
    :type max_iter: int
    :param verbose: print the number of detected intersections
    :type verbose: bool
    """
    BV.display("Starting cross-sections optimization")
    # NOTE(review): for method 'bank', `bank_gdf` and `distances` are only
    # defined when a LineString bank file is given; otherwise the 'bank'
    # branch of the reach loop below fails with a NameError.
    if method == 'bank':
        if param['I']['riverbanks_filepath']:
            riverbanks_filepath = param['I']['riverbanks_filepath']

            # A raster water mask is first polygonized into a vector file.
            if is_raster(param['I']['riverbanks_filepath']):
                riverbanks_filepath = polygonize_water_mask(riverbanks_filepath, param.work_path,
                                                            param['C']['DEM_CRS_ID'],
                                                            param['C']['mask_water_value'])
            bank_gdf = gpd.read_file(riverbanks_filepath)
            # Only the type of the first geometry is checked.
            if type(bank_gdf .loc[0, 'geometry']) == LineString:
                # For every bank line i, find the other line j with the smallest
                # average distance; entries are (i, closest_index, min_avg_dist).
                distances = []
                for i, geom1 in enumerate(bank_gdf.geometry):
                    # NOTE(review): when geom1 is falsy, min_avg_dist and
                    # closest_index keep the previous line's values (undefined
                    # for i == 0) and are still appended below.
                    if geom1:
                        min_avg_dist = float("inf")
                        closest_index = None
                        for j, geom2 in enumerate(bank_gdf.geometry):
                            if i != j:  # skip comparing a line with itself
                                if geom2:
                                    avg_dist = average_distance(geom1, geom2)
                                    if avg_dist < min_avg_dist:
                                        min_avg_dist = avg_dist
                                        closest_index = j
                    distances.append((i, closest_index, min_avg_dist))
                # Sort the pairs by increasing average distance
                distances.sort(key=lambda x: x[2])

            else:
                BV.display('Bank file must be multilinestring file', 2)

        else:
            BV.display('Bank file is needed', 2)

    for i, reach1 in enumerate(BV.reach):
        angles = []
        xs_lines_lob = []
        xs_lines_rob = []
        xs_lines_channel = []
        length_lob = []
        length_rob = []
        length_channel = []
        point_inter1 = []
        point_inter2 = []
        init_npoint = []
        if method == "angle":
            if len(reach1.section_list) > 1:
                # NOTE(review): this inner `i` shadows the reach index of the outer loop.
                for i in range(len(reach1.section_list)):  # loop over the sections
                    S1 = reach1.section_list[i]
                    angles.append(np.arctan2(S1.normal[1], S1.normal[0]) - np.pi / 2)  # radians: arctan2 of the normal minus pi/2
                    xs_lines_channel.append(LineString([(S1.start.x, S1.start.y), (S1.end.x, S1.end.y)]))
                    length_channel.append(S1.geom.length)
                    point_inter1.append(S1.centre)
                    init_npoint.append(len(S1.coord.array['X']))
                angles_channel = angles.copy()
                # Flag intersecting lines (each line counts its crossings with later lines)
                intersect_flag = np.zeros(len(xs_lines_channel), dtype=int)
                for ix1 in range(0, len(xs_lines_channel)):
                    for ix2 in range(ix1 + 1, len(xs_lines_channel)):
                        if xs_lines_channel[ix1].intersects(xs_lines_channel[ix2]):
                            intersect_flag[ix1] += 1

                nintersect = np.sum(intersect_flag)
                if verbose:
                    print("Number of intersections:", nintersect)
                if nintersect > 0:
                    # --------------------------------------
                    xs_lines_channel, angles_channel = remove_intersection(xs_lines_channel, angles_channel,
                                                                           reach1.Xinterp,
                                                                           length_channel,
                                                                           point_inter1, max_iter)
                new_list_section = []
                for i in range(len(reach1.section_list)):  # loop over the sections
                    nx = len(reach1.section_list[i].geom.coords)  # NOTE(review): overwritten two lines below
                    # NOTE(review): built from the last to the first vertex of
                    # xs_lines_channel[i]; confirm the orientation is intended.
                    line = LineString([xs_lines_channel[i].coords[-1], xs_lines_channel[i].coords[0]])
                    nx = init_npoint[i]
                    # nx + 1 evenly spaced points, i.e. one more than the original section had.
                    line_int = LineString([line.interpolate(float(n) / nx, normalized=True) for n in range(nx + 1)])

                    S1 = Section(line_int, i)
                    S1.dist_proj_axe = reach1.Xinterp[i]  # longitudinal position of the section
                    S1.name = str('river_' + str(reach1.geodata['River']) + '_Reach_' + str(
                        reach1.geodata['Reach']) + '_section_' + str(i))
                    S1.centre = reach1.section_list[i].centre  # centre point defined by the centerline .shp
                    S1.normal = [np.sin(angles_channel[i]), np.cos(angles_channel[i])]
                    S1.distanceBord()
                    new_list_section.append(S1)
                reach1.section_list = new_list_section

        elif method == "overbank":

            lines_channel =[]
            # Main channel: split every section at its two bank points.
            for i in range(len(reach1.section_list)):  # loop over the sections
                S1 = reach1.section_list[i]
                nx = len(reach1.section_list[i].geom.coords)
                # Compute normals angles
                #angles.append(np.arctan2(S1.normal[1], S1.normal[0]) - np.pi / 2)
                angles.append(np.arctan2(S1.normal[0], S1.normal[1]))
                xbank1 = S1.bank[1][0].x
                xbank2 = S1.bank[1][1].x
                ybank1 = S1.bank[1][0].y
                ybank2 = S1.bank[1][1].y
                inter1 = Point(xbank1, ybank1)
                inter2 = Point(xbank2, ybank2)
                lines_channel.append(LineString([(inter1.x, inter1.y),(inter2.x,inter2.y)]))
                # right overbank: section start -> bank 1
                xs_lines_rob.append(LineString([(S1.start.x, S1.start.y), (inter1.x, inter1.y)]))
                length_rob.append(S1.start.distance(inter1))
                point_inter1.append(inter1)
                # left overbank: bank 2 -> section end
                xs_lines_lob.append(LineString([(inter2.x, inter2.y), (S1.end.x, S1.end.y)]))
                length_lob.append(S1.end.distance(inter2))
                point_inter2.append(inter2)

            # NOTE(review): the *_old copies are not used in the rest of this function.
            xs_lines_rob_old =xs_lines_rob.copy()
            xs_lines_lob_old = xs_lines_lob.copy()
            xs_lines_channel_old = xs_lines_channel.copy()
            # NOTE(review): `start`/`end` computed below are never used afterwards.
            grouped_indices = group_intersecting_lines_with_indices(lines_channel)
            for indice_line in grouped_indices:
                if len(indice_line) > 1:
                    indice_line.sort()
                    start = indice_line[0]
                    end = indice_line[-1]

            angles_lob = angles.copy()
            angles_rob = angles.copy()

            # Flag intersecting lines (only the total count is used)
            intersect_flag_lob = np.zeros(len(xs_lines_lob), dtype=int)
            intersect_flag_rob = np.zeros(len(xs_lines_rob), dtype=int)
            for ix1 in range(0, len(xs_lines_lob)):
                for ix2 in range(ix1 + 1, len(xs_lines_lob)):
                    if xs_lines_lob[ix1].intersects(xs_lines_lob[ix2]):
                        intersect_flag_lob[ix1] += 1
                    if xs_lines_rob[ix1].intersects(xs_lines_rob[ix2]):
                        intersect_flag_rob[ix2] += 1
            nintersect = np.sum(intersect_flag_lob) + np.sum(intersect_flag_rob)
            if verbose:
                print("Number of intersections:", nintersect)

            if nintersect > 0:
                # NOTE(review): `nx` here is the vertex count of the LAST section
                # of the loop above; confirm it is meant to apply to all sections.
                # right-bank floodplain
                # --------------------------------------
                xs_lines_rob, angles_rob = edge_intersection(xs_lines_rob, angles_rob, max_iter, 'right',
                                                               int(nx/2),
                                                               point_inter1)
                # left-bank floodplain
                # --------------------------------------
                # NOTE(review): result stored in `angle_lob` (not `angles_lob`); it is unused either way.
                xs_lines_lob, angle_lob = edge_intersection(xs_lines_lob, angles_lob, max_iter, 'left',
                                                              int(nx/2), point_inter2)

                # Rebuild the sections from the (possibly modified) overbank segments
            new_list_section = []
            # NOTE(review): assumes reach1.line_int has one vertex per section — confirm.
            Xs = reach1.line_int.xy[0]
            Ys = reach1.line_int.xy[1]
            for i in range(len(reach1.section_list)):  # loop over the sections
                Stemp = reach1.section_list[i]
                nx = len(reach1.section_list[i].geom.coords)
                # Bank points of the original section
                xbank1 = Stemp.bank[1][0].x
                xbank2 = Stemp.bank[1][1].x
                ybank1 = Stemp.bank[1][0].y
                ybank2 = Stemp.bank[1][1].y
                # start of the right overbank -> bank 1 -> bank 2 -> end of the left overbank
                line = LineString([(xs_lines_rob[i].coords[0][0], xs_lines_rob[i].coords[0][1]),
                                   (xbank1,ybank1),
                                   (xbank2,ybank2),
                                  (xs_lines_lob[i].coords[-1][0], xs_lines_lob[i].coords[-1][1])])

                if nx == 0:
                    nx = 1  # avoid a division by zero below
                line_int = LineString([line.interpolate(float(n) / nx, normalized=True) for n in range(nx + 1)])
                S1 = Section(line_int, i)
                S1.dist_proj_axe = reach1.Xinterp[i]  # longitudinal position of the section
                S1.name = str('river_' + str(reach1.geodata['River']) + '_Reach_' + str(
                    reach1.geodata['Reach']) + '_section_' + str(i))
                S1.centre = Point(Xs[i], Ys[i])  # centre point defined by the centerline .shp
                S1.normal = Stemp.normal#[np.sin(angles[i]), np.cos(angles[i])]
                S1.distanceBord()
                new_list_section.append(S1)
            reach1.section_list = new_list_section

        elif method == 'bank':
            # Find the bank lines corresponding to the reach:
            # distance from the reach axis to every bank line, in `distances` order
            line_reach= reach1.geodata['geometry']
            liste_line =[]
            for i in range(len(distances)):
                liste_line.append(bank_gdf.loc[distances[i][0], 'geometry'])

            distance_centerline = [line_reach.distance(l) for l in liste_line]
            # Index of the bank line closest to the reach axis
            min_index = min(range(len(distance_centerline)), key=lambda i: distance_centerline [i])

            # NOTE(review): the parity of min_index decides which line is line1;
            # this heuristic orientation should be confirmed.
            if  min_index % 2:
                line1 = bank_gdf.loc[distances[min_index ][0], 'geometry']
                line2 = bank_gdf.loc[distances[min_index][1], 'geometry']
            else:
                line1 = bank_gdf.loc[distances[min_index][1], 'geometry']
                line2 = bank_gdf.loc[distances[min_index][0], 'geometry']

            # Resample both banks with one evenly spaced point per section
            # (divides by n - 1: fails for a reach with a single section).
            n = len(reach1.section_list)
            d_line1 = [i * line1.length / (n - 1) for i in range(n)]
            points1 = [line1.interpolate(d) for d in d_line1]
            d_line2 = [i * line2.length / (n - 1) for i in range(n)]
            points2 = [line2.interpolate(d) for d in d_line2]

            # Build each section as the straight line between matching bank points
            new_list_section = []
            for i in range(len(reach1.section_list)):  # loop over the sections

                nx = len(reach1.section_list[i].geom.coords)
                line_section = LineString([points1[i],points2[i]])
                line_int = LineString([line_section.interpolate(float(n) / nx, normalized=True) for n in range(nx + 1)])
                S1 = Section(line_int, i)

                S1.dist_proj_axe = reach1.Xinterp[i]  # longitudinal position of the section
                S1.name = str('river_' + str(reach1.geodata['River']) + '_Reach_' + str(
                    reach1.geodata['Reach']) + '_section_' + str(i))
                S1.centre = reach1.section_list[i].centre # centre point defined by the centerline .shp
                S1.normal = reach1.section_list[i].normal
                S1.distanceBord()
                new_list_section.append(S1)
            reach1.section_list = new_list_section





def save_XS_lines(BV, filepath):
    """ Write every cross-section geometry of the watershed to an ESRI shapefile.

    One row is written per section, reach after reach, in the order of
    ``BV.reach`` and of each reach's ``section_list``. The layer uses ``BV.crs``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param filepath: destination path of the shapefile
    :type filepath: str
    """
    gdf_to_save = GeoDataFrame(columns=['geometry'], crs=BV.crs)
    all_sections = (section for reach in BV.reach for section in reach.section_list)
    for row, section in enumerate(all_sections):
        gdf_to_save.loc[row, 'geometry'] = section.geom
    gdf_to_save.to_file(filepath, driver="ESRI Shapefile")


def split_line_by_boundary(line, boundary_line):
    """
    Split ``line`` into consecutive pieces delimited by the projections of the
    vertices of ``boundary_line`` onto ``line``.

    Every vertex of ``boundary_line`` is projected onto ``line``. The projections
    are ordered by their curvilinear abscissa along ``line``, and each pair of
    consecutive projections delimits one piece. A piece holds the vertices of
    ``line`` whose abscissa lies in the (inclusive) interval, plus the two
    projected points themselves when they are not already vertices.

    :param line: target LineString to split.
    :param boundary_line: LineString whose vertices define the cut positions.
    :return: list of ``len(boundary_line.coords) - 1`` LineStrings; an interval
        yielding fewer than two points gives an empty LineString.
    :raises TypeError: if either argument is not a LineString.
    """
    if not isinstance(line, LineString) or not isinstance(boundary_line, LineString):
        raise TypeError("Les deux entrées doivent être des objets de type LineString.")

    # Project every boundary vertex onto `line` and keep its abscissa, so that
    # each projection is computed once instead of once per vertex per interval.
    projections = []
    for boundary_coord in boundary_line.coords:
        projected_point = line.interpolate(line.project(Point(boundary_coord)))
        projections.append((line.project(projected_point), projected_point))
    # Order along `line` (stable sort on the abscissa only).
    projections.sort(key=lambda item: item[0])

    # Abscissa of each vertex of `line`, also computed once.
    line_coords = list(line.coords)
    coord_abscissae = [line.project(Point(coord)) for coord in line_coords]

    intervals = []
    for (start_d, start_projection), (end_d, end_projection) in zip(projections, projections[1:]):
        # Vertices of `line` lying between the two projections.
        segment_points = [coord for coord, d in zip(line_coords, coord_abscissae)
                          if start_d <= d <= end_d]

        # Explicitly add the projections so that they bound the piece.
        start_coord = tuple(start_projection.coords[0])
        end_coord = tuple(end_projection.coords[0])
        if start_coord not in segment_points:
            segment_points.insert(0, start_coord)
        if end_coord not in segment_points:
            segment_points.append(end_coord)

        intervals.append(LineString(segment_points) if len(segment_points) > 1 else LineString())
    return intervals


def create_all_levee_lines(BV, work_path):
    """Compute the levee lines of every reach and save them to ``levees_lines.shp``.

    Only reaches with more than one cross section are considered. Their
    ``create_levee_lines()`` method is called first. Then two rows per reach are
    written, named ``LeftLevee<River>_reach<Reach>`` and
    ``RightLevee<River>_reach<Reach>``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param work_path: folder where the levees_lines shapefile is written
    :type work_path: str
    """
    def has_several_sections(reach):
        return len(reach.section_list) > 1

    # First pass: every eligible reach builds its levee lines.
    for reach1 in filter(has_several_sections, BV.reach):
        reach1.create_levee_lines()

    # Second pass: gather all of them into one GeoDataFrame.
    gdf_to_save = GeoDataFrame(columns=['geometry'], crs=BV.crs)
    row = 0
    for reach1 in filter(has_several_sections, BV.reach):
        suffix = str(reach1.geodata['River']) + '_reach' + str(reach1.geodata['Reach'])
        for prefix, levee in (('LeftLevee', reach1.left_levee_line),
                              ('RightLevee', reach1.right_levee_line)):
            gdf_to_save.loc[row, 'geometry'] = levee
            gdf_to_save.loc[row, 'Name'] = prefix + suffix
            row += 1
    gdf_to_save.to_file(os.path.join(work_path, 'levees_lines.shp'))


def create_all_centerlines(BV, path):
    """Save a 3D centerline for every reach to ``<path>/centerlines.shp``.

    Only reaches with more than one cross section are considered. For each
    section, the vertex of lowest bed elevation (``coord.values['B']``) is
    taken; on ties, ``np.argmin`` picks the first one. These points are joined
    into a 3D LineString. Rows are named ``centerline<River>_reach<Reach>``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param path: output folder of the centerlines shapefile
    :type path: str
    """
    gdf_tosave = GeoDataFrame(columns=['geometry'], crs=BV.crs)
    row = 0
    for reach1 in BV.reach:
        if len(reach1.section_list) <= 1:
            continue  # a line needs at least two sections

        lowest_points = []
        for section in reach1.section_list:
            xs = section.coord.array['X']
            ys = section.coord.array['Y']
            bed = section.coord.values['B']
            k = np.argmin(bed)
            lowest_points.append(Point(xs[k], ys[k], bed[k]))

        centerline = LineString([(pt.x, pt.y, pt.z) for pt in lowest_points])
        river_id = reach1.geodata['River']
        reach_id = reach1.geodata['Reach']
        gdf_tosave.loc[row, 'geometry'] = centerline
        gdf_tosave.loc[row, 'Name'] = 'centerline' + str(river_id) + '_reach' + str(reach_id)
        row += 1

    gdf_tosave.to_file(os.path.join(path, 'centerlines.shp'))


def XSbounds(line, outline, Xs, Ys, method='channel'):
    '''
    Find the bank positions on a cross section and cut the section between them.

    The intersections of ``outline`` with the planar footprint of ``line`` are
    split into two sides. A point goes to the start side when it is closer to
    the start of ``line`` than the centre ``(Xs, Ys)`` is; otherwise it goes to
    the end side. The section end points are always candidates on their own
    side, so a side without any intersection falls back to the section end.

    * ``method`` in ``'channel'``, ``'mesh'``, ``'dike'``: on each side keep the
      candidate closest to the centre (innermost bank).
    * ``method == 'overbank'``: on each side keep the intersection farthest from
      the centre (outermost bank), or the section end if there is none.

    Parameters
    ----------
    line : LineString
        cross-section line (2D or 3D)
    outline : LineString, MultiLineString or list of lines
        bank lines (a list is converted to a MultiLineString); any other
        geometry yields no intersection
    Xs : float
        x position of the XS centre
    Ys : float
        y position of the XS centre
    method : str
        bank selection rule, see above

    Returns
    -------
    new_line : part of ``line`` between the projections of both banks
    bank : ``[[d1, d2], [bank1, bank2]]`` where ``d1 <= 0`` is minus the distance
        from the centre to the start-side bank, ``d2`` the distance to the
        end-side bank, and ``bank1``/``bank2`` are 2D Points

    Raises
    ------
    ValueError
        if ``method`` is not a supported value (it previously failed later
        with an UnboundLocalError)
    '''
    if method not in ('channel', 'mesh', 'dike', 'overbank'):
        raise ValueError("Unknown bank selection method: " + str(method))

    if isinstance(outline, list):
        outline = MultiLineString(outline)

    centre = Point(Xs, Ys)
    vertices = [Point(c) for c in line.coords]
    dist_centre_start = vertices[0].distance(centre)
    dist_centre_end = vertices[-1].distance(centre)

    # Section end points are the fallback candidates on each side.
    dbank1 = [-dist_centre_start]
    dbank2 = [dist_centre_end]
    list_point1 = [vertices[0]]
    list_point2 = [vertices[-1]]

    # Intersections are computed on the planar footprint of the section.
    linexy = LineString([(c[0], c[1]) for c in line.coords])
    if isinstance(outline, MultiLineString):
        contours = list(outline.geoms)
    elif isinstance(outline, LineString):
        contours = [outline]
    else:
        contours = []
    list_inter = []
    for contour in contours:
        inter = contour.intersection(linexy)
        if isinstance(inter, MultiPoint):
            list_inter.extend(inter.geoms)
        elif isinstance(inter, Point):
            list_inter.append(inter)

    for point in list_inter:
        dist_point_to_centre = point.distance(centre)
        if point.distance(vertices[0]) < dist_centre_start:  # start side (left bank)
            dbank1.append(-dist_point_to_centre)
            list_point1.append(point)
        else:
            dbank2.append(dist_point_to_centre)
            list_point2.append(point)

    if method == 'overbank':
        # Farthest intersection from the centre; index 0 (section end) only if none.
        ind1 = 1 + int(np.argmin(dbank1[1:])) if len(dbank1) > 1 else 0
        ind2 = 1 + int(np.argmax(dbank2[1:])) if len(dbank2) > 1 else 0
    else:
        # Closest candidate to the centre on each side.
        ind1 = int(np.argmax(dbank1))
        ind2 = int(np.argmin(dbank2))

    # Keep only the planar coordinates of the selected banks.
    inter1 = Point(list_point1[ind1].x, list_point1[ind1].y)
    inter2 = Point(list_point2[ind2].x, list_point2[ind2].y)

    # Cut the section between the projections of both banks. Distances are
    # used directly (no fraction of line.length), so a zero-length line no
    # longer raises ZeroDivisionError.
    proj1 = line.interpolate(line.project(inter1))
    proj2 = line.interpolate(line.project(inter2))
    new_line = substring(line, line.project(proj1), line.project(proj2))
    bank = [[dbank1[ind1], dbank2[ind2]], [inter1, inter2]]

    return new_line, bank


def reduceReachtoBanks(reach1, crestline, outline, nb_trans_point):
    '''
    Discretize every cross section of a reach between its innermost bank (or
    dike-foot) intersections.

    For each section of ``reach1``, the intersections of ``outline`` with the
    section geometry are split into a start side and an end side relative to
    the section centre. On each side, the intersection closest to the centre is
    kept; if there is none, the section start/end point is used. A middle point
    is then chosen: the ``ids``-th vertex of ``crestline`` when it is given,
    otherwise the midpoint of the two banks. Both halves (bank 1 -> middle and
    middle -> bank 2) are resampled with ``int(nb_trans_point / 2)`` intervals.

    Parameters
    ----------
    reach1 : reach
        reach whose ``section_list`` is discretized
    crestline : LineString or None
        dike crest line, one vertex per section; None when a channel is considered
    outline : LineString or MultiLineString
        bank or foot lines; any other geometry yields no intersection
    nb_trans_point : int
        number of transverse points per section (presumably >= 2, otherwise
        the resampling divides by zero — confirm with callers)

    Returns
    -------
    list_of_points : list of all points of the structured mesh
    coord1 : bank coordinates on the start side, one per kept section
        (with z = 0 when the section geometry is 3D)
    coord2 : same on the end side

    Sections whose two halves are not both longer than 1 contribute nothing.
    '''
    # The outline and the transverse resolution do not depend on the section.
    if isinstance(outline, LineString):
        contours = [outline]
    elif isinstance(outline, MultiLineString):
        contours = list(outline.geoms)
    else:
        contours = []
    num_vert = int(nb_trans_point / 2)

    list_of_points = []
    coord1 = []
    coord2 = []

    for ids, section in enumerate(reach1.section_list):
        # Intersections of the bank/foot lines with this section.
        list_inter = []
        for contour in contours:
            inter = contour.intersection(section.geom)
            if isinstance(inter, MultiPoint):
                list_inter.extend(inter.geoms)
            elif isinstance(inter, Point):
                list_inter.append(inter)

        # Candidates on each side of the centre; section ends are the fallback.
        dist_centre_start = section.start.distance(section.centre)
        dist_centre_end = section.end.distance(section.centre)
        dbank1 = [-dist_centre_start]
        dbank2 = [dist_centre_end]
        list_point1 = [section.start]
        list_point2 = [section.end]
        for point in list_inter:
            dist_point_to_centre = point.distance(section.centre)
            if point.distance(section.start) < dist_centre_start:  # start side (left bank)
                dbank1.append(-dist_point_to_centre)
                list_point1.append(point)
            else:
                dbank2.append(dist_point_to_centre)
                list_point2.append(point)

        # Innermost candidate on each side (closest to the centre).
        inter1 = list_point1[np.argmax(dbank1)]
        inter2 = list_point2[np.argmin(dbank2)]
        if not isinstance(inter1, Point):
            inter1 = section.start
        if not isinstance(inter2, Point):
            inter2 = section.end

        if crestline:
            xcenter = Point(crestline.coords[ids])
        else:
            xcenter = Point((inter1.x + inter2.x) / 2, (inter1.y + inter2.y) / 2)

        if isinstance(inter1, Point) and isinstance(inter2, Point):
            xsline1 = LineString([(inter1.x, inter1.y), (xcenter.x, xcenter.y)])
            xsline2 = LineString([(xcenter.x, xcenter.y), (inter2.x, inter2.y)])
            if xsline1.length > 1 and xsline2.length > 1:
                # Resample each half with num_vert intervals (num_vert + 1 points).
                for half in (xsline1, xsline2):
                    half_int = LineString(
                        [half.interpolate(float(n) / num_vert, normalized=True) for n in range(num_vert + 1)])
                    list_of_points.extend(Point(coord) for coord in half_int.coords)
                if section.geom.has_z:
                    coord1.append((inter1.x, inter1.y, 0))
                    coord2.append((inter2.x, inter2.y, 0))
                else:
                    coord1.append((inter1.x, inter1.y))
                    coord2.append((inter2.x, inter2.y))

    return list_of_points, coord1, coord2


def reduceXStobank(BV, dsearch, method ='channel'):
    """ Cut every cross section of every reach between its banks.

    Banks are searched on ``BV.outline`` with :func:`XSbounds`; a list outline
    is first converted to a MultiLineString. For each section, a new Section is
    built on the cut line and inherits centre, WSE, Zbed, bank and normal from
    the original. Sections whose two bank distances are equal are dropped.
    Each reach's ``section_list`` is replaced by the kept sections.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param dsearch: search distance forwarded to ``Section.originXSLocation``
        (unused for dikes)
    :type dsearch: float
    :param method: bank selection method forwarded to XSbounds; for ``'dike'``
        ``distanceBord()`` is called instead of ``originXSLocation``
    :type method: str
    """
    for reach in BV.reach:
        if type(BV.outline) == list:
            BV.outline = MultiLineString(BV.outline)

        kept_sections = []
        for ids, section in enumerate(reach.section_list):
            centre = section.centre
            line, bank = XSbounds(section.geom, BV.outline, centre.x, centre.y, method)
            if bank[0][0] == bank[0][1]:
                continue  # degenerate section: both banks at the same distance

            new_section = Section(line, ids)
            new_section.dist_proj_axe = reach.Xinterp[ids]  # longitudinal position of the section
            new_section.name = str('river_' + str(reach.geodata['River']) + '_Reach_' + str(
                reach.geodata['Reach']) + '_section_' + str(ids))
            new_section.centre = centre  # centre point defined by the centerline .shp
            new_section.WSE = section.WSE
            new_section.Zbed = section.Zbed
            new_section.bank = section.bank
            new_section.normal = section.normal

            if method == 'dike':
                new_section.distanceBord()
            else:
                new_section.originXSLocation(dsearch)
            kept_sections.append(new_section)
        reach.section_list = kept_sections


def remove_crossing_XS(BV):
    """ Remove the cross sections that still intersect another one.

    Crossings can remain after confluences, or when the optimization stopped
    before convergence. For every intersecting pair of sections that are both
    still kept, the one with the larger position index inside its reach's
    ``section_list`` is removed; on ties, the later one in reach order is
    removed. Sections that are already removed take no part in further
    comparisons, so no section is dropped because it crosses a section that
    has itself been discarded. Each reach's ``section_list`` and ``Xinterp``
    are rebuilt in place.

    :param BV: Watershed
    :type BV: ModelCatchment
    """
    BV.display("Remove crossing XS")

    # 1. Flatten all sections, remembering their index inside their reach.
    lines = []
    section_ids = []  # position inside the reach (upstream -> downstream order)
    for reach in BV.reach:
        for j, section in enumerate(reach.section_list):
            lines.append(section.geom)
            section_ids.append(j)

    # 2. Greedy elimination, comparing only spatially close (bbox) candidates.
    ind_to_remove = set()
    if lines:
        tree = STRtree(np.array(lines, dtype=object))
        for idx, line in enumerate(lines):
            if idx in ind_to_remove:
                continue
            for j in tree.query(line):  # indices of candidate geometries
                j = int(j)
                if j <= idx or j in ind_to_remove:
                    continue
                if line.intersects(lines[j]):
                    # Keep the most upstream section (smallest index in its reach).
                    if section_ids[idx] <= section_ids[j]:
                        ind_to_remove.add(j)
                    else:
                        ind_to_remove.add(idx)
                        break  # idx is gone: its remaining crossings are irrelevant

    # 3. Rebuild each reach, walking the sections in the same flattened order.
    flat_index = 0
    for reach in BV.reach:
        kept = []
        for section in reach.section_list:
            if flat_index not in ind_to_remove:
                kept.append(section)
            flat_index += 1
        reach.section_list = kept
        reach.Xinterp = [s.dist_proj_axe for s in kept]


def bank_to_edge_end(BV,params):
    """ Impose the banks at the section end points (used when BV.init_mode has no DEM).

    For every section of every reach:

    * ``section.distance`` is set from the transverse abscissae
      ``section.coord.array['Xt']``;
    * without an XS file (``params['I']['XS_filepath']`` empty), both
      half-widths are ``params['XS']['width']``. When
      ``params['B']['average_slope']`` is a real number (bool excluded), the
      bed level is ``slope * (Xinterp[-1] - Xinterp[ids])``, measured along
      the section's own reach (previously the first reach was always used).
      Otherwise a message is printed and ``Zbed`` is left untouched;
    * with an XS file, the half-widths are the distances from the section
      centre to its start and end points;
    * finally ``WSE = Zbed + params['B']['depth']`` and
      ``bank = [[-widthL, widthR], [start, end]]``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param params: Parameters required for computation
    :type params: Parameters
    """
    import numbers

    for reach in BV.reach:
        for ids, section in enumerate(reach.section_list):
            section.distance = list(section.coord.array['Xt'])
            if not params['I']['XS_filepath']:
                widthL, widthR = params['XS']['width'], params['XS']['width']
                slope = params['B']['average_slope']
                # numbers.Real also accepts ints and numpy scalars (np.float32, ...).
                if isinstance(slope, numbers.Real) and not isinstance(slope, bool):
                    # Bed elevation relative to the downstream end of *this* reach.
                    section.Zbed = slope * (reach.Xinterp[-1] - reach.Xinterp[ids])
                    section.slope = slope
                else:
                    print('type of average slope is not appropriated')
            else:
                widthL = section.start.distance(section.centre)
                widthR = section.end.distance(section.centre)

            section.WSE = section.Zbed + params['B']['depth']
            section.bank = [[-widthL, widthR], [section.start, section.end]]


def import_banks_from_layer(BV,param,layer):
    """ Locate the banks of every cross section from a bank layer and derive
    the water-surface elevation (WSE) from the DEM.

    The mode is selected by ``param['XS']['method_banks']``:

    * any value other than ``'buffer'``: ``layer`` is a vector file, or a
      raster water mask that is first polygonized. It is converted to bank
      lines stored in ``BV.outline``. For every section, :func:`XSbounds` gives
      the two bank points, stored in ``section.bank``. Each bank is sampled on
      the DEM with ``projOnDEM`` into ``section.WSE0`` / ``section.WSE1``; with
      several DEMs, only those whose footprint contains the bank point are
      sampled. When the first DEM of the stack has no ``wse_file``,
      ``section.WSE`` is the NaN-ignoring mean of the two.
    * ``'buffer'`` (needs a single DEM and a raster ``layer``): the water-mask
      cells (value ``param['C']['mask_water_value']``) within
      ``5 * cell_size`` of each section are sampled on the DEM. When there is
      no ``wse_file``, ``section.WSE`` is their 90th percentile.

    Any other combination only displays a message. Every DEM dataset opened
    here is closed before returning (they were previously leaked).

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    :param layer: path of the bank layer (vector file or raster water mask)
    :type layer: str
    """
    method_banks = param['XS']['method_banks']
    if is_raster(layer) and method_banks != 'buffer':
        layer = polygonize_water_mask(layer, param.work_path,
                                      param['C']['DEM_CRS_ID'], param['C']['mask_water_value'])

    flag_oneDEM = len(BV.DEM_stack['file_list']) == 1
    # With a single DEM, open it once for the whole function.
    DEMi = rasterio.open(BV.DEM_stack['file_list'][0], 'r', crs=BV.crs) if flag_oneDEM else None
    try:
        if method_banks != 'buffer':
            bank_gdf = gpd.read_file(layer)
            BV.outline = convert_geometry_to_multilinestring(bank_gdf)

            dem_footprints = []
            if not flag_oneDEM:
                # DEM footprints do not depend on the section: build them once.
                dem_footprints = [(Polygon(data_DEMi['polygon_coords']), BV.DEM_stack['file_list'][idx])
                                  for idx, data_DEMi in enumerate(BV.DEM_stack['data_DEM'])]

            for reach1 in BV.reach:
                for section in reach1.section_list:
                    _, bank = XSbounds(section.geom, BV.outline, section.centre.x, section.centre.y,
                                       method=method_banks)
                    section.bank = bank
                    bank_start, bank_end = bank[1]
                    if flag_oneDEM:
                        section.WSE0 = projOnDEM([bank_start.x], [bank_start.y], DEMi)
                        section.WSE1 = projOnDEM([bank_end.x], [bank_end.y], DEMi)
                        continue
                    for dem_polygon, dem_file in dem_footprints:
                        start_inside = dem_polygon.contains(Point(bank_start.x, bank_start.y))
                        end_inside = dem_polygon.contains(Point(bank_end.x, bank_end.y))
                        if not (start_inside or end_inside):
                            continue  # no need to open a DEM that covers neither bank
                        with rasterio.open(dem_file, 'r', crs=BV.crs) as dem:
                            if start_inside:
                                section.WSE0 = projOnDEM([bank_start.x], [bank_start.y], dem)
                            if end_inside:
                                section.WSE1 = projOnDEM([bank_end.x], [bank_end.y], dem)

            if not BV.DEM_stack['data_DEM'][0]['wse_file']:
                for reach1 in BV.reach:
                    for section in reach1.section_list:
                        section.WSE = np.nanmean([section.WSE0, section.WSE1])

        elif flag_oneDEM and is_raster(layer):
            with rasterio.open(layer) as src:
                data = src.read(1)  # first band
                transform = src.transform

            dilate_d = 5 * BV.DEM_stack['data_DEM'][0]['cell_size']
            water = data == param['C']['mask_water_value']  # same for every section

            for reach1 in BV.reach:
                for section in reach1.section_list:
                    buffer_mask = features.rasterize(
                        [(mapping(section.geom.buffer(dilate_d)), 1)],
                        out_shape=data.shape,
                        transform=transform,
                        fill=0,
                        dtype="uint8"
                    )
                    rows, cols = np.where(water & (buffer_mask == 1))
                    xs, ys = rasterio.transform.xy(transform, rows, cols, offset="center")

                    Z = projOnDEM(xs, ys, DEMi)

                    if not BV.DEM_stack['data_DEM'][0]['wse_file']:
                        section.WSE = np.nanpercentile(Z, 90)

        else:
            BV.display('data non consistent with method')
    finally:
        if DEMi is not None:
            DEMi.close()



def interpolate_points_from_multilinestring(multiline, step):
    """
    Return points regularly spaced every ``step`` along a LineString or
    MultiLineString.

    Each part is sampled from its start at abscissae ``0, step, 2*step, ...``.
    The exact end point is appended when the last sample falls short of the
    part's length. Zero-length parts are skipped.

    :param multiline: geometry to sample
    :type multiline: LineString or MultiLineString
    :param step: spacing between consecutive points, must be > 0
    :type step: float
    :return: list of shapely Points, part after part
    :raises ValueError: if ``step`` is not strictly positive (including NaN);
        it previously raised ZeroDivisionError or silently returned no points
    :raises TypeError: if ``multiline`` is neither a LineString nor a MultiLineString
    """
    if not step > 0:
        raise ValueError("step must be strictly positive, got " + repr(step))

    # Normalize the input into a list of LineStrings.
    if isinstance(multiline, LineString):
        lines = [multiline]
    elif isinstance(multiline, MultiLineString):
        lines = list(multiline.geoms)
    else:
        raise TypeError("Entrée invalide : doit être un LineString ou MultiLineString.")

    points = []
    for line in lines:
        length = line.length
        if length == 0:
            continue

        # Interpolation positions from 0 up to the length.
        num_steps = int(np.floor(length / step))
        points.extend(line.interpolate(i * step) for i in range(num_steps + 1))

        # Make sure the exact end point is included.
        if (num_steps * step) < length:
            points.append(line.interpolate(length))

    return points
