######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
import os
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import LineString, Point,MultiLineString,Polygon
from pyproj import Transformer
import pandas as pd
import geopandas as gpd
import os.path
import rasterio

try:
    import pygeoops as pg
    find_centerline_avail = True
except ImportError:
    find_centerline_avail = False

def find_pixels_in_neighborhood(image,DEM_X,DEM_Y, X, Y, window_size):
    """ Find point at the vicinity of point (X,Y)

    :param image: Raster of searching value
    :type image: numpy.ndarray
    :param DEM_X:  grid of X position (np.meshgrid)
    :type DEM_X: numpy.ndarray
    :param DEM_Y:  grid of Y position (np.meshgrid)
    :type DEM_Y: numpy.ndarray
    :param X: X position of the reference point
    :type X: float
    :param Y: Y position of the reference point
    :type Y: float
    :param window_size: number of pixel for the search area
    :type window_size: int
    :return: TODO
    """
    neighborhood = []
    half_window = int(window_size / 2)
  
    if DEM_X.shape[1]>1:
       x=DEM_X[0,:]                    
       y=DEM_Y[:,0] 
    else:
       x=DEM_X
       y=DEM_Y
    indx=np.argmin(abs(x-X))
    indy=np.argmin(abs(y-Y))
    if indx<len(x) and indy<len(y): #le point se trouve sur cette dalle
        if indx<0:indx=0
        if indy<0:indy=0
    
    for i in range(indx - half_window, indx + half_window + 1):
        for j in range(indy - half_window, indy + half_window + 1):
            if 0 <= j < len(image) and 0 <= i < len(image[0]):
                neighborhood.append(image[j][i])

    return neighborhood


def find_LCA_triangle(X,section):
    """compare the measured section with a double triangular section.

    :param X: parameter to optimize
    :type X: list
    :param section: section with z data
    :type section: Section Class
    :return:
        - the residual between double triangle section and real section
    """
    w1=X[0]
    d1=X[1]
    i2=X[2]
    i1= w1/2/d1
    wetPerimeter= []
    wetPerimeter2 =[]
    surfaceWidth = []
    A =[]
    z = section.hydro['depth']

#calcul pour réseau de drainage
    for hw in z:
        if hw<=d1:
            wetPerimeter.append(2 * hw*((i1**2+1))**0.5)
            wetPerimeter2.append(0)
            surfaceWidth.append(2*hw *i1)
            A.append(np.square(hw)*w1/(2*d1))
        else:
            wetPerimeter.append(2* (d1**2+w1**2)**0.5  + (hw -d1) * (1+(1/i2)**2)**0.5)
            wetPerimeter2.append(2* (hw -d1)*np.sqrt(1+1/i2**2))
            A1a=w1*d1
            A1b= (hw -d1) *w1
            A2=(np.square(hw-d1)/i2)/2
            A.append(A1a+A1b+A2)
    rmse1 = np.sqrt(((np.array(wetPerimeter)-np.array(section.hydro['W'])) ** 2).mean())
    rmse2 = np.sqrt(((np.sqrt(A)-np.sqrt(section.hydro['A'])) ** 2).mean())
    res =abs(rmse1+rmse2)
    return res


def map_model(BV):
    """2D mapping of DEM and cross section. Value of water surface or width can be vizualise as colored point.

    :param BV: Watershed
    :type BV: ModelCatchment
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    BV.ordered_network.plot(ax=ax, color='black')
#    BV.junction.plot(ax=ax, color='red')
    BV.globalDEM[BV.globalDEM==np.min(BV.globalDEM)]=0
    plot_extent = [BV.DEM_stack['global_extent'][0],BV.DEM_stack['global_extent'][2],
                   BV.DEM_stack['global_extent'][1],BV.DEM_stack['global_extent'][3]]
    plt.imshow(BV.globalDEM, extent=plot_extent)
    plt.colorbar(label='Elevation (m)')
    plt.grid(which='major', axis='both', linestyle=':',color='gray')

    if len(BV.list_of_outlet):
        for ii in range(len(BV.list_of_outlet)):
            plt.scatter(BV.list_of_outlet[ii][0].xy[0],BV.list_of_outlet[ii][0].xy[1], color ='green')

    for i in range(len(BV.reach)):
         for j in range(len(BV.reach[i].section_list)):
             if len(BV.reach[i].section_list)>0:
                Xs=BV.reach[i].section_list[j].geom.xy[0]
                Ys=BV.reach[i].section_list[j].geom.xy[1]
                plt.plot(Xs,Ys,color='green')
                plt.scatter(BV.reach[i].section_list[j].start.x,BV.reach[i].section_list[j].start.y, color ='red')
                plt.scatter(BV.reach[i].section_list[j].end.x, BV.reach[i].section_list[j].end.y, color='blue')

    plt.show()       


def plot_section(reach, Nsection):
    """plot of the section with the water surface associated
    :param reach: Reach
    :type reach: Reach
    :param Nsection: number of the sectin to plot in the reach
    :type reach: int
    """

    plt.figure()
    X=reach.section_list[Nsection].coord.array['Xt']
    Y=[reach.section_list[Nsection].geom.coords[i][2] for i in range(len(reach.section_list[Nsection].geom.coords))]
    plt.plot(X,Y,color='black')

    WS=reach.section_list[Nsection].WSE
    distancebank1 = reach.section_list[Nsection].bank[0][0]
    distancebank2 = reach.section_list[Nsection].bank[0][1]
    plt.scatter(distancebank1,WS,color='red')
    plt.scatter(distancebank2,WS,color='red')
    plt.plot([distancebank1,distancebank2],[WS,WS],color='blue')
    plt.show()


def plot_Morpho(reach):
    """plot of the inline profile of elevation and water surface (WS variable of the section object) for the entire reach

    :param reach: Reach
    :type reach: Reach
    """
    X=reach.Xinterp
    Zbed=[]
    Width=[]
    Acc =[]

    for j in range(len(reach.section_list)):
        Zbed.append(reach.section_list[j].Zbed)
        distancebank1 = reach.section_list[j].bank[0][0]
        distancebank2 = reach.section_list[j].bank[0][1]
        width = abs(distancebank1-distancebank2)
        Width.append(width)
        Acc.append(reach.section_list[j].acc)
    plt.scatter(Acc,Width)


def plot_longProfile(reach):
    """plot of the inline profile of elevation and water surface (WS variable of the section object) for the entire reach

    :param reach: Reach
    :type reach: Reach
    """
    X=reach.Xinterp
    Zbed=[]
    WS=[]

    for j in range(0, len(reach.section_list)):
        Zbed.append(reach.section_list[j].Zbed)
        WS.append(reach.section_list[j].WSE)

    plt.figure()
    plt.scatter(np.arange(len(X)),X)
    plt.figure()
    plt.plot(X,Zbed,color='black')
    plt.scatter(X, WS, color='red')
    plt.show()


def projOnDEM(X, Y, DEM):
    """ Projection of an X/Y list of points using DEM.

    :param X: X coordinate of line points
    :type X: float array
    :param Y: Y coordinate of line points
    :type Y: float array
    :param DEM: Digital Elevation Model
    :type DEM: XYZ dict-grid (interpolation) or rasterio object (raster)
    :return: list of elevations for (X,Y) points
    """
    list_of_point = [(X[i],Y[i]) for i in range(len(X)) ]
    Z = [x for x in DEM.sample(list_of_point)]
    return Z


def convert_projected_to_latlon(x, y, source_proj, dest_epsg=4326):
    transformer = Transformer.from_crs(source_proj,dest_epsg)
    lon, lat = transformer.transform(x, y)
    return lat, lon


def distance_along_line(line1,line2):
    """Give the distance of point in line2 considering curvilinear distance along line1.

    :param line1: reference line
    :type line1: LineString
    :param line2: Line containing point to evaluate
    :type line2: LineString
    :return:
        - list of distance for all point in line 2 (list)
    """
    list_of_distance = [line1.project(Point(coord)) for coord in line2.coords]
    return list_of_distance


def  angle_from_line(line_int, line_temp, neighbours, ii):
    """Compute an angle between line_int and line_temp

    :param line_int : line with a new section center as point
    :type line_int: LineString
    :param line_temp : line with the original point but between 2 sections
    :type line_temp: LineString
    :param  neighbours : list of nearest point of line_int to point of line_temp
    :type neighbours: list
    :param ii : index of the nearest point of intersection between the 2 lines
    :type neighbours: int
    """

    if neighbours[ii] == len(line_temp.xy[0])-1:  # tant que le point n'est pas une extrémité
        neighbours_down = neighbours[ii]
        neighbours_up = neighbours[ii] - 1
    elif ii == 0:
        neighbours_down = neighbours[ii] + 1
        neighbours_up = neighbours[ii]
    else:
        neighbours_down = neighbours[ii] + 1
        neighbours_up = neighbours[ii] - 1
        # traitement  de la ligne vertical
    if neighbours_down ==1: #2 points sur la ligne uniquement
        neighbours_down = 1
        neighbours_up = 0
    if (line_temp.xy[0][neighbours_down] - line_temp.xy[0][neighbours_up]) == 0:
        if neighbours_down ==neighbours_up: #meme point de référence, on prend la ligne interpolée pour l'angle
            dY = (line_int.xy[1][ii] - line_int.xy[1][ii+1])
            dX = (line_int.xy[0][ii] - line_int.xy[0][ii+1])  # trouver la direction normale
            angle = np.arctan2(dY, dX)
        else:
            if line_temp.xy[1][neighbours_down] > line_temp.xy[1][neighbours_up]:
                angle = np.pi/2
            else:
                angle = -np.pi/2
    else:
        if neighbours_down ==neighbours_up: #meme point de référence, on prend la ligne interpolée pour l'angle
            dY = (line_int.xy[1][ii] - line_int.xy[1][ii+1])
            dX = (line_int.xy[0][ii] - line_int.xy[0][ii+1])  # trouver la direction normale
            angle = np.arctan2(dY, dX)
        else:
            dY = (line_temp.xy[1][neighbours_down] - line_temp.xy[1][neighbours_up])
            dX = (line_temp.xy[0][neighbours_down] - line_temp.xy[0][neighbours_up])  # trouver la direction normale
            angle = np.arctan2(dY, dX)
    return angle, neighbours_down, neighbours_up


def remove_intersection(xs_lines,angles,Xinterp,length,point_inter,max_iter):
    """Modify XS directions to avoid intersection between XS.

    :param xs_lines: list of start and end point of all XS in the reach
    :type xs_lines: list
    :param angles: angles of XS sections
    :type angles: list
    :param Xinterp:curvilinear abscissa of XS
    :type Xinterp: list
    :param side: part of river considered (overbank or main channel)
    :type side: str
    :param length: half width of xs
    :type length: float
    :param point_inter: interection with banks
    :type point_inter: list
    :return:
        - list of new xslines (list)
        - list of new angles (list)
    """
    nintersect = 1
    n_iter0 =0
    while nintersect > 0 and n_iter0 < max_iter:
        # Flag intersecting lines
        intersect_flag = np.zeros(len(xs_lines), dtype=int)
        for ix1 in range(0, len(xs_lines)):
            for ix2 in range(ix1 + 1, len(xs_lines)):
                if xs_lines[ix1].intersects(xs_lines[ix2]):
                    intersect_flag[ix1] += 1
        nintersect = np.sum(intersect_flag)
        print("Number of intersections:", nintersect)
        if nintersect == 0:
            continue
        indices = np.argwhere(intersect_flag > 0).flatten()
        seq = indices[1:] - indices[:-1]
        seq = np.insert(seq, 0, 2)
        seq_start = np.argwhere(seq > 1).flatten()
        print("Number of intersecting ranges:", len(seq_start))

        for i in range(0, len(seq_start)):
            start = indices[seq_start[i]]
            if i < len(seq_start) - 1:
                end = indices[seq_start[i + 1] - 1]
            else:
                end = indices[-1]
            print("Processing intersecting range for sections: %i-%i" % (start, end))
            nintersect2 = 1
            n_iter=0
            while nintersect2 > 0 and n_iter < max_iter:
                new_angles = angles.copy()
                xr = Xinterp[start:end + 1]
                xb = [Xinterp[start], Xinterp[(start + end) // 2], Xinterp[end]]
                if xr[-1] < xr[0]:
                    xr = xr[0] - xr
                    xb = xb[0] - xb
                angles_selected = [angles[start], angles[(start + end) // 2], angles[end]]
                new_angles[start:end] = np.interp(xr, xb, angles_selected)
                # Compute new lines
                new_lines = xs_lines.copy()
                for ix in range(start, end + 1):
                        Xstart = point_inter[ix].x - length[ix]/2 * np.cos(new_angles[ix])
                        Ystart = point_inter[ix].y - length[ix]/2 * np.sin(new_angles[ix])
                        Xend = point_inter[ix].x - length[ix] /2* np.cos(new_angles[ix]+np.pi)
                        Yend = point_inter[ix].y - length[ix]/2 * np.sin(new_angles[ix]+np.pi)
                        new_lines[ix] = LineString([(Xstart, Ystart), (Xend, Yend)])

                nintersect2 = 0
                for ix1 in range(max(0, start - 2), min(end + 3, len(xs_lines))):
                    for ix2 in range(ix1 + 1, min(end + 3, len(xs_lines))):
                        if new_lines[ix1].intersects(new_lines[ix2]):
                            nintersect2 += 1

                if nintersect2 > 0:
                    start = max(0, start - 1)
                    end = min(end + 1, len(xs_lines) - 1)
                n_iter +=1

            print("Intersecting range filtered with range[%i-%i]" % (start, end))
            xs_lines =new_lines
            angles = new_angles
            del new_lines
        n_iter0 += 1

    return xs_lines , angles


def edge_intersection(xs_lines, angles, max_iter, side, nx, point_inter):
    """Modify XS directions to avoid intersection between XS.

    :param xs_lines: list of start and end point of all XS in the reach
    :type xs_lines: list
    :param angles: angles of XS sections
    :type angles: list
    :param max_iter:curvilinear abscissa of XS
    :type max_iter: list
    :param side: part of river considered (overbank or main channel)
    :type side: str
    :param nx: TODO
    :type nx: int
    :param point_inter: intersection with banks
    :type point_inter: list
    :return:
        - list of new xslines (list)
        - list of new angles (list)
    """

    grouped_indices = group_intersecting_lines_with_indices(xs_lines)
    nintersect = 0
    for indice_line in grouped_indices:
        if len(indice_line) > 1:
            nintersect += 1
    print("Number of intersecting ranges for " + side + " overbank:", nintersect)
    count_iter = 0
    while nintersect > 0 and count_iter < max_iter:
        #print("Number of intersections:", nintersect)
        count_iter += 1
        new_lines = xs_lines.copy()
        for indice_line in grouped_indices:
            if len(indice_line) > 1:
                indice_line.sort()
                start = indice_line[0]
                end = indice_line[-1]
                Xstart, Ystart = np.mean([new_lines[i].coords[0][0] for i in range(start,end)]),\
                    np.mean([new_lines[i].coords[0][1] for i in range(start,end)]) #
                Xend, Yend =  np.mean([new_lines[i].coords[1][0] for i in range(start,end)]),\
                    np.mean([new_lines[i].coords[1][1] for i in range(start,end)]) #

                for ix in range(start, end+1 ):
                    px, py = point_inter[ix].x, point_inter[ix].y

                    if side == 'right':
                        new_lines[ix] = LineString([(Xstart, Ystart), (px, py)])
                        interp_coords = [new_lines[ix].interpolate(float(n) / nx, normalized=True).coords[0] for n in
                                         range(nx + 1)]
                         # Récupérer la deuxième coordonnée
                        new_lines[ix] = LineString([(interp_coords[1]), (px, py)])
                    elif side == 'left':
                        new_lines[ix] = LineString([(px, py), (Xend, Yend)])
                        interp_coords = [new_lines[ix].interpolate(float(n) / nx, normalized=True).coords[0] for n in
                                         range(nx + 1)]
                        new_lines[ix] = LineString([(px, py), (interp_coords[-2])])


        nintersect=0
        grouped_indices = group_intersecting_lines_with_indices(new_lines)
        for indice_line in grouped_indices:
            if len(indice_line) > 1:
                nintersect += 1
        xs_lines = new_lines
        del new_lines
        print("Intersecting range filtered  for "+side+" overbank, remaining intersections sequence:" + str( nintersect))

    return xs_lines , angles


def group_intersecting_lines_with_indices(lines):
    """Regroupe les indices des LineString qui se croisent."""
    n = len(lines)
    groups = []
    assigned = set()  # Ensemble pour suivre les lignes déjà assignées

    for i in range(n):
        if i in assigned:
            continue  # Cette ligne est déjà dans un groupe

        # Créer un nouveau groupe avec la ligne actuelle
        group = {i}

        # Vérifier les intersections avec les autres lignes
        for j in range(i + 1, n):
            if j not in assigned and any(lines[j].intersects(lines[k]) for k in group):
                group.add(j)

        groups.append(list(group))
        assigned.update(group)  # Marquer les lignes comme assignées

    return groups


def Verify_type_geometry(geometry):
    """check format of geometry for import or dig channel

    :param geometry : input geometry to check
    :type geometry: shapely geometry
    :return:
    - list of LineString (list)
    """
    if isinstance(geometry, LineString):
        return geometry
    elif isinstance(geometry, MultiLineString):
        return list(geometry.geoms)[0]
    elif isinstance(geometry, Polygon):
        return geometry.exterior
    else:
        raise TypeError('Type de géométrie non pris en charge: {}'.format(type(geometry)))


def Convert_sections(XS_shp):
    """convert an imported section if needed

    :param XS_shp : imported section
    :type XS_shp: GeoDataFrame
    :return:
    - XS in the right format (MultiLineString)
    """
    linestrings = []

    if XS_shp.columns.to_list()[0] == 'sec_id': #precourlis format
        print('PreCourlis format' )

        for _, row in XS_shp.iterrows():
            line = row['geometry']
            distance = row['abs_lat'].split(',')
            Z = row['zfond'].split(',')
            distance = [float(d) for d in distance]
            Xsection = [coord[0] for coord in line.coords]
            Ysection = [coord[1] for coord in line.coords]
            distance_section = [line.project(point) for point in line.coords]
            X= np.interp(distance, distance_section, Xsection)
            Y= np.interp(distance, distance_section, Ysection)
            Z= [float(x) if x != "" else "0" for x in Z]
            line_z = LineString([(x,y,z) for x,y,z in zip(X,Y,Z)])
            linestrings.append(line_z)

    else:

        for geometry in XS_shp['geometry']:
            linestring = Verify_type_geometry(geometry)
            linestrings.append(linestring)
    return MultiLineString(linestrings)


def export_result_as_csv(BV, param):
    """export result in ascii format

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    """
    geo = pd.DataFrame(columns= ['River Name','Reach Name','XS','X','Y','Z','width','discharge','area','WSE','X_curv','classe'])
    river_name=[]
    reach_name=[]
    section_name=[]
    Xcoord=[]
    Ycoord=[]
    Zcoord=[]
    abs_curv = []
    Q= []
    Width = []
    Acc = []
    WSE =[]
    classe = []
    Slope = []
    config_cal = param['H']['createBanksMethods']

    if param['H']['createBanksMethods'] == 'Normal':
        config_cal = config_cal + '_Q' + str(param['H']['outletDischarge'])
    elif param['H']['createBanksMethods'] == '1D':
        config_cal = config_cal + '_Q' + str(param['H']['outletDischarge']) + '_h' + str(param['H']['hWaterOutlet'])
    elif param['H']['createBanksMethods'] == 'Himposed':
        config_cal = config_cal + '_h' + str(param['H']['himposed'])
    ir = 0
    colors = plt.cm.hsv(np.linspace(0, 1, len(BV.reach)))  # Génère une gamme de couleurs
    #plt.figure()
    for reach in BV.reach:
        if len(reach.section_list)>1:
            id_river = reach.geodata['River']
            plotZ = []
            plotX = []


            for ids, section in enumerate(reach.section_list):
                WSE.append(section.WSE)
                distances = section.coord.array['Xt']
                distancebank1 = section.bank[0][0]
                distancebank2 = section.bank[0][1]

                river_name.append(str(id_river))
                reach_name.append(reach.name)
                section_name.append(reach.id_first_section + ids)
                Xcoord.append(section.centre.x)
                Ycoord.append(section.centre.y)
                Zcoord.append(section.Zbed)
                Slope.append(section.slope)
                abs_curv.append(reach.Xupstream - reach.Xinterp[ids])
                Q.append(reach.Q)
                width = abs(distancebank1 - distancebank2)
                Width.append(width)
                Acc.append(section.acc)
                if BV.network_type == 'BDCarthage':
                    classe.append(reach.geodata['Classe'])
                else:
                    classe.append(0)
                if section.Zbed>0:
                    #plotX.append(reach.Xupstream -reach.Xinterp[ids])
                    averaged = section.computeHydraulicGeometry(section.WSE, param['H']['dxlat'],param['H']['levee'],
                                                                             param['H']['frictionLaw'],section.slope)

                    Fr = averaged['Q']*width**0.5/averaged['A']**1.5/(9.81)**0.5
                    plotX.append(Fr)
                    plotZ.append(width)

    '''
            plt.scatter(plotX, plotZ, color=colors[ir], label=reach.name)
            ir += 1

    plt.xscale('log')
    plt.yscale('log')
    plt.legend()
    #plt.xlim(1e-4,20)
    #plt.ylim(1, 100)

    plt.xlabel("Froude number")
    plt.ylabel("Width (m)")
    plt.show()
    '''

    geo['River Name'] = np.array(river_name)
    geo['Reach Name'] = np.array(reach_name)
    geo['XS'] = np.array(section_name)
    geo['X'] = np.array(Xcoord)
    geo['Y'] = np.array(Ycoord)
    geo['Z'] = np.array(Zcoord)
    geo['width'] = np.array(Width)
    geo['discharge'] = np.array(Q)
    geo['area'] = np.array(Acc)
    geo['WSE'] = np.array(WSE)
    geo['X_curv'] = np.array(abs_curv)
    geo['classe'] = np.array(classe)
    geo['slope'] = np.array(Slope)

    geo.to_csv(os.path.join(param.work_path, 'last_result_' +config_cal +'.csv'), sep=',', index=False)


def export_as_csv_for_HECRASgeometry(BV, outfile_path):
    """export xs in a HEC RAS format

    :param BV: Watershed
    :type BV: ModelCatchment
    :param outfile_path: folder in which export
    :type outfile_path: str
    """
    
    geo = pd.DataFrame(columns= ['River Name','Reach Name','XS','X','Y','Z','d'])
    river_name=[]
    reach_name=[]
    section_name=[]
    Xcoord=[]
    Ycoord=[]
    Zcoord=[]
    Dist =[]
    
    for reach in BV.reach:
        id_river = reach.geodata['River'] 
        id_reach = reach.geodata['Reach'] 
        for ids,section in enumerate(reach.section_list):
            for ixs,coord in enumerate(section.geom.coords):
                river_name.append(str(id_river))
                #reach_name.append(str(id_reach))
                reach_name.append(reach.name)
                section_name.append(reach.id_first_section +ids)
                Xcoord.append(coord[0])
                Ycoord.append(coord[1])
                Zcoord.append(coord[2])
                Dist.append(section.coord.values['B'] [ixs])
                
    geo['River Name'] = np.array(river_name)
    geo['Reach Name']= np.array(reach_name)
    geo['XS']= np.array(section_name)
    geo['X']= np.array(Xcoord)
    geo['Y']= np.array(Ycoord)
    geo['Z']= np.array(Zcoord)
    geo['d'] = np.array(Dist)

    geo.to_csv(os.path.join(outfile_path,'geo_forHECRAS.csv'),sep = ',',index=False)


def compute_distance(coords):
    """compute curvilinear distance with list of coordinate

    :param coord: list of coord (tuple (x,y,z))
    :type coord: list
    """
    distance = [0]
    for id_point in range(1, len(coords)):
        distance.append(distance[id_point - 1] + ((coords[id_point][0] - coords[id_point - 1][0]) ** 2 +
                                                  (coords[id_point][1] - coords[id_point - 1][1]) ** 2) ** 0.5)
    return distance


def X_from_outlet(BV):
    """attribute to each reach the distance from outlet

    :param BV: Watershed
    :type BV: ModelCatchment
    """
    ind_river = BV.list_of_outlet[BV.id_outlet][1] #indice des rivières du BV
    ind1 = [BV.list_of_outlet[BV.id_outlet][2]] #tronçon aval [river,reach]
    ind_river_reach = []
    reach = BV.reach

    for i in range(len(reach)):
        ind_river_reach.append([reach[i].geodata['River'],reach[i].geodata['Reach']])
    ind_reach = ind_river_reach.index(ind1[0])
    reach[ind_reach].Xupstream = np.max(reach[ind_reach].Xinterp)
    n_reach_withX = 1

    while n_reach_withX < len(reach):
        for j in range(len(BV.junction)):

            if BV.junction.loc[j, 'River1'] in ind_river:
                ind_new_reach = ind_river_reach.index([BV.junction.loc[j, 'River1'], BV.junction.loc[j, 'Reach1']])
                if reach[ind_new_reach].Xupstream != 0:
                    ind_reach_upstream1 = ind_river_reach.index(
                        [BV.junction.loc[j, 'River2'], BV.junction.loc[j, 'Reach2']])
                    ind_reach_upstream2 = ind_river_reach.index(
                        [BV.junction.loc[j, 'River3'], BV.junction.loc[j, 'Reach3']])
                    if reach[ind_reach_upstream1].Xupstream ==0:
                        reach[ind_reach_upstream1].Xupstream= reach[ind_new_reach].Xupstream + np.max(reach[ind_reach_upstream1].Xinterp)
                        n_reach_withX += 1
                    if reach[ind_reach_upstream2].Xupstream == 0:
                        reach[ind_reach_upstream2].Xupstream = reach[ind_new_reach].Xupstream + np.max(reach[ind_reach_upstream2].Xinterp)
                        n_reach_withX += 1


def Limerinos(manning_distr, waterdepth_distr):
    """Limerinos formula for bed friction

    :param manning_distr: value of friciton coefficicent (m) for each subdomain
    :type manning_distr: float
    :param waterdepth_distr: value of wayterdepth (m) for each subdomain
    :type waterdepth_distr: list
    :return:
        - friction coefficient for the section (list)
    """
    new_manning_distr = []
    for ix, ks in enumerate(manning_distr):

        if waterdepth_distr[ix]/ks >1.2:
            new_manning_distr.append(waterdepth_distr[ix]**(1/6) *0.0926 /(1.16 + 2* np.log10(waterdepth_distr[ix]/ks)))
        else:
            new_manning_distr.append(1)

    return  new_manning_distr



def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█', printEnd = "\r"):
    """
    Call in a loop to create terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
        printEnd    - Optional  : end character (e.g. "\r", "\r\n") (Str)
    """
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength + '-' * (length - filledLength)
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end = printEnd)
    # Print New Line on Complete
    if iteration == total:
        print()



def is_raster(filename):
    """Test if the file is a raster

    :filename: raster to test
    :type filename : str
    :return:
        - True if raster, False otherwise (bool)
    """
    try:
        with rasterio.open(filename) as src:
            return True  # Si l'ouverture réussit, c'est un raster
    except rasterio.errors.RasterioIOError:
        return False  # Si une erreur se produit, ce n'est pas un raster


def sample_points(line, num_points=10):
    """Échantillonne `num_points` points uniformément espacés sur une ligne."""
    if not isinstance(line, LineString):  # Vérifier que c'est bien une ligne
        return []

    length = line.length
    distances = np.linspace(0, length, num_points)
    points = [line.interpolate(d) for d in distances]
    return points


def average_distance(line1, line2, num_samples=10):
    """Calcule la distance moyenne entre deux lignes en échantillonnant des points."""
    points1 = sample_points(line1, num_samples)
    points2 = sample_points(line2, num_samples)

    distances = [min(p1.distance(p2) for p2 in points2) for p1 in points1]
    return np.mean(distances)


def sort_lines_and_centerline(multishape, step_long):
    """associated a couple of bank or foot lines and the corresponding centerline.

    :param multishape: Ensemble of bank or foot lines
    :type multishape: gpd.GeoDataFrame
    :param step_long: Longitudinal distance between cross section
    :type step_long: float
    :return:
        - centerlines : list of linestring
        - outline : list of linestring
    """

    outline = []
    centerlines = []
    nlines = 0
    center = gpd.GeoDataFrame(geometry=[])
    if type(multishape.loc[0, 'geometry']) == Polygon:
        if find_centerline_avail: #find lines with pygeoops


            for i, geom1 in enumerate(multishape.geometry):
                geo = pg.simplify(geom1, tolerance=100, algorithm='vw')
                cline = pg.centerline(geo, densify_distance=-0.25, min_branch_length=-5, simplifytolerance=0)

                if type(cline) == MultiLineString:
                    for idl, line in enumerate(cline.geoms):
                        num_vert = int(max(round(line.length / step_long), 1))
                        line_int = LineString(
                            [line.interpolate(float(n) / num_vert, normalized=True) for n in range(num_vert + 1)])
                        center.loc[nlines, 'geometry'] = line_int
                        outline.append(geom1.exterior)
                        nlines += 1
                elif type(cline) == LineString:
                    num_vert = int(max(round(cline.length / step_long), 1))
                    line_int = LineString(
                        [cline.interpolate(float(n) / num_vert, normalized=True) for n in range(num_vert + 1)])

                    center.loc[nlines, 'geometry'] = line_int
                    outline.append(geom1.exterior)
                    nlines += 1

            for i, row in center.iterrows():
                centerlines.append(row['geometry'])

        else:

            print("module pygeoops not available")
            # Comparer chaque ligne avec toutes les autres et stocker la distance moyenne

    else:


        print("centerline found from banks")
        # Comparer chaque ligne avec toutes les autres et stocker la distance moyenne
        distances = []

        for i, geom1 in enumerate(multishape.geometry):
            if geom1:
                min_avg_dist = float("inf")
                closest_index = None
                for j, geom2 in enumerate(multishape.geometry):
                    if i != j:  # Éviter la comparaison avec soi-même
                        if geom2:
                            avg_dist = average_distance(geom1, geom2)
                            if avg_dist < min_avg_dist:
                                min_avg_dist = avg_dist
                                closest_index = j
            distances.append((i, closest_index, min_avg_dist))


        # Trier les résultats par distance moyenne croissante
        distances.sort(key=lambda x: x[2])

        for i in range(0,len(distances),2):
            line1 = multishape.loc[distances[i][0], 'geometry']
            line2 = multishape.loc[distances[i][1], 'geometry']
            coord1 = list(line1.coords)
            coord2 = list(line2.coords)
            coord_center = []
            for c1, c2 in zip(coord1, coord2):
                coord_center.append(Point((c1[0] + c2[0]) / 2, (c1[1] + c2[1]) / 2))

            line_int = LineString(coord_center)
            center.loc[nlines, 'geometry'] = line_int
            coord2 = coord2[::-1]
            coords = coord1 + coord2 + [coord1[0]]
            contour = LineString(coords)
            outline.append(contour)
            print(outline)
            nlines += 1

        for i, row in center.iterrows():
            centerlines.append(row['geometry'])

    return centerlines, outline


def convert_geometry_to_multilinestring(multishape):
    outline = []
    if type(multishape.loc[0, 'geometry']) == Polygon:
        for i, row in multishape.iterrows():
            outline.append(row['geometry'].exterior)
    elif type(multishape.loc[0, 'geometry']) == LineString:
        for i, row in multishape.iterrows():
            outline.append(row['geometry'])
    elif type(multishape.loc[0, 'geometry']) == MultiLineString:
        for i, row in multishape.iterrows():
            for j in range(len(row['geometry'].geoms)):
                outline.append(row['geometry'].geoms[j])
    return MultiLineString(outline)


def polygonize_water_mask(filename_mask,path,crs,water_value):
    """Transform water mask from raster to polygone

    :param filename_mask: file with mask data
    :type filename_mask: string
    :param path:
    :type path: str
    :param
    :type N_reach: int

    """


    mask = None
    with rasterio.Env():
        with rasterio.open(filename_mask) as src:
            image = src.read(1)  # first band
            image = image.astype('uint8')
            results = (
                {'properties': {'raster_val': v}, 'geometry': s}
                for i, (s, v) in enumerate(rasterio.features.shapes(image, mask=mask, transform=src.transform)) if v == water_value)

    geoms = list(results)
    gpd_polygonized_raster = gpd.GeoDataFrame.from_features(geoms)
    gpd_polygonized_raster['area'] = gpd_polygonized_raster['geometry'].area
    gdf_sort = gpd_polygonized_raster.sort_values(by='area',ascending=False)
    #gdf_sort.set_crs(crs)
    gdf_filtered = gpd.GeoDataFrame(geometry=[], crs=crs)
    nline = 0
    for i, row in gdf_sort.iterrows():
        if not row['geometry'].is_empty and i > 0:
            if row['area']>1e2: #aire>10ha
                gdf_filtered.loc[nline, 'geometry'] = row['geometry']
                nline += 1
    gdf_filtered.to_file(os.path.join(path, 'polygone_found.shp'))
    return os.path.join(path, 'polygone_found.shp')
