######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
import rasterio.mask
from rasterio.merge import merge
from rasterio.fill import fillnodata
from rasterio.features import rasterize
import copy

# local
from core.Section import Section
from core.postCEPHEE import *
from core.Hydraulics import *


def interpolate_DEM_channel(BV, param):
    """ Rasterize the elevation of main channel by interpolation

    The bed points returned by ``create_poly_CEPHEE`` are linearly
    interpolated (``LinearNDInterpolator``) at the pixel centres of a
    north-up grid of resolution ``param['B']['output_resolution']`` covering
    the bounding box of the channel polygons. The grid is written to
    ``<work_path>/global_channel_DEM.tif``; the raster is then cropped to the
    channel polygons and the same file is overwritten with the cropped result.

    Pixels where the interpolation is undefined (NaN, e.g. outside the convex
    hull of the bed points) are set to ``1e4`` so that a later ``min`` merge
    with the terrain DEM keeps the terrain value there.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    # Elevation given to pixels that must not dig into the terrain DEM.
    no_dig_elevation = 1e4

    output_res = param['B']['output_resolution']
    channel_path = os.path.join(param.work_path, 'global_channel_DEM.tif')
    data_poly = create_poly_CEPHEE(BV, param, output_res)
    list_of_poly, _, _, _, list_for_interp, _ = data_poly
    list_of_point, _, _, _, list_of_point_z = list_for_interp
    x_XS = [pt.x for pt in list_of_point]
    y_XS = [pt.y for pt in list_of_point]
    multi_polygon = MultiPolygon(list_of_poly)

    # North-up grid covering the bounding box of the channel polygons
    minx, miny, maxx, maxy = multi_polygon.bounds
    height = int(abs(maxy - miny) / output_res)
    width = int(abs(maxx - minx) / output_res)
    transform = Affine.translation(minx, maxy) * Affine.scale(output_res, -1 * output_res)

    # Sample at pixel centres derived from (height, width, transform) so the
    # array always has exactly the raster's shape. An np.arange(min, max, res)
    # grid has ceil() instead of int() samples per axis and samples pixel
    # corners. Rows run north to south, matching the negative y scale.
    x_centres = minx + (np.arange(width) + 0.5) * output_res
    y_centres = maxy - (np.arange(height) + 0.5) * output_res
    Xres, Yres = np.meshgrid(x_centres, y_centres, indexing='xy')
    interp_z = LinearNDInterpolator(list(zip(x_XS, y_XS)), list_of_point_z)
    z_interp = interp_z(Xres, Yres).astype(np.float32)
    # NaN never compares equal to anything (`z == np.nan` is always False),
    # so undefined cells must be located with np.isnan.
    z_interp[np.isnan(z_interp)] = no_dig_elevation

    with rasterio.open(
            channel_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            nodata=np.nan,
            dtype=rasterio.float32,  # floating-point type for the Z values
            crs=param['C']['DEM_CRS_ID'],
            transform=transform,
    ) as dst:
        dst.write(z_interp, 1)

    # Crop to the channel polygons; pixels outside them become nodata (NaN).
    geom = [mapping(multi_polygon)]
    with rasterio.open(channel_path) as src:
        out_image, out_transform = rasterio.mask.mask(src, geom, crop=True)
        out_meta = src.meta.copy()
        src_nodata = src.nodata

    # Take band 1 explicitly: np.squeeze would also drop a length-1 row/column axis.
    out_image = out_image[0].astype(np.float32)
    # NOTE(review): with nodata = NaN, `out_image == src_nodata` is always
    # False, so the mask marks every pixel as "to fill" and none as a valid
    # source; the call is therefore likely a no-op. Kept unchanged - confirm
    # the intended mask before altering it.
    out_image = fillnodata(out_image, mask=out_image == src_nodata,
                           max_search_distance=3 * param['H']['dxlat'],
                           smoothing_iterations=1)

    # Update the metadata for the cropped raster
    out_meta.update({
        "driver": "GTiff",
        "height": out_image.shape[-2],
        "width": out_image.shape[-1],
        "transform": out_transform,
        "tiled": True,          # tiled layout, better for large files
        "compress": "deflate",  # lossless compression
        "BIGTIFF": "YES",       # allow files larger than 4 GB
    })
    # Overwrite the uncropped file only after the reading handle is closed.
    with rasterio.open(channel_path, "w", **out_meta) as dest:
        dest.write(out_image, 1)
    BV.display("Main Channel Raster created")

def rasterize_DEM_channel(BV, param):
    """ Rasterize the bed elevation of the main channel polygon by polygon

    Every polygon between consecutive sections (from ``create_poly_CEPHEE``)
    is burnt with its own bed value into ``<work_path>/global_channel_DEM.tif``.
    The file is a north-up float32 grid of resolution
    ``param['B']['output_resolution']`` covering the polygons' bounding box.
    Every pixel touched by a polygon is burnt (``all_touched=True``). Pixels
    touched by no polygon receive the highest section-point elevation plus 20.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    res = param['B']['output_resolution']
    polygons, _, _, _, interp_data, bed_values = create_poly_CEPHEE(BV, param, res)
    _, _, _, _, point_elevations = interp_data

    # Grid geometry from the bounding box of all channel polygons
    left, bottom, right, top = MultiPolygon(polygons).bounds
    n_rows = int(abs(top - bottom) / res)
    n_cols = int(abs(right - left) / res)
    grid_transform = Affine.translation(left, top) * Affine.scale(res, -res)

    raster_path = os.path.join(param.work_path, 'global_channel_DEM.tif')
    with rasterio.open(raster_path, 'w', driver='GTiff', height=n_rows, width=n_cols,
                       count=1, dtype=rasterio.float32,
                       crs=param['C']['DEM_CRS_ID'], transform=grid_transform) as dst:
        burnt = rasterio.features.rasterize(
            zip(polygons, bed_values),
            out_shape=dst.shape,
            fill=np.max(point_elevations) + 20,
            transform=dst.transform,
            all_touched=True,
            merge_alg=MergeAlg.replace,
        )
        dst.write(burnt, 1)
    BV.display("Main Channel Raster created")

def drag_all_XS(BV, param):
    """ modify line section to drag initial elevation between banks

    If no river-bank file is given and banks are not created,
    ``bank_to_edge_end`` is called first. Then, for every reach of ``BV``:

    1. When ``param['B']['smoothSlope']`` is set, the section WSE values are
       replaced by either:

       * a polynomial fitted on the valid (non-NaN) WSE values along
         ``Xinterp``, when ``average_slope == 'fromMask'``. The degree is
         ``min(npoly, n_valid - 2)``, clamped at 0. Sections are left untouched
         when no WSE is valid.
       * a straight line of slope ``average_slope`` starting at the WSE of the
         first section, when ``average_slope`` is a float.
    2. The reach slope is recomputed with ``reach.compute_slope``.
    3. When ``param['B']['fromQ']`` is set, a depth is searched with
       ``find_BfromQ`` for every section with a positive slope. That depth is
       linearly interpolated along the reach, or set to 0 when fewer than two
       sections gave one, and stored in ``param['B']['depth']`` before each
       section is modified.
    4. Unless ``param['B']['bathymetricSections'] == 'Original'``, every
       section line is modified with ``section.modif_line``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    imposedWSE = False
    if not param['I']['riverbanks_filepath'] and not param['H']['createBanks']:
        bank_to_edge_end(BV, param)

    for reach1 in BV.reach:
        sections = reach1.section_list

        if param['B']['smoothSlope']:
            average_slope = param['B']['average_slope']
            if isinstance(average_slope, str):
                if average_slope == 'fromMask' and sections:
                    imposedWSE = True
                    # The fit does not depend on the section being updated,
                    # so it is computed once per reach.
                    valid = [(reach1.Xinterp[i], section.WSE)
                             for i, section in enumerate(sections)
                             if not np.isnan(section.WSE)]
                    if valid:
                        lineXinterp, lineWSE = zip(*valid)
                        # Degree capped at (number of points - 2), never negative.
                        npoly = max(0, min(param['N']['npoly'], len(lineWSE) - 2))
                        WSEinterpolator = np.poly1d(np.polyfit(lineXinterp, lineWSE, npoly))
                        for ids, section in enumerate(sections):
                            section.WSE = WSEinterpolator(reach1.Xinterp[ids])
            elif isinstance(average_slope, (float, np.floating)):
                WSstart = sections[0].WSE
                for ids, section in enumerate(sections):
                    section.WSE = WSstart - average_slope * (reach1.Xinterp[ids] - reach1.Xinterp[0])

        reach1.compute_slope(param['N']['npoly'], 'WSE')  # degree of the slope fit

        from_q = param['B']['fromQ']
        dig_depths = None
        if from_q:
            xdig, hdig = [], []
            for ids, section in enumerate(sections):
                if section.slope > 0:
                    # WSE nudged up by 0.001 before the depth search (offset kept as-is)
                    hdig.append(find_BfromQ(section.WSE + 0.001, section, param, imposedWSE))
                    xdig.append(reach1.Xinterp[ids])
            if len(xdig) > 1:
                dig_depths = np.interp(reach1.Xinterp, xdig, hdig)

        for ids, section in enumerate(sections):
            if from_q:
                param['B']['depth'] = dig_depths[ids] if dig_depths is not None else 0
            if param['B']['bathymetricSections'] != 'Original':
                section.modif_line(param, imposedWSE)

    print("Dug channel end")


def find_BfromQ(WS, section, param, imposedWSE=False, depth_min=0.01, depth_max=40):
    """ return the depth to remove to get the target discharge

    Bisection on the digging depth within ``[depth_min, depth_max]``. For each
    trial depth, a deep copy of ``section`` is modified with ``modif_line``
    (which reads the depth from ``param['B']['depth']``). Its discharge at
    water surface ``WS`` is then computed with ``computeHydraulicGeometry``.
    When that discharge is >= ``section.Q`` the upper bound is lowered,
    otherwise the lower bound is raised. The search stops when
    ``|Q - section.Q| <= param['H']['eps']`` or after
    ``param['H']['MaxIter']`` iterations; at least one trial is always made
    when ``MaxIter > 0``.

    Side effect: ``param['B']['depth']`` is left at the last trial depth.

    :param WS: water surface elevation used for the discharge computation
    :type WS: float
    :param section: considered section (not modified; copies are dug)
    :type section: Section
    :param param: Parameters required for computation
    :type param: Parameters
    :param imposedWSE: forwarded to ``Section.modif_line``
    :type imposedWSE: bool
    :param depth_min: lower bound of the searched depth
    :type depth_min: float
    :param depth_max: upper bound of the searched depth
    :type depth_max: float
    :return: the last trial depth (midpoint of the final bracket)
    :rtype: float
    """
    new_depth_inf = depth_min
    new_depth_sup = depth_max
    new_depth = (new_depth_sup + new_depth_inf) / 2
    # Start with an infinite gap so the search always runs. Starting from
    # section.Q skipped it when Q <= eps, and the function then returned a
    # stale depth left in param by a previous call.
    ecart = float('inf')
    iter_i = 0

    while ecart > param['H']['eps'] and iter_i < param['H']['MaxIter']:
        section_temp = copy.deepcopy(section)
        iter_i += 1
        # midpoint of the current bracket
        new_depth = (new_depth_sup + new_depth_inf) / 2
        param['B']['depth'] = new_depth
        section_temp.modif_line(param, imposedWSE)
        averaged = section_temp.computeHydraulicGeometry(
            WS, param['H']['dxlat'], param['H']['levee'], param['H']['frictionLaw'], section_temp.slope)
        Q1 = averaged['Q']

        if Q1 >= section.Q:
            # the solution is below the midpoint
            new_depth_sup = new_depth
        else:
            # the solution is above the midpoint
            new_depth_inf = new_depth

        ecart = abs(Q1 - section.Q)

    return new_depth

def MergeRaster_channel_DEM(BV, work_path):
    """ merge the raster channel with initial DEM. Create folder and file of raster

    For every tile ``idx`` of ``BV.DEM_stack['data_DEM']`` (numbered from 1 in
    file names):

    * ``<work_path>/global_channel_DEM.tif`` is cropped to the tile footprint
      (``polygon_coords``) and saved as
      ``<work_path>/channel_DEM/Divided_channel_<idx>.tif``;
    * the tile DEM (``BV.DEM_stack['file_list'][idx]``) and that crop are
      merged with ``rasterio.merge.merge(method='min')`` at the tile cell
      size. The result is saved as
      ``<work_path>/Merged_DEM/Merged_DEM_<idx>.tif``.

    Output folders are created when missing. Nothing is done when
    ``BV.DEM_stack`` is empty.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param work_path: folder to put new DEM raster
    :type work_path: string
    """
    if not BV.DEM_stack:
        print('No DEM to merge')
        return

    channel_dir = os.path.join(work_path, 'channel_DEM')
    merged_dir = os.path.join(work_path, 'Merged_DEM')
    # GeoTIFF creation options shared by every output file
    tiff_options = {
        "tiled": True,          # tiled layout, better for large files
        "compress": "deflate",  # lossless compression
        "BIGTIFF": "YES",       # allow files larger than 4 GB
    }

    with rasterio.open(os.path.join(work_path, 'global_channel_DEM.tif'), 'r') as src:
        os.makedirs(channel_dir, exist_ok=True)
        os.makedirs(merged_dir, exist_ok=True)

        for idx, tile in enumerate(BV.DEM_stack['data_DEM']):
            res = tile['cell_size']
            polygon = Polygon(tile['polygon_coords'])
            if not polygon.is_valid:
                # buffer(0) is applied to try to repair an invalid footprint
                polygon = polygon.buffer(0)

            out_image, out_transform = rasterio.mask.mask(src, [polygon], crop=True)
            channel_meta = src.meta.copy()
            channel_meta.update({
                "driver": "GTiff",
                "height": out_image.shape[-2],
                "width": out_image.shape[-1],
                "transform": out_transform,
                "crs": BV.crs,
                **tiff_options,
            })
            channel_path = os.path.join(channel_dir, "Divided_channel_{}.tif".format(idx + 1))
            with rasterio.open(channel_path, "w", **channel_meta) as dest:
                dest.write(out_image)

            dem_file = BV.DEM_stack['file_list'][idx]
            with rasterio.open(dem_file) as dem_src, rasterio.open(channel_path) as channel_src:
                out_meta = dem_src.meta.copy()
                try:
                    merged_raster, transform = merge([dem_src, channel_src], method='min', res=res)
                except Exception:
                    # Retry with file paths instead of open datasets, kept for
                    # compatibility with other rasterio `merge` versions.
                    merged_raster, transform = merge([dem_file, channel_path], method='min', res=res)

            out_meta.update({
                "driver": "GTiff",
                "height": merged_raster.shape[1],
                "width": merged_raster.shape[2],
                "crs": BV.crs,
                "transform": transform,
                **tiff_options,
            })
            merged_path = os.path.join(merged_dir, "Merged_DEM_{}.tif".format(idx + 1))
            with rasterio.open(merged_path, "w", **out_meta) as dest:
                dest.write(merged_raster)

    print(f"Merged raster created")


def run_bathymetric(BV1, params):
    """ Run the bathymetric workflow on a watershed

    The steps are:

    1. Discharges are attached to the sections (``addQtoSection``) when
       ``params['B']['fromQ']`` is set and banks are not created.
    2. Every cross-section is dug between its banks (``drag_all_XS``).
    3. The channel raster is built according to
       ``params['B']['mapping_method']``: ``'interpolation'`` or
       ``'by_polygon'``.
    4. The channel raster is merged into the DEM tiles.

    Steps 3 and 4 are skipped for any other mapping method.

    :param BV1: Watershed
    :type BV1: ModelCatchment
    :param params: Parameters required for computation
    :type params: Parameters
    """
    if params['B']['fromQ'] and not params['H']['createBanks']:
        addQtoSection(BV1, params)
    # create the line of each section with the points of the banks
    drag_all_XS(BV1, params)

    method = params['B']['mapping_method']
    if method == 'interpolation':
        build_channel_raster = interpolate_DEM_channel
    elif method == 'by_polygon':
        build_channel_raster = rasterize_DEM_channel
    else:
        return

    build_channel_raster(BV1, params)
    MergeRaster_channel_DEM(BV1, params.work_path)


def dig_init_DEM_with_channel(BV, param):
    """ Rasterize the bed elevation of the main channel on the input DEM grid

    The channel polygons built by ``create_poly_CEPHEE`` are burnt, each with
    its bed value, onto a raster with the same shape, transform, dtype and
    nodata value as the input DEM. The result is written to
    ``<work_path>/global_channel_DEM.tif``. Pixels covered by no polygon
    hold the DEM nodata value.

    Only the single-DEM case is implemented. When
    ``BV.DEM_stack['file_list']`` does not hold exactly one file, nothing is
    written and a message is printed.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters required for computation
    :type param: Parameters
    """
    output_res = param['B']['output_resolution']
    data_poly = create_poly_CEPHEE(BV, param, output_res)
    list_of_poly, _, _, _, _, list_of_bed = data_poly

    dem_files = BV.DEM_stack['file_list']
    if len(dem_files) != 1:
        # Do not report a raster that was never written.
        print('{} DEM files found: channel raster only built for a single DEM, '
              'nothing written'.format(len(dem_files)))
        return

    print('only one DEM considered')
    with rasterio.open(dem_files[0]) as src:
        meta = src.meta.copy()
        out_shape = (src.height, src.width)
        transform = src.transform
        nodata = src.nodata

    # (geometry, value) pairs: each polygon is burnt with its bed value.
    # NOTE: values are cast to the DEM dtype (meta["dtype"]).
    rasterized = rasterize(
        shapes=list(zip(list_of_poly, list_of_bed)),
        out_shape=out_shape,
        transform=transform,
        fill=nodata,
        dtype=meta["dtype"],
        all_touched=False,  # True would burn every pixel a polygon touches
    )

    meta.update({
        "count": 1,
        "nodata": nodata,
        "tiled": True,          # tiled layout, better for large files
        "compress": "deflate",  # lossless compression
        "BIGTIFF": "YES",       # allow files larger than 4 GB
    })

    with rasterio.open(os.path.join(param.work_path, 'global_channel_DEM.tif'), "w", **meta) as dest:
        dest.write(rasterized, 1)
    BV.display("Main Channel Raster created")

