######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

import os
import numpy as np
import sys
from shapely.geometry import Point,Polygon,MultiPolygon, mapping
from shapely.ops import unary_union
from rasterio.features import geometry_mask
from rasterio.mask import mask
import rasterio
from rasterio.enums import Resampling
from rasterio.transform import from_origin
from rasterio.enums import MergeAlg
from scipy.interpolate import griddata
import pandas as pd
from rasterio.transform import Affine
from scipy.interpolate import LinearNDInterpolator
from .Tools import convert_projected_to_latlon

def create_poly_CEPHEE(BV,nx_trans = None, pixel_size=None,resultType ='Normal'):
    """Transform the hydraulic results stored on cross-sections into polygons and points.

    For every pair of consecutive sections of every reach, the lateral profiles of
    water depth and velocity are resampled on a common number of stations. The
    stations of both sections are then joined into quadrilaterals (used for
    rasterization) and stored as points (used for interpolation).

    :param BV: Watershed
    :type BV: ModelCatchment
    :param nx_trans: number of stations in the transverse direction; ``None`` or ``0``
        derives it from the distance between banks divided by ``pixel_size``
    :type nx_trans: int
    :param pixel_size: resolution of the output raster (scalar)
    :type pixel_size: float
    :param resultType: result table to read on each reach:
        'Normal', '1D', 'Himposed' or 'Obs'
    :type resultType: string

    :return: ``[list_of_poly, valueh, valueV, list_of_WSE, list_for_interp]`` where the
        first four lists hold one entry per polygon (geometry, depth, velocity, water
        surface elevation) and ``list_for_interp`` is the tuple
        ``(points, h, V, WSE, z)`` holding one entry per station
    :rtype: list
    :raises ValueError: if ``resultType`` is unknown or a reach has an empty result table
    """

    if pixel_size is None:
        print('pixel size have to be given if no raster is provided')
        # non-zero status: a missing argument is a failure, not a normal exit
        sys.exit(1)

    result_attribute = {'Normal': 'resNormal',
                        '1D': 'res1D',
                        'Himposed': 'resHimposed',
                        'Obs': 'resObs'}
    if resultType not in result_attribute:
        raise ValueError("Unknown resultType '" + str(resultType) + "', expected one of "
                         + ", ".join(result_attribute))

    # per-polygon outputs (rasterization)
    list_of_poly = []
    list_of_WSE = []
    valueh = []
    valueV = []

    # per-station outputs (interpolation)
    list_of_point = []
    list_of_point_h = []
    list_of_point_v = []
    list_of_point_wse = []
    list_of_point_z = []

    def store_points(xs, ys, hs, vs, zs, wse):
        for x, y, h, v, z in zip(xs, ys, hs, vs, zs):
            list_of_point.append(Point(x, y))
            list_of_point_h.append(h)
            list_of_point_v.append(v)
            list_of_point_wse.append(wse)
            list_of_point_z.append(z)

    for reach in BV.reach:
        df = getattr(reach, result_attribute[resultType])
        if df.empty:
            raise ValueError("No result available for this type of computation")

        # Reset for every reach so that a reach with fewer than two sections never
        # re-uses the closing profile of the previous reach.
        last_profile = None

        for j in range(len(reach.section) - 1):
            res_up = df[df['idSection'] == j]
            res_down = df[df['idSection'] == j + 1]

            wse_up = res_up['WSE'].iloc[0]
            wse_down = res_down['WSE'].iloc[0]

            if nx_trans is None or nx_trans == 0:
                # automatic: about one station per pixel between the two banks
                bank_a = res_up['bank'].iloc[0][0][0]
                bank_b = res_up['bank'].iloc[0][0][1]
                n_stations = max(int(abs(bank_b - bank_a) / pixel_size), 1)
            else:
                n_stations = nx_trans

            d_new1, V_new1, h_new1, _ = _resample_lateral_profile(
                res_up['distance'].iloc[0], res_up['V'].iloc[0], res_up['h'].iloc[0],
                n_stations, pixel_size)
            d_new2, V_new2, h_new2, degenerate = _resample_lateral_profile(
                res_down['distance'].iloc[0], res_down['V'].iloc[0], res_down['h'].iloc[0],
                n_stations, pixel_size)
            if degenerate:
                print('warning: number of point less than lateral space, width fixed to 2 * dx lat')

            # X, Y, Z of the stations along both section lines
            x_sub, y_sub, z_sub = _project_stations_on_section(reach.section[j], d_new1)
            x_sub2, y_sub2, z_sub2 = _project_stations_on_section(reach.section[j + 1], d_new2)

            store_points(x_sub, y_sub, h_new1, V_new1, z_sub, wse_up)

            # one quadrilateral per pair of neighbouring stations, carrying the
            # values of the upstream section
            n_common = min(len(x_sub), len(x_sub2))
            for i in range(n_common - 1):
                list_of_poly.append(Polygon([(x_sub[i], y_sub[i]),
                                             (x_sub[i + 1], y_sub[i + 1]),
                                             (x_sub2[i + 1], y_sub2[i + 1]),
                                             (x_sub2[i], y_sub2[i])]))
                list_of_WSE.append(wse_up)
                valueh.append(h_new1[i])
                valueV.append(V_new1[i])

            last_profile = (x_sub2, y_sub2, h_new2, V_new2, z_sub2, wse_down)

        # Add the stations of the last section of the reach, with that section's own
        # water surface elevation. Skipped when the reach produced no section pair.
        if last_profile is not None:
            store_points(*last_profile)

    list_for_interp = (list_of_point, list_of_point_h, list_of_point_v,
                       list_of_point_wse, list_of_point_z)
    return [list_of_poly, valueh, valueV, list_of_WSE, list_for_interp]


def _resample_lateral_profile(dist, velocity, depth, n_stations, pixel_size):
    """Resample one lateral profile (velocity and depth versus distance) on regular stations.

    A profile with fewer than two points is replaced by two stations located
    ``pixel_size`` on each side of its single point (or of 0 when it is empty),
    both carrying the value of that point.

    :return: ``(stations, velocity, depth, degenerate)`` where ``degenerate`` is True
        when the two-station replacement was used
    """
    if len(dist) < 2:
        if len(dist) == 0:
            dist, velocity, depth = [0], [0], [0]
        stations = (dist[0] - pixel_size, dist[0] + pixel_size)
        return stations, [velocity[0], velocity[0]], [depth[0], depth[0]], True

    stations = np.linspace(np.min(dist), np.max(dist), n_stations, endpoint=True)
    return (stations,
            np.interp(stations, dist, velocity),
            np.interp(stations, dist, depth),
            False)


def _project_stations_on_section(section, stations):
    """Return the x, y and z arrays of ``stations`` interpolated along ``section.line``."""
    coords = list(section.line.coords)
    return tuple(np.interp(stations, section.distance, [c[k] for c in coords])
                 for k in range(3))


def interpolate_result_CEPHEE(data_poly,output_path,config_cal,pixel_size,crs=None):
    """Interpolate the station values from create_poly_CEPHEE on a regular grid and save rasters.

    Water depth, velocity and water surface elevation are linearly interpolated
    (Delaunay based) at the centre of every pixel of a grid covering the bounding
    box of the polygons. Each raster is then masked with the polygons so that pixels
    outside of them are set to nodata (NaN).

    :param data_poly: data from create_poly_CEPHEE
    :type data_poly: list
    :param output_path: folder to save raster
    :type output_path: string
    :param config_cal: additional name to raster
    :type config_cal: string
    :param pixel_size: resolution of output raster, either one value used for both
        directions or a pair ``(x resolution, y resolution)``
    :type pixel_size: float or sequence of two floats
    :param crs: crs for the output raster
    :type crs: crs type

    :return:
        - dictionary with the file names of the 'depth', 'velocity' and 'WSE' rasters
    """

    nodata = np.nan
    list_of_poly, _, _, _, list_for_interp = data_poly
    list_of_point, list_of_point_h, list_of_point_v, list_of_point_wse, _ = list_for_interp

    if np.ndim(pixel_size) == 0:
        res_x = res_y = float(pixel_size)
    else:
        res_x, res_y = float(pixel_size[0]), float(pixel_size[1])

    multi_polygon = MultiPolygon(list_of_poly)
    # Bounding box of the MultiPolygon
    minx, miny, maxx, maxy = multi_polygon.bounds
    height = max(int(abs(maxy - miny) / res_y), 1)
    width = max(int(abs(maxx - minx) / res_x), 1)
    # Origin at (minx, miny) with a positive y step: row 0 holds the smallest y.
    transform = Affine.translation(minx, miny) * Affine.scale(res_x, res_y)

    # Pixel centres, built from width/height so that the grid always has exactly
    # the shape declared for the raster, whatever the (possibly fractional) resolution.
    x_centres = minx + (np.arange(width) + 0.5) * res_x
    y_centres = miny + (np.arange(height) + 0.5) * res_y
    Xres, Yres = np.meshgrid(x_centres, y_centres, indexing='xy')

    # A single triangulation interpolates the three variables at once.
    xy = np.column_stack(([pt.x for pt in list_of_point],
                          [pt.y for pt in list_of_point]))
    values = np.column_stack((list_of_point_h, list_of_point_v, list_of_point_wse))
    fields = LinearNDInterpolator(xy, values)(Xres, Yres).astype(np.float32)

    # MultiPolygon as GeoJSON-like mapping (format readable by rasterio)
    geom = [mapping(multi_polygon)]

    profile = {'driver': 'GTiff',
               'height': height,
               'width': width,
               'count': 1,
               'dtype': rasterio.float32,
               'crs': crs,
               'nodata': nodata,
               'transform': transform}

    raster_result_path = {'depth': os.path.join(output_path, 'water_depth_' + config_cal + '.tif'),
                          'velocity': os.path.join(output_path, 'velocity_' + config_cal + '.tif'),
                          'WSE': os.path.join(output_path, 'WSE_' + config_cal + '.tif')}

    for band_index, key in enumerate(('depth', 'velocity', 'WSE')):
        _write_masked_band(raster_result_path[key],
                           np.ascontiguousarray(fields[..., band_index]),
                           geom, profile)

    return raster_result_path


def _write_masked_band(path, band, geom, profile):
    """Write ``band`` to ``path``, then overwrite the file with its version cropped/masked by ``geom``."""
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(band, 1)

    # Read back and mask; the read handle is closed before the file is rewritten.
    with rasterio.open(path) as src:
        out_image, out_transform = mask(src, geom, crop=True)
        out_meta = src.meta.copy()

    # Update the metadata for the cropped raster
    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[1],
                     "width": out_image.shape[2],
                     "transform": out_transform})

    with rasterio.open(path, "w", **out_meta) as dest:
        dest.write(out_image)


def rasterize_poly_CEPHEE(data_poly,output_path,config_cal,pixel_size,DEM_file,crs=None):
    """Rasterize the polygons from create_poly_CEPHEE into depth, velocity and WSE rasters.

    When ``DEM_file`` is given, the DEM is resampled to ``pixel_size`` (saved as
    ``DEM_reshape.tif`` in ``output_path``), the rasters use its grid and CRS, and the
    water depth is computed as water surface elevation minus DEM. Otherwise the grid
    covers the bounding box of the polygons and the depth stored on the polygons is
    rasterized directly.

    :param data_poly: data from create_poly_CEPHEE
    :type data_poly: list
    :param output_path: folder to save raster
    :type output_path: string
    :param config_cal: additional name to raster
    :type config_cal: string
    :param pixel_size: resolution of output raster, either one value used for both
        directions or a pair ``(x resolution, y resolution)``
    :type pixel_size: float or sequence of two floats
    :param DEM_file: filename of the DEM raster to compute water depth from water
        elevation raster (falsy value: no DEM)
    :type DEM_file: str
    :param crs: crs for the output raster, replaced by the DEM crs when a DEM is given
    :type crs: crs type

    :return:
        - dictionary with the file names of the 'depth', 'velocity' and 'WSE' rasters
    """
    list_of_poly, valueh, valueV, list_of_WSE, _ = data_poly

    if np.ndim(pixel_size) == 0:
        res_x = res_y = float(pixel_size)
    else:
        res_x, res_y = float(pixel_size[0]), float(pixel_size[1])

    DEM = None
    if DEM_file:
        dem_path = os.path.join(output_path, 'DEM_reshape.tif')

        with rasterio.open(DEM_file, 'r') as src:
            upscale_factorx = src.transform[0] / res_x
            upscale_factory = abs(src.transform[4]) / res_y

            resampled = src.read(
                out_shape=(
                    src.count,
                    int(src.height * upscale_factory),
                    int(src.width * upscale_factorx)
                ),
                resampling=Resampling.bilinear
            )

            # scale image transform
            scaled_transform = src.transform * Affine.scale(
                (src.width / resampled.shape[-1]),
                (src.height / resampled.shape[-2]))
            new_meta = src.meta.copy()
            new_meta.update({"driver": "GTiff",
                             "height": resampled.shape[-2],
                             "width": resampled.shape[-1],
                             "transform": scaled_transform,
                             "crs": src.crs,
                             })

        with rasterio.open(dem_path, 'w', **new_meta) as out_n:
            out_n.write(resampled)

        with rasterio.open(dem_path) as rasterDEM:
            DEM = rasterDEM.read(1)
            height = rasterDEM.height
            width = rasterDEM.width
            transform = rasterDEM.transform
            crs = rasterDEM.crs

    else:
        # Bounding box of the MultiPolygon
        minx, miny, maxx, maxy = MultiPolygon(list_of_poly).bounds
        # Raster size in pixels (not in map units), rounded up to cover the whole box
        height = max(int(np.ceil(abs(maxy - miny) / res_y)), 1)
        width = max(int(np.ceil(abs(maxx - minx) / res_x)), 1)
        transform = Affine.translation(minx, maxy) * Affine.scale(res_x, -1 * res_y)

    out_shape = (height, width)

    # Cells not covered by any polygon are burnt with 0 and masked for the velocity
    rasterized = _burn_polygon_values(list_of_poly, valueV, out_shape, transform)
    V_interp = np.ma.masked_where(rasterized == 0, rasterized)

    WSE_interp = _burn_polygon_values(list_of_poly, list_of_WSE, out_shape, transform)

    if DEM_file:
        # The water surface cannot be below the ground: clip it to the DEM,
        # which gives a zero depth on those cells.
        WSE_interp = np.where(WSE_interp <= DEM, DEM, WSE_interp)
        h_interp = WSE_interp - DEM
    else:
        rasterized = _burn_polygon_values(list_of_poly, valueh, out_shape, transform)
        h_interp = np.ma.masked_where(rasterized == 0, rasterized)

    raster_result_path = {'depth': os.path.join(output_path, 'water_depth_' + config_cal + '.tif'),
                          'velocity': os.path.join(output_path, 'velocity_' + config_cal + '.tif'),
                          'WSE': os.path.join(output_path, 'WSE_' + config_cal + '.tif')}

    for key, band in (('depth', h_interp), ('velocity', V_interp), ('WSE', WSE_interp)):
        _write_float32_band(raster_result_path[key], band, height, width, crs, transform)

    return raster_result_path


def _burn_polygon_values(polygons, values, out_shape, transform):
    """Rasterize ``polygons`` with their ``values``; later polygons replace earlier ones, background is 0."""
    return rasterio.features.rasterize(zip(polygons, values),
                                       out_shape=out_shape,
                                       fill=0,
                                       out=None,
                                       transform=transform,
                                       all_touched=True,
                                       merge_alg=MergeAlg.replace,
                                       dtype=None)


def _write_float32_band(path, band, height, width, crs, transform):
    """Write ``band`` as the single float32 band of a new GeoTIFF at ``path``."""
    with rasterio.open(path, 'w', driver='GTiff', height=height, width=width,
                       count=1, dtype=rasterio.float32, crs=crs, transform=transform) as dst:
        dst.write(band.astype(np.float32), 1)

def export_raster(param,output_file, rasterh, rasterV,rasterWSE,data_poly):
    """Write the depth, velocity and water surface elevation arrays as georeferenced rasters.

    Three single-band GeoTIFF files are created, named ``output_file`` followed by
    ``_h``, ``_V`` and ``_WSE``. Each file uses the dtype of its own array.

    :param param: parameter for computation; ``param['C']['crs']`` is used as raster crs
    :type param: Parameter
    :param output_file: path prefix of the files to write
    :type output_file: str
    :param rasterh: water depth array
    :param rasterV: velocity array
    :param rasterWSE: water surface elevation array
    :param data_poly: 7-item sequence whose last three items are the raster height,
        width and affine transform
    :type data_poly: list

    """
    # NOTE(review): a 7-item data_poly is expected here, whereas create_poly_CEPHEE
    # in this file returns 5 items - confirm which producer this function is meant for.
    [_, _, _, _, height, width, transform] = data_poly
    crs = param['C']['crs']

    # Save each georeferenced raster with rasterio
    for suffix, band in (('_h', rasterh), ('_V', rasterV), ('_WSE', rasterWSE)):
        with rasterio.open(output_file + suffix, 'w', driver='GTiff', height=height, width=width,
                           count=1, dtype=str(band.dtype), crs=crs, transform=transform) as dst:
            dst.write(band, 1)

    print('export raster finished')

        
def export_GISformat(BV,param,pathname,format):
    """Export CEPHEE results for use in GIS tools other than QGIS.

    With ``format == 'GoogleEarth'`` a csv file ``GISformat.csv`` is written in
    ``pathname``. It holds two rows per section (one per end of the water surface
    extent) with the columns Latitude, Longitude, Elevation (section water surface)
    and Width. Only the reaches whose river belongs to the selected outlet are exported.

    With ``format == 'raster'`` the polygons of each selected reach are created,
    rasterized and exported as rasters.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters requires for computation
    :type param: Parameters
    :param pathname: folder to write csv file
    :type pathname: str
    :param format: type of export ('GoogleEarth' or 'raster')
    :type format: str

    """

    if format == 'GoogleEarth':
        source_epsg = BV.crs
        filename = pathname + '/GISformat.csv'
        selected_rivers = BV.outlet[BV.id_outlet][1]

        # one record per end point of the water surface extent of each section
        records = []
        for reach in BV.reach:
            if reach.geodata['River'] not in selected_rivers:
                continue
            for section in reach.section:
                for end in (0, 1):
                    records.append((section.Wsextent[end][0],
                                    section.Wsextent[end][1],
                                    section.WS,
                                    section.width))

        xs = np.array([rec[0] for rec in records])
        ys = np.array([rec[1] for rec in records])
        elevations = np.array([rec[2] for rec in records])
        widths = np.array([rec[3] for rec in records])

        table = np.zeros((len(xs), 4))
        for row, (px, py, elevation, width) in enumerate(zip(xs, ys, elevations, widths)):
            lon, lat = convert_projected_to_latlon(px, py, source_epsg)
            table[row, 0] = lat
            table[row, 1] = lon
            table[row, 2] = elevation
            table[row, 3] = width

        pd.DataFrame(table, columns=['Latitude','Longitude','Elevation','Width']).to_csv(filename)

    if format == 'raster':
        # NOTE(review): the calls below do not match the functions defined in this file:
        # create_poly_CEPHEE iterates over the `reach` attribute of its first argument and
        # returns 5 items, and rasterize_poly_CEPHEE requires config_cal, pixel_size and
        # DEM_file and returns a dictionary. This branch is kept unchanged until the
        # intended behaviour is confirmed.
        selected_rivers = BV.outlet[BV.id_outlet][1]

        for reach in BV.reach:
            if reach.geodata['River'] not in selected_rivers:
                continue

            data_poly = create_poly_CEPHEE(reach, nx_trans=10, pixel_size=param['C']['resolution'])
            output_file = os.path.join(param['C']['pathname'], 'output')
            print('creation polygons finished')

            rasterh, rasterV, rasterWSE = rasterize_poly_CEPHEE(data_poly, output_file, param['C']['crs'])
            print('rasterisation polygons finished')
            [_, _, _, _, height, width, transform] = data_poly
            print(np.max(rasterh), height, width, transform)

            export_raster(param, output_file, rasterh, rasterV, rasterWSE, data_poly)

