######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################
import numpy as np
# global
from shapely.geometry import Polygon,MultiPolygon,mapping
from shapely.ops import nearest_points,linemerge
from rasterio.transform import Affine
from rasterio.enums import MergeAlg
import rasterio.mask
from rasterio.mask import mask
from rasterio.merge import merge
from rasterio.fill import fillnodata
# local
from .Tools import *
from .Section import Section
from .Hydraulics import find_bank_from_poly
from .postCEPHEE import create_poly_CEPHEE
from scipy.interpolate import LinearNDInterpolator
import geopandas as gpd

def interpolate_MNTchannel(BV, param, res):
    """ rasterize the elevation of main channel by interpolation

    The channel points returned by ``create_poly_CEPHEE`` are linearly
    interpolated (``scipy.interpolate.LinearNDInterpolator``) on a regular
    grid of step ``res`` covering the bounding box of the channel polygons.
    The result is written to ``<param.work_path>/global_channel_DEM.tif``
    (float32, nodata NaN, CRS ``param['C']['DEM_CRS_ID']``), then masked and
    cropped to the channel polygons and written again to the same file.

    Inside the channel polygons, cells where the interpolation is undefined
    (outside the convex hull of the points) and cells whose elevation is
    exactly 0 are set to 1e4. Cells outside the polygons are set to nodata
    by ``rasterio.mask.mask``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    :param res:  resolution of the output DEM raster
    :type res: float
    """
    data_poly = create_poly_CEPHEE(BV, nx_trans=0, pixel_size=res,
                                   resultType=param['B']['interpolationMethod'])

    list_of_poly, _, _, _, list_for_interp = data_poly
    list_of_point, _, _, _, list_of_point_z = list_for_interp

    x_XS = [pt.x for pt in list_of_point]
    y_XS = [pt.y for pt in list_of_point]

    multi_polygon = MultiPolygon(list_of_poly)
    # Bounding box of the channel polygons
    minx, miny, maxx, maxy = multi_polygon.bounds
    # North-up raster anchored at the top-left corner of the bounding box
    transform = Affine.translation(minx, maxy) * Affine.scale(res, -1 * res)

    Xres, Yres = np.meshgrid(np.arange(minx, maxx, res),
                             np.arange(miny, maxy, res),
                             indexing='xy')

    interp_z = LinearNDInterpolator(list(zip(x_XS, y_XS)), list_of_point_z)
    z_interp = interp_z(Xres, Yres).astype(np.float32)
    # Grid rows were generated south -> north; raster rows go north -> south.
    z_interp = np.flipud(z_interp)
    # NaN never compares equal to anything (not even NaN), so np.isnan is
    # required. Undefined cells become 0 here and 1e4 after masking below.
    z_interp[np.isnan(z_interp)] = 0
    # Take the raster size from the sampling grid itself. np.arange yields
    # ceil(extent / res) samples whereas int(extent / res) truncates, and the
    # written array must match the dataset shape.
    height, width = z_interp.shape

    out_path = os.path.join(param.work_path, 'global_channel_DEM.tif')
    # MultiPolygon as a GeoJSON-like mapping, the format rasterio.mask expects
    geom = [mapping(multi_polygon)]

    with rasterio.open(
            out_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            nodata=np.nan,
            dtype=rasterio.float32,  # float type suitable for Z values
            crs=param['C']['DEM_CRS_ID'],
            transform=transform
    ) as dst:
        dst.write(z_interp, 1)

    # Crop/mask the raster to the channel polygons. The source is closed
    # before the same path is rewritten below.
    with rasterio.open(out_path) as src:
        out_image, out_transform = mask(src, geom, crop=True)
        out_meta = src.meta.copy()
        src_nodata = src.nodata

    out_image = np.squeeze(out_image).astype(np.float32)
    out_image = np.where(out_image == 0, 1e4, out_image)
    # NOTE(review): with a NaN nodata, `out_image == src_nodata` is False
    # everywhere. rasterio's fillnodata also treats mask == 0 as "pixels to
    # fill", so as written this call looks like a no-op. Confirm the intent
    # (probably mask=~np.isnan(out_image)) before changing it.
    out_image = fillnodata(out_image, mask=out_image == src_nodata,
                           max_search_distance=3 * param['H']['dxlat'],
                           smoothing_iterations=1)

    # Update the metadata for the cropped raster
    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[-2],
                     "width": out_image.shape[-1],
                     "transform": out_transform})
    # Save the cropped raster over the uncropped one
    with rasterio.open(out_path, "w", **out_meta) as dest:
        dest.write(out_image, 1)

    print("Main Channel Raster created")

def rasterize_MNTchannel(BV, param, res):
    """ rasterize the elevation of main channel by polygone between sections

    For every pair of consecutive sections of every reach, both
    ``line_channel`` profiles are resampled (``np.interp``) on the sorted
    unique values of ``distance1 + distance2``. Each quadrilateral joining
    two consecutive resampled points of the upstream and downstream
    profiles is burned into the raster with the mean elevation of its four
    corners (``all_touched=True``, later polygons replace earlier ones).
    Cells touched by no quadrilateral get 10000. The raster is written to
    ``<param.work_path>/global_channel_DEM.tif`` with CRS
    ``param['C']['DEM_CRS_ID']``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    :param res:  resolution of the output DEM raster
    :type res: float
    :raises ValueError: if no quadrilateral can be built (no reach has at
        least two sections)
    """
    list_of_poly = []
    list_of_z = []

    for reach1 in BV.reach:
        for section1, section2 in zip(reach1.section[:-1], reach1.section[1:]):
            # (n, 3) arrays of x, y, z vertices of both channel lines
            coords_up = np.array(list(section1.line_channel.coords))
            coords_down = np.array(list(section2.line_channel.coords))

            distance1 = section1.distance_channel
            distance2 = section2.distance_channel

            # NOTE(review): presumably distance_channel is a list, so `+`
            # concatenates both abscissa sets; confirm (arrays would be summed).
            # np.unique returns sorted values.
            distancetot = np.unique(distance1 + distance2)

            x_up, y_up, z_up = (np.interp(distancetot, distance1, coords_up[:, k])
                                for k in range(3))
            x_down, y_down, z_down = (np.interp(distancetot, distance2, coords_down[:, k])
                                      for k in range(3))

            for ip in range(len(x_down) - 1):
                list_of_poly.append(Polygon([(x_up[ip], y_up[ip]),
                                             (x_up[ip + 1], y_up[ip + 1]),
                                             (x_down[ip + 1], y_down[ip + 1]),
                                             (x_down[ip], y_down[ip])]))
                list_of_z.append(np.mean([z_up[ip], z_up[ip + 1],
                                          z_down[ip], z_down[ip + 1]]))

    if not list_of_poly:
        raise ValueError("No channel polygon could be built: "
                         "reaches need at least two sections")

    multi_polygon = MultiPolygon(list_of_poly)
    # Bounding box of the channel polygons
    minx, miny, maxx, maxy = multi_polygon.bounds
    # Round up so that the raster covers the whole extent of the polygons
    # (truncation would clip the last partial row/column).
    height = int(np.ceil(abs(maxy - miny) / res))
    width = int(np.ceil(abs(maxx - minx) / res))
    transform = Affine.translation(minx, maxy) * Affine.scale(res, -1 * res)

    z_raster = rasterio.features.rasterize(zip(list_of_poly, list_of_z),
                                           out_shape=(height, width),
                                           fill=10000,
                                           out=None,
                                           transform=transform,
                                           all_touched=True,
                                           merge_alg=MergeAlg.replace,
                                           dtype=None)

    with rasterio.open(
            os.path.join(param.work_path, 'global_channel_DEM.tif'),
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            dtype=rasterio.float32,  # float type suitable for Z values
            crs=param['C']['DEM_CRS_ID'],
            transform=transform
    ) as dst:
        dst.write(z_raster, 1)

    print("Main Channel Raster created")


def drag_all_XS(BV, param, type_comp=None, bank_filename=None):
    """ modify line section to drag initial elevation between banks

    For every reach of ``BV``, the hydraulic results selected by
    ``type_comp`` are used to rebuild the channel line of each section
    (``section.linechannel``). The channel is also dug (``section.modif_line``)
    unless imported sections are used with ``param['B']['depth'] <= 0``.

    Side effects:

    * ``param['N']['distSearchMin']`` is set to 0.
    * ``param['B']['fromQ']`` is set to False when imported sections are
      used with a non-positive depth.
    * ``param['B']['depth']`` is overwritten by ``find_BfromQ`` when
      ``param['B']['fromQ']`` is true.
    * The ``'Sf'`` column of the selected results is overwritten, and so is
      ``'WSE'`` when ``param['B']['smoothSlope']`` is true.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    :param type_comp: computation whose results are used: 'Normal'
        (``reach.resNormal``), '1D' (``res1D``), 'Himposed'
        (``resHimposed``) or 'Obs' (``resObs``). Forced to 'Obs' when
        ``bank_filename`` is given.
    :type type_comp: str
    :param bank_filename: name of bank shapefile; when given, banks are
        located from its merged lines with ``find_bank_from_poly``
    :type bank_filename: str
    :raises ValueError: if ``type_comp`` is not one of the values above
    """
    if bank_filename:
        multishape = gpd.read_file(bank_filename)
        all_banklines = Convert_sections(multishape)
        merged_line = linemerge(all_banklines)
        find_bank_from_poly(BV, param['H']['dxlat'], merged_line)
        type_comp = 'Obs'

    # Reach attribute holding the results of each computation type
    result_attr = {'Normal': 'resNormal',
                   '1D': 'res1D',
                   'Himposed': 'resHimposed',
                   'Obs': 'resObs'}
    if type_comp not in result_attr:
        raise ValueError("Unknown type_comp {!r}: expected one of {}"
                         .format(type_comp, sorted(result_attr)))

    for reach1 in BV.reach:
        df = getattr(reach1, result_attr[type_comp])

        # Friction slope of each section taken from the reach slope
        for ids, _ in enumerate(reach1.section):
            df.loc[ids, 'Sf'] = reach1.slope[ids]

        if param['B']['smoothSlope']:
            # Recompute the slope, then replace the water surface by the
            # straight line joining the first and last WSE along Xinterp.
            reach1.compute_slope(1)
            WSstart = df['WSE'].iloc[0]
            WSend = df['WSE'].iloc[-1]
            x_start = reach1.Xinterp[0]
            gradient = (WSend - WSstart) / (reach1.Xinterp[-1] - x_start)
            for ids, _ in enumerate(reach1.section):
                df.loc[ids, 'Sf'] = reach1.slope[ids]
                # Abscissa measured from the first section, so that the line
                # goes through (Xinterp[0], WSstart) and (Xinterp[-1], WSend).
                df.loc[ids, 'WSE'] = WSstart + gradient * (reach1.Xinterp[ids] - x_start)

        param['N']['distSearchMin'] = 0
        for ids, section in enumerate(reach1.section):
            df_filter = df[df['idSection'] == ids]

            if param['B']['useImportedSections'] and param['B']['depth'] <= 0:
                param['B']['fromQ'] = False
                section.linechannel(param, df_filter)
                continue

            if param['B']['fromQ']:
                # Dig depth chosen so that the section conveys its discharge
                WS = df_filter.loc[df_filter.index[0], 'WSE'] + 0.01
                param['B']['depth'] = find_BfromQ(WS, section, param, df_filter)
            section.linechannel(param, df_filter)
            section.modif_line(param, df_filter)

    print("Dug channel end")



def find_BfromQ(WS, section, param, df_filter):
    """ return the depth to remove to get the target discharge

    The depth is searched by bisection in [0.05, 20]. At each iteration the
    midpoint depth is stored in ``param['B']['depth']``. A copy of
    ``section`` (``section.copy()``) is then dug with it (``linechannel``
    followed by ``modif_line``), and its Manning discharge at water surface
    ``WS`` is computed. A negative slope is replaced by 1e-5. If the
    discharge exceeds ``section.Q``, the upper bound is lowered, otherwise
    the lower bound is raised. The search stops when
    ``|Q - section.Q| <= param['H']['eps']`` or after
    ``param['H']['MaxIter']`` iterations.

    :param WS:  water surface
    :type WS: float
    :param section: considered section
    :type section: Section
    :param param: Parameters requires for computation (``param['B']['depth']``
        is overwritten with every trial depth)
    :type param: Parameters
    :param df_filter: hydraulic data for the considered section (the ``'Sf'``
        value of its first row is used as slope)
    :type df_filter: pandas DataFrame
    :return: the last evaluated trial depth, i.e. ``param['B']['depth']``.
        This is the incoming value unchanged if no iteration runs.
    :rtype: float
    """
    depth_inf = 0.05
    depth_sup = 20
    mismatch = section.Q
    n_iter = 0

    while mismatch > param['H']['eps'] and n_iter < param['H']['MaxIter']:
        # Trial depth = midpoint of the current bracket. It is set BEFORE
        # digging so that the discharge below is evaluated for this very
        # depth, and the bracket is then updated with that same depth.
        depth_mid = (depth_sup + depth_inf) / 2
        param['B']['depth'] = depth_mid

        section_temp = section.copy()
        section_temp.linechannel(param, df_filter)
        section_temp.modif_line(param, df_filter)
        n_iter += 1

        _, _, hydro_distr, manning_distr = section_temp.computeHydraulicGeometry(
            WS, param['H']['dxlat'], param['H']['levee'], param['H']['frictionLaw'])
        A_distr = hydro_distr[:, 2]
        Rh_distr = hydro_distr[:, 4]

        Sf = df_filter.loc[df_filter.index[0], 'Sf']
        if Sf < 0:
            Sf = 0.00001

        # Manning velocity of each sub-section: V = 1/n * Rh^(2/3) * Sf^(1/2)
        V = 1 / np.asarray(manning_distr, dtype=float) * Rh_distr ** (2 / 3) * Sf ** 0.5
        Q1 = np.sum(V * A_distr)

        if Q1 > section.Q:
            # too much discharge: the solution is below the trial depth
            depth_sup = depth_mid
        else:
            # not enough discharge: the solution is above the trial depth
            depth_inf = depth_mid
        mismatch = abs(Q1 - section.Q)

    return param['B']['depth']


def MergeRaster_channel_DEM(BV, work_path):
    """ merge the raster channel with initial DEM. Create folder and file of raster

    For each DEM tile ``idx`` of ``BV.DEM_stack`` (files numbered from 1):

    1. ``<work_path>/global_channel_DEM.tif`` is masked and cropped to the
       tile polygon (``data_DEM[idx]['polygon_coords']``, repaired with
       ``buffer(0)`` if invalid). The result is written to
       ``<work_path>/channel_DEM/Divided_channel_<idx+1>.tif``.
    2. That crop is merged with the tile file
       (``DEM_stack['file_list'][idx]``) keeping the cell-wise minimum, at
       resolution ``data_DEM[idx]['cell_size']``. The result is written to
       ``<work_path>/Merged_DEM/Merged_DEM_<idx+1>.tif`` with CRS ``BV.crs``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param work_path: folder to put new DEM raster
    :type work_path: string
    """
    channel_dir = os.path.join(work_path, 'channel_DEM')
    merged_dir = os.path.join(work_path, 'Merged_DEM')

    with rasterio.open(os.path.join(work_path, 'global_channel_DEM.tif')) as src:
        data = BV.DEM_stack['data_DEM']

        # Output folders for the channel tiles and the merged DEM tiles
        os.makedirs(channel_dir, exist_ok=True)
        os.makedirs(merged_dir, exist_ok=True)

        for idx in range(len(data)):
            res = data[idx]['cell_size']
            polygon = Polygon(data[idx]['polygon_coords'])
            if not polygon.is_valid:
                polygon = polygon.buffer(0)

            # 1. Channel raster cropped to the footprint of this DEM tile
            out_image, out_transform = rasterio.mask.mask(src, [polygon], crop=True)
            channel_meta = src.meta.copy()
            channel_meta.update({"driver": "GTiff",
                                 "height": out_image.shape[-2],
                                 "width": out_image.shape[-1],
                                 "transform": out_transform})

            channel_file = os.path.join(channel_dir, "Divided_channel_{}.tif".format(idx + 1))
            with rasterio.open(channel_file, "w", **channel_meta) as dest:
                dest.write(out_image)

            # 2. Cell-wise minimum of the DEM tile and the channel tile
            dem_file = BV.DEM_stack['file_list'][idx]
            with rasterio.open(dem_file) as dem_src, rasterio.open(channel_file) as channel_src:
                merged_meta = dem_src.meta.copy()
                try:
                    merged_raster, transform = merge([dem_src, channel_src],
                                                     method='min', res=res)
                except Exception:
                    # Fallback kept for older rasterio versions: merge from
                    # file paths instead of open datasets ("old merge version").
                    merged_raster, transform = merge([dem_file, channel_file],
                                                     method='min', res=res)

            merged_meta.update({"driver": "GTiff",
                                "height": merged_raster.shape[1],
                                "width": merged_raster.shape[2],
                                "crs": BV.crs,
                                "transform": transform})

            merged_file = os.path.join(merged_dir, "Merged_DEM_{}.tif".format(idx + 1))
            with rasterio.open(merged_file, "w", **merged_meta) as dest:
                dest.write(merged_raster)

    print("Merged raster created")




