######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
from scipy.optimize import minimize
from shapely.geometry import point,multipoint
# local
from .Tools import *
from .Hydraulics import Limerinos
import copy

class Section:
    """Cross-section (XS) of a river reach and its hydraulic geometry.

    The section is stored as a 3D polyline ``self.line`` whose vertices are
    (x, y, z). Lateral abscissae of the vertices are kept in ``self.distance``
    (see :meth:`distanceBord` and :meth:`originXSLocation`) and per-vertex
    friction coefficients in ``self.manning``. For a given water surface
    elevation the class finds the bank points, splits the wetted part into
    lateral sub-sections and computes their hydraulic properties.
    """
    def __init__(self, line, name=""):
        """Build a section from a 2D or 3D line.

        Parameters
        ----------
        line : LineString
            Section trace. When its vertices carry no z coordinate, z is set
            to 0 so that every method can read ``coords[i][2]``.
        name : str, optional
            Section name.
        """
        has_z = len(line.coords[0]) > 2
        self.line = LineString([(c[0], c[1], c[2] if has_z else 0) for c in line.coords])
        # Geometry as given at construction; ``self.line`` itself may later be
        # replaced by linechannel / modif_line.
        self.origin_line = self.line
        self.centre = Point()
        self.d0 = 0.
        self.start = Point()
        self.end = Point()
        self.name = name
        self.Zbed = []
        self.distance = []
        self.res = None

        self.hydro = {'depth': [], 'A': [], 'W': [], 'Q': [],
                      'equivalent': {'w1': 0, 'd1': 0, 'i2': 0}}
        self.Q = 0.
        self.acc = 0.
        self.manning = []
        self.slope = []
        self.flagProj = 0
        self.type = 'XS'

    def distanceBord(self):
        '''
        Compute the lateral abscissa of every vertex of the section.

        The abscissae are computed by the ``X_abscissa`` helper (presumably
        the distance in the (X, Y) plane from the first vertex -- confirm) and
        stored in ``self.distance``. The first and last vertices are stored,
        in 2D, in ``self.start`` and ``self.end``.

        Returns
        -------
        None.
        '''
        self.distance = X_abscissa(self.line)
        first, last = self.line.coords[0], self.line.coords[-1]
        self.start = Point(first[0], first[1])
        self.end = Point(last[0], last[1])

    def originXSLocation(self, centre=None, d_search=0):
        '''
        Shift the lateral abscissae so that 0 lies on the channel bed.

        Parameters
        ----------
        centre : str
            ``'minZ'``: the origin is the lowest vertex of the section.
            ``'riverline'``: the origin is searched around the vertex closest
            (in plan) to ``self.centre``; the lowest vertex within
            ``d_search`` of it is taken.
        d_search : float, optional
            Half width of the search window used with ``'riverline'``.

        Returns
        -------
        None. Sets ``self.distance`` (as a list), ``self.Zbed`` and
        ``self.d0`` (signed plan distance between ``self.centre`` and the
        selected bed vertex).

        Raises
        ------
        ValueError
            If ``centre`` is neither ``'minZ'`` nor ``'riverline'``.
        '''
        # Lateral abscissae before shifting the origin
        self.distanceBord()

        if centre == 'minZ':
            Z = [c[2] for c in self.line.coords]
            ind_min = int(np.argmin(Z))
            min_distance_centre = self.distance[ind_min]
            # NOTE(review): ``line.distance`` returns a scalar, so this argmin
            # is always 0; the vertex closest to ``self.centre`` (as in the
            # 'riverline' branch) was probably intended -- confirm.
            distance_centre = self.line.distance(self.centre)
            ind_min_centre = np.argmin(distance_centre)
            self.Zbed = Z[ind_min]

        elif centre == 'riverline':
            # Vertex of the section closest (in plan) to the stream-line point
            distance_centre = [((c[0] - self.centre.x) ** 2 + (c[1] - self.centre.y) ** 2) ** 0.5
                               for c in self.line.coords]
            ind_min = int(np.argmin(distance_centre))
            ind_min_centre = ind_min

            # Vertices closest to distance[ind_min] -/+ d_search bound the
            # window in which the lowest bed point is searched.
            distance = np.asarray(self.distance)
            npoint_search_sup = int(np.argmin(np.abs(distance - (distance[ind_min] + d_search))))
            npoint_search_inf = int(np.argmin(np.abs(distance - (distance[ind_min] - d_search))))

            if self.line.has_z:
                Z = [c[2] for c in self.line.coords]

                if npoint_search_inf != npoint_search_sup:
                    lo, hi = sorted((npoint_search_inf, npoint_search_sup))
                    # Both window ends are included, and the index is kept
                    # absolute because it is reused below to locate the vertex
                    # in self.line when computing d0.
                    ind_min = lo + int(np.argmin(Z[lo:hi + 1]))
                    min_distance_centre = self.distance[ind_min]
                    self.Zbed = Z[ind_min]
                elif d_search == 0:  # mask use case
                    min_distance_centre = self.distance[ind_min]
                    self.Zbed = Z[ind_min]
                else:
                    # The lateral point spacing is larger than the bank-to-bank
                    # width: fall back to the lowest vertex of the section.
                    print('Warning : lateral point space superior to water surface width, increase number of point')
                    ind_min = int(np.argmin(Z))
                    min_distance_centre = self.distance[ind_min]
                    self.Zbed = Z[ind_min]
            else:
                min_distance_centre = self.distance[ind_min]
                self.Zbed = 0
        else:
            raise ValueError("originXSLocation: centre must be 'minZ' or 'riverline', got %r" % (centre,))

        self.distance = [d - min_distance_centre for d in self.distance]

        # Signed plan distance between the stream-line point and the bed vertex
        offset = ((self.centre.x - self.line.xy[0][ind_min]) ** 2
                  + (self.centre.y - self.line.xy[1][ind_min]) ** 2) ** 0.5
        self.d0 = offset if ind_min_centre < ind_min else -offset

    def interpDistributionLat(self, dbank, dxlat):
        '''
        Split the wetted part of the section into lateral sub-sections.

        Parameters
        ----------
        dbank : list
            Lateral abscissae of the intersections between the water surface
            and the bed (see ``findBankPoint``). The wetted part spans
            ``dbank[0]``..``dbank[1]`` when ``dxlat == 0`` and
            ``dbank[0]``..``dbank[-1]`` otherwise.
        dxlat : float
            Lateral discretization step; 0 keeps the surveyed vertices.

        Returns
        -------
        dist_distr : numpy.ndarray
            Lateral abscissa of the centre of each sub-section.
        manning_distr : list
            Friction coefficient interpolated at each sub-section centre
            (one value per sub-section).
        Wsextent : list of Point
            Plan position of the bank points ``dbank[0]`` and ``dbank[-1]``.
        dist_point : list
            Sorted, unique lateral abscissae of the sub-section limits.
        z_point : numpy.ndarray
            Bed elevation interpolated at each sub-section limit.
        '''
        coords = self.line.coords
        X = [c[0] for c in coords]
        Y = [c[1] for c in coords]
        Z = [c[2] for c in coords]
        distance = self.distance

        if dxlat == 0:
            # No resampling: keep the surveyed vertices between the banks and
            # add the bank points themselves, otherwise the wetted wedges
            # between each bank and its nearest vertex would be dropped.
            inside = [d for d in distance if dbank[0] <= d <= dbank[1]]
            dist_point = list(np.unique(inside + [dbank[0], dbank[1]]))
        else:
            # Regular grid of step dxlat built outwards from the middle of the
            # wetted width, so the remainders that are not a multiple of dxlat
            # end up next to each bank.
            dcentre = (dbank[0] + dbank[-1]) / 2
            right_side = list(np.arange(dcentre, dbank[-1], dxlat))
            left_side = [-d for d in np.arange(-dcentre, -dbank[0], dxlat)]
            dist_point = list(np.unique(left_side + right_side + list(dbank)))

        # Sub-section centres are the midpoints of consecutive limits
        # (empty when fewer than two limits exist).
        limits = np.asarray(dist_point, dtype=float)
        dist_distr = (limits[:-1] + limits[1:]) / 2

        manning_distr = list(np.interp(dist_distr, distance, self.manning))
        z_point = np.interp(dist_point, distance, Z)

        X0 = np.interp(dbank[0], distance, X)
        Y0 = np.interp(dbank[0], distance, Y)
        X1 = np.interp(dbank[-1], distance, X)
        Y1 = np.interp(dbank[-1], distance, Y)
        Wsextent = [Point(X0, Y0), Point(X1, Y1)]
        return dist_distr, manning_distr, Wsextent, dist_point, z_point

    def findBankPoint(self, WS, method):
        '''
        Find the lateral abscissae where the water surface meets the bed.

        Parameters
        ----------
        WS : float
            Water surface elevation.
        method : bool-like
            Truthy: levees are taken into account; only the two bank points
            around the channel centre (abscissa 0) are kept.
            Falsy: every intersection is kept, plus each section end lying
            below the water surface.

        Returns
        -------
        list
            Sorted lateral abscissae of the bank points (exactly two when
            ``method`` is truthy).

        Raises
        ------
        ValueError
            With ``method`` truthy, when a single bank point remains while both
            section ends are above the water surface.
        '''
        distance = self.distance
        Z = [c[2] for c in self.line.coords]
        lineXS = LineString([(d, z) for d, z in zip(distance, Z)])
        # Horizontal water line spanning the whole section
        linesurf = LineString([(np.min(distance), WS), (np.max(distance), WS)])
        inter = lineXS.intersection(linesurf)  # water surface / bed intersection

        # NOTE(review): when the bed is flat exactly at WS the intersection
        # contains line parts; such results are ignored here.
        if isinstance(inter, multipoint.MultiPoint):
            dtemp = [pt.x for pt in inter.geoms]
        elif isinstance(inter, point.Point) and not inter.is_empty:
            dtemp = [inter.x]
        else:
            dtemp = []
        dtemp.sort()

        # All intersections on the same side of the centre: the section end on
        # the other side is used as the opposite bank.
        if len(dtemp) >= 2:
            if np.min(dtemp) >= 0:  # all points on the right bank
                dtemp[1] = np.min(dtemp)
                dtemp[0] = np.min(distance)
            if np.max(dtemp) <= 0:  # all points on the left bank
                dtemp[0] = np.max(dtemp)
                dtemp[1] = np.max(distance)
        dtemp = list(np.unique(dtemp))  # sorted, duplicates removed

        if not method:
            # Every intersection is kept, plus the section ends under water
            if not dtemp:
                dtemp = [distance[0], distance[-1]]
            if WS >= Z[0]:
                dtemp.append(distance[0])
            if WS >= Z[-1]:
                dtemp.append(distance[-1])
            return list(np.unique(dtemp))

        # Levee handling: for each pair of consecutive points separated by
        # ground higher than WS, keep only the point on the centre side.
        if len(dtemp) >= 2:
            dcorr = []
            for left, right in zip(dtemp[:-1], dtemp[1:]):
                coteMax = np.max([z for d, z in zip(distance, Z) if left <= d <= right])
                if coteMax > WS:
                    if right < 0:  # both points on the left side of the river
                        dcorr.append(right)
                    elif left > 0:  # both points on the right side of the river
                        dcorr.append(left)
                else:
                    dcorr.append(left)
                    dcorr.append(right)
            dtemp = list(np.unique(dcorr))

        if len(dtemp) == 0:
            dbank1, dbank2 = distance[0], distance[-1]
        elif len(dtemp) == 1:
            # The water surface is above one of the section ends
            if WS >= Z[0]:
                dbank1, dbank2 = distance[0], dtemp[0]
            elif WS >= Z[-1]:
                dbank1, dbank2 = dtemp[0], distance[-1]
            else:
                raise ValueError("findBankPoint: only one bank point found for WS=%s "
                                 "and both section ends are above the water surface" % (WS,))
        elif len(dtemp) == 2:
            dbank1, dbank2 = dtemp[0], dtemp[1]
        else:
            # More than two points: keep the closest ones on each side of 0
            dleft = np.array(dtemp)
            dleft[dleft > 0] = -np.inf
            dbank1 = dtemp[int(np.argmax(dleft))]
            dright = np.array(dtemp)
            dright[dright < 0] = np.inf
            dbank2 = dtemp[int(np.argmin(dright))]

        return [dbank1, dbank2]

    def computeHydraulicGeometry(self, WS, dxlat, method, friction="Manning"):
        ''' Compute the hydraulic parameters of every lateral sub-section.

        Parameters
        ----------
        WS : float
            Water surface elevation (a one-element sequence/array is also
            accepted; its first value is used).
        dxlat : float
            Lateral discretization step (0 keeps the surveyed vertices).
        method : bool-like
            Bank detection mode passed to ``findBankPoint`` (levees taken into
            account or not).
        friction : str, optional
            ``'Limerinos'`` recomputes the sub-section coefficients with
            ``Limerinos``; any other value keeps the interpolated ones.

        Returns
        -------
        averagedValue : dict
            Whole-section values: ``'A'`` area, ``'P'`` wetted perimeter,
            ``'Rh'`` hydraulic radius.
        bank : list
            ``[dbank, Wsextent]``: bank abscissae and their plan positions.
        hydro_distr : numpy.ndarray
            One row per sub-section: centre abscissa, mean water depth, area,
            wetted perimeter, hydraulic radius.
        manning_distr : list
            Friction coefficient of each sub-section.
        '''
        if np.ndim(WS) > 0:
            WS = np.ravel(WS)[0]

        ztopo = [c[2] for c in self.line.coords]
        # Bank points, with or without levee handling (see HEC-RAS method)
        dbank = self.findBankPoint(WS, method)

        # Resample the wetted part into sub-sections
        dist_distr, manning_distr, Wsextent, dist_point, z_point = self.interpDistributionLat(dbank, dxlat)

        n_sub = len(dist_distr)
        waterdepth = np.array(WS - np.array(z_point))
        waterdepth[waterdepth < 0] = 0
        waterdepth_distr = np.zeros(n_sub)
        A_distr = np.zeros(n_sub)
        Pm_distr = np.zeros(n_sub)
        Rh_distr = np.zeros(n_sub)

        for ii in range(n_sub):
            d_lo, d_hi = dist_point[ii], dist_point[ii + 1]
            # Surveyed vertices strictly between the two sub-section limits
            inner = [i for i, d in enumerate(self.distance) if d_lo < d < d_hi]
            dtemp = [d_lo] + [self.distance[i] for i in inner] + [d_hi]
            ztemp = [z_point[ii]] + [ztopo[i] for i in inner] + [z_point[ii + 1]]
            htemp = np.array([WS - z for z in ztemp])
            htemp[htemp < 0] = 0  # dry points contribute no area

            waterdepth_distr[ii] = (waterdepth[ii + 1] + waterdepth[ii]) / 2
            Pm_distr[ii] = sum(((dtemp[j + 1] - dtemp[j]) ** 2 + (ztemp[j + 1] - ztemp[j]) ** 2) ** 0.5
                               for j in range(len(ztemp) - 1))
            A_distr[ii] = np.trapz(htemp, dtemp)
            Rh_distr[ii] = A_distr[ii] / Pm_distr[ii] if Pm_distr[ii] > 0 else 0

        hydro_distr = np.column_stack((np.asarray(dist_distr, dtype=float), waterdepth_distr,
                                       A_distr, Pm_distr, Rh_distr))

        A = np.sum(A_distr)
        P = np.sum(Pm_distr)
        Rh = A / P if P > 0 else 0

        if friction == 'Limerinos':
            manning_distr = Limerinos(manning_distr, list(hydro_distr[:, 1]))

        averagedValue = {'A': A, 'P': P, 'Rh': Rh}
        bank = [dbank, Wsextent]
        return averagedValue, bank, hydro_distr, manning_distr

    def equivalentChannel(self, shape):
        '''
        Compute an equivalent channel of a given shape (same wetted area and
        perimeter).

        Parameters
        ----------
        shape : str
            Shape of the equivalent section; only ``'triangle'`` is handled.

        Returns
        -------
        numpy.ndarray or None
            Parameters optimized by BFGS on ``find_LCA_triangle`` (initial
            guess ``[self.hydro['W'][0], 1, 1e-2]``), or None for any other
            shape.
        '''
        if shape == 'triangle':
            X0 = [self.hydro['W'][0], 1, 1e-2]
            res = minimize(find_LCA_triangle, X0, method='BFGS', args=(self,))
            return res.x
        return None

    def linechannel(self, param, df_filter):
        '''
        Insert the two bank points in the section and extract the channel line
        between them.

        Parameters
        ----------
        param : Parameters
            Not used (kept for interface compatibility).
        df_filter : DataFrame
            Computation result; the first row's ``'WSE'`` (water surface
            elevation) and ``'bank'`` values are used. ``'bank'`` is expected
            to be ``[[d1, d2], [Point1, Point2]]``, presumably the ``bank``
            value returned by ``computeHydraulicGeometry`` -- confirm.

        Returns
        -------
        None. Updates ``self.distance`` (bank abscissae inserted),
        ``self.manning`` (interpolated on the new abscissae), ``self.line``
        (bank vertices inserted at elevation WSE), ``self.line_channel`` and
        ``self.distance_channel``.
        '''
        WS = df_filter['WSE'].iloc[0]

        distancebank1 = df_filter['bank'].iloc[0][0][0]
        distancebank2 = df_filter['bank'].iloc[0][0][1]
        xbank1 = df_filter['bank'].iloc[0][1][0].x
        xbank2 = df_filter['bank'].iloc[0][1][1].x
        ybank1 = df_filter['bank'].iloc[0][1][0].y
        ybank2 = df_filter['bank'].iloc[0][1][1].y
        # self.distance is extended in place and must be a list here
        # (originXSLocation stores it as a list).
        old_distance = [d for d in self.distance]
        self.distance.append(distancebank1)
        self.distance.append(distancebank2)
        self.distance.sort()
        self.manning = np.interp(self.distance, old_distance, self.manning)
        # Positions of the bank points in the extended abscissae
        index1 = self.distance.index(distancebank1)
        index2 = self.distance.index(distancebank2)
        # Coordinate lists of the current line
        list_x = [coord[0] for coord in self.line.coords]
        list_y = [coord[1] for coord in self.line.coords]
        list_z = [coord[2] for coord in self.line.coords]
        # Insert bank 1 (index1 < index2, so it is inserted first)
        list_x.insert(index1, xbank1)
        list_y.insert(index1, ybank1)
        list_z.insert(index1, WS)
        # Insert bank 2 at its position in the fully extended list
        list_x.insert(index2, xbank2)
        list_y.insert(index2, ybank2)
        list_z.insert(index2, WS)
        self.line = LineString([(x, y, z) for x, y, z in zip(list_x, list_y, list_z)])
        # Line made of the points between the banks (inclusive)
        sub_list_x = list_x[index1:index2 + 1]
        sub_list_y = list_y[index1:index2 + 1]
        sub_list_z = list_z[index1:index2 + 1]

        if len(sub_list_x) > 1:
            self.line_channel = LineString([(x, y, z) for x, y, z in zip(sub_list_x, sub_list_y, sub_list_z)])
            self.distance_channel = self.distance[index1:index2 + 1]
        else:
            self.line_channel = LineString()
            self.distance_channel = []

    def modif_line(self, param, df_filter):
        """ Change the bathymetry of the line section between the 2 banks.

        The elevations of the vertices from bank 1 to bank 2 (inclusive) are
        replaced by the shape ``param['B']['bathymetricSections']``
        ("Rectangular", "Triangular" or "Parabolic") of depth
        ``param['B']['depth']``. Nothing is changed when
        ``param['B']['useImportedSections']`` is truthy. ``self.line``,
        ``self.line_channel`` and ``self.distance_channel`` are updated.

        The bank abscissae must already be present in ``self.distance`` (e.g.
        after ``linechannel``); ``list.index`` raises ValueError otherwise.

        :param param: Parameters required for computation (uses ``param['B']``)
        :type param: Parameters
        :param df_filter: result from computation with banks position; the
            first row's ``'bank'`` and ``'WSE'`` values are used
        :type df_filter: dataframe
        """
        paramB = param['B']
        distancebank1 = df_filter['bank'].iloc[0][0][0]
        distancebank2 = df_filter['bank'].iloc[0][0][1]
        WSE = df_filter['WSE'].iloc[0]
        distances = self.distance

        # Positions of the bank points
        index1 = distances.index(distancebank1)
        index2 = distances.index(distancebank2)
        # Coordinate lists of the line
        list_x = [coord[0] for coord in self.line.coords]
        list_y = [coord[1] for coord in self.line.coords]
        list_z = [coord[2] for coord in self.line.coords]

        # Vertex closest to the middle of the wetted width (deepest point)
        distance_mid = (distancebank1 + distancebank2) / 2
        i_mid = np.abs(np.array(distances) - distance_mid).argmin()

        if not paramB['useImportedSections']:
            if paramB['bathymetricSections'] == "Rectangular":
                # Flat bed at WSE - depth from bank 1 to bank 2
                for i in range(len(list_z)):
                    if i >= index1 and i <= index2:
                        list_z[i] = WSE - paramB['depth']

            elif paramB['bathymetricSections'] == "Triangular":
                # V shape: down from the bank-1 elevation to that elevation
                # minus depth at i_mid, then back up to bank 2.
                # NOTE(review): slopes use WSE while offsets use list_z[index1];
                # they agree when the bank elevation equals WSE (as set by
                # linechannel). Divides by zero when i_mid equals index1 or index2.
                slope1 = ((WSE - paramB['depth']) - WSE) / (distances[i_mid] - distances[index1])
                slope2 = (WSE - (WSE - paramB['depth'])) / (distances[index2] - distances[i_mid])
                for i in range(len(list_z)):
                    if i >= index1 and i <= i_mid:
                        list_z[i] = list_z[index1] + slope1 * (distances[i] - distances[index1])
                    if i > i_mid and i <= index2:
                        list_z[i] = (list_z[index1] - paramB['depth']) + slope2 * (distances[i] - distances[i_mid])

            elif paramB['bathymetricSections'] == "Parabolic":
                if index2 - index1 + 1 < 5:
                    # Too few points for a parabola: triangular fallback
                    print("Il n'y a pas assez de point pour faire une section parabolique. La section est triangulaire.")
                    if i_mid == index1:
                        list_z[i_mid] = list_z[index1] - paramB['depth']
                    else:
                        slope1 = ((list_z[index1] - paramB['depth']) - list_z[index1]) / (distances[i_mid] - distances[index1])
                        for i in range(len(list_z)):
                            if i >= index1 and i <= i_mid:
                                list_z[i] = WSE + slope1 * (distances[i] - distances[index1])
                    if i_mid == index2:
                        list_z[i_mid] = list_z[index2] - paramB['depth']
                    else:
                        slope2 = (list_z[index2] - (list_z[index1] - paramB['depth'])) / (distances[index2] - distances[i_mid])
                        for i in range(len(list_z)):
                            if i > i_mid and i <= index2:
                                list_z[i] = (WSE - paramB['depth']) + slope2 * (distances[i] - distances[i_mid])
                else:
                    # Parabola through (bank1, WSE), (mid, WSE - depth), (bank2, WSE)
                    x1, z1 = distances[index1], WSE
                    x2, z2 = distances[i_mid], WSE - paramB['depth']
                    x3, z3 = distances[index2], WSE
                    A = np.array([[x1**2, x1, 1],
                                  [x2**2, x2, 1],
                                  [x3**2, x3, 1]])
                    B = np.array([z1, z2, z3])
                    a, b, c = np.linalg.solve(A, B)
                    for i in range(len(list_z)):
                        if i >= index1 and i <= index2:
                            list_z[i] = a * distances[i]**2 + b * distances[i] + c

        # Coordinates between the banks (inclusive)
        list_x_channel = list_x[index1:index2 + 1]
        list_y_channel = list_y[index1:index2 + 1]
        list_z_channel = list_z[index1:index2 + 1]

        self.line = LineString([(x, y, z) for x, y, z in zip(list_x, list_y, list_z)])
        if len(list_x_channel) > 1:
            self.line_channel = LineString([(x, y, z) for x, y, z in zip(list_x_channel, list_y_channel, list_z_channel)])
            self.distance_channel = self.distance[index1:index2 + 1]
        else:
            self.line_channel = LineString()
            self.distance_channel = []

    def copy(self):
        """ Return an independent deep copy of the section. """
        return copy.deepcopy(self)

    def weir_discharge(self, paramH, ws_downstream=0):
        """ Compute the upstream water surface for an inline structure section.

        Free-flow weir law ``Q = C * L * sqrt(2 g) * H**(3/2)`` solved for the
        head over the crest: ``H = (Q / (C * L * sqrt(2 g)))**(2/3)``.

        NOTE(review): assumes ``self.inlinedata`` is
        ``[C (discharge coefficient), L (crest length), Zs (crest elevation)]``
        -- confirm where inlinedata is set.

        :param paramH: Hydraulics Parameters required for computation
            (uses ``'dxlat'`` and ``'levee'``)
        :type paramH: Parameters
        :param ws_downstream: water surface downstream (only for submerged
            condition, not implemented yet; currently unused)
        :type ws_downstream: float
        :return: upstream water surface, friction slope (always 0), velocity
            of each sub-section, sub-section hydraulic geometry
        """
        coef, length, Zs = self.inlinedata[0], self.inlinedata[1], self.inlinedata[2]
        # Free flow: invert the weir law for the head (exponent 2/3)
        head = (self.Q / (coef * length * (2 * 9.81) ** 0.5)) ** (2 / 3)
        WS = head + Zs
        Sf = 0
        averagedValue, bank, hydro_distr, manning_distr = self.computeHydraulicGeometry(
            WS, paramH['dxlat'], paramH['levee'])
        # Same mean velocity Q / (H * L) for every sub-section; 0 when there is
        # no head over the crest (avoids 0/0).
        v_weir = self.Q / (WS - Zs) / length if WS > Zs else 0.
        V = np.full(len(hydro_distr), v_weir, dtype=float)

        return WS, Sf, V, hydro_distr
