######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# third-party
from tatooinemesher.section import CrossSectionSequence

# local
from core.Hydraulics import *


class Reach(CrossSectionSequence):
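    """ Sequence of cross-sections forming a reach, with 1D hydraulic helpers

    Extends CrossSectionSequence with methods that approximate the longitudinal
    slope, compute a water surface elevation for each section, and build the
    bank and levee lines of the reach.
    """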
    def __init__(self, name=""):
        """ Create an empty reach

        :param name: name of the reach
        :type name: str (optional)
        """
        super().__init__()

        # identification
        self.name = name
        self.id_first_section = 0
        self.geodata = []
        self.area = 0.0
        self.line_int = LineString()

        # hydraulic state
        self.Q = 0.0
        self.flag1D = 0
        self.slope = []  # smoothing the slope used for hydraulic calculation
        self.Zbed = []  # filled from the Zbed attribute of each section

        # longitudinal abscissas (x-axis of the slope fit, spacing between sections)
        self.Xinterp = [0]
        self.Xupstream = 0

        # bank lines start as empty lists, levee lines as None
        self.left_bank_line = []
        self.right_bank_line = []
        self.left_levee_line = None
        self.right_levee_line = None

    def compute_normal_and_critical_depth(self, paramH, verbose):
        """ Compute a water surface elevation for each section of the reach

        For every section, the reach slope stored at the same index is copied
        onto the section (a non-positive value is replaced by 0.0001), the water
        surface is obtained from ``find_WaterdepthFromQ`` and the hydraulic
        geometry of the section is evaluated at that level. The results are
        stored in ``section.WSE`` and ``section.bank``.

        :param paramH: hydraulic computation parameter
        :type paramH: dict
        :param verbose: To display information
        :type verbose: boolean
        """
        fallback_slope = 0.0001
        for index, section in enumerate(self.section_list):
            reach_slope = self.slope[index]
            # a flat or adverse slope is replaced by a small positive one
            section.slope = fallback_slope if reach_slope <= 0 else reach_slope

            water_level = find_WaterdepthFromQ(self.Zbed[index], section, paramH)
            geometry = section.computeHydraulicGeometry(
                water_level,
                paramH['dxlat'],
                paramH['levee'],
                paramH['frictionLaw'],
                section.slope,
            )

            section.WSE = water_level
            section.bank = geometry['dbank']
            if verbose:
                print('Water elevation at section ' + str(index) + ' is ' + str(water_level))

    def compute_slope(self, n_poly, data):
        """ Approximate a longitudinal slope of the reach

        :param n_poly: order of the polynom for approximation
        :type n_poly : int
        :param data: Data used to approximate the slope (bed or WSE)
        :type data: string
        """
        Y_to_smooth = []
        X_to_smooth = []
        X_smooth = []
        if data == 'bed':
            self.Zbed = [section.Zbed for section in self.section_list]
            Y_to_smooth = self.Zbed
            X_to_smooth = self.Xinterp
        elif data == 'WSE':
            for i, section in enumerate(self.section_list):
                if not np.isnan(section.WSE) and section.WSE is not None:
                    X_to_smooth.append(self.Xinterp[i])
                    Y_to_smooth.append(section.WSE)

        if len(Y_to_smooth) > 1:
            if n_poly:
                n_poly = np.min([n_poly, len(Y_to_smooth) - 2])
                coeff_poly = np.polyfit(X_to_smooth, Y_to_smooth, n_poly)
                Z_inter = np.poly1d(coeff_poly)
                for i in range(len(Y_to_smooth) - 1):
                    X_smooth.append(-1 * (Z_inter(self.Xinterp[i + 1]) - Z_inter(self.Xinterp[i])) / (
                            self.Xinterp[i + 1] - self.Xinterp[i]))
                X_smooth.append(
                    -1 * (Z_inter(self.Xinterp[-1]) - Z_inter(self.Xinterp[-2])) / (self.Xinterp[-1] - self.Xinterp[-2]))
        else:
            X_smooth = np.zeros(len(Y_to_smooth))

        if data == 'bed':
            self.slope = X_smooth
            for i, section in enumerate(self.section_list):
                section.slope = X_smooth[i]
        elif data == 'WSE':
            if len(X_smooth)<len(self.section_list):
                X_smooth = np.interp(self.Xinterp,X_to_smooth,X_smooth)
            for i, section in enumerate(self.section_list):
                section.slope = X_smooth[i]

    def compute_back_water(self, WSE_out, paramH, verbose):
        """ Compute the water surface for each section of the reach

        The downstream boundary condition is applied to the last section, then
        the sections are processed from the end of ``section_list`` towards its
        start, one every ``paramH['step_section']`` sections. Each processed
        section receives its ``WSE``, its ``bank`` and ``flag_1D = True``.
        When ``paramH['step_section']`` is not 1, ``interp_result_1D`` is called
        at the end.

        :param WSE_out: Water Surface Elevation downstream boundary condition
        :type WSE_out: float
        :param paramH: hydraulic computation parameters
        :type paramH : dict
        :param verbose: To display information
        :type verbose: boolean
        """
        self.Zbed = [section.Zbed for section in self.section_list]
        # initialization from the downstream end

        averagedValue = self.section_list[-1].computeHydraulicGeometry(WSE_out, paramH['dxlat'], paramH['levee'],
                                                                            paramH['frictionLaw'],
                                                                            self.section_list[-1].slope)
        self.section_list[-1].WSE = WSE_out
        self.section_list[-1].bank = averagedValue['dbank']
        self.section_list[-1].flag_1D= True
        CritDepth = find_CriticaldepthFromQ(self.Zbed[-1], self.section_list[-1], paramH)

        # The boundary level is raised above the critical level when it is below it.
        # NOTE(review): the last section keeps the unadjusted WSE stored above; the
        # raised value is only used as downstream level for the first upstream step.
        if WSE_out < CritDepth:
            WSE_out = CritDepth + 0.1
            print('condition torrentiel à l aval')

        # j counts sections from the downstream end: section -(j + 1) is computed
        # using section -j as its downstream neighbour.
        # NOTE(review): with step_section > 1, dx and section -j refer to the
        # adjacent section while WSE_out comes from the last processed one - confirm.
        for j in range(1, len(self.section_list), paramH['step_section']):

            dx = self.Xinterp[-j] - self.Xinterp[-(j + 1)]
            # look for the water level of the upstream section
            CritDepth = find_CriticaldepthFromQ(self.Zbed[-(j + 1)], self.section_list[-(j + 1)], paramH)


            if self.section_list[-(j + 1)].label == 'Cross-section':
                (WS, Sf) = find_hsup(self.Zbed[-(j + 1)], self.section_list[-(j + 1)], WSE_out, self.section_list[-j], dx,
                                     paramH)

                averagedValue = self.section_list[-(j + 1)].computeHydraulicGeometry(WS,
                                                                                     paramH['dxlat'], paramH['levee'],
                                                                                     paramH['frictionLaw'], Sf=Sf)

                self.section_list[-(j + 1)].slope = Sf
                self.section_list[-(j + 1)].bank = averagedValue['dbank']
            elif self.section_list[-(j + 1)].label == 'inline_structure':
                if verbose:
                    print('weir at:' + str(self.section_list[-(j + 1)].centre))
                if self.section_list[-(j + 1)].inlinedata[3] == 'weir':
                    WS = self.section_list[-(j + 1)].weir_discharge(paramH)

            # NOTE(review): WS is only assigned in the two cases above. For any other
            # label, or an inline structure that is not a 'weir', it keeps the value
            # of the previous iteration (and is unbound on the first one).
            if WS <= CritDepth:
                # level at or below the critical one: use the critical level plus a
                # 0.05 margin and recompute the geometry with a zero slope argument
                WSE_out = CritDepth + 0.05
                averagedValue = self.section_list[-(j + 1)].computeHydraulicGeometry(
                    WSE_out, paramH['dxlat'], paramH['levee'], paramH['frictionLaw'], 0)
                self.section_list[-(j + 1)].bank = averagedValue['dbank']
            else:
                WSE_out = WS

            if verbose:
                print('Water elevation at section ' + str(len(self.section_list) - (j + 1)) + ' is ' + str(WSE_out))
            self.section_list[-(j + 1)].WSE = WSE_out
            self.section_list[-(j + 1)].flag_1D = True

        # add slope at downstream
        # NOTE(review): paramH['hWaterOutlet'] is compared directly with a WSE here,
        # while the commented line below adds the bed elevation to it - confirm.
        if len(self.section_list) > 1:
            #self.section_list[-1].WSE = paramH['hWaterOutlet']+ self.section_list[-1].Zbed
            self.slope[-1] = (paramH['hWaterOutlet'] - self.section_list[-2].WSE) / (
                    self.Xinterp[-1] - self.Xinterp[-2])

        # sections skipped by the step are filled by interpolation
        if paramH['step_section'] !=1:
            self.interp_result_1D(paramH)

    def interp_result_1D(self,paramH):
        Y_computed = []
        X_computed = []

        for i, section in enumerate(self.section_list):
            if not np.isnan(section.WSE):
                if section.flag_1D:
                    X_computed.append(self.Xinterp[i])
                    Y_computed.append(section.WSE)

        for i, section in enumerate(self.section_list):
            if not section.flag_1D:
                section.WSE = np.interp(self.Xinterp[i],X_computed,Y_computed)
                averagedValue = section.computeHydraulicGeometry(section.WSE, paramH['dxlat'], paramH['levee'],
                                                                 paramH['frictionLaw'], section.slope)
                section.bank = averagedValue['dbank']

    def compute_imposed_water_depth(self, paramH, verbose=False):
        """ Compute the water surface for each section of the reach with a constant water depth

        :param paramH: hydraulic computation parameters
        :type paramH: dict
        :param verbose: To display debugging information
        :type verbose: bool (optional)
        """
        self.Zbed = [section.Zbed for section in self.section_list]
        if len(self.section_list) > 1:
            for j, section in enumerate(self.section_list):
                section.slope = np.max([0, self.slope[j]])
                section.WSE = self.Zbed[j] + paramH['himposed']
                averagedValue = section.computeHydraulicGeometry(section.WSE, paramH['dxlat'], paramH['levee'],
                                                                 paramH['frictionLaw'], section.slope)
                section.bank = averagedValue['dbank']
                if verbose:
                    print('Water elevation at section ' + str(j) + ' is ' + str(section.WSE))

        if len(self.section_list) > 1:
            section = self.section_list[-1]
            section.slope = (self.Zbed[-1] + paramH['himposed'] - self.section_list[-2].WSE) / (
                    self.Xinterp[-1] - self.Xinterp[-2])

    def compute_imposed_water_elevation(self, paramH, verbose=False):
        """ Compute the water surface for each section of the reach with a constant water elevation

        :param paramH: hydraulic computation parameters
        :type paramH: dict
        :param verbose: To display debugging information
        :type verbose: bool (optional)
        """
        for j, section in enumerate(self.section_list[:-1]):
            section.slope = np.max([1e-6, self.slope[j]])
            averagedValue = section.computeHydraulicGeometry(section.WSE, paramH['dxlat'], paramH['levee'],
                                                             paramH['frictionLaw'], section.slope)
            section.bank = averagedValue['dbank']
            if verbose:
                print('Water elevation at section ' + str(j) + ' is ' + str(section.WSE))

        # add slope at dowstream
        if len(self.section_list) > 1:
            section = self.section_list[-1]
            Sf = (self.section_list[-1].WSE - self.section_list[-2].WSE) / (self.Xinterp[-1] - self.Xinterp[-2])
            averagedValue = section.computeHydraulicGeometry(section.WSE, paramH['dxlat'], paramH['levee'],
                                                             paramH['frictionLaw'], Sf)
            section.bank = averagedValue['dbank']

    def create_bank_lines(self):
        """ Build the left and right bank lines of the reach

        For each section, the two points stored in ``section.bank[1]`` give the
        planar position of the left (index 0) and right (index 1) banks. The
        lines are stored in ``left_bank_line`` and ``right_bank_line`` as 3D
        lines whose third coordinate is the ``WSE`` of the matching section.
        """
        left_points = []
        right_points = []
        for section in self.section_list:
            left_end = section.bank[1][0]
            right_end = section.bank[1][1]
            left_points.append(Point(left_end.x, left_end.y))
            right_points.append(Point(right_end.x, right_end.y))

        # one point per section, so both lists can be paired with the sections
        self.left_bank_line = LineString(
            [(point.x, point.y, section.WSE) for point, section in zip(left_points, self.section_list)]
        )
        self.right_bank_line = LineString(
            [(point.x, point.y, section.WSE) for point, section in zip(right_points, self.section_list)]
        )

    def create_levee_lines(self):
        """ Create and save the left and right levee lines of the reach

        In each section, the point whose transverse distance ``Xt`` is closest to
        zero splits the profile in two parts. The left levee is the highest point
        before it and the right levee the highest point from it to the end of the
        section. The lines are stored in ``left_levee_line`` and
        ``right_levee_line`` as 3D lines.
        """
        left_points = []
        right_points = []

        for section in self.section_list:
            xs = section.coord.array['X']
            ys = section.coord.array['Y']
            elevations = section.coord.values['B']
            offsets = section.coord.array['Xt']

            centre = np.argmin(np.abs(offsets))
            last = len(elevations) - 1

            # left side: highest point strictly before the centre (first point if none)
            left_index = np.argmax(elevations[:centre]) if centre > 0 else 0
            # right side: highest point from the centre onwards, as an absolute index
            right_index = centre + np.argmax(elevations[centre:]) if centre != last else last

            left_offset = offsets[left_index]
            right_offset = offsets[right_index]

            left_points.append(Point(np.interp(left_offset, offsets, xs),
                                     np.interp(left_offset, offsets, ys),
                                     elevations[left_index]))
            right_points.append(Point(np.interp(right_offset, offsets, xs),
                                      np.interp(right_offset, offsets, ys),
                                      elevations[right_index]))

        self.left_levee_line = LineString([(point.x, point.y, point.z) for point in left_points])
        self.right_levee_line = LineString([(point.x, point.y, point.z) for point in right_points])
