# standard library
import os
from contextlib import ExitStack
from typing import List

# third-party
import numpy as np
import rasterio
import rasterio.features
from rasterio.merge import merge
from rasterio.warp import Resampling, calculate_default_transform, reproject
from pysheds.grid import Grid
import geopandas as gpd

def read_DEM(DEM_filepath, crs):
    """Read the DEM file given by DEM_filepath. The function can be used by various classes.

    Supported formats are ESRI ASCII grids (``.asc``) and rasters
    (``.tif`` / ``.tiff``); the extension comparison is case-insensitive.
    ``.xyz`` files are recognised but not implemented yet.

    :param DEM_filepath: file with DEM data
    :type DEM_filepath: string
    :param crs: Coordinate Reference System, forwarded to pysheds for ``.asc``
        files only
    :type crs: pyproj.CRS object
    :return:
        - dictionary with global data of DEM grid
        - pysheds Grid instance
        - dictionary with three 2D-arrays (X, Y, Z); X and Y have shape
          (n_rows, n_cols), Z is the DEM as returned by pysheds
    :raises NotImplementedError: for ``.xyz`` files
    :raises ValueError: for any other unsupported extension
    """
    extension = os.path.splitext(DEM_filepath)[1].lower()

    if extension == '.xyz':
        # Fail before touching the file: xyz parsing does not exist yet.
        raise NotImplementedError("Not implemented yet")

    if extension == '.asc':
        grid1 = Grid.from_ascii(DEM_filepath, crs=crs)
        dem = grid1.read_ascii(DEM_filepath, crs=crs)
    elif extension in ('.tif', '.tiff'):
        grid1 = Grid.from_raster(DEM_filepath)
        dem = grid1.read_raster(DEM_filepath)
    else:
        raise ValueError("DEM file format not supported: " + extension)

    # Grid geometry, shared by both supported formats.
    n_rows, n_cols = grid1.shape[0], grid1.shape[1]
    # Pixel (0, 0) is mapped through the affine and treated as the upper-left
    # corner (north-up raster assumed); one step along the columns gives the
    # cell size.
    x_origin, y_origin = grid1.affine * (0, 0)
    x_next = (grid1.affine * (1, 0))[0]
    cell_size = float(x_next - x_origin)
    x_ll_corner = x_origin
    # Square cells assumed: the row height is taken equal to the column width.
    y_ll_corner = y_origin - n_rows * cell_size
    x_ur_corner = x_ll_corner + n_cols * cell_size
    y_ur_corner = y_ll_corner + n_rows * cell_size
    no_data = grid1.nodata

    # Build coordinates from integer indices so X/Y have exactly n_cols x n_rows
    # samples; np.arange with a float step can yield one extra element through
    # rounding, which would make X/Y disagree in shape with Z.
    xs = x_ll_corner + np.arange(n_cols) * cell_size
    ys = y_ll_corner + np.arange(n_rows) * cell_size
    X, Y = np.meshgrid(xs, ys, indexing='xy')
    # With the north-up layout assumed above, Y decreases down the rows.
    Y = np.flipud(Y)

    polygon = ((x_ll_corner, y_ll_corner), (x_ur_corner, y_ll_corner),
               (x_ur_corner, y_ur_corner), (x_ll_corner, y_ur_corner))
    data_DEM = {'ncols': n_cols, 'path': DEM_filepath, 'n_rows': n_rows, 'cell_size': cell_size,
                'll_corner': (x_ll_corner, y_ll_corner), 'ur_corner': (x_ur_corner, y_ur_corner),
                'polygon_coords': polygon, 'no_data': no_data}
    DEM = {'X': X, 'Y': Y, 'Z': dem}

    return data_DEM, grid1, DEM
   
    
def _single_raster_to_epsg_out(raster: str, save_name: str, _epsg_out) -> str:
    """
    Reproject a single raster so that its CRS matches the ``_epsg_out`` EPSG code.

    The reprojected raster is written next to the input file as
    ``<save_name>_proj.tif``. Every band is reprojected with bilinear
    resampling. When the raster is already in the target CRS, nothing is
    written and the input path is returned unchanged.

    Parameters
    ----------
        raster : path to raster to reproject
        save_name : reprojected raster saving name (without extension)
        _epsg_out : EPSG code of the target CRS

    Returns
    -------
        save_path : path to reprojected raster, or ``raster`` itself when no
            reprojection was needed
    """
    out_path = os.path.join(os.path.dirname(raster), save_name + "_proj.tif")

    with rasterio.open(raster) as source:
        target_crs = rasterio.crs.CRS.from_epsg(_epsg_out)

        # Guard clause: nothing to do when the CRS already matches.
        if source.crs == target_crs:
            return raster

        # Output grid geometry expressed in the target CRS.
        dst_transform, dst_width, dst_height = calculate_default_transform(
            source.crs, target_crs, source.width, source.height, *source.bounds
        )

        # Start from the source profile and override the georeferencing fields.
        out_profile = source.profile
        out_profile.update(
            crs=target_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
        )

        with rasterio.open(out_path, "w", **out_profile) as target:
            # rasterio band indices are 1-based.
            for band_index in range(1, source.count + 1):
                reproject(
                    rasterio.band(source, band_index),
                    rasterio.band(target, band_index),
                    src_transform=source.transform,
                    src_crs=source.crs,
                    dst_transform=dst_transform,
                    dst_crs=target_crs,
                    resampling=Resampling.bilinear,
                )

    return out_path


def merge_rasters(rasters: List[str], save_name, crs, res):
    """
    Merge multiple rasters into a single GeoTIFF.

    Parameters
    ----------
       rasters : raster tile paths list
       save_name : merged data saving name; ``.tif`` is appended and the path
           is resolved relative to the current working directory
       crs : CRS written into the output metadata (the tiles are not
           reprojected to it)
       res : output resolution, forwarded to ``rasterio.merge.merge``

    Returns
    -------
       result : tuple (absolute path to merged raster, output metadata dict)

    Raises
    ------
       ValueError : if ``rasters`` is empty
    """
    if not rasters:
        raise ValueError("merge_rasters needs at least one raster path")

    save_path = save_name + ".tif"

    # Open the rasters inside an ExitStack so every handle is closed, even
    # when opening a later tile or merging fails.
    with ExitStack() as stack:
        datasets = [stack.enter_context(rasterio.open(fp)) for fp in rasters]
        # The first tile provides the base metadata (dtype, count, nodata, ...).
        out_meta = datasets[0].meta.copy()
        merged_raster, transform = merge(datasets, res=res)

    # merged_raster is indexed (band, row, col).
    out_meta.update(
        {
            "driver": "GTiff",
            "height": merged_raster.shape[1],
            "width": merged_raster.shape[2],
            "crs": crs,
            "transform": transform,
        }
    )

    with rasterio.open(save_path, "w+", **out_meta) as dest:
        dest.write(merged_raster)

    return os.path.abspath(save_path), out_meta

def polygonize_water_mask(filename_mask, param, N_reach):
    """Transform a water mask from raster to polygons.

    The first band of ``filename_mask`` is cast to uint8 and vectorised; only
    the regions whose pixel value equals 0 are kept. Two shapefiles are written
    into ``param.work_path``:

    - ``polygone_found.shp``: every polygon found;
    - ``polygone_smooth.shp``: the ``N_reach`` largest polygons, sorted by
      decreasing area, with an extra ``area`` column.

    :param filename_mask: file with mask data
    :type filename_mask: string
    :param param: Parameters required for computation; ``param['C']['DEM_CRS_ID']``
        is passed as the CRS of the written files and ``param.work_path`` is
        the output folder
    :type param: Parameters
    :param N_reach: number of polygons to keep (from the largest one in decreasing order)
    :type N_reach: int
    :raises ValueError: if the mask contains no pixel equal to 0
    """
    crs = param['C']['DEM_CRS_ID']
    with rasterio.Env():
        with rasterio.open(filename_mask) as src:
            image = src.read(1).astype('uint8')  # first band
            transform = src.transform

        # shapes() yields (geometry, value) pairs; keep only the 0-valued regions.
        geoms = [
            {'properties': {'raster_val': value}, 'geometry': geometry}
            for geometry, value in rasterio.features.shapes(image, transform=transform)
            if value == 0
        ]
        if not geoms:
            raise ValueError(
                "No pixel equal to 0 found in water mask " + str(filename_mask)
            )

        gpd_polygonized_raster = gpd.GeoDataFrame.from_features(geoms)
        gpd_polygonized_raster.to_file(
            os.path.join(param.work_path, 'polygone_found.shp'), crs=crs
        )

    # Areas are computed in the units of the mask's georeferencing.
    gpd_polygonized_raster['area'] = gpd_polygonized_raster['geometry'].area
    gdf_sort = gpd_polygonized_raster.sort_values(by='area', ascending=False)
    # head() already caps the count at the number of available polygons.
    gdf_sort = gdf_sort.head(N_reach)

    gdf_sort.to_file(os.path.join(param.work_path, 'polygone_smooth.shp'), crs=crs)

