######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
from shapely.geometry import LineString,Point
from scipy.spatial import cKDTree
import pandas as pd
# local
from .Hydraulics import *
from shapely.ops import nearest_points

class Reach:
    """A river reach: an ordered list of cross sections and the hydraulic
    results computed on them.

    Sections are stored in ``self.section`` from upstream to downstream
    (``reachBackWater`` initialises its computation from the last one).
    Each computation method fills one pandas DataFrame with the columns
    ``ID, idSection, Q, WSE, CritDepth, bank, distance, h, A, P, Rh, V, Sf``:

    - ``resNormal``   : normal depth (``reachNormalAndCriticalDepth``)
    - ``res1D``       : 1D backwater profile (``reachBackWater``)
    - ``resHimposed`` : constant imposed water depth (``reachImposedWaterDepth``)
    - ``resObs``      : not filled in this class (presumably observations)
    """

    def __init__(self, name=""):
        self.section = []  # cross sections, from upstream to downstream
        self.id_first_section = 0
        self.slope = []  # smoothed slope used for hydraulic calculation
        self.name = name
        # Indexed with 'River' / 'Reach' keys by the computation methods, so it
        # must be replaced by a mapping before they are called.
        self.geodata = []
        self.area = 0.
        self.line_int = LineString()
        self.Q = 0.
        self.flag1D = 0
        self.Xinterp = [0]  # abscissa of each section along the reach
        self.Zbed = []  # bed elevation of each section (refreshed by compute_slope / reachBackWater)
        header = ['ID', 'idSection', 'Q', 'WSE', 'CritDepth', 'bank', 'distance', 'h', 'A', 'P', 'Rh', 'V', 'Sf']
        self.resNormal = pd.DataFrame(columns=header)
        self.res1D = pd.DataFrame(columns=header)
        self.resHimposed = pd.DataFrame(columns=header)
        self.resObs = pd.DataFrame(columns=header)

    @staticmethod
    def _manning_velocity(Rh, manning, slope):
        '''
        Manning velocity ``1/n * Rh^(2/3) * slope^(1/2)`` for each point of a
        lateral distribution.

        Parameters
        ----------
        Rh : sequence of float
            hydraulic radius of each point
        manning : sequence of float
            Manning coefficient of each point (same length as Rh)
        slope : float
            energy / bed slope used for every point

        Returns
        -------
        numpy.ndarray
            velocity of each point
        '''
        V = np.zeros((len(Rh),))
        for idx, rh in enumerate(Rh):
            V[idx] = 1 / manning[idx] * rh ** (2 / 3) * slope ** 0.5
        return V

    @staticmethod
    def _result_row(ID, idSection, Q, WSE, CritDepth, bank, hydro_distr, V, Sf):
        '''Assemble one result row in the column order of the result DataFrames.

        Columns 0..4 of ``hydro_distr`` are stored as distance, h, A, P and Rh.
        '''
        return [ID, idSection, Q, WSE, CritDepth, bank,
                hydro_distr[:, 0], hydro_distr[:, 1], hydro_distr[:, 2],
                hydro_distr[:, 3], hydro_distr[:, 4], V, Sf]

    def reachNormalAndCriticalDepth(self, paramH):
        '''
        Compute the normal-depth water surface for each section of the reach.

        For each section the water surface elevation is obtained with
        ``find_WaterdepthFromQ`` after assigning ``self.slope[j]`` to the
        section; velocities follow Manning's formula with the section slope
        floored at 0.0001. The critical depth is not computed here and is
        stored as 0. One row per section is written to ``self.resNormal``.

        ``self.slope`` and ``self.Zbed`` must be filled beforehand (for
        instance by ``compute_slope``).

        Parameters
        ----------
        paramH : dict
            hydraulic computation parameters; 'dxlat' and 'levee' are read
            here, other keys may be read by find_WaterdepthFromQ.

        Returns
        -------
        None.

        '''
        idRiver, idReach = self.geodata['River'], self.geodata['Reach']

        for j, section in enumerate(self.section):
            section.slope = self.slope[j]

            WS = find_WaterdepthFromQ(self.Zbed[j], section, paramH)
            _, bank, hydro_distr, manning_distr = section.computeHydraulicGeometry(
                WS, paramH['dxlat'], paramH['levee'])

            # NOTE(review): the floor below is applied after the water level
            # search; if find_WaterdepthFromQ reads section.slope it sees the
            # raw (possibly non-positive) value — confirm this is intended.
            if section.slope <= 0:
                section.slope = 0.0001
            # The last column of hydro_distr is the hydraulic radius.
            V = self._manning_velocity(hydro_distr[:, -1], manning_distr, section.slope)

            # Critical depth is not computed in this method.
            section.CritDepth = 0

            self.resNormal.loc[j] = self._result_row(
                [idRiver, idReach], j, section.Q, WS, section.CritDepth, bank,
                hydro_distr, V, section.slope)

    def compute_slope(self, npoly=10):
        '''
        Approximate the longitudinal bed slope of the reach.

        The bed elevations of the sections are fitted by a polynomial of
        degree ``min(npoly, len(self.section) - 2)`` against ``self.Xinterp``.
        The slope of section ``i`` is the negated finite difference of the
        fitted profile between sections ``i`` and ``i + 1`` (positive when the
        bed goes down as the abscissa increases); the last section reuses the
        slope of the last interval. Reaches with fewer than two sections get a
        zero slope. ``self.Zbed`` is refreshed as a side effect.

        Parameters
        ----------
        npoly : int
            maximum degree of the approximating polynomial

        Returns
        -------
        None.
        '''
        self.Zbed = [section.Zbed for section in self.section]
        npoly = np.min([npoly, len(self.section) - 2])

        if len(self.Zbed) <= 1:
            self.slope = np.zeros(len(self.Zbed))
            return

        X = np.asarray(self.Xinterp)
        Zfit = np.poly1d(np.polyfit(self.Xinterp, self.Zbed, npoly))(X)
        slopes = -np.diff(Zfit) / np.diff(X)
        self.slope = list(slopes) + [slopes[-1]]

    def reachBackWater(self, WS2, paramH):
        '''
        Compute the water surface profile of the reach by marching upstream
        from a downstream boundary condition.

        Sections are processed from the last one (downstream) to the first
        one. Results go to ``self.res1D``: row 0 holds the downstream section
        and row ``j`` holds section ``len(self.section) - 1 - j``. Whenever a
        computed level lies below the level returned by
        ``find_CriticaldepthFromQ`` it is raised just above it (+0.1 at the
        downstream boundary, +0.05 elsewhere) and the geometry is recomputed.

        Parameters
        ----------
        WS2 : float
            downstream boundary condition (water surface elevation)
        paramH : dict
            hydraulic computation parameters; 'dxlat', 'levee' and
            'frictionLaw' are read here, other keys may be read by the
            helper functions.

        Returns
        -------
        None.

        '''
        self.Zbed = [section.Zbed for section in self.section]

        nSection = len(self.section)
        indDown = nSection - 1
        idRiver, idReach = self.geodata['River'], self.geodata['Reach']
        # NOTE(review): the downstream row uses a 3-element ID whereas the other
        # rows use [River, Reach]; kept as-is for compatibility.
        ID = [idRiver, idReach, indDown]

        # Initialisation from the downstream section.
        downSection = self.section[indDown]
        # NOTE(review): frictionLaw is not passed here, unlike the upstream
        # calls below — confirm the default friction law is intended.
        averagedValue, bank, hydro_distr, manning_distr = downSection.computeHydraulicGeometry(
            WS2, paramH['dxlat'], paramH['levee'])
        CritDepth = find_CriticaldepthFromQ(self.Zbed[indDown], downSection, paramH)

        if WS2 < CritDepth:
            WS2 = CritDepth + 0.1
            print('condition torrentiel à l aval')
            # The boundary level changed: recompute the geometry so that the
            # stored bank / hydraulic distributions match the stored WSE.
            averagedValue, bank, hydro_distr, manning_distr = downSection.computeHydraulicGeometry(
                WS2, paramH['dxlat'], paramH['levee'])

        # Velocity is not computed at the downstream boundary.
        V = np.zeros((len(hydro_distr[:, -1]),))
        self.res1D.loc[0] = self._result_row(
            ID, indDown, downSection.Q, WS2, CritDepth, bank, hydro_distr, V, 0.)

        for j in range(1, nSection):
            k = indDown - j  # index of the current (upstream) section, i.e. -(j + 1)
            upSection = self.section[k]

            dx = self.Xinterp[-j] - self.Xinterp[-(j + 1)]
            # Search the water level of the upstream section.
            CritDepth = find_CriticaldepthFromQ(self.Zbed[k], upSection, paramH)
            if upSection.type == 'XS':
                WS, Sf = find_hsup(self.Zbed[k], upSection, WS2, self.section[k + 1], dx, paramH)
                averagedValue, bank, hydro_distr, manning_distr = upSection.computeHydraulicGeometry(
                    WS, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'])
                if Sf >= 0:
                    V = self._manning_velocity(hydro_distr[:, -1], manning_distr, Sf)
                else:
                    V = np.zeros((len(hydro_distr[:, -1]),))

                ID = [idRiver, idReach]

            elif upSection.type == 'inline_structure':
                print('weir at:' + str(upSection.centre))
                if upSection.inlinedata[3] == 'weir':
                    WS, Sf, V, hydro_distr = upSection.weir_discharge(paramH, 0)
                # NOTE(review): bank is not updated on this path, and
                # structures other than 'weir' keep WS/Sf/V/hydro_distr from
                # the previously processed section — confirm intended.

            if WS < CritDepth:
                # Supercritical result: clamp just above the critical level and
                # recompute the geometry there.
                WS2 = CritDepth + 0.05
                averagedValue, bank, hydro_distr, manning_distr = upSection.computeHydraulicGeometry(
                    WS2, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'])

                nPoint = len(hydro_distr[:, -1])
                if upSection.type == 'XS':
                    if Sf >= 0:
                        V = self._manning_velocity(hydro_distr[:, -1], manning_distr, Sf)
                    else:
                        V = np.zeros((nPoint,))
                else:
                    V = np.zeros((nPoint,))
                    for idx in range(nPoint):
                        V[idx] = upSection.Q / (WS - upSection.inlinedata[2]) / \
                                 upSection.inlinedata[1]
            else:
                WS2 = WS

            self.res1D.loc[j] = self._result_row(
                ID, k, upSection.Q, WS2, CritDepth, bank, hydro_distr, V, Sf)

        # Slope at the downstream boundary: water-surface slope between the
        # last two sections (needs at least two sections).
        if nSection > 1:
            self.res1D.loc[0, 'Sf'] = (self.res1D.loc[1, 'WSE'] - self.res1D.loc[0, 'WSE']) / \
                                      (self.Xinterp[-1] - self.Xinterp[-2])

    def reachImposedWaterDepth(self, paramH):
        '''
        Compute the water surface for each section of the reach with a constant water depth.

        The bed slope is first recomputed with ``compute_slope`` (which also
        refreshes ``self.Zbed``); each section gets ``WSE = Zbed + himposed``,
        its slope floored at 0, and Manning velocities. One row per section is
        written to ``self.resHimposed`` with CritDepth stored as 0.

        Parameters
        ----------
        paramH : dict
            hydraulic computation parameters; 'himposed' (imposed water
            depth), 'dxlat' and 'levee' are read here.

        Returns
        -------
        None.

        '''
        idRiver, idReach = self.geodata['River'], self.geodata['Reach']
        self.compute_slope()  # also refreshes self.Zbed

        for j, section in enumerate(self.section):
            section.slope = np.max([0, self.slope[j]])

            WS = self.Zbed[j] + paramH['himposed']
            _, bank, hydro_distr, manning_distr = section.computeHydraulicGeometry(
                WS, paramH['dxlat'], paramH['levee'])

            V = self._manning_velocity(hydro_distr[:, -1], manning_distr, section.slope)

            self.resHimposed.loc[j] = self._result_row(
                [idRiver, idReach], j, section.Q, WS, 0, bank, hydro_distr, V, section.slope)

    def createBankLines(self, resultType='Normal'):
        '''
        Build the 3D bank lines of the reach from a stored result table.

        For each section, the bank points ``bank[1][0]`` (left) and
        ``bank[1][1]`` (right) of its result row are taken with that
        section's WSE as elevation, and joined into ``self.leftLine`` and
        ``self.rightLine``.

        Parameters
        ----------
        resultType : str
            origin of the bank position: 'Normal', '1D', 'Himposed' or 'Obs'
            (selects resNormal, res1D, resHimposed or resObs).

        Returns
        -------
        None.

        Raises
        ------
        ValueError
            if resultType is not one of the supported values.

        '''
        sources = {'Normal': 'resNormal', '1D': 'res1D',
                   'Himposed': 'resHimposed', 'Obs': 'resObs'}
        if resultType not in sources:
            raise ValueError('unknown resultType {!r}; expected one of {}'.format(
                resultType, sorted(sources)))
        df = getattr(self, sources[resultType])

        left_coords = []
        right_coords = []
        for j in range(len(self.section)):
            df_filter = df[df['idSection'] == j]
            bank_points = df_filter['bank'].iloc[0][1]
            # Each vertex carries the water level of its own section.
            WSE = df_filter['WSE'].iloc[0]
            left_coords.append((bank_points[0].x, bank_points[0].y, WSE))
            right_coords.append((bank_points[1].x, bank_points[1].y, WSE))

        self.leftLine = LineString(left_coords)
        self.rightLine = LineString(right_coords)

