######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
from shapely.geometry import point, multipoint

# third-party
from tatooinemesher.section import *
from tatooinemesher.coord import Coord
from tatooinemesher.utils import float_vars

# local
from core.Tools import *
from core.Hydraulics import Limerinos


class Section(CrossSection):
    """River cross-section carrying per-point hydraulic variables.

    Extends tatooinemesher's ``CrossSection``: ``self.coord.array`` holds the planar coordinates (``X``, ``Y``)
    and the curvilinear abscissa (``Xt``), while ``self.coord.values`` stores the per-point variables listed in
    ``self.var``. The methods locate the lateral origin of the section, find the banks for a given water
    surface elevation, re-discretise the wetted width and compute the hydraulic geometry (area, wetted
    perimeter, hydraulic radius, velocities, discharge).
    """

    def __init__(self, line, ID=0, name=""):
        """Build a section from a shapely line.

        :param line: trace of the section. When it carries Z values they become the bed elevation (``B``) and
            the 3D line is kept in ``self.geom``; otherwise the bed elevation is initialised to 0.
        :type line: shapely LineString
        :param ID: identifier forwarded to ``CrossSection``.
        :type ID: int
        :param name: name of the section.
        :type name: str
        """
        xs, ys = line.xy
        coord = np.array(list(zip(xs, ys)), dtype=[('X', float), ('Y', float)])
        super().__init__(ID, coord)
        # Per-point variables stored in self.coord.values (all float64):
        # B bed elevation, W friction coefficient, H water depth, S water surface elevation,
        # U/V velocity components, M velocity magnitude, TAU (only zeros are written in this class),
        # QS (not written in this class).
        self.var = [('B', 'f8'), ('W', 'f8'), ('H', 'f8'), ('S', 'f8'), ('U', 'f8'), ('V', 'f8'), ('M', 'f8'),
                    ('TAU', 'f8'), ('QS', 'f8')]

        if line.has_z:
            bed = [c[2] for c in line.coords]
            self.geom = LineString([(x, y, z) for x, y, z in zip(xs, ys, bed)])
        else:
            bed = [0.] * len(line.coords)

        self.coord.values = np.zeros(len(line.coords), dtype=self.var)
        self.coord.values['B'] = bed
        self.centre = Point()  # reference point used by originXSLocation to place the lateral origin
        self.d0 = 0.  # abscissa of the centre point relative to the chosen origin (set by originXSLocation)
        self.start = Point()  # first point of the section (set by distanceBord)
        self.end = Point()  # last point of the section (set by distanceBord)
        self.name = name
        self.Zbed = []  # bed elevation at the origin (set by originXSLocation when Z values exist)
        self.hydro = {'depth': [], 'A': [], 'W': [], 'Q': [], 'equivalent': {'w1': 0, 'd1': 0, 'i2': 0}}
        self.WSE = 0.
        self.Q = 0.
        self.acc = 0.
        self.slope = []
        self.flagProj = 0
        self.flag_1D = False
        self.bank = []  # read by modif_line as [[abscissa bank 1, abscissa bank 2], [Point bank 1, Point bank 2]]
        self.normal = []  # 2-component vector used to orient the velocities in computeHydraulicGeometry
        self.line_channel = LineString()
        self.distance_channel = []

    def distanceBord(self):
        """Compute the curvilinear abscissae of the section and store its end points.

        Calls ``Coord.compute_Xt`` and ``Coord.compute_xt`` (tatooinemesher; presumably the cumulative planar
        distance from the first point and its normalised counterpart) and sets ``self.start`` / ``self.end``
        to the first and last points of the section.
        """
        self.coord.compute_Xt()
        self.coord.compute_xt()
        self.start = Point(self.coord.array['X'][0], self.coord.array['Y'][0])
        self.end = Point(self.coord.array['X'][-1], self.coord.array['Y'][-1])

    def originXSLocation(self, d_search=0):
        """Choose the lateral origin (Xt = 0) of the section and shift ``Xt`` accordingly.

        The starting point is the section vertex closest to ``self.centre``. When the geometry carries Z
        values and the window ``[Xt_centre - d_search, Xt_centre + d_search]`` spans several vertices, the
        origin is moved to the lowest vertex of that window; otherwise it stays on the centre vertex. Without
        Z values the origin is the middle of the section.

        Side effects: updates ``self.start`` / ``self.end`` (via ``distanceBord``), ``self.d0`` (abscissa of
        the centre vertex relative to the new origin), ``self.Zbed`` (only when Z values exist) and
        ``self.coord.array['Xt']``.

        :param d_search: half-width of the search window around the centre vertex, in Xt units.
        :type d_search: float
        """
        self.distanceBord()
        Xt = self.coord.array['Xt']

        # vertex of the section closest to the reference centre point
        coords = np.asarray(self.geom.coords)
        distance_centre = np.sqrt((coords[:, 0] - self.centre.x) ** 2 + (coords[:, 1] - self.centre.y) ** 2)
        ind_min_centre = np.nanargmin(distance_centre)

        # vertices closest to the two ends of the search window
        x_centre = Xt[ind_min_centre]
        i_sup = np.nanargmin(np.abs(Xt - (x_centre + d_search)))
        i_inf = np.nanargmin(np.abs(Xt - (x_centre - d_search)))

        if self.geom.has_z:
            Z = coords[:, 2]
            if i_inf != i_sup:
                # NOTE(review): the slice excludes the vertex at i_sup — confirm whether that is intended
                i_low = i_inf + np.nanargmin(Z[i_inf:i_sup])
                min_distance_centre = Xt[i_low]
                self.Zbed = Z[i_low]
                self.d0 = Xt[ind_min_centre] - min_distance_centre
            else:
                min_distance_centre = Xt[ind_min_centre]
                self.Zbed = Z[ind_min_centre]
                self.d0 = 0
        else:
            # no elevation (bathymetry created by carving): origin at the middle of the section
            min_distance_centre = (Xt[0] + Xt[-1]) / 2
            self.d0 = Xt[ind_min_centre] - min_distance_centre

        self.coord.array['Xt'] = Xt - min_distance_centre

    def interpDistributionLat(self, dbank, dxlat):
        """Re-sample the wetted part of the section laterally with a constant step.

        For ``dxlat > 0`` the limits are laid out from the middle of the banks towards each bank with step
        ``dxlat`` (the remainders end up at the banks), and every point of ``dbank`` is also a limit. For
        ``dxlat <= 0`` the section is a single subsection whose limits are ``dbank``.

        :param dbank: abscissae (Xt) of the intersections between the water surface and the bed.
        :type dbank: list
        :param dxlat: lateral discretisation step.
        :type dxlat: float
        :return:
            - dist_distr: abscissa of the centre of each subsection (array or list)
            - manning_distr: friction coefficient interpolated at each subsection centre (list)
            - Wsextent: planar points of the first and last bank (list of Point)
            - dist_point: abscissae of the subsection limits (list or array)
            - z_point: bed elevation interpolated at each subsection limit (array)
        """
        Z = self.coord.values['B']
        Xt = self.coord.array['Xt']
        dcentre = (dbank[0] + dbank[-1]) / 2  # remainders of the division by dxlat are spread on both banks

        if dxlat > 0:
            # limits from the centre to the right bank, and mirrored from the centre to the left bank
            right = np.arange(dcentre, dbank[-1], dxlat)
            left = -np.arange(-dcentre, -dbank[0], dxlat)[::-1]
            dist_point = list(np.unique(np.concatenate((left, right, np.asarray(dbank, dtype=float)))))
            if len(dist_point) < 2:
                dist_point = [dcentre - dxlat, dcentre + dxlat]
            dist_point.sort()  # limits of the distributed subsections

            limits = np.asarray(dist_point, dtype=float)
            dist_distr = (limits[1:] + limits[:-1]) / 2  # vertical located in the middle of each subsection

            z_point = np.interp(dist_point, Xt, Z)
            manning_distr = list(np.interp(dist_distr, Xt, self.coord.values['W']))
        else:
            # single subsection between the banks
            dist_point = np.array(dbank)
            dist_distr = [dcentre]
            manning_distr = [np.nanmean(self.coord.values['W'])]
            z_point = np.interp(dist_point, Xt, Z)

        X0 = np.interp(dbank[0], Xt, self.coord.array['X'])
        Y0 = np.interp(dbank[0], Xt, self.coord.array['Y'])
        X1 = np.interp(dbank[-1], Xt, self.coord.array['X'])
        Y1 = np.interp(dbank[-1], Xt, self.coord.array['Y'])
        Wsextent = [Point(X0, Y0), Point(X1, Y1)]
        return dist_distr, manning_distr, Wsextent, dist_point, z_point

    def findBankPoint(self, WSE, method):
        """Find the abscissae where the water surface meets the bed.

        :param WSE: water surface elevation.
        :type WSE: float
        :param method: when truthy, levees are taken into account: intersections separated by ground higher
            than ``WSE`` are filtered and exactly two banks are returned. When falsy, every intersection is
            kept (plus the section ends when they are below ``WSE``).
        :type method: string
        :return: dtemp, sorted bank abscissae (list). ``[-1, 1]`` is returned when ``WSE`` does not exceed the
            lowest bed elevation.
        """
        Z = self.coord.values['B']
        Xt = self.coord.array['Xt']
        lineXS = LineString(list(zip(Xt, Z)))
        linesurf = LineString([(np.min(Xt), WSE), (np.max(Xt), WSE)])
        inter = lineXS.intersection(linesurf)  # water surface / bed intersection

        if isinstance(inter, multipoint.MultiPoint):
            dtemp = [p.x for p in inter.geoms]
        elif isinstance(inter, point.Point):
            dtemp = [inter.x]
        else:
            dtemp = []
        dtemp.sort()

        # all intersections on the same side of the origin: extend to the opposite section end
        if len(dtemp) >= 2:
            if np.min(dtemp) >= 0:  # all points on the right bank
                dtemp[1] = np.min(dtemp)
                dtemp[0] = np.min(Xt)
            if np.max(dtemp) <= 0:  # all points on the left bank
                dtemp[0] = np.max(dtemp)
                dtemp[1] = np.max(Xt)
        dtemp = list(np.unique(dtemp))

        if method:
            # two intersections around a high point: keep the one on the side of the centre
            if len(dtemp) >= 2:
                dcorr = []
                for left, right in zip(dtemp[:-1], dtemp[1:]):
                    cote = Z[(Xt >= left) & (Xt <= right)]
                    coteMax = np.max(cote) if cote.size else -np.inf
                    if coteMax > WSE:
                        if right < 0:  # both points on the left side of the river
                            dcorr.append(right)
                        elif left > 0:  # both points on the right side of the river
                            dcorr.append(left)
                    else:
                        dcorr.extend((left, right))
                dtemp = list(np.unique(dcorr))

            if len(dtemp) == 0:
                dbank1, dbank2 = Xt[0], Xt[-1]
            elif len(dtemp) == 1:  # the water surface overflows one of the section ends
                if WSE >= Z[0]:
                    dbank1, dbank2 = Xt[0], dtemp[0]
                elif WSE >= Z[-1]:
                    dbank1, dbank2 = dtemp[0], Xt[-1]
                else:
                    # the water surface only touches the bed at one point: zero-width wetted section
                    dbank1 = dbank2 = dtemp[0]
            elif len(dtemp) == 2:
                dbank1, dbank2 = dtemp[0], dtemp[1]
            else:  # more than 2 intersections: keep the closest ones on each side of the origin
                candidates = np.asarray(dtemp, dtype=float)
                left_side = np.where(candidates > 0, -np.inf, candidates)
                right_side = np.where(candidates < 0, np.inf, candidates)
                dbank1 = dtemp[int(np.argmax(left_side))]
                dbank2 = dtemp[int(np.argmin(right_side))]
            dtemp = [dbank1, dbank2]

        else:  # every intersection point is taken into account
            if not dtemp:
                dtemp = [Xt[0], Xt[-1]]
            if WSE >= Z[0]:
                dtemp.append(Xt[0])
            if WSE >= Z[-1]:
                dtemp.append(Xt[-1])
            dtemp = list(np.unique(dtemp))

        if WSE <= np.min(Z):
            # dry section: sentinel detected downstream as an inconsistent mask
            dtemp = [-1, 1]
        return dtemp

    def computeHydraulicGeometry(self, WSE, dxlat, method, friction="Manning", Sf=0.0001):
        """Compute the hydraulic parameters of the section and of each lateral subsection.

        The banks are found with ``findBankPoint``, the wetted width is re-sampled with
        ``interpDistributionLat``, and for each subsection the flow area, wetted perimeter and hydraulic radius
        are computed from the bed points it contains. Velocities follow Manning-Strickler
        ``V = 1/n * Rh**(2/3) * Sf**0.5``.

        Side effects: writes ``H``, ``S``, ``U``, ``V``, ``M`` and ``TAU`` into ``self.coord.values`` (and
        ``W`` when ``friction == 'Limerinos'``).

        :param WSE: water surface elevation; a one-element list/array is also accepted.
        :type WSE: float
        :param dxlat: lateral discretisation step (``<= 0`` for a single subsection).
        :type dxlat: float
        :param method: levee handling passed to ``findBankPoint``.
        :type method: string
        :param friction: ``'Limerinos'`` updates the per-point friction coefficients with ``Limerinos``.
        :type friction: str
        :param Sf: friction slope used in the Manning-Strickler velocity.
        :type Sf: float
        :return: averagedValue, a dict with the discharge ``Q``, flow area ``A``, wetted perimeter ``P``,
            hydraulic radius ``Rh = A / P``, ``dbank`` (``[bank abscissae, Wsextent]``) and the equivalent
            friction coefficient ``neq``.
        """
        # a one-element container is accepted for the water surface elevation
        if np.ndim(WSE) > 0:
            WSE = np.ravel(WSE)[0]
        WSE = float(WSE)

        ztopo = self.coord.values['B']
        Xt = self.coord.array['Xt']

        # bank points, with or without levees (see HEC-RAS method)
        dbank = self.findBankPoint(WSE, method)

        # re-sample the section for the lateral distribution
        dist_distr, manning_distr, Wsextent, dist_point, z_point = self.interpDistributionLat(dbank, dxlat)

        # water depth at each subsection limit; negative bed elevations are treated as dry
        # (NOTE(review): presumably a no-data convention — confirm)
        waterdepth = np.where(z_point >= 0, WSE - z_point, 0.)
        waterdepth[waterdepth < 0] = 0

        # the first and last entries stand for the two bank points (zero area, perimeter and velocity)
        dist_temp = [dbank[0]]
        waterdepth_distr = [0]
        A_distr = [0]
        Pm_distr = [0]
        Rh_distr = [0]
        n_distr = [self.coord.values['W'][0]]
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz  # np.trapz was renamed in numpy 2.0

        for ii in range(len(dist_distr)):
            d_left, d_right = dist_point[ii], dist_point[ii + 1]
            dist_temp.append(dist_distr[ii])
            n_distr.append(manning_distr[ii])

            # bed points strictly between the two limits, framed by the interpolated limits themselves
            inside = (Xt > d_left) & (Xt < d_right)
            d_sub = np.concatenate(([d_left], Xt[inside], [d_right]))
            z_sub = np.concatenate(([z_point[ii]], ztopo[inside], [z_point[ii + 1]]))
            h_sub = WSE - z_sub
            h_sub[h_sub < 0] = 0

            waterdepth_distr.append((waterdepth[ii + 1] + waterdepth[ii]) / 2)
            Pm_temp = np.sum(np.sqrt(np.diff(d_sub) ** 2 + np.diff(z_sub) ** 2))
            Pm_distr.append(Pm_temp)
            A_distr.append(trapezoid(h_sub, d_sub))
            Rh_distr.append(A_distr[-1] / Pm_temp if Pm_temp > 0 else 0)

        dist_temp.append(dbank[-1])  # last bank: keeps dist_temp increasing for np.interp
        Pm_distr.append(0)
        Rh_distr.append(0)
        A_distr.append(0)
        waterdepth_distr.append(0)
        n_distr.append(manning_distr[-1])
        # NOTE: zero-depth subsections are deliberately not masked: in the single-subsection case
        # (dxlat <= 0) both limit depths are 0 at the banks, so such a mask would cancel the whole area.

        A = np.sum(A_distr)
        P = np.sum(Pm_distr)
        Rh = A / P if P > 0 else 0

        if friction == 'Limerinos':
            nd = np.interp(Xt, dist_temp, n_distr)
            waterd = np.interp(Xt, dist_temp, waterdepth_distr)
            # the updated coefficients are not used by the velocities computed below in this call
            self.coord.values['W'] = Limerinos(nd, waterd)

        # Manning-Strickler velocity of each subsection (pairs each Rh with the coefficient of the same
        # subsection); sections whose label is not 'Cross-section' (label presumably set by CrossSection)
        # get a zero velocity.
        is_cross_section = self.label == 'Cross-section'
        V = [0]
        for rh_i, n_i in zip(Rh_distr[1:-1], n_distr[1:-1]):
            if is_cross_section and Sf >= 0 and n_i > 0:
                V.append(1 / n_i * rh_i ** (2 / 3) * Sf ** 0.5)
            else:
                V.append(0)
        V.append(0)

        tau2 = [0] * len(Pm_distr)
        # conveyance of each subsection, used for the equivalent friction coefficient
        deb = [a_i * r_i ** (2 / 3) / n_i for a_i, r_i, n_i in zip(A_distr, Rh_distr, n_distr) if n_i > 0]
        Q = np.sum([v_i * a_i for v_i, a_i in zip(V, A_distr)])
        # NOTE(review): arctan2(y, x) is called with (normal[0], normal[1]); check the component order of
        # self.normal against the intended U/V orientation.
        angle = np.arctan2(self.normal[0], self.normal[1])
        u = [v_i * np.cos(angle) for v_i in V]
        v = [v_i * np.sin(angle) for v_i in V]
        wse = [WSE] * len(V)
        if sum(deb) > 0:
            neq = (A * Rh ** (2 / 3)) / sum(deb)
        else:
            neq = self.coord.values['W'][0]
        averagedValue = {'Q': Q, 'A': A, 'P': P, 'Rh': Rh, 'dbank': [dbank, Wsextent], 'neq': neq}

        if dxlat <= 0:
            # single subsection: depth and velocity computed at every bed point between the banks
            wet = (Xt >= dbank[0]) & (Xt <= dbank[-1])
            depth = np.where(wet, WSE - self.coord.values['B'], 0.)
            if n_distr[1] > 0 and Sf >= 0:
                speed = 1 / n_distr[1] * depth ** (2 / 3) * Sf ** 0.5
            else:
                speed = np.zeros_like(depth)
            self.coord.values['H'] = depth
            self.coord.values['S'] = WSE
            self.coord.values['M'] = speed
            self.coord.values['U'] = speed * np.cos(angle)
            self.coord.values['V'] = speed * np.sin(angle)
            self.coord.values['TAU'] = tau2[0]
        else:
            self.coord.values['H'] = np.interp(Xt, dist_temp, waterdepth_distr)
            self.coord.values['S'] = np.interp(Xt, dist_temp, wse)
            self.coord.values['U'] = np.interp(Xt, dist_temp, u)
            self.coord.values['V'] = np.interp(Xt, dist_temp, v)
            self.coord.values['M'] = np.interp(Xt, dist_temp, V)
            self.coord.values['TAU'] = np.interp(Xt, dist_temp, tau2)

        return averagedValue

    def modif_line(self, param, imposedWSE=False):
        """Carve the bathymetry of the section between its two banks.

        The bank points (``self.bank``) are inserted in the geometry when missing, then the elevations between
        them are replaced by a rectangular, triangular or parabolic profile whose bottom lies ``depth`` below
        the water surface. Triangular/parabolic profiles fall back to rectangular when the middle vertex
        coincides with a bank vertex.

        Side effects: rebuilds ``self.geom``, ``self.coord`` (``Xt``, ``B`` and interpolated ``W``),
        ``self.line_channel`` and ``self.distance_channel`` (both from bank 1 to bank 2 inclusive).

        :param param: parameters; ``param['B']`` must provide ``'bathymetricSections'`` (``"Rectangular"``,
            ``"Triangular"`` or ``"Parabolic"``) and ``'depth'``.
        :type param: Parameters
        :param imposedWSE: falsy to take the bank elevations from the current geometry; a ``[WS1, WS2]`` list
            or tuple to impose the bank water surfaces (their mean is used as WSE); any other truthy value
            to use ``self.WSE`` at both banks.
        :type imposedWSE: bool or list
        """
        paramB = param['B']
        shape = paramB['bathymetricSections']
        WSE = self.WSE
        distancebank1 = self.bank[0][0]
        distancebank2 = self.bank[0][1]

        old_distance = list(self.coord.array['Xt'])
        old_manning = list(self.coord.values['W'])
        distances = list(old_distance)

        # lists of coordinates of the line
        list_x = [c[0] for c in self.geom.coords]
        list_y = [c[1] for c in self.geom.coords]
        list_z = [c[2] for c in self.geom.coords]

        if imposedWSE:
            if isinstance(imposedWSE, (list, tuple)):
                WS1, WS2 = imposedWSE[0], imposedWSE[1]
                WSE = (WS1 + WS2) / 2
            else:
                WS1 = WS2 = WSE
        else:
            WS1 = np.interp(distancebank1, old_distance, list_z)
            WS2 = np.interp(distancebank2, old_distance, list_z)

        # insert the bank points at their sorted position when they are not already vertices
        add_bank1 = distancebank1 not in distances
        if add_bank1:
            distances.append(distancebank1)
        add_bank2 = distancebank2 not in distances
        if add_bank2:
            distances.append(distancebank2)
        distances.sort()
        index1 = distances.index(distancebank1)
        index2 = distances.index(distancebank2)

        if add_bank1:
            list_x.insert(index1, self.bank[1][0].x)
            list_y.insert(index1, self.bank[1][0].y)
            list_z.insert(index1, WS1)
        if add_bank2:
            # index2 is the final sorted position, bank 1 being already inserted
            list_x.insert(index2, self.bank[1][1].x)
            list_y.insert(index2, self.bank[1][1].y)
            list_z.insert(index2, WS2)

        # modification of the z coordinate
        distance_array = np.array(distances)
        distance_mid = (distancebank1 + distancebank2) / 2
        i_mid = int(np.abs(distance_array - distance_mid).argmin())
        i_1 = int(np.abs(distance_array - distancebank1).argmin())
        i_2 = int(np.abs(distance_array - distancebank2).argmin())
        bottom = WSE - paramB['depth']

        warning = False
        if shape == "Rectangular":
            for i in range(i_1, i_2 + 1):
                list_z[i] = bottom
        elif shape in ("Triangular", "Parabolic"):
            if i_mid in (i_1, i_2):
                # not enough points between the banks: fall back to a rectangular section
                warning = True
                for i in range(i_1, i_2 + 1):
                    list_z[i] = bottom
            elif shape == "Triangular":
                f_inter1 = np.poly1d(np.polyfit([distancebank1, distances[i_mid]], [WS1, bottom], 1))
                f_inter2 = np.poly1d(np.polyfit([distances[i_mid], distancebank2], [bottom, WS2], 1))
                for i in range(i_1, i_mid + 1):
                    list_z[i] = f_inter1(distances[i])
                for i in range(i_mid + 1, i_2 + 1):
                    list_z[i] = f_inter2(distances[i])
            else:
                f_inter = np.poly1d(np.polyfit([distancebank1, distances[i_mid], distancebank2],
                                               [WS1, bottom, WS2], 2))
                for i in range(i_1, i_2 + 1):
                    list_z[i] = f_inter(distances[i])

        if warning:
            print(f"There are not enough points to make a {shape.lower()} section on certain sections. "
                  "The section is rectangular.")

        # coordinates of the channel, from bank 1 to bank 2 inclusive (same span as distance_channel)
        list_x_channel = list_x[i_1:i_2 + 1]
        list_y_channel = list_y[i_1:i_2 + 1]
        list_z_channel = list_z[i_1:i_2 + 1]

        # update
        self.geom = LineString(list(zip(list_x, list_y, list_z)))
        coord = np.array(list(zip(self.geom.xy[0], self.geom.xy[1])), dtype=[('X', float), ('Y', float)])
        self.coord = Coord(np.array(coord, dtype=float_vars(['X', 'Y'])), ['Xt', 'xt'])
        self.coord.array['Xt'] = distances
        self.coord.values = np.zeros(len(self.geom.coords), dtype=self.var)
        self.coord.values['B'] = list_z
        self.coord.values['W'] = np.interp(self.coord.array['Xt'], old_distance, old_manning)

        if len(list_x_channel) > 1:
            self.line_channel = LineString(list(zip(list_x_channel, list_y_channel, list_z_channel)))
            self.distance_channel = self.coord.array['Xt'][i_1:i_2 + 1]
        else:
            self.line_channel = LineString()
            self.distance_channel = []

    def weir_discharge(self, paramH):
        """Compute the upstream water surface elevation of an inline-structure (weir) section.

        Inverts the free-flow weir law ``Q = c0 * c1 * sqrt(2 g) * h**(3/2)`` for the head ``h`` above the
        crest elevation, then updates the section hydraulics with ``computeHydraulicGeometry``.
        ``self.inlinedata`` is set outside this class; ``inlinedata[2]`` is the crest elevation and
        ``inlinedata[0] * inlinedata[1]`` the weir factor (NOTE(review): presumably discharge coefficient and
        crest width — confirm).

        :param paramH: hydraulic parameters; ``'dxlat'`` and ``'levee'`` are forwarded to
            ``computeHydraulicGeometry``.
        :type paramH: Parameters
        :return: WS, upstream water surface elevation (float)
        """
        g = 9.81
        Zs = self.inlinedata[2]
        # free flow: h = (Q / (c0 * c1 * sqrt(2 g))) ** (2/3)
        head = (self.Q / (self.inlinedata[0] * self.inlinedata[1] * (2 * g) ** 0.5)) ** (2 / 3)
        WS = head + Zs

        self.computeHydraulicGeometry(WS, paramH['dxlat'], paramH['levee'])
        return WS