######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
import glob
from os import mkdir
from shapely.geometry import multilinestring, linestring
from rasterio.merge import merge
from rasterio.transform import from_origin
from scipy.interpolate import griddata
from typing import List
from pyproj import CRS
from shapely.ops import unary_union, linemerge
import shutil
# local
from core.ModelCatchment import ModelCatchment
from core.CrossSection import *
from core.Parameters import Parameters
from core.Tools import is_raster, sort_lines_and_centerline

# optional
try:
    import cv2
    cv2_avail = True
except ImportError:
    cv2_avail = False

try:
    import rioxarray as rxr
    rioxarray_avail = True
except Exception:
    # rioxarray is optional: any failure while importing it (missing package or broken install)
    # marks it as unavailable instead of aborting the import of this module.
    # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    rioxarray_avail = False


def _write_single_band_geotiff(filepath, array, crs, transform, dtype, nodata=None):
    """Write a 2D array as the single band of a new GeoTIFF file.

    :param filepath: output file path
    :param array: 2D array to write
    :param crs: Coordinate Reference System of the raster
    :param transform: affine transform of the raster
    :param dtype: data type of the band
    :param nodata: no-data value; omitted from the file profile when None
    """
    profile = {
        'driver': "GTiff",
        'height': array.shape[0],
        'width': array.shape[1],
        'count': 1,
        'dtype': dtype,
        'crs': crs,
        'transform': transform,
    }
    if nodata is not None:
        profile['nodata'] = nodata
    with rasterio.open(filepath, "w", **profile) as dst:
        dst.write(array, 1)


def _rasterize_point_cloud(csv_filepath, crs, res):
    """Interpolate a point cloud onto a regular grid and save it as rasters.

    The csv file is ';'-separated with columns x, y, z (elevation) and e (water surface elevation).
    dem_0.tif, wse.tif and mask.tif are written in the folder of the csv file.

    :param csv_filepath: path to the point cloud csv file
    :type csv_filepath: str
    :param crs: Coordinate Reference System written in the rasters
    :param res: raster resolution, in CRS units
    :type res: float
    :return:
        - path to the DEM raster (dem_0.tif)
        - path to the water surface elevation raster (wse.tif)
        - uint8 mask of the pixels covered by the points, after dilation
    :raises Exception: if OpenCV (cv2) is not installed
    """
    if not cv2_avail:
        # cv2 builds the footprint of the survey below; fail early with a clear message
        raise Exception("OpenCV (cv2) is required to rasterize a csv point cloud")

    df = pd.read_csv(csv_filepath, sep=";")
    x = df["x"].values
    y = df["y"].values
    z = df["z"].values
    wse_vect = df["e"].values

    # Regular grid covering the points
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()
    grid_x, grid_y = np.meshgrid(np.arange(x_min, x_max, res), np.arange(y_min, y_max, res))

    # Cubic interpolation leaves NaN outside the convex hull of the points: fill them with nearest values
    dem = griddata((x, y), z, (grid_x, grid_y), method='cubic')
    holes = np.isnan(dem)
    dem[holes] = griddata((x, y), z, (grid_x, grid_y), method='nearest')[holes]
    # Grid rows go with increasing y; rasters store the top row first
    dem = np.flipud(dem)
    wse = np.flipud(griddata((x, y), wse_vect, (grid_x, grid_y), method='nearest'))

    # Flag every pixel containing at least one point (row 0 is the top of the raster).
    # Indices are clipped: points on x_max / y_min fall one cell past the grid when the
    # extent is a multiple of res.
    mask = np.zeros(dem.shape, dtype=np.uint8)
    cols = np.clip(((x - x_min) / res).astype(int), 0, mask.shape[1] - 1)
    rows = np.clip(((y_max - y) / res).astype(int), 0, mask.shape[0] - 1)
    mask[rows, cols] = 1

    kernel = np.ones((3, 3), np.uint8)
    # Dilate the point footprint until it forms a single outer contour.
    # [-2] selects the contour list for both OpenCV 3 (3 outputs) and OpenCV 4 (2 outputs).
    while len(cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]) > 1:
        mask = cv2.dilate(mask, kernel, iterations=1)
    # Morphological closing of the footprint
    filled_raster = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)

    # Cells outside the footprint are set to 0
    dem = np.float32(dem * filled_raster)
    wse = np.float32(wse * filled_raster)

    transform = from_origin(x_min, y_max, res, res)
    folder = path.dirname(csv_filepath)
    dem_file = path.join(folder, 'dem_0.tif')
    wse_file = path.join(folder, 'wse.tif')
    _write_single_band_geotiff(dem_file, dem, crs, transform, np.float32, nodata=np.nan)
    _write_single_band_geotiff(wse_file, wse, crs, transform, np.float32, nodata=np.nan)
    _write_single_band_geotiff(path.join(folder, 'mask.tif'), filled_raster, crs, transform, np.uint8)
    return dem_file, wse_file, mask


def read_DEM(DEM_filepath, crs, hydro_lib=None, res=None):
    """Read a DEM file given by DEM_filepath. The function is used by various classes.

    A .csv point cloud is first interpolated on a regular grid of resolution ``res`` and saved as
    dem_0.tif (with wse.tif and mask.tif) next to the csv file; that raster is then read.

    :param DEM_filepath: file with DEM data
    :type DEM_filepath: string
    :param crs: Coordinate Reference System
    :type crs: pyproj.CRS
    :param hydro_lib: Name of the library used for the simulation. With 'pysheds' the DEM is read as a
        pysheds Grid (.asc, .tif, .tiff), otherwise with rasterio. The default value is None.
    :type hydro_lib: string, optional
    :param res: grid resolution used for a .csv point cloud (required for .csv, ignored otherwise)
    :type res: float, optional
    :return:
        - dictionary with global data of DEM grid
        - pysheds Grid instance, or the (closed) rasterio dataset when pysheds is not used
        - dictionary with the 2D-arrays X, Y, Z and 'mask' (point-cloud mask, None unless input is .csv)
    :raises Exception: missing resolution for a .csv file, pysheds not installed,
        or file format not supported by pysheds
    """
    _, extension = path.splitext(DEM_filepath)
    wse_file = None
    mask = None
    print('read file :' + DEM_filepath)

    if extension == '.csv':
        if not res:
            raise Exception("A raster resolution must be provided")
        DEM_filepath, wse_file, mask = _rasterize_point_cloud(DEM_filepath, crs, res)
        extension = '.tif'

    if hydro_lib == 'pysheds':
        try:
            from pysheds.grid import Grid
        except ImportError as err:
            raise Exception("pysheds not available") from err

        if extension == '.asc':
            grid1 = Grid.from_ascii(DEM_filepath, crs=crs, dtype=np.float32)
            dem = grid1.read_ascii(DEM_filepath, crs=crs, dtype=np.float32)
        elif extension in ('.tiff', '.tif'):
            grid1 = Grid.from_raster(DEM_filepath, crs=crs)
            dem = grid1.read_raster(DEM_filepath, crs=crs)
        else:
            raise Exception("DEM file format not supported: " + extension)

        n_rows, n_cols = grid1.shape[0], grid1.shape[1]
        transform = grid1.affine
        origin = transform * (0, 0)
        next_col = transform * (1, 0)
        cell_size = float(next_col[0] - origin[0])
        x_ll_corner = origin[0]
        # the origin is the top of the grid; pixels are assumed square
        y_ll_corner = origin[1] - n_rows * cell_size
        no_data = grid1.nodata
    else:
        with rasterio.open(DEM_filepath, "r") as src:
            dem = src.read(1)
            grid1 = src
            n_rows, n_cols = dem.shape[0], dem.shape[1]
            cell_size = src.res[0]
            transform = src.transform
            x_ll_corner = src.bounds[0]
            y_ll_corner = src.bounds[1]
            no_data = src.nodata

    # Lower-left corner of every cell. Built from integer indices so X and Y always have exactly
    # n_rows x n_cols entries (np.arange with a float step may add an extra value).
    X, Y = np.meshgrid(x_ll_corner + cell_size * np.arange(n_cols),
                       y_ll_corner + cell_size * np.arange(n_rows),
                       indexing='xy')
    Y = np.flipud(Y)

    x_ur_corner = x_ll_corner + n_cols * cell_size
    y_ur_corner = y_ll_corner + n_rows * cell_size
    polygon = ((x_ll_corner, y_ll_corner), (x_ur_corner, y_ll_corner),
               (x_ur_corner, y_ur_corner), (x_ll_corner, y_ur_corner))

    data_DEM = {'ncols': n_cols, 'path': DEM_filepath, 'n_rows': n_rows, 'cell_size': cell_size,
                'll_corner': (x_ll_corner, y_ll_corner), 'ur_corner': (x_ur_corner, y_ur_corner),
                'polygon_coords': polygon, 'no_data': no_data, 'wse_file': wse_file, 'transform': transform}

    DEM = {'X': X, 'Y': Y, 'Z': dem, 'mask': mask}

    return data_DEM, grid1, DEM


def merge_rasters(rasters: List[str], save_name, crs, res):
    """Merge multiple rasters into a single tiled, deflate-compressed GeoTIFF.

    :param rasters: Raster tiles paths list
    :type rasters: [string]
    :param save_name: path of the merged raster without extension (".tif" is appended)
    :type save_name: string
    :param crs: coordinate reference system written in the merged raster
    :param res: Output resolution in units of coordinate reference system. If not set, a source resolution will be used.
        If a single value is passed, output pixels will be square.
    :type res: tuple, optional
    :return:
        - absolute path to merged raster (str)
        - metadata of the merged raster (dict)
    """
    from contextlib import ExitStack

    save_path = save_name + ".tif"
    # The metadata of the first tile is the template of the output file
    with rasterio.open(rasters[0]) as first_raster:
        out_meta = first_raster.meta.copy()

    # ExitStack closes every dataset opened here, even if opening or merging fails part-way
    with ExitStack() as stack:
        try:
            datasets = [stack.enter_context(rasterio.open(fp)) for fp in rasters]
            merged_raster, transform = merge(datasets, res=res)
        except Exception:
            # Fall back on letting rasterio open the tiles from their paths
            merged_raster, transform = merge(rasters, res=res)

    out_meta.update(
        {
            "driver": "GTiff",
            "height": merged_raster.shape[1],
            "width": merged_raster.shape[2],
            "crs": crs,
            "transform": transform,
            "tiled": True,              # tiled layout, better for large files
            "compress": "deflate",      # lossless compression
            "BIGTIFF": "YES"            # allows files larger than 4 GB
        }
    )

    with rasterio.open(save_path, "w+", **out_meta) as dest:
        dest.write(merged_raster)

    return path.abspath(save_path), out_meta



def run_data_process(BV1, params, hydro_lib=None, nochannel=False):
    """Launch data analysis: read DEM and network, then prepare the network for the hydraulic model.

    :param BV1: catchment model
    :type BV1: ModelCatchment
    :param params: Parameters required for computation
    :type params: Parameters
    :param hydro_lib: library for hydrological computation, forwarded to init_DEM_and_network
    :type hydro_lib: str, optional
    :param nochannel: for calculation without DEM data in main channel (water mask)
    :type nochannel: bool
    """
    # -----------------------------------------------------------------------------
    # Reading topographic data
    init_DEM_and_network(BV1, params, hydro_lib)
    BV1.display('DEM and Network initialized', 0)

    # -----------------------------------------------------------------------------
    # Selection and projection of the network in the region of interest
    if not params['I']['DEM_path'] or BV1.init_mode == 6:
        print('only channel mode needs 1)existing XS  or average slope 2) only one reach')
        if params['I']['riverbanks_filepath'] is None:
            BV1.display('No bank lines used', 0)
        else:
            BV1.display('The banks file is ' + params['I']['riverbanks_filepath'], 0)
        if params['I']['XS_filepath']:
            BV1.display('Field bathymetry used :' + params['I']['XS_filepath'], 0)
        elif BV1.init_mode in (4, 5, 11):
            average_slope = params['B']['average_slope']
            BV1.display('creation of new river with slope :' + str(average_slope), 0)
        if BV1.init_mode == 4:
            read_network(BV1, river_filename=params['I']['network_filepath'], nclassemax=3)
        BV1.ordered_network = BV1.hydro_network
        BV1.display('Network read', 0)

        # Only the first reach is kept: its river identifier is its gid
        BV1.ordered_network['River'] = BV1.ordered_network['gid']
        BV1.ordered_network.loc[0, 'Reach'] = 0
        BV1.display('only the first reach is considered', 1)
        first_river = BV1.ordered_network.loc[0, 'River']
        BV1.list_of_outlet = [[first_river, [first_river],
                               [first_river, BV1.ordered_network.loc[0, 'Reach']]]]

    elif nochannel:
        # Bring the network to the format expected by a 1D code: add z=0 to every vertex and,
        # when an outlet is known, reverse lines whose first endpoint is the closest to it
        for i, row in BV1.hydro_network.iterrows():
            line = row['geometry']
            line_with_z = LineString([(x, y, 0) for x, y in zip(line.xy[0], line.xy[1])])
            endpoint0 = line_with_z.boundary.geoms[0]
            endpoint1 = line_with_z.boundary.geoms[1]
            if BV1.outlet_point:
                if endpoint0.distance(BV1.outlet_point) < endpoint1.distance(BV1.outlet_point):
                    line_with_z = LineString(list(line_with_z.coords)[::-1])
            BV1.hydro_network.loc[i, 'geometry'] = line_with_z

        BV1.ordered_network = BV1.hydro_network
        BV1.ordered_network['River'] = BV1.ordered_network['gid']
        BV1.junction = GeoDataFrame(geometry=[])

    else:
        BV1.projectionOnDEM()
        # merge and order the lines from upstream to downstream
        BV1.order_network(params['N']['minDistJunction'])
        # find the position of the junctions and outlets
        BV1.find_junctionAndOutlet(params['N']['minDistJunction'])
        # create the reaches and rename them from the position of the river segments
        BV1.renameReachFromJunction(params)





def _select_init_mode(has_dem, has_network, has_banks, has_xs, has_outlet):
    """Return the study mode (init_mode) matching the kind of input data provided.

    Each argument is tested for truthiness (a path / outlet given or not).

    With a DEM: 1 DEM + river file; 2 river from outlet; 3 river from banks mask;
    6 bathymetric survey (mask created from data); 7 river file and channel modification (XS);
    8 river from banks and XS bathymetry; 9 river using pysheds and XS bathymetry;
    10 river file and river bank file.
    Without DEM: 4 channel only, from XS bathymetry and river; 5 channel only, from XS bathymetry
    and mask; 11 digging the main channel; 12 digging the main channel using centerline only;
    -1 any other combination.
    """
    if has_dem:
        if has_network:
            if has_banks:
                return 10
            return 7 if has_xs else 1
        if has_xs:
            return 8 if has_banks else 9
        if has_banks:
            return 3
        return 2 if has_outlet else 6
    if has_network:
        if has_xs:
            return -1 if has_banks else 4
        return 11 if has_banks else 12
    if has_xs and has_banks:
        return 5
    return -1


def init_DEM_and_network(BV, param, hydro_lib=None):
    """Choose the kind of study from the provided inputs, then import the DEM and the river network.

    Sets BV.crs, BV.outline and BV.init_mode (plus BV.network_type in some modes), reads the DEM
    stack, initializes the hydrographic network and, when no outlet point is provided (except in
    modes 4 and 5), uses the lowest DEM point as outlet.

    :param BV: catchment model
    :type BV: ModelCatchment class
    :param param: Parameters required for computation
    :type param: Parameters
    :param hydro_lib: library for hydrological computation, used to build the network from the outlet (mode 2)
    :type hydro_lib: str
    """
    from os import makedirs

    # Create the work folder (and its missing parents) if not present
    makedirs(param.work_path, exist_ok=True)

    # Init CRS method
    BV.crs = CRS.from_user_input(param['C']['DEM_CRS_ID'])
    BV.outline = MultiLineString()

    # Selecting DEM-network mode
    BV.init_mode = _select_init_mode(param['I']['DEM_path'], param['I']['network_filepath'],
                                     param['I']['riverbanks_filepath'], param['I']['XS_filepath'],
                                     param['I']['outlet_point'])
    if BV.init_mode == 3:
        # DEM file + river from the banks mask
        BV.network_type = 'fromMask'

    print("init mode is : " + str(BV.init_mode))

    # Reading DEM files or layer
    if param['I']['DEM_path']:
        if param['I']['boundary_filepath']:
            read_DEM_stack(BV, param['I']['DEM_path'], DEM_filter=param['I']['DEM_file_extension'],
                           hydro_lib=None, res=param['C']['resolution'])
            # keep the DEM inside the boundary, then read the selected files again
            filter_dem_stack_using_boundary(BV, param, param['I']['boundary_filepath'], copy=False)
            read_DEM_stack(BV, BV.DEM_stack['file_list'], DEM_file_list=True, hydro_lib=None,
                           res=param['C']['resolution'])
        else:
            read_DEM_stack(BV, param['I']['DEM_path'], DEM_filter=param['I']['DEM_file_extension'],
                           hydro_lib=None, res=param['C']['resolution'])

        # point cloud survey: mask.tif in the folder of the first DEM is used as bank file
        if BV.init_mode == 6 and BV.DEM_stack['extension'] == '.csv':
            param['I']['riverbanks_filepath'] = path.join(BV.DEM_stack['data_DEM'][0]['path'], 'mask.tif')
            BV.network_type = 'fromMask'

    if param['I']['outlet_point']:
        BV.outlet_point = param['I']['outlet_point']

    # Hydro_network initialization
    if param['I']['riverbanks_filepath'] and not param['I']['network_filepath']:
        init_network_using_banks(BV, param)
    elif BV.init_mode == 2:
        build_globalDEM(BV, param)
        create_network_hydrolib(BV, param, hydro_lib)
    else:
        read_network(BV, param['I']['network_filepath'], nclassemax=param['N']['classeMax'])

    # Find the lowest Z point as the possible outlet if not provided
    # NOTE(review): BV.DEM_stack is only read above when a DEM path is given; confirm it exists
    # for the DEM-less modes (11, 12, -1) that reach this branch.
    if not param['I']['outlet_point'] and BV.init_mode not in (4, 5):
        minz = [pt.z for pt in BV.DEM_stack['minZ']]
        index_min = np.argmin(minz)
        BV.outlet_point = Point(BV.DEM_stack['minZ'][index_min].x, BV.DEM_stack['minZ'][index_min].y)
        BV.display('Detected outlet is : ' + str(BV.outlet_point), 0)

    if BV.init_mode in (1, 2, 6, 7, 8, 9, 10):
        BV.clipRiverFromDEM()
        if param['C']['computeGlobal']:
            build_globalDEM(BV, param)


def read_DEM_stack(BV, DEM_path, DEM_filter=None, DEM_file_list=False, hydro_lib=None, res=None):
    """Read all the DEM given by DEM_path and store them in ``BV.DEM_stack``.

    ``BV.DEM_stack`` holds the file list (csv point clouds replaced by their dem_0.tif raster),
    the per-file DEM data, the extension, the lowest point of each DEM and the global extent
    [x_min, y_min, x_max, y_max]. ``BV.global_DEM_path`` is set to the first file.

    :param BV: catchment model
    :type BV: ModelCatchment class
    :param DEM_path: DEM path to one file or a folder or a list of files
    :type DEM_path: str or list(str)
    :param DEM_filter: DEM file extension used to list the files when DEM_path is a folder
    :type DEM_filter: str, optional
    :param DEM_file_list: specify if DEM_path is a list
    :type DEM_file_list: bool, optional
    :param hydro_lib: library forwarded to read_DEM
    :type hydro_lib: str, optional
    :param res: resolution forwarded to read_DEM (required for .csv point clouds)
    :type res: float, optional
    :raises Exception: if no DEM file is found
    """
    # select one merged layer or DEM folder
    if DEM_filter:
        input_paths = sorted(glob.glob(path.join(DEM_path, '*' + DEM_filter)))
    elif DEM_file_list:
        # copy: csv entries are replaced below, the caller's list must not be modified
        input_paths = list(DEM_path)
    else:
        input_paths = [DEM_path]

    if not input_paths:
        raise Exception('No DEM data, folder is empty or QGIS layer not present')

    if not DEM_filter:
        # splitext returns (root, ext): keep only the extension
        DEM_filter = path.splitext(input_paths[0])[1]

    BV.DEM_stack = {
        # [x_min, y_min, x_max, y_max]; infinite start values work whatever the coordinate signs
        'global_extent': [np.inf, np.inf, -np.inf, -np.inf],
        'file_list': input_paths,
        'filtered_file_list': [],
        'data_DEM': [None for _ in range(len(input_paths))],
        'extension': DEM_filter,
        'minZ': []
    }
    extent = BV.DEM_stack['global_extent']

    for idx, input_path in enumerate(input_paths):
        print('read DEM n°' + str(idx + 1) + '/' + str(len(input_paths)))
        data_DEMi, _, DEMi = read_DEM(input_path, BV.crs, hydro_lib=hydro_lib, res=res)
        if input_path.endswith('.csv'):
            # the point cloud has been gridded by read_DEM: switch to the raster DEM
            BV.DEM_stack['file_list'][idx] = path.join(path.dirname(input_path), 'dem_0.tif')

        X, Y, dem = DEMi['X'], DEMi['Y'], DEMi['Z']
        # the stored 'path' is the folder of the input file
        data_DEMi['path'] = path.dirname(input_path)
        BV.DEM_stack['data_DEM'][idx] = data_DEMi

        # lowest point of this DEM, ignoring NaN cells
        min_index = np.unravel_index(np.nanargmin(dem), dem.shape)
        BV.DEM_stack['minZ'].append(Point(X[min_index], Y[min_index], dem[min_index]))

        extent[0] = min(extent[0], data_DEMi['ll_corner'][0])
        extent[1] = min(extent[1], data_DEMi['ll_corner'][1])
        extent[2] = max(extent[2], data_DEMi['ur_corner'][0])
        extent[3] = max(extent[3], data_DEMi['ur_corner'][1])
        if idx == 0:
            BV.global_DEM_path = input_path


def filter_dem_stack_using_boundary(BV, param, boundary_file, copy=False):
    """Select the DEM files lying entirely within the polygon(s) provided by user.

    ``BV.DEM_stack['file_list']`` is replaced by the selected files, in the order polygons then DEM
    are scanned; each DEM is listed once even if several polygons contain it. A DEM only partly
    inside a polygon is not selected. Nothing changes if ``boundary_file`` does not exist.

    :param BV: catchment model
    :type BV: ModelCatchment
    :param param: parameters (``work_path`` is used when copy is True)
    :type param: Parameter
    :param boundary_file: path to the file containing the boundary polygon
    :type boundary_file: str
    :param copy: copy the selected DEM into <work_path>/DEM_selected and list the copies
    :type copy: bool, optional
    """
    BV.display('Filtering DEM stack process starts with ' + str(len(BV.DEM_stack['file_list'])) + ' DEM', 0)
    if path.exists(boundary_file):
        poly_shp = gpd.read_file(boundary_file)
        selected_dir = path.join(param.work_path, 'DEM_selected') if copy else None
        list_of_files = []
        selected_indices = set()
        for polygone in poly_shp['geometry']:
            for dem_i, data_dem in enumerate(BV.DEM_stack['data_DEM']):
                if dem_i in selected_indices:
                    # already selected through a previous polygon
                    continue
                (x_min, y_min), (x_max, y_max) = data_dem['ll_corner'], data_dem['ur_corner']
                dem_polygon = Polygon(((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)))
                if not dem_polygon.within(polygone):
                    continue
                selected_indices.add(dem_i)
                source = BV.DEM_stack['file_list'][dem_i]
                if copy:
                    if not path.isdir(selected_dir):
                        mkdir(selected_dir)
                    print(source)
                    destination = path.join(selected_dir, path.basename(source))
                    shutil.copy2(source, destination)
                    list_of_files.append(destination)
                else:
                    list_of_files.append(source)
        BV.DEM_stack['file_list'] = list_of_files
    BV.display('Filtering DEM stack process finishes with ' + str(len(BV.DEM_stack['file_list'])) + ' DEM', 0)


def _pysheds_hydrology(BV, param):
    """Compute flow directions, flow accumulation and the outlet catchment with pysheds.

    Also writes the catchments returned by generate_catchments to <work_path>/boundary_area.shp.

    :return: (fdir, acc, catch) on the grid of BV.globalGrid
    """
    from thirdparty.nested_watersheds.make_catchments import generate_catchments

    grid = BV.globalGrid
    # DEM conditioning: fill pits, fill depressions, resolve flats
    pit_filled_dem = grid.fill_pits(BV.globalDEM)
    flooded_dem = grid.fill_depressions(pit_filled_dem)
    inflated_dem = grid.resolve_flats(flooded_dem)
    dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
    fdir = grid.flowdir(inflated_dem, dirmap=dirmap)
    acc = grid.accumulation(fdir, dirmap=dirmap)
    # move the outlet onto a cell whose accumulation reaches the threshold
    x_snap, y_snap = grid.snap_to_mask(acc >= param['N']['minAccumulativeArea'],
                                       (BV.outlet_point.x, BV.outlet_point.y))
    catch = grid.catchment(x=x_snap, y=y_snap, fdir=fdir, dirmap=dirmap, xytype='coordinate')
    basins, _ = generate_catchments(BV.global_DEM_path, acc_thresh=param['N']['minAccumulativeArea'],
                                    so_filter=param['N']['classeMax'], crs=BV.crs)
    basins.to_file(path.join(param.work_path, 'boundary_area.shp'), driver="ESRI Shapefile")
    return fdir, acc, catch


def _pyflwdir_hydrology(BV, param, data_DEMi, dem):
    """Compute flow directions, upstream area and the outlet catchment with pyflwdir.

    Also writes the boundaries between sub-basins (minimum stream order param['N']['classeMax'])
    as lines to <work_path>/boundary_area.shp.

    :return: (fdir, acc, catch) on the grid of ``dem``
    :raises Exception: if pyflwdir is not installed
    """
    try:
        import pyflwdir
    except ImportError as err:
        raise Exception("pyflwdir not available") from err
    from rasterio import features as rio_features

    # A no-data value of 0 is valid: only a missing value falls back on NaN
    nodata = np.nan if data_DEMi['no_data'] is None else data_DEMi['no_data']

    flw = pyflwdir.from_dem(
        data=dem,
        nodata=nodata,
        transform=data_DEMi['transform'],
        latlon=BV.crs.is_geographic
    )
    # accumulated number of cells converted to an area with the pixel size (width * height)
    cell_count = flw.accuflux(np.ones(flw.shape, dtype=np.float32))
    cell_area = abs(data_DEMi['transform'].a) * abs(data_DEMi['transform'].e)
    acc = cell_count * cell_area

    if BV.outlet_point.within(Polygon(data_DEMi['polygon_coords'])):
        catch = flw.basins(xy=(BV.outlet_point.x, BV.outlet_point.y))
    else:
        # outlet outside the DEM: use the lowest point of the first DEM of the stack
        catch = flw.basins(xy=(BV.DEM_stack['minZ'][0].x, BV.DEM_stack['minZ'][0].y))
    fdir = flw.to_array()

    # sub-basins with a minimum stream order, converted to polygons
    subbas, _ = flw.subbasins_streamorder(min_sto=param['N']['classeMax'], mask=None)
    subbas = subbas.astype(np.int32)
    shapes = rio_features.shapes(subbas, mask=subbas != nodata, transform=flw.transform, connectivity=8)
    feats = [{"geometry": geom, "properties": {'basin': val}} for geom, val in shapes]
    gdf_bas = gpd.GeoDataFrame.from_features(feats, crs=BV.crs)
    gdf_bas['basin'] = gdf_bas['basin'].astype(np.int32)

    # union of all boundaries removes shared duplicates; linemerge joins them into continuous lines
    merged_lines = linemerge(unary_union(gdf_bas.boundary))
    if isinstance(merged_lines, LineString):
        geoms = [merged_lines]
    elif isinstance(merged_lines, MultiLineString):
        geoms = list(merged_lines.geoms)
    else:
        geoms = []

    gdf_lines_clean = gpd.GeoDataFrame(geometry=geoms, crs=gdf_bas.crs)
    gdf_lines_clean.to_file(path.join(param.work_path, 'boundary_area.shp'), driver="ESRI Shapefile")
    return fdir, acc, catch


def build_globalDEM(BV, param):
    """Compute the hydrological map using pysheds or pyflwdir library. The final map is a merge of all DEM

    Writes in param.work_path the merged DEM, flow_acc.tif, watershed_mask.tif and, with
    pysheds/pyflwdir, boundary_area.shp. Sets BV.global_DEM_path, BV.globalGrid, BV.globalDEM,
    BV.Map and updates param['C']['resolution'] to the cell size of the merged DEM.

    :param BV: watershed
    :type BV: ModelCatchment
    :param param: parameters for computation
    :type param: Parameters
    """
    DEM_basename = path.splitext(path.basename(BV.DEM_stack['file_list'][0]))[0]
    full_save_path = path.join(param.work_path, DEM_basename + '_' + str(int(param['C']['resolution'])) + "m")

    BV.global_DEM_path, merge_meta = merge_rasters(BV.DEM_stack['file_list'], full_save_path,
                                                   BV.crs, param['C']['resolution'])

    data_DEMi, BV.globalGrid, DEMi = read_DEM(BV.global_DEM_path, BV.crs, hydro_lib=param['C']['hydro_lib'])
    BV.globalDEM = DEMi['Z']
    param['C']['resolution'] = data_DEMi['cell_size']

    hydro_lib = param['C']['hydro_lib']
    if hydro_lib == 'pysheds':
        fdir, acc, catch = _pysheds_hydrology(BV, param)
    elif hydro_lib == 'pyflwdir':
        fdir, acc, catch = _pyflwdir_hydrology(BV, param, data_DEMi, DEMi['Z'])
    else:
        fdir = np.zeros(BV.globalDEM.shape)
        acc = np.zeros(BV.globalDEM.shape)
        catch = np.zeros(BV.globalDEM.shape)
        print('No hydrological computation, pysheds/pyflwdir may not be installed or no raster format for input')

    # Lower-left corner of every cell, built from integer indices so X and Y match the DEM shape
    x_min, y_min = data_DEMi['ll_corner']
    cell_size = param['C']['resolution']
    n_rows, n_cols = BV.globalDEM.shape[0], BV.globalDEM.shape[1]
    X, Y = np.meshgrid(x_min + cell_size * np.arange(n_cols),
                       y_min + cell_size * np.arange(n_rows),
                       indexing='xy')
    Y = np.flipud(Y)
    # map for hydrology
    BV.Map = {'X': X, 'Y': Y, 'fdir': fdir, 'acc': acc, 'mask': catch, 'w1': None, 'd1': None, 'i2': None}

    # create flow acc raster
    save_path = path.join(param.work_path, 'flow_acc.tif')
    merge_meta.update(dtype=acc.dtype)
    with rasterio.open(save_path, "w+", **merge_meta) as dest:
        dest.write(BV.Map['acc'], 1)

    # create watershed mask raster
    save_path = path.join(param.work_path, 'watershed_mask.tif')
    catch = catch.astype(np.int32)
    merge_meta.update(dtype=np.int32, nodata=0)
    with rasterio.open(save_path, "w+", **merge_meta) as dest:
        dest.write(catch, 1)


def read_network(BV, river_filename, nclassemax=1):
    """Read a hydrographic network from a vector file into ``BV.hydro_network``.

    Three layouts are recognised:

    * BD Carthage (file has a ``CdEntiteHy`` column): rows whose ``Classe`` is lower
      than or equal to ``nclassemax`` are kept. Every row is kept when ``Classe`` is
      missing or cannot be converted to int.
    * BD Topage (file has a ``CdOH`` column): ``CdOH`` / ``TopoOH`` are mapped to the
      BD Carthage column names.
    * Any other file whose first geometry is a LineString or a MultiLineString
      ("simple" shapefile). A file with a ``swot_obs`` column is treated as SWORD,
      and its ``wse`` value is stored in ``Classe``.

    MultiLineString geometries are then split into their individual LineStrings, and
    every line is rebuilt with a zero Z coordinate. Rows whose geometry is neither a
    LineString nor a MultiLineString are dropped. ``BV.network_type`` is set
    according to the detected layout. It is left untouched for an unrecognised
    simple file.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param river_filename: River file path
    :type river_filename: str
    :param nclassemax: maximum river class kept (BD Carthage files only)
    :type nclassemax: int, optional
    """
    shapefile = gpd.read_file(river_filename)
    columns = shapefile.columns.to_list()

    if 'CdEntiteHy' in columns:  # BD Carthage
        try:
            classe = shapefile['Classe'].astype(int)
        except (KeyError, ValueError, TypeError, OverflowError):
            # Missing or non-numeric class: keep every reach
            classe = pd.Series(np.zeros((len(shapefile),)))

        full_network = shapefile[classe <= nclassemax]
        BV.network_type = 'BDCarthage'

    elif 'CdOH' in columns:  # BD Topage
        full_network = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
        full_network['gid'] = shapefile['gid']
        full_network['CdEntiteHy'] = shapefile['CdOH']
        full_network['NomEntiteH'] = shapefile['TopoOH']
        full_network['geometry'] = shapefile['geometry']
        full_network['Classe'] = pd.Series(np.zeros((len(shapefile),)))
        full_network['Reach'] = pd.Series(np.zeros((len(shapefile),)))
        BV.network_type = 'BDTopage'

    else:  # simple shapefile (possibly SWORD)
        full_network = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
        is_sword = 'swot_obs' in columns
        # The file type is decided on its first geometry (guarded for empty files)
        first_geometry = shapefile['geometry'].iloc[0] if len(shapefile) > 0 else None

        if isinstance(first_geometry, linestring.LineString):
            BV.network_type = 'SWORD' if is_sword else 'line from series'
            recognised = True
        elif isinstance(first_geometry, multilinestring.MultiLineString):
            BV.network_type = 'SWORD' if is_sword else 'multiline'
            recognised = True
        else:
            recognised = False

        if recognised:
            for count_river, (i, row) in enumerate(shapefile.iterrows()):
                full_network.loc[i, 'gid'] = count_river
                full_network.loc[i, 'CdEntiteHy'] = i
                full_network.loc[i, 'NomEntiteH'] = 'River' + str(i)
                full_network.loc[i, 'geometry'] = row['geometry']
                # SWORD: water surface elevation is stored in the 'Classe' column
                full_network.loc[i, 'Classe'] = row['wse'] if is_sword else 0
                full_network.loc[i, 'Reach'] = 'Reach'

    name_river = full_network.columns.to_list()[0]
    pd_temp = pd.DataFrame(columns=[name_river, 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH', 'geometry'])
    filtered_network = gpd.GeoDataFrame(pd_temp, crs=BV.crs)

    # Split MultiLineStrings into LineStrings and force a zero Z coordinate.
    # A row is only written once its geometry is known to be a line, so other
    # geometry types never leave a stray row behind.
    n_lines = 0
    for _, row in full_network.iterrows():
        geometry = row['geometry']
        if isinstance(geometry, linestring.LineString):
            parts = [geometry]
        elif isinstance(geometry, multilinestring.MultiLineString):
            parts = list(geometry.geoms)
        else:
            continue

        for part in parts:
            x_river, y_river = part.xy
            filtered_network.loc[n_lines] = row
            filtered_network.loc[n_lines, 'geometry'] = linestring.LineString(
                [(x, y, 0) for x, y in zip(x_river, y_river)])
            n_lines += 1

    BV.hydro_network = filtered_network


def create_network_hydrolib(BV, param, hydro_lib):
    """Create the hydrographic network of BV with an external hydrological library.

    The network is stored in ``BV.hydro_network`` (one row per stream segment), and
    ``BV.network_type`` is set to ``'fromAccumulation'``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation. ``param['N']['minAccumulativeArea']``
        is the stream threshold. With pysheds it is compared to ``BV.Map['acc']``; with
        pyflwdir it is compared to the upstream area in m2.
    :type param: Parameters
    :param hydro_lib: library used for hydrological computation, ``'pysheds'`` or ``'pyflwdir'``
    :type hydro_lib: str
    :raises ValueError: if ``hydro_lib`` is not a supported library
    """
    pd_temp = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH', 'geometry'])
    network = gpd.GeoDataFrame(pd_temp, crs=BV.crs)

    if hydro_lib == 'pysheds':
        minAcc_pixel = param['N']['minAccumulativeArea']
        # Extract the river network made of the cells whose accumulation exceeds the
        # threshold; dirmap is the D8 direction encoding of BV.Map['fdir']
        dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
        branches = BV.globalGrid.extract_river_network(BV.Map['fdir'], BV.Map['acc'] > minAcc_pixel, dirmap=dirmap)
        n_lines = 0
        for branch in branches['features']:
            lineS = LineString(branch['geometry']['coordinates'])
            if not lineS.is_empty:
                network.loc[n_lines, 'gid'] = n_lines
                network.loc[n_lines, 'Reach'] = n_lines
                network.loc[n_lines, 'CdEntiteHy'] = str(n_lines)
                network.loc[n_lines, 'Classe'] = 0
                network.loc[n_lines, 'NomEntiteH'] = str(n_lines)
                network.loc[n_lines, 'geometry'] = lineS
                n_lines += 1

    elif hydro_lib == 'pyflwdir':
        import pyflwdir

        data_DEMi, BV.globalGrid, DEMi = read_DEM(BV.global_DEM_path, BV.crs, hydro_lib=hydro_lib)
        # Only a missing nodata value falls back to NaN: 0 is a legitimate nodata value
        nodata = data_DEMi['no_data']
        if nodata is None:
            nodata = np.nan

        flw = pyflwdir.from_dem(
            data=DEMi['Z'],
            nodata=nodata,
            transform=data_DEMi['transform'],
            latlon=BV.crs.is_geographic
        )
        # Streams are the cells whose upstream area (m2) exceeds the user threshold
        stream_mask = flw.upstream_area("m2") > param['N']['minAccumulativeArea']
        # Strahler order of these streams
        strahler = flw.stream_order(type="strahler", mask=stream_mask)
        # Vectorise the streams; each feature carries its Strahler order in 'strord'
        feats = flw.streams(stream_mask, strord=strahler)
        gdf = gpd.GeoDataFrame.from_features(feats, crs=BV.crs)
        for count_river, (i, row) in enumerate(gdf.iterrows()):
            network.loc[i, 'gid'] = count_river
            network.loc[i, 'CdEntiteHy'] = i
            network.loc[i, 'Classe'] = row['strord']
            network.loc[i, 'NomEntiteH'] = 'River' + str(i)
            network.loc[i, 'geometry'] = row['geometry']
    else:
        raise ValueError('Hydrological library selected is not available in create_network_hydrolib')

    BV.hydro_network = network
    BV.network_type = 'fromAccumulation'


def create_network_from_lines(BV, centerlines, outline,type = 'channel',DEMfile = None):
    """Fill BV with a network of independent reaches built from given centerlines.

    Each non-empty centerline becomes one row of ``BV.hydro_network``, and the entry
    of ``outline`` with the same index is appended to ``BV.outline``. Intended for
    handling dikes separately.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param centerlines: centerlines (or reaches of BV) to turn into network rows
    :type centerlines: list
    :param outline: polygon outlines or bank lines, indexed like ``centerlines``
    :type outline: list
    :param type: with ``'dike'``, the DEM stack is read from ``DEMfile`` and the lines
        are projected on it (``BV.projectionOnDEM``) to find the crest. Any other
        value only gives every line a zero Z coordinate.
    :type type: str
    :param DEMfile: DEM passed to ``read_DEM_stack`` when ``type == 'dike'``
    """
    BV.hydro_network = gpd.GeoDataFrame()
    BV.outline = []

    reach_id = 0
    for idx, centerline in enumerate(centerlines):
        if centerline.is_empty:
            continue
        # Column creation order: geometry, River, Reach, CdEntiteHy, Classe, NomEntiteH
        BV.hydro_network.loc[reach_id, 'geometry'] = centerline
        BV.hydro_network.loc[reach_id, 'River'] = reach_id
        BV.hydro_network.loc[reach_id, 'Reach'] = 0
        BV.hydro_network.loc[reach_id, 'CdEntiteHy'] = str(reach_id)
        BV.hydro_network.loc[reach_id, 'Classe'] = 0
        BV.hydro_network.loc[reach_id, 'NomEntiteH'] = str(reach_id)
        BV.outline.append(outline[idx])
        reach_id += 1

    if type != 'dike':
        # Plain channels: rebuild each line in 3D with a zero elevation
        for idx, row in BV.hydro_network.iterrows():
            xs, ys = row['geometry'].xy[0], row['geometry'].xy[1]
            BV.hydro_network.loc[idx, 'geometry'] = LineString([(x, y, 0) for x, y in zip(xs, ys)])
        return

    # Dikes: elevations come from the DEM
    read_DEM_stack(BV,DEMfile)
    BV.projectionOnDEM()


def init_network_using_banks(BV, param, rasterDEM=None):
    """Initialise the hydrographic network of BV from river-bank data.

    The bank file (``param['I']['riverbanks_filepath']``) is read as vector data. A
    raster water mask is first polygonized. What happens next depends on
    ``param['I']['network_filepath']``:

    * If it is set, the bank geometries become ``BV.outline`` and the network is read
      from that file.
    * Otherwise, centerlines are derived from the banks, the network is built from
      them, and ``BV.network_type`` is set to ``'fromMask'``.

    :param BV: catchment model
    :type BV: ModelCatchment class
    :param param: Parameters required for computation
    :type param: Parameters
    :param rasterDEM: DEM meant to assign elevation to bank lines (currently unused here)
    :type rasterDEM: str
    """
    banks_source = param['I']['riverbanks_filepath']
    if is_raster(banks_source):
        # Raster water mask -> polygon vector file written in the work directory
        banks_source = polygonize_water_mask(banks_source, param.work_path,
                                             param['C']['DEM_CRS_ID'], param['C']['mask_water_value'])
    bank_gdf = gpd.read_file(banks_source)

    network_file = param['I']['network_filepath']
    if not network_file:
        # No network given: derive centerlines from the bank geometries
        centerlines, outline = sort_lines_and_centerline(bank_gdf, param['XS']['creation_step'])
        create_network_from_lines(BV, centerlines, outline)
        BV.network_type = 'fromMask'
        return

    BV.outline = convert_geometry_to_multilinestring(bank_gdf)
    read_network(BV, network_file, nclassemax=param['N']['classeMax'])
