######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
import glob
import os
import shutil
import geopandas as gpd
from geopandas import GeoDataFrame
from shapely.geometry import multilinestring, linestring, Polygon,MultiPoint,GeometryCollection
from shapely.ops import nearest_points
from shapely.ops import split
from scipy.spatial import cKDTree
from os import path, mkdir
import rasterio.mask
# local
#
from .Tools import *
from .Reach import Reach
from .Data import *

class ModelCatchment:
    """ TODO

    """
    def __init__(self):
        """Create an empty catchment model.

        Every attribute is declared here (including those filled later by
        ``read_DEM_stack`` and ``read_network``) so that reading one before the
        corresponding step does not raise ``AttributeError``.
        """
        self.reach = []
        self.ordered_network = GeoDataFrame()  # rebuilt by order_network
        self.HydroNetwork = None  # river network (read_network / create_network / projectionOnDEM)
        self.network_type = None  # input layout detected by read_network
        self.id_outlet = 0
        self.list_of_outlet = []  # extended by find_junctionAndOutlet
        self.outlet_point = None  # user outlet, replaced by its nearest point on the network
        self.junction = None  # geodataframe
        self.DEM_method = None
        self.DEM_stack = None  # dict built by read_DEM_stack
        self.DEM = []
        self.crs = None
        self.resmin = None  # smallest DEM cell size, set by read_DEM_stack
        self.grid = []
        self.Zmin = []
        self.Zmax = []
        self.cellsize = []
        self.globalGrid = None  # grid of the merged DEM, set by build_globalDEM
        self.globalDEM = None  # merged DEM elevations, set by build_globalDEM
        self.global_DEM_path = None  # path of the merged DEM file
        self.Map = None  # hydrological maps (fdir, acc, mask, ...) from build_globalDEM
        self.BDOE = None

    def read_data(self, param):
        """Read every model input and prepare the global DEM.

        Steps, in order:
          1. create ``param.work_path`` (and any missing parent folder) if needed;
          2. set ``self.crs`` from ``param['C']['DEM_CRS_ID']``;
          3. read the river network from the QGIS layer or from
             ``param['N']['riverPath']``; an empty network is used when no river
             path is given;
          4. store the user outlet point;
          5. read the DEM stack (QGIS layer or DEM folder) and, when
             ``param['C']['findCatchment']`` is set, restrict it to the catchment tiles;
          6. build the global DEM, then clip the read network to the DEM extent or,
             when no network was read, extract one from the flow accumulation.

        :param param: parameters for computation
        :type param: Parameter
        """
        # Create the work folder if not present (parents included)
        os.makedirs(param.work_path, exist_ok=True)

        # Init CRS method
        self.crs = CRS.from_user_input(param['C']['DEM_CRS_ID'])

        # River network: QGIS layer first, then file path, otherwise empty
        if param.qgis_river_layer and param['N']['riverPath']:
            self.read_network(qgis_river_layer=param.qgis_river_layer,
                              nclassemax=param['N']['classeMax'])
        elif param['N']['riverPath']:
            self.read_network(river_filename=param['N']['riverPath'],
                              nclassemax=param['N']['classeMax'])
        else:
            self.HydroNetwork = GeoDataFrame()

        # User outlet
        self.outlet_point = param.user_outlet_point

        # Read the global DEM extent
        if param.qgis_DEM_layer:
            self.read_DEM_stack(qgis_DEM_layer=param.qgis_DEM_layer)
        else:
            self.read_DEM_stack(DEM_path_ext=(param['C']['DEMpath'], param['C']['DEM_file_extension']))

        # Keep only the DEM tiles concerned by the catchment when several tiles are given
        if param['C']['findCatchment']:
            self.DEM_stack = self.findCatchment(param)

        # Clip the existing network, or generate it from the DEM if none was read
        if not self.HydroNetwork.empty:
            self.clipRiverFromDEM()
            self.build_globalDEM(param)
        else:
            self.build_globalDEM(param)
            self.create_network(param)


    def read_DEM_stack(self, DEM_path_ext = None, qgis_DEM_layer = None, DEM_file_list = None):
        """Read the DEM files and store their metadata in ``self.DEM_stack``.

        The DEM source is chosen by priority:
          1. the QGIS raster layer;
          2. the explicit file list;
          3. every file of folder ``DEM_path_ext[0]`` ending with ``DEM_path_ext[1]``,
             sorted by name.

        ``self.DEM_stack`` receives:
          * ``global_extent``: the bounding box ``[xmin, ymin, xmax, ymax]`` of all files;
          * ``file_list``: the list of files;
          * ``data_DEM``: the per-file DEM metadata returned by ``read_DEM``.

        ``self.resmin`` receives the smallest cell size.

        :param DEM_path_ext: DEM folder and file extension
        :type DEM_path_ext: tuple(str, str), optional
        :param qgis_DEM_layer: QGIS DEM layer
        :type qgis_DEM_layer: QGIS DEM layer, optional
        :param DEM_file_list: DEM file list
        :type DEM_file_list: list(str), optional
        :raises Exception: if no DEM file is found
        """
        # select one merged layer or DEM folder
        if qgis_DEM_layer:
            input_paths = [qgis_DEM_layer.dataProvider().dataSourceUri()]
        elif DEM_file_list:
            input_paths = list(DEM_file_list)
        elif DEM_path_ext:
            input_paths = sorted(glob.glob(os.path.join(DEM_path_ext[0], '*' + DEM_path_ext[1])))
        else:
            input_paths = []

        if not input_paths:
            raise Exception('No DEM data, folder is empty or QGIS layer not present')

        # Start from an "empty" box so that any tile (including negative coordinates)
        # updates every bound; the former [1e9, 1e9, 0, 0] kept xmax/ymax at 0 when the
        # tiles lie at negative coordinates.
        global_extent = [np.inf, np.inf, -np.inf, -np.inf]
        data_DEM = [None] * len(input_paths)
        self.DEM_stack = {
            'global_extent': global_extent,
            'file_list': input_paths,
            'filtered_file_list': [],
            'data_DEM': data_DEM,
        }

        resolution = []
        for idx, input_path in enumerate(input_paths):
            print('read DEM n°' + str(idx + 1) + '/' + str(len(input_paths)))
            data_DEMi, _, _ = read_DEM(input_path, self.crs)
            data_DEM[idx] = data_DEMi
            resolution.append(data_DEMi['cell_size'])
            ll_corner, ur_corner = data_DEMi['ll_corner'], data_DEMi['ur_corner']
            global_extent[0] = min(global_extent[0], ll_corner[0])
            global_extent[1] = min(global_extent[1], ll_corner[1])
            global_extent[2] = max(global_extent[2], ur_corner[0])
            global_extent[3] = max(global_extent[3], ur_corner[1])
        self.resmin = np.min(resolution)

    def findCatchment(self, param):
        """Restrict the DEM stack to the tiles covering the user catchment polygon(s).

        This only happens when ``<work_path>/boundary_catchment.shp`` exists:
          * every DEM tile whose bounding box intersects one of its polygons is copied
            into ``<work_path>/DEM_selected``, once;
          * the DEM stack is then re-read from these copies.

        Otherwise the stack is left unchanged.

        :param param: parameters for computation (``param.work_path`` is read)
        :type param: Parameter
        :return: the (possibly updated) DEM stack, also stored in ``self.DEM_stack``
        :rtype: dict
        :raises Exception: from ``read_DEM_stack`` when no tile intersects the polygons
        """
        print('findCatchment starts with ' + str(len(self.DEM_stack['file_list'])) + ' DEM')

        boundary_path = os.path.join(param.work_path, 'boundary_catchment.shp')
        if os.path.exists(boundary_path):
            poly_shp = gpd.read_file(boundary_path)
            selected_dir = os.path.join(param.work_path, 'DEM_selected')
            os.makedirs(selected_dir, exist_ok=True)

            list_of_files = []
            selected_tiles = set()
            for polygone in poly_shp['geometry']:
                for dem_i, data_dem in enumerate(self.DEM_stack['data_DEM']):
                    if dem_i in selected_tiles:
                        continue  # already selected through another polygon
                    ll_corner, ur_corner = data_dem['ll_corner'], data_dem['ur_corner']
                    tile = Polygon((ll_corner, (ur_corner[0], ll_corner[1]),
                                    ur_corner, (ll_corner[0], ur_corner[1])))
                    # intersects (not within): tiles crossed by the catchment boundary,
                    # such as the one holding the outlet, must be kept
                    if not tile.intersects(polygone):
                        continue

                    source = self.DEM_stack['file_list'][dem_i]
                    destination = os.path.join(selected_dir, os.path.basename(source))
                    try:
                        # streamed copy: the raster is not loaded entirely in memory
                        shutil.copyfile(source, destination)
                    except shutil.SameFileError:
                        pass  # tile already lives in DEM_selected
                    selected_tiles.add(dem_i)
                    list_of_files.append(destination)

            self.read_DEM_stack(DEM_file_list=list_of_files)
        print('findCatchment finishes with ' + str(len(self.DEM_stack['file_list'])) + ' DEM')

        return self.DEM_stack


    def build_globalDEM(self, param, user_list=None):
        """Merge the DEM files into a global DEM and compute its hydrological maps (pysheds).

        Steps:
          1. merge the DEM files and write the result in ``param.work_path`` at
             resolution ``param['C']['resolution']``;
          2. condition the merged DEM (pits, depressions, flats);
          3. compute D8 flow directions and flow accumulation;
          4. delineate the catchment draining to ``self.outlet_point``, snapped to a
             cell whose accumulation is at least half of the maximum.

        The maps are stored in ``self.Map``. Flow accumulation and the catchment mask
        are written to ``flow_acc.tif`` and ``watershed_mask.tif``.
        ``param['C']['resolution']`` is replaced by the x cell size of the merged raster.

        :param param: parameters for computation
        :type param: Parameter
        :param user_list: DEM files to merge instead of the DEM stack (ignored when a
            QGIS DEM layer is set)
        :type user_list: list(str), optional
        """
        res_tag = '_' + str(int(param['C']['resolution'])) + 'm'
        if param.qgis_DEM_layer:
            raster_layer_path = param.qgis_DEM_layer.dataProvider().dataSourceUri()
            layer_name = os.path.splitext(os.path.basename(raster_layer_path))[0]
            full_save_path = os.path.join(param.work_path, layer_name + res_tag)
            dem_files = [raster_layer_path]
        elif user_list:
            full_save_path = os.path.join(param.work_path, 'findCatchment' + res_tag + '_merged')
            dem_files = user_list
        else:
            # NOTE(review): this file name has no prefix; the former code computed the DEM
            # folder name here without using it — confirm whether it was meant as a prefix.
            full_save_path = os.path.join(param.work_path, res_tag + '_merged')
            dem_files = self.DEM_stack['file_list']
        self.global_DEM_path, merge_meta = merge_rasters(dem_files, full_save_path,
                                                         self.crs, param['C']['resolution'])

        _, self.globalGrid, DEMi = read_DEM(self.global_DEM_path, self.crs)
        self.globalDEM = DEMi['Z']
        param['C']['resolution'] = self.globalGrid.affine[0]
        resolution = param['C']['resolution']

        # Condition DEM: fill pits, fill depressions, resolve flats
        pit_filled_dem = self.globalGrid.fill_pits(self.globalDEM)
        flooded_dem = self.globalGrid.fill_depressions(pit_filled_dem)
        inflated_dem = self.globalGrid.resolve_flats(flooded_dem)

        # D8 flow directions and flow accumulation (same directional mapping everywhere)
        dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
        fdir = self.globalGrid.flowdir(inflated_dem, dirmap=dirmap)
        acc = self.globalGrid.accumulation(fdir, dirmap=dirmap)

        # Snap the outlet to a high-accumulation cell and delineate its catchment
        x_snap, y_snap = self.globalGrid.snap_to_mask(acc >= np.max(acc) / 2,
                                                      (self.outlet_point.x, self.outlet_point.y))
        catch = self.globalGrid.catchment(x=x_snap, y=y_snap, fdir=fdir, dirmap=dirmap,
                                          xytype='coordinate')

        # Coordinate grids built from integer offsets: np.arange with a float step can
        # return one element too many, which would no longer match the raster shape.
        x_min, y_min = self.globalDEM.extent[0], self.globalDEM.extent[2]
        n_rows, n_cols = self.globalDEM.shape[0], self.globalDEM.shape[1]
        X, Y = np.meshgrid(x_min + resolution * np.arange(n_cols),
                           y_min + resolution * np.arange(n_rows),
                           indexing='xy')
        Y = np.flipud(Y)

        # Maps for the hydrology
        self.Map = {'X': X, 'Y': Y, 'fdir': fdir, 'acc': acc, 'mask': catch,
                    'w1': None, 'd1': None, 'i2': None}

        # Flow accumulation raster
        save_path = os.path.join(param.work_path, 'flow_acc.tif')
        merge_meta.update(dtype=acc.dtype)
        with rasterio.open(save_path, "w+", **merge_meta) as dest:
            dest.write(self.Map['acc'], 1)

        # Watershed mask raster
        save_path = os.path.join(param.work_path, 'watershed_mask.tif')
        catch = catch.astype(np.int32)
        merge_meta.update(dtype=np.int32)
        with rasterio.open(save_path, "w+", **merge_meta) as dest:
            dest.write(catch, 1)


    def read_network(self, qgis_river_layer=None, river_filename=None, nclassemax=1, BDOEfilename=None):
        """Read the hydrographic network file and store it in ``self.HydroNetwork``.

        Three layouts are recognised, and ``self.network_type`` records which one was found:
          * BD Carthage (a ``Classe`` column): only reaches with
            ``Classe <= nclassemax`` are kept;
          * BD Topage (a ``CdOH`` column): attributes are mapped to the BD Carthage names;
          * any other shapefile whose first geometry is a (Multi)LineString: generic
            attributes are generated for every feature. Other shapefiles give an empty
            network.

        Every feature is then converted into 3D LineStrings with Z = 0. Each part of a
        MultiLineString becomes its own row. Features with another geometry type are
        dropped.

        :param qgis_river_layer: QGIS vector layer of the network (takes precedence over
            ``river_filename``)
        :param river_filename: path of the network file
        :type river_filename: str, optional
        :param nclassemax: maximal BD Carthage ``Classe`` kept
        :type nclassemax: int
        :param BDOEfilename: optional BDOE file, stored in ``self.BDOE``
        :type BDOEfilename: str, optional
        :raises Exception: if neither a layer nor a filename is provided
        """
        if qgis_river_layer:
            river_filename = qgis_river_layer.dataProvider().dataSourceUri()
        if not river_filename:
            raise Exception('river layer or filename not provided')

        shapefile = gpd.read_file(river_filename)

        # read BDOE if present
        if BDOEfilename:
            self.BDOE = gpd.read_file(BDOEfilename)

        columns = shapefile.columns.to_list()
        if 'Classe' in columns:  # BD Carthage
            classe = shapefile['Classe'].astype(int)
            full_network = shapefile[classe <= nclassemax]
            self.network_type = 'BDCarthage'
        elif 'CdOH' in columns:  # BD Topage
            full_network = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
            full_network['gid'] = shapefile['gid']
            full_network['CdEntiteHy'] = shapefile['CdOH']
            full_network['NomEntiteH'] = shapefile['TopoOH']
            full_network['geometry'] = shapefile['geometry']
            full_network['Classe'] = pd.Series(np.zeros((len(shapefile), )))
            full_network['Reach'] = pd.Series(np.zeros((len(shapefile), )))
            self.network_type = 'BDTopage'
        else:  # simple shapefile: generic attributes for every feature
            full_network = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
            geometries = list(shapefile['geometry'])
            first_type = type(geometries[0]) if geometries else None
            if first_type in (linestring.LineString, multilinestring.MultiLineString):
                # Lines and multilines get the same attributes; multilines are split below.
                if first_type is linestring.LineString:
                    self.network_type = 'line from series'
                else:
                    self.network_type = 'multiline'
                for i, geom in enumerate(geometries):
                    full_network.loc[i, 'gid'] = i
                    full_network.loc[i, 'CdEntiteHy'] = i
                    full_network.loc[i, 'NomEntiteH'] = 'River' + str(i)
                    full_network.loc[i, 'geometry'] = geom
                    full_network.loc[i, 'Classe'] = 0
                    full_network.loc[i, 'Reach'] = 'Reach'

        pd_temp = pd.DataFrame(columns=['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH', 'geometry'])
        filtered_network = GeoDataFrame(pd_temp, crs=self.crs)

        # Convert every feature into flat (Z = 0) LineStrings, one row per part
        n_lines = 0
        for _, row in full_network.iterrows():
            geom = row['geometry']
            if type(geom) is linestring.LineString:
                parts = [geom]
            elif type(geom) is multilinestring.MultiLineString:
                parts = list(geom.geoms)
            else:
                continue  # not a line: the row is not kept
            for part in parts:
                x_river, y_river = part.xy
                filtered_network.loc[n_lines] = row
                filtered_network.loc[n_lines, 'geometry'] = LineString(
                    [(x, y, 0) for x, y in zip(x_river, y_river)])
                n_lines += 1

        self.HydroNetwork = filtered_network



    def create_network(self, param):
        """Extract the hydrographic network from the flow maps of ``build_globalDEM``.

        pysheds turns the cells whose flow accumulation exceeds
        ``param['N']['minAccumulativeArea']`` into branches. Each branch becomes one row
        of ``self.HydroNetwork``:
          * ``gid`` and ``Reach`` hold the branch rank;
          * ``CdEntiteHy`` and ``NomEntiteH`` hold that rank as text;
          * ``Classe`` is 0.

        :param param: parameters for computation
        :type param: Parameter
        """
        column_names = ['gid', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH', 'geometry']
        hydro = GeoDataFrame(pd.DataFrame(columns=column_names), crs=self.crs)
        threshold = param['N']['minAccumulativeArea']
        d8_codes = (64, 128, 1, 2, 4, 8, 16, 32)
        river_cells = self.Map['acc'] > threshold
        branches = self.globalGrid.extract_river_network(self.Map['fdir'], river_cells,
                                                         dirmap=d8_codes)

        for rank, feature in enumerate(branches['features']):
            polyline = LineString(feature['geometry']['coordinates'])
            label = str(rank)
            hydro.loc[rank, 'gid'] = rank
            hydro.loc[rank, 'Reach'] = rank
            hydro.loc[rank, 'CdEntiteHy'] = label
            hydro.loc[rank, 'Classe'] = 0
            hydro.loc[rank, 'NomEntiteH'] = label
            hydro.loc[rank, 'geometry'] = polyline

        self.HydroNetwork = hydro


    def clipRiverFromDEM(self):
        """Clip the hydrographic network to the global extent of the DEM stack.

        For every LineString of ``self.HydroNetwork``, only the vertices strictly inside
        ``self.DEM_stack['global_extent']`` are kept, as 3D points with Z = 0. Rows are
        removed when:
          * fewer than two vertices remain;
          * the geometry is not a LineString.

        The network is modified in place and re-indexed from 0.
        """
        network = self.HydroNetwork
        Xmin, Ymin, Xmax, Ymax = self.DEM_stack['global_extent']

        remove_index = []
        for index, row in network.iterrows():
            geom = row['geometry']
            if type(geom) is not linestring.LineString:
                remove_index.append(index)
                continue
            # Keeping points inside the DEM
            kept = [(xr, yr, 0) for xr, yr in zip(*geom.xy)
                    if Xmin < xr < Xmax and Ymin < yr < Ymax]
            if len(kept) > 1:
                network.loc[index, 'geometry'] = LineString(kept)
            else:
                # Removing line outside the DEM (by label, not by position)
                remove_index.append(index)

        network.drop(remove_index, inplace=True)
        network.reset_index(drop=True, inplace=True)


    def projectionOnDEM(self, type_projection, interpolation_method):
        """Project the river reaches of the hydrographic network on the DEM.

        The projection is made sequentially for each DEM file of ``self.DEM_stack``. The
        network vertices lying inside a DEM polygon get the elevation returned by
        ``projOnDEM``. DEM files containing no vertex are not read.

        The variable type_projection gives the method for projection:
            interpolation : use of interpolation 2D from scipy for each point on the DEM file
            raster : transform line in binary map (fast but only one measurement by DEM cell size)

        A vertex is dropped when its elevation equals the no-data value of the DEM it was
        read from and is not greater than -99. Each reach keeping at least two vertices is
        rebuilt as a 3D LineString, with vertices in their original order. Its attributes
        are kept and completed with ``River`` (= ``gid``) and ``Reach`` (= 0). The result
        replaces ``self.HydroNetwork``.

        :param type_projection: interpolation or raster
        :type type_projection: string
        :param interpolation_method:  method for 2D interpolation (nearest, cubic, linear)
        :type interpolation_method: string
        """
        network = self.HydroNetwork

        pd_temp = pd.DataFrame(columns=['River', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH', 'geometry'])
        projected_network = GeoDataFrame(pd_temp, crs=self.crs)

        # Every vertex of the network as (x, y, reach label, vertex rank in its reach)
        vertices = [(coord[0], coord[1], line_id, rank)
                    for line_id, row in network.iterrows()
                    for rank, coord in enumerate(row['geometry'].coords)]

        # Projected vertices grouped by reach label: (rank, x, y, z, no-data of source DEM)
        projected = {}
        n_dem = len(self.DEM_stack['file_list'])
        for idx, data_DEMi in enumerate(self.DEM_stack['data_DEM']):
            print('HydroNetwork projection on  DEM n°' + str(idx + 1) + '/' + str(n_dem))
            DEM_polygon = Polygon(data_DEMi['polygon_coords'])
            current_file = self.DEM_stack['file_list'][idx]

            inside = [v for v in vertices if DEM_polygon.contains(Point(v[0], v[1]))]
            if not inside:
                continue  # nothing to project: do not read this DEM file
            X = [v[0] for v in inside]
            Y = [v[1] for v in inside]
            if type_projection == 'interpolation':
                _, _, DEMi = read_DEM(current_file, self.crs)
                Z = projOnDEM(X, Y, DEMi, type_projection, interpolation_method)
            else:
                # the rasterio dataset is closed once the elevations are read
                with rasterio.open(current_file, 'r', crs=self.crs) as DEMi:
                    Z = projOnDEM(X, Y, DEMi, type_projection, interpolation_method)

            # keep the no-data value of THIS DEM for the validity test below
            no_data = data_DEMi['no_data']
            for (x, y, line_id, rank), z in zip(inside, Z):
                projected.setdefault(line_id, []).append((rank, x, y, z, no_data))

        for line_id, row in network.iterrows():
            kept = [v for v in projected.get(line_id, [])
                    if not v[3] == v[4] or v[3] > -99]
            if len(kept) < 2:
                continue  # a LineString needs at least two vertices
            # Sort the points back into their initial order (ties: a vertex read from
            # several overlapping DEMs, ordered by elevation)
            kept.sort(key=lambda v: (v[0], v[3]))
            line_proj = LineString([(x, y, z) for _, x, y, z, _ in kept])

            projected_network.loc[line_id] = row
            projected_network.loc[line_id, 'River'] = row['gid']
            projected_network.loc[line_id, 'Reach'] = 0
            projected_network.loc[line_id, 'geometry'] = line_proj

        self.HydroNetwork = projected_network


    def order_network(self, min_dist_junction):
        """Sort river reaches from upstream to downstream and merge consecutive reaches.

        1. Each reach of ``self.HydroNetwork`` is oriented so that its first vertex is
           the highest; it is reversed in place when its last Z exceeds its first Z.
        2. The highest unprocessed reach starts a new river.
        3. While a reach whose upstream vertex lies closer than ``min_dist_junction`` to
           the current downstream vertex exists, it is appended to the river.
        4. Each river is written to ``self.ordered_network`` with the attributes of its
           first reach. A reach belongs to only one river.

        Reaches with an empty geometry, or whose upstream Z is not strictly positive, are
        never used.

        :param min_dist_junction: maximal distance between the end of a reach and the
            start of the next one for them to be merged
        :type min_dist_junction: float
        """
        base_network = self.HydroNetwork

        pd_temp = pd.DataFrame(columns=['River', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
        self.ordered_network = GeoDataFrame(pd_temp, geometry=[], crs=self.crs)

        # Upstream/downstream vertices and upstream elevation of every non-empty reach
        starts, ends, z_max, indexes = [], [], [], []
        for i, row in base_network.iterrows():
            geom = row['geometry']
            if geom.is_empty:
                continue
            # invert LineString if z_end > z_start
            if geom.coords[-1][2] > geom.coords[0][2]:
                geom = LineString(geom.coords[::-1])
                base_network.loc[i, 'geometry'] = geom
            starts.append(geom.coords[0][:2])
            ends.append(geom.coords[-1][:2])
            z_max.append(geom.coords[0][2])
            indexes.append(i)

        starts = np.array(starts, dtype=float).reshape(-1, 2)
        ends = np.array(ends, dtype=float).reshape(-1, 2)
        z_max = np.array(z_max, dtype=float)  # set to -inf once a reach is processed

        n_river = 0
        while z_max.size and np.max(z_max) > 0:
            # Highest point not yet processed starts a new river
            current_ind = int(np.argmax(z_max))
            self.ordered_network.loc[n_river] = base_network.loc[indexes[current_ind]].copy()
            current_linestring = base_network.loc[indexes[current_ind], 'geometry']
            z_max[current_ind] = -np.inf

            # Connecting all successive reaches
            while True:
                # distance between the downstream end of the current reach and the
                # upstream end of every available reach
                dist = np.sqrt(((starts - ends[current_ind]) ** 2).sum(axis=1))
                dist[~(z_max > 0)] = np.inf
                closest_ind = int(np.argmin(dist))
                if not dist[closest_ind] < min_dist_junction:
                    break
                total_coords = (list(current_linestring.coords) +
                                list(base_network.loc[indexes[closest_ind], 'geometry'].coords))
                current_linestring = LineString(total_coords)
                z_max[closest_ind] = -np.inf
                current_ind = closest_ind

            # Re-setting line string of the current river of the ordered network
            self.ordered_network.loc[n_river, 'geometry'] = current_linestring
            n_river += 1


    def find_junctionAndOutlet(self, dist_min):
        """Trim the network at the outlet, then find junctions and reach outlets.

        1. The reach of ``self.ordered_network`` closest to ``self.outlet_point`` is cut
           at its vertex nearest to the outlet:
             * if the first vertex is at least as high as that vertex, the vertices
               before it are kept;
             * otherwise the vertices from it to the end are kept.
           The cut is skipped when it would leave fewer than two vertices.
           ``self.outlet_point`` becomes the nearest point of that reach.
        2. Reaches whose highest vertex is below the outlet elevation are removed, and
           ``self.ordered_network`` is re-indexed from 0.
        3. For every pair of reaches, the closest pair of vertices is a junction when
           they are at most ``dist_min`` apart. Junctions are stored in
           ``self.junction`` with the Z of the first reach's vertex.
        4. The downstream vertex of every remaining reach is appended to
           ``self.list_of_outlet``.

        :param dist_min: minimum distance between 2 reaches to consider a junction
        :type dist_min: float
        """
        gdf = self.ordered_network
        junction = GeoDataFrame(columns=['River1', 'River2', 'geometry'], crs=self.crs)

        # --- Cut the line downstream of the outlet
        dist = [self.outlet_point.distance(geom) for geom in gdf['geometry']]
        id_outlet = np.argmin(dist)
        gdf.index = list(range(len(gdf)))

        outlet_line = gdf.loc[id_outlet, 'geometry']
        PointNearOutlet = nearest_points(outlet_line, self.outlet_point)[0]
        x_line, y_line = outlet_line.xy
        tree = cKDTree(np.transpose(np.array([x_line, y_line])))
        _, neighbours = tree.query(np.array([PointNearOutlet.x, PointNearOutlet.y]))
        neighbours = int(neighbours)
        outlet_coords = list(outlet_line.coords)
        z_outlet = outlet_coords[neighbours][2]

        if z_outlet <= outlet_coords[0][2]:
            # keep from the first point up to the outlet
            if neighbours > 1:
                gdf.loc[id_outlet, 'geometry'] = LineString(outlet_coords[:neighbours])
        else:
            # keep from the outlet to the last point
            if len(outlet_coords) - neighbours > 1:
                gdf.loc[id_outlet, 'geometry'] = LineString(outlet_coords[neighbours:])

        self.outlet_point = PointNearOutlet

        # --- Remove the reaches downstream of the outlet
        remove_index = [i for i, geom in zip(gdf.index, gdf['geometry'])
                        if np.max([coord[2] for coord in geom.coords]) < z_outlet]
        gdf.drop(remove_index, inplace=True)
        # labels must be 0..n-1 again: the loops below (and later steps) use gdf.loc[i]
        gdf.reset_index(drop=True, inplace=True)

        # --- Junctions between pairs of reaches
        geoms = list(gdf['geometry'])
        rivers = list(gdf['River'])
        xy_all = [np.array([(coord[0], coord[1]) for coord in geom.coords]) for geom in geoms]
        minLinei = [geom.coords[-1] for geom in geoms]  # downstream point of each reach

        ind_junction = 0
        for i in range(1, len(geoms)):
            coords_i = list(geoms[i].coords)
            tree_i = cKDTree(xy_all[i])  # built once per reach, reused for every j < i
            for j in range(i):  # distances are symmetric: only reaches before i
                distance, neighbours_j = tree_i.query(xy_all[j])
                ind_c1 = int(np.argmin(distance))
                ind_c0 = int(neighbours_j[ind_c1])

                if distance[ind_c1] <= dist_min:
                    Zjunction = coords_i[ind_c0][2]
                    junction.loc[ind_junction, 'River1'] = rivers[i]
                    junction.loc[ind_junction, 'River2'] = rivers[j]
                    junction.loc[ind_junction, 'geometry'] = Point(xy_all[i][ind_c0][0],
                                                                   xy_all[i][ind_c0][1],
                                                                   Zjunction)
                    ind_junction += 1

        self.junction = junction

        # --- One outlet per reach at its downstream point
        self.list_of_outlet += [[Point(coord_min), [], [], []] for coord_min in minLinei]


    def renameReachFromJunction(self, params: dict) -> None:
        """ Split the rivers at their junctions and rename the resulting reaches.

        Each river of ``self.ordered_network`` is cut at the junction points of ``self.junction`` closer than
        ``params['N']['minDistJunction']`` to it, and at its own first and last vertices. Within each river the
        resulting reaches are then renumbered so that reach 0 is the segment ending at the last vertex of the river
        line (stored as the downstream point when junctions are detected) and ids increase going upstream.

        The method then:

        * rebuilds ``self.junction``: for every junction, the (river, reach) pairs of the three closest reaches,
          ordered by increasing minimum elevation, plus a drained area when ``params['C']['computeGlobal']`` is set
          (0 otherwise);
        * finds the reach closest to the current outlet ``self.list_of_outlet[self.id_outlet]`` and every river
          transitively connected to it through the junctions, and stores them (with the outlet drained area) in that
          outlet entry;
        * replaces ``self.ordered_network`` with the split network.

        :param params: parameter dictionary; reads ``params['N']['minDistJunction']`` (only when a junction touches
            a river), ``params['C']['computeGlobal']`` and, when the latter is true, ``params['C']['window_size']``.
        :type params: dict
        """
        junction = self.junction
        gdf = self.ordered_network.copy()

        n_reach = 0
        junction_temp = GeoDataFrame(columns=['geometry', 'River1', 'Reach1', 'River2', 'Reach2', 'River3', 'Reach3',
                                              'area'], crs=self.crs)
        pd_temp = pd.DataFrame(columns=['River', 'Reach', 'CdEntiteHy', 'Classe', 'NomEntiteH'])
        gdf_temp = GeoDataFrame(pd_temp, geometry=[], crs=self.crs)

        # --- 1. split every river into reaches at its junctions -----------------------------------------------------
        for i in range(len(gdf)):
            river_coords = gdf.loc[i, 'geometry'].coords
            line = LineString([(x, y, z) for (x, y, z) in river_coords])
            ind_river = gdf.loc[i, 'River']

            # Cut points and their elevations are two parallel lists, reset for every river. The elevation list used
            # to be shared by all rivers, which paired the cut points of a river with other rivers' elevations.
            points_to_cut = []
            z_to_cut = []

            for ind_junction in range(len(junction)):
                if junction.loc[ind_junction, 'River1'] == ind_river or \
                        junction.loc[ind_junction, 'River2'] == ind_river:
                    junction_point = junction.loc[ind_junction, 'geometry']
                    if junction_point.distance(line) < params['N']['minDistJunction']:
                        points_to_cut.append(junction_point)
                        z_to_cut.append(junction_point.z)

            # the first and last vertices of the river are always cut points, with their own elevation
            points_to_cut.append(Point(river_coords[0][0], river_coords[0][1]))
            points_to_cut.append(Point(river_coords[-1][0], river_coords[-1][1]))
            z_to_cut.append(river_coords[0][2])
            z_to_cut.append(river_coords[-1][2])

            # project every cut point on the river line, keeping the elevation attached to it
            projected_coords = []
            for point, z in zip(points_to_cut, z_to_cut):
                projected_point = line.interpolate(line.project(point))
                projected_coords.append((projected_point.x, projected_point.y, z))
            projected_set = set(projected_coords)

            # merge the projected points with the river vertices, ordered by distance along the line
            coords = list(line.coords) + projected_coords
            sorted_coords = sorted(coords, key=lambda coord: line.project(Point(coord)))

            # cut the ordered vertices into segments, each ending on a projected point
            segments = []
            current_segment = []
            for coord in sorted_coords:
                current_segment.append(coord)
                if coord in projected_set:
                    # Keep the segment only if its end points differ in plan view. The former test required both x
                    # AND y to differ, which silently dropped reaches aligned with a coordinate axis.
                    if len(current_segment) > 1 and current_segment[0][:2] != current_segment[-1][:2]:
                        segments.append(LineString(current_segment))
                    # the next segment starts from this cut point
                    current_segment = [coord]

            for ind_reach, segment in enumerate(segments):
                gdf_temp.loc[n_reach, 'River'] = ind_river
                gdf_temp.loc[n_reach, 'Reach'] = ind_reach
                gdf_temp.loc[n_reach, 'CdEntiteHy'] = gdf.loc[i, 'CdEntiteHy']
                gdf_temp.loc[n_reach, 'Classe'] = gdf.loc[i, 'Classe']
                gdf_temp.loc[n_reach, 'NomEntiteH'] = gdf.loc[i, 'NomEntiteH']
                gdf_temp.loc[n_reach, 'geometry'] = segment
                n_reach += 1

        # --- 2. renumber reaches of each river from the last vertex (reach 0) towards the first one ----------------
        for river in np.unique([gdf_temp.loc[j, 'River'] for j in range(len(gdf_temp))]):
            rows = [j for j in range(len(gdf_temp)) if gdf_temp.loc[j, 'River'] == river]
            for j in rows:
                gdf_temp.loc[j, 'Reach'] = len(rows) - gdf_temp.loc[j, 'Reach'] - 1

        # --- 3. describe every junction by the (river, reach) of its three closest reaches -------------------------
        for i in range(len(junction)):
            p = junction.loc[i, 'geometry']
            dist_end = []
            for j in range(len(gdf_temp)):
                line1 = LineString([(c[0], c[1]) for c in gdf_temp['geometry'].loc[j].coords])
                dist_end.append(p.distance(line1))

            # A junction is assumed to connect exactly three reaches. With fewer reaches available, np.argmin over
            # an all-inf list returns index 0 again, so the same reach can be picked more than once.
            list_of_reach = []
            list_of_river = []
            list_of_z = []
            for _ in range(3):
                ind_min = np.argmin(dist_end)
                z = [coord[2] for coord in gdf_temp.loc[ind_min, 'geometry'].coords]
                list_of_reach.append(gdf_temp['Reach'].loc[ind_min])
                list_of_river.append(gdf_temp['River'].loc[ind_min])
                list_of_z.append(np.min(z))
                dist_end[ind_min] = np.inf  # exclude this reach from the next search

            # order the three reaches by increasing minimum elevation: the lowest comes first
            sorted_reach = [[river, reach]
                            for _, river, reach in sorted(zip(list_of_z, list_of_river, list_of_reach))]

            if params['C']['computeGlobal']:
                # drained area: maximum of the accumulation map around the junction
                neighborhood = find_pixels_in_neighborhood(self.Map['acc'], self.Map['X'], self.Map['Y'],
                                                           p.x, p.y, params['C']['window_size'])
                area = np.max(neighborhood)
            else:
                area = 0

            junction_temp.loc[i, 'geometry'] = p
            junction_temp.loc[i, 'River1'] = sorted_reach[0][0]
            junction_temp.loc[i, 'Reach1'] = sorted_reach[0][1]
            junction_temp.loc[i, 'River2'] = sorted_reach[1][0]
            junction_temp.loc[i, 'Reach2'] = sorted_reach[1][1]
            junction_temp.loc[i, 'River3'] = sorted_reach[2][0]
            junction_temp.loc[i, 'Reach3'] = sorted_reach[2][1]
            junction_temp.loc[i, 'area'] = area

        # NOTE(review): a former `junction_temp.sort_values(by='area')` discarded its result (sort_values returns a
        # new frame), so junctions have always been kept in detection order. Assign the sorted frame here if
        # ordering by area is really intended and downstream code tolerates it.
        self.junction = junction_temp

        # --- 4. find the reach connected to the current outlet -----------------------------------------------------
        o = self.id_outlet
        p = self.list_of_outlet[o][0]
        dist_end = []
        for j in range(len(gdf_temp)):
            line1 = LineString([(x, y) for (x, y, z) in gdf_temp['geometry'].loc[j].coords])
            dist_line = p.distance(line1)
            dist_end.append(np.inf if np.isnan(dist_line) else dist_line)

        ind_min = np.argmin(dist_end)
        river1 = gdf_temp['River'].loc[ind_min]  # river connected to the outlet
        reach1 = gdf_temp['Reach'].loc[ind_min]  # reach connected to the outlet

        # --- 5. collect every river connected, through junctions, to the outlet river -------------------------------
        ind_river = [river1]
        n_prev = 0
        while len(ind_river) > n_prev:  # stop once a full pass over the junctions adds no new river
            n_prev = len(ind_river)
            for j in range(len(self.junction)):
                nriver = [self.junction.loc[j, 'River1'],
                          self.junction.loc[j, 'River2'],
                          self.junction.loc[j, 'River3']]
                if any(r in ind_river for r in nriver):
                    # add the rivers of this junction, without duplicates
                    ind_river = list(np.unique(ind_river + nriver))

        # --- 6. drained area at the outlet (only available when the accumulation map was computed) -----------------
        X = p.xy[0][0]
        Y = p.xy[1][0]
        if params['C']['computeGlobal']:
            neighborhood = find_pixels_in_neighborhood(self.Map['acc'], self.Map['X'], self.Map['Y'], X, Y,
                                                       params['C']['window_size'])
            area = np.max(neighborhood)
        else:
            area = 0

        self.list_of_outlet[o][3] = area
        self.list_of_outlet[o][1] = ind_river
        self.list_of_outlet[o][2] = [river1, reach1]
        self.ordered_network = gdf_temp

    def setOutlet(self, arg, outlet=Point()):
        """ Select the outlet of the considered network among the outlets.

        :param arg: either an integer, in which case ``outlet`` is appended to ``self.list_of_outlet`` and
            ``self.id_outlet`` is set to ``arg``; or a shapely ``Point``, in which case the already listed outlet
            closest to that point becomes the current one. Any other type leaves the state unchanged.
        :type arg: int or Point
        :param outlet: outlet point appended when ``arg`` is an integer
        :type outlet: Point
        :raises ValueError: if ``arg`` is a point and ``self.list_of_outlet`` is empty
        """
        if isinstance(arg, int):
            self.id_outlet = arg
            self.list_of_outlet.append([outlet, [], [], []])
            # NOTE(review): the connected-river list and [river, reach] are written to entry 0, not to the entry just
            # appended nor to entry ``arg``; they coincide only when the list was empty and arg == 0 — confirm.
            self.list_of_outlet[0][1] = [0]
            self.list_of_outlet[0][2] = [0, 0]
        elif isinstance(arg, Point):
            if not self.list_of_outlet:
                raise ValueError("setOutlet: no outlet available to select the closest one from")
            dist = [arg.distance(candidate[0]) for candidate in self.list_of_outlet]
            self.id_outlet = np.argmin(dist)


    def interpolateReach(self, step):
        """ Resample the centre line of every reach attached to the current outlet.

        The resampled points are the locations where cross sections will be plotted. For each processed reach,
        ``line_int`` receives the resampled line and ``Xinterp`` the result of ``distance_along_line`` on it; reaches
        whose length is not greater than 1 only get an empty ``line_int``.

        :param step: target spacing, in meters, between two consecutive points of the reach centre line
        :type step: int
        """
        connected_rivers = self.list_of_outlet[self.id_outlet][1]

        for current in self.reach:
            # only reaches belonging to rivers connected to the current outlet are resampled
            if current.geodata['River'] not in connected_rivers:
                continue

            centerline = current.geodata['geometry']
            # number of intervals between resampled points (at least one)
            n_intervals = int(max(round(centerline.length / step), 1))

            if centerline.length <= 1:
                current.line_int = LineString()
                continue

            fractions = [float(k) / n_intervals for k in range(n_intervals + 1)]
            resampled = LineString([centerline.interpolate(f, normalized=True) for f in fractions])
            current.line_int = resampled
            current.Xinterp = distance_along_line(centerline, resampled)


    def createReach(self):
        """ Build ``self.reach`` from the ordered hydrographic network.

        When ``self.ordered_network`` is empty, a one-row network (river 0, reach 0, named ``river0``) holding the
        first geometry of ``self.HydroNetwork`` is created first. A ``Reach`` named ``"<NomEntiteH>_<Reach>"`` is then
        created for every row whose river is connected to the current outlet, with a copy of that row as ``geodata``.

        :return: number of reaches created
        :rtype: int
        """
        if self.ordered_network.empty:
            # Every value is wrapped in a one-element list: with a dict of scalars only (shapely 2 geometries are
            # scalars for pandas) the DataFrame constructor raises "If using all scalar values, you must pass an index".
            self.ordered_network = GeoDataFrame(
                {'River': [0], 'Reach': [0], 'CdEntiteHy': [0], 'Classe': [1],
                 'NomEntiteH': ['river0'], 'geometry': [self.HydroNetwork.loc[0, 'geometry']]},
                crs=self.crs,
            )

        # rivers connected to the current outlet
        connected_rivers = self.list_of_outlet[self.id_outlet][1]

        self.reach = []
        for _, row in self.ordered_network.iterrows():
            if row['River'] in connected_rivers:
                reach = Reach(str(row['NomEntiteH']) + "_" + str(row['Reach']))
                reach.geodata = row.copy()
                self.reach.append(reach)

        return len(self.reach)
