######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
from shapely.ops import nearest_points

# local
from core.Tools import *
from core.CrossSection import *


def _max_accumulation_near(BV, x, y, window_size):
    """Return the maximum of ``BV.Map['acc']`` over the pixels returned by ``find_pixels_in_neighborhood`` at (x, y)."""
    neighborhood = find_pixels_in_neighborhood(BV.Map['acc'], BV.Map['X'], BV.Map['Y'], x, y, window_size)
    return np.max(neighborhood)


def _reach_mid_accumulation(BV, reach1, window_size):
    """Return the maximum flow accumulation sampled around the middle vertex of the reach geometry."""
    coords = reach1.geodata['geometry'].coords
    mid = coords[len(coords) // 2]
    return _max_accumulation_near(BV, mid[0], mid[1], window_size)


def addQtoSection(BV, param):
    """Add discharge to the reaches and sections of the network linked to the chosen outlet.

    The method is selected by ``param['H']['dischargeMethod']``:

    - ``'Uniform'``: every reach takes the user value ``param['H']['outletDischarge']``;
    - ``'Uniform per reach using discharge shapefile'``: values are read from ``discharge.shp`` in the working
      directory. A reach with a negative ``'Q imposed'`` takes ``'Q computed'``; a non-negative ``'Q imposed'`` is only
      applied to the outlet reach;
    - ``'Uniform per reach'``: one user value per reach when a list of matching length is given, otherwise a single
      value (the maximum of the list, or the scalar) for every reach;
    - ``'Uniform per reach using accumulated area'``: the outlet reach takes the user value and, at each junction, the
      discharge of the downstream reach is split between the two upstream reaches proportionally to the flow
      accumulation sampled at the middle of each of them. If the discharge cannot reach every reach (e.g. reaches not
      connected to the outlet network), propagation stops with a warning instead of looping forever;
    - ``'Accumulated area'``: each section takes the user value scaled by
      ``section.acc / BV.list_of_outlet[BV.id_outlet][3]`` (presumably the accumulation at the outlet — confirm), and
      each reach takes the discharge of its last section.

    For the "Uniform..." methods the reach discharge is then copied to every section of the reach. When
    ``param['H']['write_discharge']`` is set, the reach discharges are written to ``discharge.shp``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    reach = BV.reach
    Q = param['H']['outletDischarge']
    method = param['H']['dischargeMethod']
    crs = param['C']['DEM_CRS_ID']
    window_size = param['C']['window_size']
    ind_river = BV.list_of_outlet[BV.id_outlet][1]  # indices of the rivers of the watershed
    ind1 = [BV.list_of_outlet[BV.id_outlet][2]]  # downstream reach as [river, reach]
    ind_river_reach = [[reach1.geodata['River'], reach1.geodata['Reach']] for reach1 in reach]

    # Flow accumulation at each section of the outlet network (0 when the global computation is disabled)
    for reach1 in reach:
        if reach1.geodata['River'] in ind_river:
            for section in reach1.section_list:
                if param['C']['computeGlobal']:
                    section.acc = _max_accumulation_near(BV, section.centre.coords[0][0],
                                                         section.centre.coords[0][1], window_size)
                else:
                    section.acc = 0

    if method == 'Uniform':
        # Attribute the user value to all reaches
        for reach1 in reach:
            reach1.Q = Q

    elif method == 'Uniform per reach using discharge shapefile':
        ind_reach = ind_river_reach.index(ind1[0])
        discharge_path = os.path.join(param.work_path, 'discharge.shp')
        if not os.path.exists(discharge_path):
            BV.display('discharge.shp file not present', 2)
            return
        discharge_df = gpd.read_file(discharge_path)
        if len(discharge_df) != len(reach):
            BV.display('Discharge data are not consistent with hydro network', 2)
            return
        for index, row in discharge_df.iterrows():
            if float(row['Q imposed']) >= 0:
                # NOTE(review): a non-negative imposed value is only used for the outlet reach; other reaches with
                # an imposed value keep their current discharge — confirm this is intended.
                if index == ind_reach:
                    reach[index].Q = float(row['Q imposed'])
            else:
                reach[index].Q = float(row['Q computed'])
        discharge_df.to_file(discharge_path)

    elif method == 'Uniform per reach':
        if type(Q) is list:
            if len(Q) == len(reach):
                for reach1, q_value in zip(reach, Q):
                    reach1.Q = q_value
            else:
                BV.display('Number of discharge value and reaches are not equal, using the uniform max value', 1)
                Q = max(Q)
                for reach1 in reach:
                    reach1.Q = Q
        else:
            BV.display('Number of discharge value and reaches are not equal, using the used value', 1)
            for reach1 in reach:
                reach1.Q = Q

    elif method == 'Uniform per reach using accumulated area':
        ind_reach = ind_river_reach.index(ind1[0])
        reach[ind_reach].Q = Q  # the outlet reach takes the user value
        outlet_centre = reach[ind_reach].section_list[-1].centre.coords[0]
        reach[ind_reach].area = _max_accumulation_near(BV, outlet_centre[0], outlet_centre[1], window_size)
        n_reach_withQ = 1

        while n_reach_withQ < len(reach):
            n_before_pass = n_reach_withQ
            for j in range(len(BV.junction)):
                if BV.junction.loc[j, 'River1'] not in ind_river:
                    continue
                ind_new_reach = ind_river_reach.index([BV.junction.loc[j, 'River1'], BV.junction.loc[j, 'Reach1']])
                if reach[ind_new_reach].Q == 0:
                    continue  # downstream discharge not known yet: handled in a later pass
                upstream1 = reach[ind_river_reach.index([BV.junction.loc[j, 'River2'], BV.junction.loc[j, 'Reach2']])]
                upstream2 = reach[ind_river_reach.index([BV.junction.loc[j, 'River3'], BV.junction.loc[j, 'Reach3']])]

                upstream1.area = _reach_mid_accumulation(BV, upstream1, window_size)
                upstream2.area = _reach_mid_accumulation(BV, upstream2, window_size)
                total_area = upstream1.area + upstream2.area

                # Split the downstream discharge proportionally to the drained area of each upstream reach
                if upstream1.Q == 0:
                    upstream1.Q = upstream1.area / total_area * reach[ind_new_reach].Q
                    n_reach_withQ += 1
                if upstream2.Q == 0:
                    upstream2.Q = upstream2.area / total_area * reach[ind_new_reach].Q
                    n_reach_withQ += 1

            if n_reach_withQ == n_before_pass:
                # A full pass over the junctions gave no new discharge: the remaining reaches cannot be reached
                # from the outlet, so looping again would never terminate.
                BV.display('Discharge could not be propagated to every reach, '
                           + str(len(reach) - n_reach_withQ) + ' reach(es) left without discharge', 1)
                break

    elif method == 'Accumulated area':
        # Compute from outlet discharge and accumulated area
        for reach1 in reach:
            for section in reach1.section_list:
                section.acc = _max_accumulation_near(BV, section.centre.coords[0][0],
                                                     section.centre.coords[0][1], window_size)
                section.Q = section.acc / BV.list_of_outlet[BV.id_outlet][3] * Q
            reach1.Q = reach1.section_list[-1].Q

    # Transfer the reach discharge to every section for the uniform methods
    if method.startswith("Uniform"):
        for reach1 in reach:
            for section in reach1.section_list:
                section.Q = reach1.Q

    if param['H']['write_discharge']:
        # NOTE(review): with the 'Uniform per reach using discharge shapefile' method this overwrites the user file
        # and resets every 'Q imposed' to -1 — confirm this is intended.
        BV.discharge = GeoDataFrame(columns=['idReach', 'River', 'Reach', 'Q computed', 'Q imposed', 'geometry'],
                                    crs=crs)
        for count, reach1 in enumerate(reach):
            BV.discharge.loc[count, 'idReach'] = count
            BV.discharge.loc[count, 'River'] = reach1.geodata['River']
            BV.discharge.loc[count, 'Reach'] = reach1.geodata['Reach']
            BV.discharge.loc[count, 'Q computed'] = reach1.Q
            BV.discharge.loc[count, 'Q imposed'] = -1
            BV.discharge.loc[count, 'geometry'] = reach1.geodata['geometry']

        BV.discharge.to_file(os.path.join(param.work_path, 'discharge.shp'))


def addWidthtoSection(BV, param):
    """Compute, for every section of the outlet network, the hydraulic curves as a function of water depth.

    For each section of the reaches of the outlet network, the depth above ``section.Zbed`` is increased by steps of
    ``param['H']['dz']`` (starting from 0.01 + dz). The loop stops once the discharge returned by
    ``section.computeHydraulicGeometry`` exceeds five times the reach discharge. At each depth where
    ``section.findBankPoint`` returns at least two bank positions, the following are appended to ``section.hydro``:

    - the wetted area;
    - the depth;
    - the width (distance between the first two banks);
    - the discharge;
    - the reach discharge.

    All curves are exported to ``hydraulic_curve.csv`` in the working directory. They are also plotted on a new
    matplotlib figure as normalised discharge (Q / reach Q) versus normalised width (W / width interpolated at
    ``section.Q``), one colour per reach.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    reach = BV.reach
    ind_river = BV.list_of_outlet[BV.id_outlet][1]
    dz = param['H']['dz']
    method = param['H']['levee']  # levee method or not
    dxlat = param['H']['dxlat']
    for reach1 in reach:
        if reach1.geodata['River'] not in ind_river:
            continue
        for j, section in enumerate(reach1.section_list):
            section.hydro = {'depth': [], 'A': [], 'W': [], 'Q': [], 'QM': [],
                             'equivalent': {'w1': 0, 'd1': 0, 'i2': 0}}
            z = 0.01
            Qtemp = 0
            # NOTE(review): if findBankPoint never returns two banks, Qtemp never changes and this loop does not
            # terminate — consider bounding the depth.
            while Qtemp <= 5 * reach1.Q:
                z = z + dz
                dbank = section.findBankPoint(section.Zbed + z, method)
                if len(dbank) > 1:
                    averagedValue = section.computeHydraulicGeometry(section.Zbed + z, dxlat, method)
                    Qtemp = averagedValue['Q']
                    section.hydro['A'].append(averagedValue['A'])
                    section.hydro['depth'].append(z)
                    section.hydro['W'].append(dbank[1] - dbank[0])
                    section.hydro['Q'].append(averagedValue['Q'])
                    section.hydro['QM'].append(reach1.Q)
                else:
                    print('No width at section ' + str(j) + ' of reach ' + reach1.name)

    # Rows are collected in a list and the DataFrame is built once (row-by-row .loc enlargement is quadratic).
    rows = []
    colors = plt.cm.hsv(np.linspace(0, 1, len(BV.reach)))  # one colour per reach
    plt.figure()
    for ir, reach1 in enumerate(BV.reach):
        plotX = []
        plotZ = []
        for ids, section in enumerate(reach1.section_list):
            hydro = section.hydro
            if len(hydro['Q']) > 0:
                # width at the section discharge, used to normalise the width curve
                WM = np.interp(section.Q, hydro['Q'], hydro['W'])
                for depth, area, width, discharge, mean_discharge in zip(
                        hydro['depth'], hydro['A'], hydro['W'], hydro['Q'], hydro['QM']):
                    rows.append([reach1.name, 'XS_' + str(reach1.id_first_section + ids), depth, area, width,
                                 discharge, mean_discharge, WM])
                    plotX.append(discharge / mean_discharge)
                    plotZ.append(width / WM)

        plt.scatter(plotX, plotZ, facecolors='none', edgecolors=colors[ir], label=reach1.name)

    df = pd.DataFrame(rows,
                      columns=['reach', 'section', 'depth', 'Area', 'Width', 'Discharge', 'MeanDischarge', 'MeanWidth'],
                      dtype=object)
    df.to_csv(os.path.join(param.work_path, "hydraulic_curve.csv"))


def computeNormalAndCriticalDepth(BV, param):
    """Compute normal and critical depths reach by reach.

    Each reach of ``BV.reach`` runs its own ``compute_normal_and_critical_depth`` with the hydraulic parameter
    set ``param['H']`` and the verbosity flag ``param.verbose``, using the discharge already stored on its sections.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    for current_reach in BV.reach:
        hydraulic_params = param['H']
        current_reach.compute_normal_and_critical_depth(hydraulic_params, param.verbose)


def _compute_upstream_backwater(upstream_reach, river, reach_id, h_junction, param):
    """Compute the backwater profile of one upstream reach starting from the water surface elevation at its junction."""
    if param.verbose:
        print('1D computation for reach ' + str(river) + ', ' + str(reach_id))
    if len(upstream_reach.section_list) > 1:
        upstream_reach.compute_back_water(h_junction, param['H'], param.verbose)
        upstream_reach.WSE = upstream_reach.section_list[0].WSE
    else:
        # a single section has no profile to compute: it takes the junction elevation
        upstream_reach.WSE = h_junction
    upstream_reach.flag1D = 1


def backWaterProfile(BV, param):
    """Computation of the backwater profile over the reach network linked to the selected outlet.

    If the regime is supercritical, the water depth is limited to the critical one.

    If ``inline_structure.shp`` exists in the working directory, each structure point is attached to the closest
    section centre (``label = 'inline_structure'`` and ``inlinedata = [Cd, L, Zs, structure]``).

    The downstream reach (``BV.list_of_outlet[BV.id_outlet][2]``) is computed first, starting from
    ``param['H']['hWaterOutlet']`` above the bed of its last section. Upstream reaches are then computed junction by
    junction, each starting from the water surface elevation of the downstream reach of the junction. If some reaches
    cannot be reached from the outlet, the computation stops with a warning instead of looping forever.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    reach = BV.reach
    ind1 = [BV.list_of_outlet[BV.id_outlet][2]]  # downstream reach as [river, reach]
    ind_river = list(BV.list_of_outlet[BV.id_outlet][1])  # rivers of the outlet network

    # Attach inline structures (if any) to the closest section centre
    inline_path = os.path.join(param.work_path, 'inline_structure.shp')
    if os.path.exists(inline_path):
        inline_df = gpd.read_file(inline_path)
        candidates = [(idr, ids, section.centre)
                      for idr, reach1 in enumerate(BV.reach)
                      for ids, section in enumerate(reach1.section_list)]
        if candidates:
            for _, row in inline_df.iterrows():
                point_inline = row['geometry']
                idr, ids, _centre = min(candidates, key=lambda cand: cand[2].distance(point_inline))
                section = BV.reach[idr].section_list[ids]
                section.label = 'inline_structure'
                section.inlinedata = [float(row['Cd']), float(row['L']), float(row['Zs']), row['structure']]

    ind_river_reach = [[reach1.geodata['River'], reach1.geodata['Reach']] for reach1 in reach]

    ind_reach_down = ind_river_reach.index(ind1[0])
    if param.verbose:
        print('1D computation for reach ' + str(ind1[0]))

    # Downstream reach: start from the imposed outlet water depth above the bed of its last section
    WSE_out = param['H']['hWaterOutlet'] + reach[ind_reach_down].section_list[-1].Zbed
    reach[ind_reach_down].compute_back_water(WSE_out, param['H'], param.verbose)
    reach[ind_reach_down].WSE = reach[ind_reach_down].section_list[0].WSE
    reach[ind_reach_down].flag1D = 1
    n_reach_with_comp = 1

    while n_reach_with_comp < len(reach):
        n_before_pass = n_reach_with_comp
        for j in range(len(BV.junction)):
            if not (BV.junction.loc[j, 'River1'] in ind_river and BV.junction.loc[j, 'River2'] in ind_river
                    and BV.junction.loc[j, 'River3'] in ind_river):
                continue
            ind_new_reach = ind_river_reach.index([BV.junction.loc[j, 'River1'], BV.junction.loc[j, 'Reach1']])
            if reach[ind_new_reach].flag1D == 0:
                continue  # downstream reach not computed yet: handled in a later pass
            h_junction = reach[ind_new_reach].WSE
            upstream_keys = [(BV.junction.loc[j, 'River2'], BV.junction.loc[j, 'Reach2']),
                             (BV.junction.loc[j, 'River3'], BV.junction.loc[j, 'Reach3'])]
            upstream_ids = [ind_river_reach.index([river, reach_id]) for river, reach_id in upstream_keys]
            for (river, reach_id), ind_up in zip(upstream_keys, upstream_ids):
                if reach[ind_up].flag1D == 0:
                    _compute_upstream_backwater(reach[ind_up], river, reach_id, h_junction, param)
                    n_reach_with_comp += 1

        if n_reach_with_comp == n_before_pass:
            # A full pass over the junctions computed nothing new: the remaining reaches are not connected to the
            # outlet network, so looping again would never terminate.
            BV.display('Backwater profile: ' + str(len(reach) - n_reach_with_comp)
                       + ' reach(es) not connected to the outlet were not computed', 1)
            break


def impose_water_depth(BV, param):
    """Compute the hydraulic parameters of every reach for an imposed water depth at each section.

    The work is delegated to ``compute_imposed_water_depth`` of each reach, with ``param['H']`` and
    ``param.verbose``. Water balance is not verified.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    for reach_item in BV.reach:
        verbose_flag = param.verbose
        reach_item.compute_imposed_water_depth(param['H'], verbose_flag)


def impose_water_elevation(BV, param):
    """Compute the hydraulic parameters of every reach for an imposed water elevation at each section.

    Each reach performs the computation itself through ``compute_imposed_water_elevation``, receiving
    ``param['H']`` and ``param.verbose``. Water balance is not verified.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    for reach_item in BV.reach:
        hydro_settings = param['H']
        reach_item.compute_imposed_water_elevation(hydro_settings, param.verbose)


def projOnFrictionMap(BV, poly_filename):
    """Provide the friction coefficient to each point of all sections according to a polygon map.

    For each polygon of the shapefile and each section point (vertices of ``section.geom``) it contains, the value in
    ``section.coord.values['W']`` at the same position is replaced by the polygon value when the latter is smaller.
    The final value of a point is therefore the minimum of its current value and of all the polygons containing it.
    NOTE(review): assumes the vertex order of ``section.geom`` matches ``section.coord.values['W']`` — confirm.

    The friction value is read from the second attribute column of the shapefile. When the shapefile has a single
    attribute column, that column is used (previously the ``geometry`` column was picked, which crashed).

    :param BV: Watershed
    :type BV: ModelCatchment
    :param poly_filename: name of the shapefile containing the friction areas (multipolygon type)
    :type poly_filename: str
    :raises ValueError: if the shapefile has no attribute column
    """
    poly_shp = gpd.read_file(poly_filename)
    attribute_columns = [name for name in poly_shp.columns.to_list() if name != 'geometry']
    if not attribute_columns:
        raise ValueError('No friction attribute found in ' + str(poly_filename))
    ManningName = attribute_columns[1] if len(attribute_columns) > 1 else attribute_columns[0]

    # The section points do not depend on the polygon: build them once instead of once per polygon.
    section_points = [(section, [Point(coord) for coord in section.geom.coords])
                      for reach1 in BV.reach for section in reach1.section_list]

    for _, row in poly_shp.iterrows():
        polygon = row['geometry']
        manning_selected = row[ManningName]
        for section, points in section_points:
            friction = section.coord.values['W']
            for ip, pt in enumerate(points):
                if polygon.contains(pt) and friction[ip] > manning_selected:
                    friction[ip] = manning_selected


def setConstantFriction(BV, manning_value):
    """Assign the same friction coefficient to every point of every section.

    Each entry of ``section.coord.values['W']`` is overwritten in place with ``manning_value``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param manning_value: constant friction value
    :type manning_value: float
    """
    for current_reach in BV.reach:
        for section in current_reach.section_list:
            friction = section.coord.values['W']
            # index-based writes: the container is modified in place
            for position in range(len(friction)):
                friction[position] = manning_value


def createAllBankLines(BV):
    """Build the bank lines of every reach that has more than one section.

    Calls ``create_bank_lines`` on each eligible reach, in order; reaches with zero or one section are skipped.

    :param BV: Watershed
    :type BV: ModelCatchment
    """
    eligible_reaches = (r for r in BV.reach if len(r.section_list) > 1)
    for current_reach in eligible_reaches:
        current_reach.create_bank_lines()


def save_bank_lines(BV, filepath, config_cal):
    """Save the left and right bank lines of every reach to a shapefile and build the reach outlines.

    Only reaches with more than one section are considered. For each of them:

    - a closed outline (left bank, then right bank reversed, then back to the first left-bank point) is appended to
      ``BV.outline``, which is reset at each call;
    - two features named ``LeftBank_river<River>_reach<Reach>`` and ``RightBank_river<River>_reach<Reach>`` are added
      to the shapefile ``banks_lines_<config_cal>.shp`` written in ``filepath``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param filepath: path of the directory for the shapefile of all bank lines
    :type filepath: str
    :param config_cal: label of the computation configuration, used as suffix of the file name
    :type config_cal: str
    """
    banks_gdf = GeoDataFrame(columns=['geometry'], crs=BV.crs)
    BV.outline = []
    row_id = 0
    for current in BV.reach:
        if len(current.section_list) <= 1:
            continue
        left_coords = list(current.left_bank_line.coords)
        right_coords_reversed = list(current.right_bank_line.coords)[::-1]
        BV.outline.append(LineString(left_coords + right_coords_reversed + [left_coords[0]]))
        name_suffix = '_river' + str(current.geodata['River']) + '_reach' + str(current.geodata['Reach'])
        for side, bank_line in (('LeftBank', current.left_bank_line), ('RightBank', current.right_bank_line)):
            banks_gdf.loc[row_id, 'geometry'] = bank_line
            banks_gdf.loc[row_id, 'Name'] = side + name_suffix
            row_id += 1

    banks_gdf.to_file(os.path.join(filepath, 'banks_lines_' + config_cal + '.shp'))


def detect_inline_structure(BV, param):
    """Find the sections where the water surface slope exceeds a threshold (inline structures).

    Only done when ``param['H']['createBanksMethods']`` is ``'Himposed'`` or ``'WSE'``; ``BV.inline_structure`` is
    reset to an empty table in all cases. For each pair of consecutive sections of a reach, the slope is
    ``|(WSE_up - WSE_down) / (Xinterp_down - Xinterp_up)|``. When it exceeds ``param['H']['slope_structure']``:

    - the upstream section is labelled ``'inline'``;
    - a weir (Cd = 0.4, L = bank-to-bank distance of the upstream section, Zs = upstream WSE) located at the upstream
      section centre is added to ``BV.inline_structure``.

    The table is then written to ``inline_structure.shp`` in the working directory.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    crs = BV.crs.to_string()
    BV.inline_structure = GeoDataFrame(columns=['Cd', 'L', 'Zs', 'structure', 'geometry'], crs=crs)
    if param['H']['createBanksMethods'] not in ['Himposed', 'WSE']:
        return

    count = 0
    for reach in BV.reach:
        sections = reach.section_list
        for ids, (section_u, section_d) in enumerate(zip(sections[:-1], sections[1:])):
            WS_u = section_u.WSE
            WS_d = section_d.WSE
            dL = reach.Xinterp[ids + 1] - reach.Xinterp[ids]
            slope = abs((WS_u - WS_d) / dL)
            if slope > param['H']['slope_structure']:
                BV.inline_structure.loc[count, 'Cd'] = 0.4
                width = abs(section_u.bank[0][0] - section_u.bank[0][1])
                BV.inline_structure.loc[count, 'L'] = width
                BV.inline_structure.loc[count, 'Zs'] = WS_u
                BV.inline_structure.loc[count, 'structure'] = 'weir'
                BV.inline_structure.loc[count, 'geometry'] = section_u.centre
                section_u.label = 'inline'
                count += 1

    BV.inline_structure.to_file(os.path.join(param.work_path, 'inline_structure.shp'))


def _lines_to_intersect(geom):
    """Return the line geometries of a bank shape to intersect with cross sections (polygon exteriors, lines as is)."""
    if geom is None or geom.is_empty:
        return []
    if geom.geom_type == 'Polygon':
        return [LineString(geom.exterior)]
    if geom.geom_type == 'MultiPolygon':
        return [LineString(part.exterior) for part in geom.geoms]
    return [geom]


def find_bank_from_poly(BV, file_multishape):
    """Find the position of the banks by intersection between each cross section and bank lines.

    Bank lines come from a file (path), a LineString or a MultiLineString. Polygon shapes contribute their exterior
    ring(s); line shapes are used directly.

    For each section, the intersection points are split into left side (closer to ``section.start`` than the centre
    is) and right side. The bank on each side is the candidate closest to the centre, the section ends being default
    candidates. The bed elevations interpolated at both banks give ``WS = min(Z1, Z2)``. Then
    ``section.bank = [[d_left, d_right], [Point(x, y, WS), Point(x, y, WS)]]``, with signed distances to the centre.
    NOTE(review): ``section.Zbed`` is also overwritten with ``WS`` — confirm this is intended.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param file_multishape: bank line obtained from computation or user defined
    :type file_multishape: str, os.PathLike, LineString or MultiLineString
    :raises TypeError: if ``file_multishape`` is of another type
    """
    import os  # only needed for the path-like type check below

    if isinstance(file_multishape, (str, os.PathLike)):
        multishape = gpd.read_file(file_multishape)
    elif isinstance(file_multishape, MultiLineString):
        multishape = gpd.GeoDataFrame(geometry=[], crs=BV.crs)
        for idl, line in enumerate(file_multishape.geoms):
            multishape.loc[idl, 'geometry'] = line
    elif isinstance(file_multishape, LineString):
        multishape = gpd.GeoDataFrame(geometry=[], crs=BV.crs)
        multishape.loc[0, 'geometry'] = file_multishape
    else:
        raise TypeError('file_multishape must be a file path, a LineString or a MultiLineString, got '
                        + type(file_multishape).__name__)

    # The bank lines do not depend on the section: build them once.
    bank_lines = [line for geom in multishape['geometry'] for line in _lines_to_intersect(geom)]

    for reach in BV.reach:
        for section in reach.section_list:
            Z = section.coord.values['B']
            list_inter = []
            for line_to_intersect in bank_lines:
                inter = line_to_intersect.intersection(section.geom)
                if isinstance(inter, MultiPoint):
                    list_inter += list(inter.geoms)
                elif isinstance(inter, Point):
                    list_inter.append(inter)

            dist_centre_start = section.start.distance(section.centre)
            dist_centre_end = section.end.distance(section.centre)
            # Candidates on each side: the section end, plus the intersections on that side (signed distances)
            dbank1 = [-dist_centre_start]
            dbank2 = [dist_centre_end]
            list_point1 = [section.start]
            list_point2 = [section.end]
            for point in list_inter:
                dist_point_to_centre = point.distance(section.centre)
                if point.distance(section.start) < dist_centre_start:  # left bank
                    dbank1.append(-1 * dist_point_to_centre)
                    list_point1.append(point)
                else:
                    dbank2.append(dist_point_to_centre)
                    list_point2.append(point)

            # Candidate closest to the centre on each side
            ind1 = np.argmax(dbank1)
            ind2 = np.argmin(dbank2)
            point_bank1 = list_point1[ind1]
            point_bank2 = list_point2[ind2]

            Z1 = np.interp(dbank1[ind1], section.coord.array['Xt'], Z)
            Z2 = np.interp(dbank2[ind2], section.coord.array['Xt'], Z)
            WS = np.min([Z1, Z2])
            section.Zbed = WS
            point_bank1 = Point(point_bank1.x, point_bank1.y, WS)
            point_bank2 = Point(point_bank2.x, point_bank2.y, WS)

            section.bank = [[dbank1[ind1], dbank2[ind2]], [point_bank1, point_bank2]]


def mapForMarine(BV, param, shape):
    """Create the maps of drainage network parameters from the DEM for hydrological computation in MARINE.

    ``equivalentChannel(shape)`` is evaluated for every section of every reach and returns ``[w1, d1, i2]``. The grid
    is given by ``BV.Map['X'][0, :]`` (columns) and ``BV.Map['Y'][:, 0]`` (rows). Each visited cell receives the mean
    (NaN ignored) of the values of the sections whose centre lies within half a resolution
    (``param['C']['resolution']``) of the cell coordinates along both axes. Cells with no such section are set to
    NaN. The maps, shaped like ``BV.globalDEM``, are stored in ``BV.Map['w1']``, ``BV.Map['d1']`` and
    ``BV.Map['i2']``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    :param shape: forwarded unchanged to ``equivalentChannel`` of each section — TODO document its expected value
    :type shape: ...
    """
    list_w1, list_d1, list_i2, x_pos_obj, y_pos_obj = [], [], [], [], []
    for reach1 in BV.reach:
        for section in reach1.section_list:
            w1, d1, i2 = section.equivalentChannel(shape)
            list_w1.append(w1)
            list_d1.append(d1)
            list_i2.append(i2)
            x_pos_obj.append(section.centre.x)
            y_pos_obj.append(section.centre.y)

    x = BV.Map['X'][0, :]
    y = BV.Map['Y'][:, 0]
    half_resolution = param['C']['resolution'] / 2
    W1 = np.zeros(BV.globalDEM.shape)
    D1 = np.zeros(BV.globalDEM.shape)
    I2 = np.zeros(BV.globalDEM.shape)

    # Convert once instead of rebuilding the arrays at every grid cell.
    w1_values = np.array(list_w1)
    d1_values = np.array(list_d1)
    i2_values = np.array(list_i2)
    x_obj = np.array(x_pos_obj)
    y_obj = np.array(y_pos_obj)

    for idx, x_pos in enumerate(x):
        # Sections close enough to this grid column; the per-cell test only scans them.
        in_column = np.flatnonzero(np.abs(x_obj - x_pos) <= half_resolution)
        for idy, y_pos in enumerate(y):
            index = in_column[np.abs(y_obj[in_column] - y_pos) <= half_resolution]
            if index.size == 0:
                # Same value as the mean of an empty selection, without the "Mean of empty slice" warning.
                W1[idy, idx] = np.nan
                D1[idy, idx] = np.nan
                I2[idy, idx] = np.nan
                continue
            W1[idy, idx] = np.nanmean(w1_values[index])
            D1[idy, idx] = np.nanmean(d1_values[index])
            I2[idy, idx] = np.nanmean(i2_values[index])

    BV.Map['w1'] = W1
    BV.Map['d1'] = D1
    BV.Map['i2'] = I2


def find_WaterdepthFromQ(Zbed, section, paramH):
    """Find the water surface elevation at which the section conveys its target discharge ``section.Q``.

    Bisection between ``Zbed + paramH['hinf']`` and ``Zbed + paramH['hsup']``. At each step the discharge is obtained
    from ``section.computeHydraulicGeometry`` at the midpoint, with ``paramH['dxlat']``, ``paramH['levee']`` and
    ``paramH['frictionLaw']``. The upper bound moves down when that discharge exceeds the target, otherwise the lower
    bound moves up. The search stops when the bracket is no wider than ``paramH['eps']`` or after
    ``paramH['MaxIter']`` steps.

    :param Zbed: lower point of the section used in reach bed
    :type Zbed: float
    :param section: considered section
    :type section: Section
    :param paramH: Parameters required for hydraulics computation
    :type paramH: dict
    :return:
      - water surface elevation (float): the last evaluated midpoint, or the initial midpoint if no step was done
    """
    upper = Zbed + paramH['hsup']
    lower = Zbed + paramH['hinf']
    candidate = (upper + lower) / 2
    steps = 0

    while upper - lower > paramH['eps'] and steps < paramH['MaxIter']:
        steps += 1
        candidate = (upper + lower) / 2
        computed_q = section.computeHydraulicGeometry(
            candidate, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'])['Q']
        if computed_q > section.Q:
            upper = candidate  # too much discharge: the solution is lower
        else:
            lower = candidate  # not enough discharge: the solution is higher

    return candidate


def find_CriticaldepthFromQ(Zbed, section, paramH):
    """Solve Froude = 1 for the section discharge ``section.Q``.

    Bisection on the elevation between ``Zbed + paramH['hinf']`` and ``Zbed + paramH['hsup']`` on the squared Froude
    number ``Q**2 * B / (g * A**3)``. Here ``A`` is the wetted area and ``B`` the distance between the two banks
    returned by ``section.computeHydraulicGeometry`` (stored in ``section.bank`` at each step). A midpoint with no
    wetted area (``A <= 0``) is treated as supercritical, since the Froude number grows without bound as the area
    vanishes, so the search moves up. The search stops when the bracket is no wider than ``paramH['eps']`` or after
    ``paramH['MaxIter']`` steps.

    :param Zbed: the lowest point of the section used in reach bed
    :type Zbed: float
    :param section: the considered section
    :type section: Section
    :param paramH: Parameters required for hydraulics computation
    :type paramH: dict
    :return:
        - critical elevation (float): the last evaluated midpoint
    """
    hsup = Zbed + paramH['hsup']
    hinf = Zbed + paramH['hinf']
    h = (hsup + hinf) / 2
    iter_i = 0
    ecart = hsup - hinf

    while ecart > paramH['eps'] and iter_i < paramH['MaxIter']:
        iter_i += 1
        h = (hsup + hinf) / 2  # midpoint
        averagedValue = section.computeHydraulicGeometry(h, paramH['dxlat'], paramH['levee'])
        section.bank = averagedValue['dbank']

        if averagedValue['A'] > 0:
            width = section.bank[0][1] - section.bank[0][0]
            Fr = section.Q ** 2 * width / 9.81 / averagedValue['A'] ** 3
        else:
            # Dry elevation: do not reuse the previous step's Froude number (which could push the bracket further
            # into the dry region); the critical level is above.
            Fr = float('inf')

        if Fr < 1:
            hsup = h  # subcritical: the solution is lower
        else:
            hinf = h  # supercritical: the solution is higher
        ecart = hsup - hinf
    return h


def charge(Q, A, WS):
    """Return the total head (m): water surface elevation plus the velocity head ``Q**2 / (2 g A**2)``.

    :param Q: objective discharge
    :param A: wetted area
    :param WS: water elevation
    :return:
        - head (float)
    """
    velocity_head = Q ** 2 / 2 / 9.81 / A ** 2
    return WS + velocity_head


def find_hsup(Zbed, section1, WS2, section2, dx, paramH):
    """Solve the head-loss equation between two sections (integral formulation).

    Bisection on the upstream water elevation ``WS`` until the head difference
    ``charge(Q1, A1, WS) - charge(Q2, A2, WS2)`` matches the friction loss
    ``(Sf1 + Sf2) / 2 * dx``, with ``Sf = Q**2 * neq**2 / (A**2 * Rh**(4/3))``
    evaluated at each section.

    The search interval is ``[Zbed + paramH['hinf'], Zbed + paramH['hsup']]``,
    narrowed to ``WS2 +/- 0.1 * dx`` (maximum water-surface slope of 10 %).

    :param Zbed: lowest bed elevation of the upstream section (m)
    :type Zbed: float
    :param section1: upstream section (provides ``Q`` and ``computeHydraulicGeometry``)
    :type section1: Section
    :param WS2: water elevation at the downstream section (m)
    :type WS2: float
    :param section2: downstream section (provides ``Q`` and ``computeHydraulicGeometry``)
    :type section2: Section
    :param dx: distance between the 2 sections (m)
    :type dx: float
    :param paramH: hydraulic parameters
        ('hsup', 'hinf', 'eps', 'MaxIter', 'dxlat', 'levee', 'frictionLaw')
    :type paramH: dict
    :return:
        - upstream water elevation (float)
        - mean friction slope between the two sections (float); 0 when no
          bisection step was performed
    """
    # bracket limited to a maximum water-surface slope of 10 %
    WSsup = min(WS2 + dx * 0.1, Zbed + paramH['hsup'])
    WSinf = max(WS2 - dx * 0.1, Zbed + paramH['hinf'])

    WS = (WSsup + WSinf) / 2
    Sf = 0
    ecart = WSsup - WSinf
    iter_i = 0

    if not (ecart > paramH['eps'] and iter_i < paramH['MaxIter']):
        # nothing to solve: empty or already converged bracket
        return WS, Sf

    # The downstream state does not depend on WS: evaluate it once, not at every step.
    downstream = section2.computeHydraulicGeometry(
        WS2, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'])
    A2 = downstream['A']
    Sf2 = (section2.Q ** 2 / A2 ** 2) * downstream['neq'] ** 2 / (downstream['Rh']) ** (4 / 3)
    H2 = charge(section2.Q, A2, WS2)

    while ecart > paramH['eps'] and iter_i < paramH['MaxIter']:
        iter_i += 1
        # midpoint of the current bracket
        WS = (WSsup + WSinf) / 2

        # upstream section at the trial elevation
        upstream = section1.computeHydraulicGeometry(
            WS, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'])
        Sf1 = (section1.Q ** 2 / upstream['A'] ** 2) * upstream['neq'] ** 2 / (upstream['Rh']) ** (4 / 3)
        Sf = (Sf1 + Sf2) / 2
        friction_loss = Sf * dx
        head_loss = charge(section1.Q, upstream['A'], WS) - H2

        if friction_loss < head_loss:
            # head drop larger than the friction loss: the solution is below the midpoint
            WSsup = WS
        else:
            # the solution is above the midpoint
            WSinf = WS

        ecart = WSsup - WSinf

    return WS, Sf


def read_obs(BV, param, in_obs, time, var='Z', type='EnsObs'):
    """Impose in situ observations on the cross-sections of every reach.

    The observation date of ``in_obs`` closest to ``time`` is selected. Each
    observation position is projected onto the planar (2D) centreline of the
    reach to obtain its curvilinear abscissa. The observed values are then
    linearly interpolated at ``reach.Xinterp`` and stored as ``section.WSE``.
    Each section's hydraulic geometry is recomputed with the local absolute
    water-surface slope (0 for the last section).

    When ``var == 'H'`` the observed values are water depths. They are turned
    into elevations by adding the lowest DEM value found around each
    observation point.

    Only ``type == 'EnsObs'`` is handled; any other value does nothing.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation (uses ``param['H']``)
    :type param: Parameters
    :param in_obs: observation data; must provide ``times_list(var)`` and
        ``in_obs[var][t]['pos']`` / ``in_obs[var][t]['val']``
    :type in_obs: csv file / pickle
    :param time: considered time of observation, compared with the entries of
        ``in_obs.times_list(var)``
    :type time: datetime
    :param var: observed variable: 'Z' (default) or 'H' (depth, converted to elevation)
    :type var: str
    :param type: data observation source
    :type type: str
    """
    paramH = param['H']
    if type != 'EnsObs':
        return

    # Observation date closest to the requested time.
    time_list = in_obs.times_list(var)
    indtime = np.argmin([abs(time - t) for t in time_list])
    near_time = int(time_list[indtime])
    observations = in_obs[var][near_time]
    positions = observations['pos']

    # Observed values as elevations. They do not depend on the reach, so the
    # DEM lookup for depth observations is done once for all reaches.
    obs_values = []
    for ip, p in enumerate(positions):
        value = observations['val'][ip]
        if var == 'H':
            point_obs = Point(p[0], p[1])
            # bed elevation taken on the DEM around the station
            Zneighborhood = np.min(
                find_pixels_in_neighborhood(BV.globalDEM, BV.Map['X'], BV.Map['Y'], point_obs.x, point_obs.y,
                                            1))
            value = value + Zneighborhood
        obs_values.append(value)

    for reach in BV.reach:
        geometry = reach.geodata['geometry']
        # curvilinear abscissas are measured along the 2D centreline
        if geometry.has_z:
            centerline = LineString([(coord[0], coord[1]) for coord in geometry.coords])
        else:
            centerline = geometry

        # abscissa of each station = distance along the centreline to its projection
        X_posObs = np.array([centerline.project(Point(p)) for p in positions])
        # np.interp requires increasing sample abscissas: order stations along the reach
        order = np.argsort(X_posObs, kind='stable')
        # interpolate the elevations on the sections
        Z_section = np.interp(reach.Xinterp, X_posObs[order], np.asarray(obs_values)[order])

        Xinterp = reach.Xinterp
        slope = [abs((Z_section[i + 1] - Z_section[i]) / (Xinterp[i + 1] - Xinterp[i]))
                 for i in range(len(Z_section) - 1)]
        slope.append(0)

        # recompute the hydraulic quantities of each section
        for j, section in enumerate(reach.section_list):
            section.WSE = Z_section[j]
            section.computeHydraulicGeometry(section.WSE, paramH['dxlat'], paramH['levee'],
                                             paramH['frictionLaw'], slope[j])


def equivalentChannel(section, shape):
    """Fit an equivalent channel geometry of the given shape to a section.

    Only ``shape == 'triangle'`` is handled. ``find_LCA_triangle`` is minimised
    with BFGS, starting from ``[section.hydro['W'][0], 1, 1e-2]`` and with
    ``section`` passed as the extra argument. ``minimize`` is presumably
    ``scipy.optimize.minimize``, brought in by the file's imports (TODO confirm).

    :param section: section to fit; ``section.hydro['W']`` provides the initial width
    :type section: Section
    :param shape: equivalent shape ('triangle')
    :type shape: str
    :return: optimal parameter vector (``res.x``) for 'triangle'; None for any other shape
    """
    if shape == 'triangle':
        X0 = [section.hydro['W'][0], 1, 1e-2]
        # args must be a tuple: ``(section)`` is just ``section``
        res = minimize(find_LCA_triangle, X0, method='BFGS', args=(section,))
        return res.x
    return None


def select_computation(BV1, params):
    """Run the hydraulic computation selected by ``params['H']['createBanksMethods']``.

    Each method runs its computation and returns a label built from the method name:

    - 'Normal'   -> computeNormalAndCriticalDepth, label ``Normal_Q<outletDischarge>``
    - '1D'       -> backWaterProfile, label ``1D_Q<outletDischarge>_h<hWaterOutlet>``
    - 'Himposed' -> impose_water_depth, label ``Himposed_h<himposed>``
    - 'WSE'      -> impose_water_elevation, label ``WSE_WSE``

    Any other method runs nothing and the label is the method name alone.

    :param BV1: Watershed
    :type BV1: ModelCatchment
    :param params: Parameters required for computation
    :type params: dict
    :return: label describing the computation configuration
    """
    method = params['H']['createBanksMethods']
    hydro = params['H']

    if method == 'Normal':
        computeNormalAndCriticalDepth(BV1, params)
        return method + '_Q' + str(hydro['outletDischarge'])

    if method == '1D':
        backWaterProfile(BV1, params)
        return method + '_Q' + str(hydro['outletDischarge']) + '_h' + str(hydro['hWaterOutlet'])

    if method == 'Himposed':
        impose_water_depth(BV1, params)
        return method + '_h' + str(hydro['himposed'])

    if method == 'WSE':
        impose_water_elevation(BV1, params)
        return method + '_WSE'

    return method


def run_hydraulics(BV, param, assim=False):
    """Launch the hydraulic calculation and export the results.

    Steps, driven by the flags of ``param['H']``:

    1. assign a constant friction, then optionally project a friction map;
    2. export the HEC-RAS geometry CSV (skipped in assimilation mode);
    3. attribute discharge to sections;
    4. build the hydraulic curves;
    5. compute the water line and the bank lines, with exports skipped in
       assimilation mode.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation (indexed by 'H', 'N', 'I',
        'XS'; provides ``work_path``)
    :type param: Parameters
    :param assim: True when used in an assimilation loop, to avoid saving bank
        lines and exporting results at each step
    :type assim: bool
    :return: label of the bank computation configuration (' ' when no banks are created)
    """
    config_cal = ' '
    setConstantFriction(BV, param['H']['frictionValue'])
    if param['H']['frictionMap']:
        projOnFrictionMap(BV, param['H']['friction_filename'])
    print("friction coefficient associated to sections")

    if not assim:
        export_as_csv_for_HECRASgeometry(BV, param.work_path)

    # - Attribute discharge to sections.
    # Hydraulic curves always get a discharge. For bank creation it is skipped
    # when the water line is imposed directly ('Himposed': water depth,
    # 'WSE': water elevation). The former test `m != "Himposed" or m != "WSE"`
    # was always True and never skipped anything.
    if param['H']['hydraulicCurve'] or (
            param['H']['createBanks']
            and param['H']['createBanksMethods'] not in ("Himposed", "WSE")):
        addQtoSection(BV, param)
        BV.display("Discharge associated to sections")

    # ------Hydraulic curve computation--------------------------------------------
    if param['H']['hydraulicCurve']:
        addWidthtoSection(BV, param)
        BV.display("Hydraulic curves created")

    # ------Hydraulic map and banks computation--------------------------------------------
    if param['H']['createBanks']:
        BV.compute_slope(param['N']['npoly'])
        config_cal = select_computation(BV, param)
        if not assim:
            createAllBankLines(BV)
            save_bank_lines(BV, param.work_path, config_cal)
            BV.display("Bank lines created")
            export_result_as_csv(BV, param)

        if not param['I']['riverbanks_filepath'] and BV.init_mode == 12:  # creation channel only
            bank_to_edge_end(BV, param)
        if param['XS']['width_from_banks'] and not param['I']['riverbanks_filepath']:
            reduceXStobank(BV, param['XS']['distSearchMin'], param['XS']['method_banks'])

    return config_cal
