

from os import getcwd, path
import time
import argparse
import numpy as np
import glob
import os
from pyproj import CRS
import importlib
import rasterio
from rasterio.merge import merge
from rasterio.warp import Resampling, calculate_default_transform, reproject
from typing import List
import geopandas as gpd




def merge_rasters(rasters: List[str], save_name, crs, res):
    """
    Merge multiple raster tiles into a single GeoTIFF.

    Before merging, every tile whose nodata value is 0 or unset is rewritten
    IN PLACE as a single-band float32 GeoTIFF. In the rewritten file, pixels
    equal to 0 are replaced by NaN and NaN is declared as nodata. The tiles
    are then merged with ``rasterio.merge.merge(method='min')``, so where
    tiles overlap the smallest valid value is kept.

    Parameters
    ----------
       rasters : list of paths to the raster tiles to merge
       save_name : output path without extension (".tif" is appended)
       crs : CRS written into the merged raster's metadata
       res : output resolution forwarded to ``rasterio.merge.merge``

    Returns
    -------
       result : tuple ``(absolute path to the merged raster, metadata dict
                used to write it)``

    Raises
    ------
       ValueError : if ``rasters`` is empty
    """
    if not rasters:
        raise ValueError("merge_rasters() needs at least one raster path")

    save_path = save_name + ".tif"

    for raster in rasters:
        _nodata_zero_to_nan(raster)

    # Open the tiles and guarantee every handle is released, even on error.
    datasets = []
    try:
        for fp in rasters:
            datasets.append(rasterio.open(fp))
        merged_raster, transform = merge(datasets, res=res, method='min')
        # Take metadata AFTER normalisation so dtype/nodata describe the
        # tiles that were actually merged, not their pre-rewrite state.
        out_meta = datasets[0].meta.copy()
    finally:
        for ds in datasets:
            ds.close()

    out_meta.update(
        {
            "driver": "GTiff",
            # Shape of the merged array is (bands, rows, cols).
            "count": merged_raster.shape[0],
            "height": merged_raster.shape[1],
            "width": merged_raster.shape[2],
            "dtype": merged_raster.dtype.name,
            "crs": crs,
            "transform": transform,
        }
    )

    with rasterio.open(save_path, "w", **out_meta) as dest:
        dest.write(merged_raster)

    return os.path.abspath(save_path), out_meta


def _nodata_zero_to_nan(raster):
    """Rewrite *raster* in place as float32 band 1 with 0 mapped to NaN nodata.

    Only tiles whose nodata is 0 or unset are rewritten. Tiles that already
    declare another nodata value (for example NaN from a previous run) are
    left untouched.
    """
    with rasterio.open(raster) as src:
        if src.nodata is not None and src.nodata != 0:
            return
        # Cast first: assigning NaN into an integer array raises ValueError.
        data = src.read(1).astype(np.float32)
        height, width = src.height, src.width
        src_crs, src_transform = src.crs, src.transform
    data[data == 0] = np.nan
    # The read handle is closed above before the file is overwritten.
    with rasterio.open(raster, 'w', driver='GTiff',
                       height=height, width=width, count=1,
                       dtype=rasterio.float32, crs=src_crs,
                       transform=src_transform, nodata=np.nan) as dst:
        dst.write(data, 1)



# Script configuration: folder holding the tiles, output resolution, target
# CRS (EPSG:2154, RGF93 / Lambert-93, metre units) and tile file extension.
DEM_path = '/Users/cassan/Documents/SWIFT/chinon/raster_to_merge'
resolution = 2
crs = CRS.from_user_input('EPSG:2154')
extension = '.tif'
full_save_path = os.path.join(DEM_path, '_' + str(int(resolution)) + "m_merged")

# merge_rasters() writes its output into DEM_path with a ".tif" suffix, which
# the glob below would also match. Exclude it so a re-run does not merge the
# previous result back into itself (or rewrite it in place).
merged_output = os.path.abspath(full_save_path + ".tif")
flist = sorted(glob.glob(os.path.join(DEM_path, '*' + extension)))
input_paths = [fp for fp in flist if os.path.abspath(fp) != merged_output]

print(input_paths)
if not input_paths:
    raise SystemExit(f"No '{extension}' tiles found in {DEM_path}")

savepath, merge_meta = merge_rasters(input_paths, full_save_path, crs, resolution)
