######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

# global
import os
from shapely.geometry import Point,Polygon,MultiPolygon, mapping
import rasterio
from rasterio.features import geometry_mask
from rasterio.mask import mask
from rasterio.enums import Resampling, MergeAlg
from rasterio.transform import Affine
import pandas as pd
from scipy.interpolate import LinearNDInterpolator
from tatooinemesher.constraint_line import ConstraintLine
from tatooinemesher.mesh_constructor import MeshConstructor
from tatooinemesher.utils import TatooineException
import numpy as np
import sys
# local
from core.Tools import convert_projected_to_latlon
try:
    sys.path.append(os.path.abspath('../../telemac-mascaret/scripts/python3'))
    from data_manip.extraction.telemac_file import TelemacFile
    from scipy.spatial import KDTree
    import gmsh
except ImportError:
   print('open telemac librairy not available')

def result_slf(BV, params, work_path, config_cal,crs):
    """Build and export a TatooineMesher mesh for every reach having several sections.

    Reaches with more than one section are numbered from 1 in iteration order.
    For each of them the sections are projected on ``reach.line_int`` (step 10),
    checked for intersections and sorted. Constraint lines are derived with
    linear interpolation, and a mesh with
    ``int(params['XS']['number_of_points'] / 2)`` lateral points is built. The
    interpolated points go to ``<work_path>/result_<config_cal>.shp``; this path
    is the same for every reach (NOTE(review): later reaches presumably
    overwrite it). The mesh goes to ``<work_path>/result_<config_cal>_<n>.slf``.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param params: Parameters; reads ``params['XS']['number_of_points']`` and
        ``params['XS']['width']``
    :type params: Parameters
    :param work_path: folder where the files are written
    :type work_path: str
    :param config_cal: name of the computation, used in the file names
    :type config_cal: str
    :param crs: not used
    :raises TatooineException: if a reach does not give exactly 2 constraint lines
    """
    xs_params = params['XS']
    nb_point = int(xs_params['number_of_points'] / 2)
    # lateral step used by the interpolation of the mesh constructor
    pixel_size = xs_params['width'] / nb_point

    meshable_reaches = (reach for reach in BV.reach if len(reach.section_list) > 1)
    for nreach, section_seq in enumerate(meshable_reaches, start=1):
        # the reach itself is used as the section sequence expected by TatooineMesher
        section_seq.compute_dist_proj_axe(section_seq.line_int, 10)
        section_seq.check_intersections()
        section_seq.sort_by_dist()
        constraint_lines = ConstraintLine.get_lines_and_set_limits_from_sections(section_seq, 'LINEAR')

        if nb_point is not None and len(constraint_lines) != 2:
            raise TatooineException("Argument `--nb_pts_lat` is only compatible with 2 constraint lines!")

        mesh_constr = MeshConstructor(section_seq=section_seq, lat_step=None,
                                      nb_pts_lat=nb_point, interp_values='LINEAR')
        mesh_constr.build_interp(constraint_lines, pixel_size, True)
        points_file = os.path.join(work_path, 'result_'+config_cal+'.shp')
        mesh_constr.export_points(points_file)
        mesh_constr.build_mesh()
        mesh_file = os.path.join(work_path, 'result_'+config_cal+'_'+str(nreach) +'.slf')
        mesh_constr.export_mesh(mesh_file, lang="fr")

def result_slf_gmsh(BV, params, data_poly,work_path, config_cal, crs, size_mesh=None):
    """Mesh the area covered by the CEPHEE polygons with gmsh and convert it to Selafin.

    The union of the polygons of ``data_poly`` is meshed by ``compute_mesh_cephee``.
    The mesh is written as ``<work_path>/mesh_init`` and converted by ``gmsh2slf``
    into ``<work_path>/result_<config_cal>.slf``.

    NOTE(review): ``compute_mesh_cephee`` and ``gmsh2slf`` are neither defined nor
    imported in this module. They must be made available before this function
    can run — TODO confirm where they live.

    :param BV: Watershed (not used)
    :param params: Parameters; reads ``params['XS']['creation_step']``, or
        ``params['XS']['interpolation_step']`` when ``params['XS']['interpolate_XS']``
        is truthy
    :type params: Parameters
    :param data_poly: output of ``create_poly_CEPHEE``
    :type data_poly: list
    :param work_path: folder where the mesh and the Selafin file are written
    :type work_path: str
    :param config_cal: suffix of the Selafin file name
    :type config_cal: str
    :param crs: not used
    :param size_mesh: mesh size passed to ``compute_mesh_cephee``
    :type size_mesh: float
    :raises ValueError: if ``size_mesh`` is not provided
    """
    # shapely is already a dependency of this module; unary_union was never imported
    from shapely.ops import unary_union

    if size_mesh is None:
        # the previous code read an undefined global ``size_mesh`` (NameError)
        raise ValueError("size_mesh must be provided to build the gmsh mesh")

    list_of_poly, _, _, _, list_for_interp, _ = data_poly
    list_of_point, _, _, _, _ = list_for_interp

    union_poly = unary_union(list_of_poly)

    step_long = params['XS']['creation_step']
    if params['XS']['interpolate_XS']:
        step_long = params['XS']['interpolation_step']

    # honour the folder given by the caller (previously ignored for params.work_path)
    mesh_name = os.path.join(work_path, 'mesh_init')
    slf_file = os.path.join(work_path, 'result_' + config_cal + '.slf')
    compute_mesh_cephee(union_poly, [], [], list_of_point, size_mesh, step_long, mesh_name,
                        density_line=False, boundary_distance_null=True)
    gmsh2slf(mesh_name, slf_file)

def _section_wse(section):
    """Return ``section.WSE`` when it is a scalar, otherwise its first element."""
    if np.isscalar(section.WSE):
        return section.WSE
    return section.WSE[0]


def create_poly_CEPHEE(BV, params, output_res = None):
    """Transform the hydraulic results of the cross-sections into polygons and points.

    Only reaches with at least two sections are processed. Each of their sections
    is resampled once by ``process_section`` with a lateral step of ``output_res``.
    Two kinds of output are then built:

    * points for interpolation or rasterization: every resampled point with its
      depth and velocity (both set to 0 where the depth is not positive), the
      water surface elevation of its own section, and its bed elevation;
    * quadrilaterals between sections j and j+1, built from the points of the same
      index. Each carries the WSE of section j and the bed elevation, depth
      (not clipped) and velocity of the point of section j.

    :param BV: Watershed
    :type BV: ModelCatchment
    :param params: Parameters; ``params['H']['output_resolution']`` is used when
        ``output_res`` is not given
    :type params: Parameters
    :param output_res: lateral resampling step of the sections
    :type output_res: float, optional
    :return: ``[list_of_poly, valueh, valueV, list_of_WSE, list_for_interp, list_of_bed]``
        where ``list_for_interp`` is the tuple
        ``(list_of_point, list_of_point_h, list_of_point_v, list_of_point_wse, list_of_point_z)``
    :rtype: list
    """
    if not output_res:
        output_res = params['H']['output_resolution']
    pixel_size = output_res

    # polygons between consecutive sections and the values attached to them
    list_of_poly = []
    list_of_WSE = []
    list_of_bed = []
    valueh = []
    valueV = []

    # points used for the interpolation
    list_of_point = []
    list_of_point_h = []
    list_of_point_v = []
    list_of_point_wse = []
    list_of_point_z = []

    for reach in BV.reach:
        sections = reach.section_list
        if len(sections) < 2:
            continue

        wse_values = [_section_wse(section) for section in sections]
        # each section is resampled once, although it borders two polygon strips
        profiles = [process_section(section, wse, pixel_size)
                    for section, wse in zip(sections, wse_values)]

        # points of every section (the last one included) carry their own section's WSE;
        # the previous code gave the last section the WSE of the one before it
        for ds, wse in zip(profiles, wse_values):
            for x, y, h, v, z in zip(ds['x_sub'], ds['y_sub'], ds['h_new'], ds['V_new'], ds['z_sub']):
                wet = h > 0
                list_of_point.append(Point(x, y))
                list_of_point_h.append(h if wet else 0)
                list_of_point_v.append(v if wet else 0)
                list_of_point_wse.append(wse)
                list_of_point_z.append(z)

        # quadrilaterals between sections j and j+1, valued from section j
        for ds1, ds2, wse in zip(profiles[:-1], profiles[1:], wse_values[:-1]):
            n_common = min(len(ds1['x_sub']), len(ds2['x_sub']))
            for i in range(n_common - 1):
                poly = Polygon([Point(ds1['x_sub'][i], ds1['y_sub'][i]),
                                Point(ds1['x_sub'][i + 1], ds1['y_sub'][i + 1]),
                                Point(ds2['x_sub'][i + 1], ds2['y_sub'][i + 1]),
                                Point(ds2['x_sub'][i], ds2['y_sub'][i])])
                list_of_poly.append(poly)
                list_of_WSE.append(wse)
                list_of_bed.append(ds1['z_sub'][i])
                valueh.append(ds1['h_new'][i])
                valueV.append(ds1['V_new'][i])

    list_for_interp = (list_of_point, list_of_point_h, list_of_point_v, list_of_point_wse, list_of_point_z)
    return [list_of_poly, valueh, valueV, list_of_WSE, list_for_interp, list_of_bed]


def interpolate_result_CEPHEE(data_poly, output_path, config_cal, params, crs=None ,DEM_file =None, qgstask=None):
    """Interpolate the CEPHEE point data on a regular grid and write cropped rasters.

    Water surface elevation, velocity and (when no DEM is given) water depth are
    linearly interpolated from the points of ``data_poly``. Values are computed
    at the centre of every pixel of the output grid; pixels outside the convex
    hull of the points are NaN.

    With a DEM, the grid is the DEM resampled by ``resize_DEM``, and the depth is
    ``WSE - DEM``. Without a DEM, the grid is a north-up grid covering the bounding
    box of the polygons.

    Pixels with a non-positive depth get NaN depth and velocity. Each raster is
    written as a float32 GeoTIFF (nodata NaN) and then cropped to the polygons
    of ``data_poly``.

    :param data_poly: data from ``create_poly_CEPHEE``
    :type data_poly: list
    :param output_path: folder to save the rasters
    :type output_path: str
    :param config_cal: suffix added to the raster names
    :type config_cal: str
    :param params: Parameters; the pixel size is ``params['H']['output_resolution']``
    :type params: Parameters
    :param crs: crs of the output rasters (replaced by the DEM crs when a DEM is given)
    :param DEM_file: optional DEM used to compute the water depth
    :type DEM_file: str
    :param qgstask: optional task whose progress is updated
    :return: paths of the rasters, keys ``'depth'``, ``'velocity'`` and ``'WSE'``
    :rtype: dict
    """
    from scipy.spatial import Delaunay

    pixel_size = params['H']['output_resolution']
    pixel_size = [pixel_size, pixel_size]
    list_of_poly, _, _, _, list_for_interp, _ = data_poly
    list_of_point, list_of_point_h, list_of_point_v, list_of_point_wse, _ = list_for_interp
    multi_polygon = MultiPolygon(list_of_poly)

    if DEM_file:
        rasterDEM = resize_DEM(DEM_file, pixel_size, output_path)
        try:
            DEM = rasterDEM.read(1)
            height = rasterDEM.height
            width = rasterDEM.width
            transform = rasterDEM.transform
            crs = rasterDEM.crs
        finally:
            rasterDEM.close()
        # NOTE(review): DEM nodata values are not masked before computing WSE - DEM
    else:
        minx, miny, maxx, maxy = multi_polygon.bounds
        height = max(1, int(np.ceil(abs(maxy - miny) / pixel_size[1])))
        width = max(1, int(np.ceil(abs(maxx - minx) / pixel_size[0])))
        # north-up grid (row 0 at maxy); the previous south-up transform flipped the output
        transform = Affine.translation(minx, maxy) * Affine.scale(pixel_size[0], -pixel_size[1])

    nodata = np.nan

    # pixel-centre coordinates computed from the very transform written to the files
    cols, rows = np.meshgrid(np.arange(width) + 0.5, np.arange(height) + 0.5)
    Xres = transform.a * cols + transform.b * rows + transform.c
    Yres = transform.d * cols + transform.e * rows + transform.f

    points = np.column_stack(([pt.x for pt in list_of_point], [pt.y for pt in list_of_point]))
    # a single triangulation shared by every interpolated field
    triangulation = Delaunay(points)

    def _interpolate(values):
        """Linearly interpolate ``values`` (one per point) at the pixel centres."""
        return LinearNDInterpolator(triangulation, values)(Xres, Yres).astype(np.float32)

    WSE = _interpolate(list_of_point_wse)
    if qgstask:
        qgstask.setProgress(66)

    if DEM_file:
        H = (WSE - DEM).astype(np.float32)
    else:
        H = _interpolate(list_of_point_h)
    if qgstask:
        qgstask.setProgress(83)
    V = _interpolate(list_of_point_v)
    V = np.where(H <= 0, nodata, V)
    H[H <= 0] = nodata

    # GeoJSON-like geometry understood by rasterio.mask
    geom = [mapping(multi_polygon)]

    raster_result_path = {}
    for key, prefix, array in (('depth', 'water_depth_', H),
                               ('velocity', 'velocity_', V),
                               ('WSE', 'WSE_', WSE)):
        path = os.path.join(output_path, prefix + config_cal + '.tif')
        with rasterio.open(path, 'w', driver='GTiff', height=height, width=width, count=1,
                           dtype=rasterio.float32, crs=crs, nodata=nodata,
                           transform=transform) as dst:
            dst.write(array.astype(np.float32), 1)

        # crop to the polygons; the file is closed before being rewritten in place
        with rasterio.open(path) as src:
            out_image, out_transform = mask(src, geom, crop=True)
            out_meta = src.meta.copy()
        out_meta.update({"driver": "GTiff", "height": out_image.shape[1],
                         "width": out_image.shape[2], "transform": out_transform})
        with rasterio.open(path, "w", **out_meta) as dest:
            dest.write(out_image)

        raster_result_path[key] = path

    return raster_result_path


def rasterize_poly_CEPHEE(data_poly, output_path, config_cal, params, DEM_file, crs=None, qgstask=None):
    """Rasterize the polygons from ``create_poly_CEPHEE`` into depth, velocity and WSE rasters.

    Each polygon is burnt with its own value (``all_touched=True``; on shared
    pixels, later polygons replace earlier ones). Pixels outside every polygon
    are NaN.

    With a DEM, the grid is the DEM resampled by ``resize_DEM``. The WSE is raised
    to the DEM where it is lower, and the depth is ``WSE - DEM``. Without a DEM,
    the grid is a north-up grid covering the bounding box of the polygons, and
    the depth is rasterized from the polygon depths.

    :param data_poly: data from ``create_poly_CEPHEE``
    :type data_poly: list
    :param output_path: folder to save the rasters
    :type output_path: str
    :param config_cal: suffix added to the raster names
    :type config_cal: str
    :param params: Parameters; the pixel size is ``params['H']['output_resolution']``
    :type params: Parameters
    :param DEM_file: filename of the DEM used to compute the water depth from the WSE
    :type DEM_file: str
    :param crs: crs of the output rasters (replaced by the DEM crs when a DEM is given)
    :param qgstask: optional task whose progress is updated
    :return: paths of the rasters, keys ``'depth'``, ``'velocity'`` and ``'WSE'``
    :rtype: dict
    """
    list_of_poly, valueh, valueV, list_of_WSE, _, _ = data_poly

    pixel_size = params['H']['output_resolution']
    pixel_size = [pixel_size, pixel_size]
    if DEM_file:
        rasterDEM = resize_DEM(DEM_file, pixel_size, output_path)
        try:
            DEM = rasterDEM.read(1)
            height = rasterDEM.height
            width = rasterDEM.width
            transform = rasterDEM.transform
            crs = rasterDEM.crs
        finally:
            rasterDEM.close()
    else:
        minx, miny, maxx, maxy = MultiPolygon(list_of_poly).bounds
        # number of pixels (the previous code used the extent in map units)
        height = max(1, int(np.ceil(abs(maxy - miny) / pixel_size[1])))
        width = max(1, int(np.ceil(abs(maxx - minx) / pixel_size[0])))
        transform = Affine.translation(minx, maxy) * Affine.scale(pixel_size[0], -pixel_size[1])

    def _burn(values):
        """Rasterize the polygons, each with its value from ``values``; NaN elsewhere."""
        return rasterio.features.rasterize(zip(list_of_poly, values),
                                           out_shape=(height, width),
                                           fill=np.nan,
                                           transform=transform,
                                           all_touched=True,
                                           merge_alg=MergeAlg.replace)

    V_interp = _burn(valueV)
    if qgstask:
        qgstask.setProgress(66)
    WSE_interp = _burn(list_of_WSE)
    if qgstask:
        qgstask.setProgress(83)

    if DEM_file:
        # no negative depth: the water surface cannot be below the ground
        WSE_interp = np.where(WSE_interp <= DEM, DEM, WSE_interp)
        h_interp = WSE_interp - DEM
    else:
        h_interp = _burn(valueh)

    raster_result_path = {}
    for key, prefix, array in (('depth', 'water_depth_', h_interp),
                               ('velocity', 'velocity_', V_interp),
                               ('WSE', 'WSE_', WSE_interp)):
        path = os.path.join(output_path, prefix + config_cal + '.tif')
        with rasterio.open(path, 'w', driver='GTiff', height=height, width=width, count=1,
                           dtype=rasterio.float32, crs=crs, nodata=np.nan,
                           transform=transform) as dst:
            dst.write(array.astype(np.float32), 1)
        raster_result_path[key] = path

    return raster_result_path
        
def export_GISformat(BV,param,pathname,format):
    """Write the bank positions of the outlet rivers to a CSV file for other GIS tools.

    Only ``format == 'GoogleEarth'`` is handled; any other value writes nothing.

    Reaches are selected when ``reach.geodata['River']`` belongs to
    ``BV.outlet[BV.id_outlet][1]``. For every section of those reaches, both ends
    of ``section.Wsextent`` are converted from ``BV.crs`` to latitude/longitude.
    Each end is written with the section ``WS`` and ``width`` to
    ``<pathname>/GISformat.csv`` (columns Latitude, Longitude, Elevation, Width).

    :param BV: Watershed
    :type BV: ModelCatchment
    :param param: Parameters (not used)
    :type param: Parameters
    :param pathname: folder where the csv file is written
    :type pathname: str
    :param format: type of export; only 'GoogleEarth' is supported
    :type format: str
    """
    if format != 'GoogleEarth':
        return

    source_epsg = BV.crs
    filename = pathname + '/GISformat.csv'
    outlet_rivers = BV.outlet[BV.id_outlet][1]

    # two rows per section: its first and its second Wsextent point
    bank_x, bank_y, bank_z, bank_w = [], [], [], []
    for reach in BV.reach:
        if reach.geodata['River'] not in outlet_rivers:
            continue
        for section in reach.section_list:
            first_bank, second_bank = section.Wsextent[0], section.Wsextent[1]
            bank_x.extend((first_bank[0], second_bank[0]))
            bank_y.extend((first_bank[1], second_bank[1]))
            bank_z.extend((section.WS, section.WS))
            bank_w.extend((section.width, section.width))

    bank_x = np.array(bank_x)
    bank_y = np.array(bank_y)
    bank_z = np.array(bank_z)
    bank_w = np.array(bank_w)

    table = np.zeros((len(bank_x), 4))
    for row, (px, py) in enumerate(zip(bank_x, bank_y)):
        lon, lat = convert_projected_to_latlon(px, py, source_epsg)
        table[row, 0] = lat
        table[row, 1] = lon
        table[row, 2] = bank_z[row]
        table[row, 3] = bank_w[row]

    df = pd.DataFrame(table, columns=['Latitude','Longitude','Elevation','Width'])
    df.to_csv(filename)


def export_map(BV1,params,config_cal):
    """Create the result maps with the method chosen in ``params['H']['mapping_method']``.

    - ``'interpolation'``: points from ``create_poly_CEPHEE`` are interpolated by
      ``interpolate_result_CEPHEE``;
    - ``'by_polygon'``: polygons from ``create_poly_CEPHEE`` are rasterized by
      ``rasterize_poly_CEPHEE``;
    - ``'selafin'``: a Selafin mesh is exported by ``result_slf``;
    - any other value does nothing.

    Files are written in ``params.work_path``. The raster methods use
    ``BV1.global_DEM_path`` as DEM and ``params['C']['DEM_CRS_ID']`` as crs.

    :param BV1: Watershed
    :type BV1: ModelCatchment
    :param params: Parameters requires for computation
    :type params: Parameters
    :param config_cal: name of computation with parameter value
    :type config_cal: str
    :return: paths of the rasters (keys ``'depth'``, ``'velocity'``, ``'WSE'``) for the
        raster methods, ``None`` otherwise
    :rtype: dict or None
    """
    method = params['H']['mapping_method']

    if method == 'selafin':
        # result_slf works on the reaches directly; no polygon data is needed
        result_slf(BV1, params, params.work_path, config_cal, params['C']['DEM_CRS_ID'])
        return None

    if method not in ('interpolation', 'by_polygon'):
        return None

    data_poly = create_poly_CEPHEE(BV1, params, params['H']['output_resolution'])
    if method == 'interpolation':
        return interpolate_result_CEPHEE(data_poly, params.work_path, config_cal, params,
                                         params['C']['DEM_CRS_ID'], BV1.global_DEM_path)
    # forward the crs as the interpolation branch does (used when there is no DEM)
    return rasterize_poly_CEPHEE(data_poly, params.work_path, config_cal, params,
                                 BV1.global_DEM_path, crs=params['C']['DEM_CRS_ID'])


def resize_DEM(DEM_file,pixel_size,output_path):
    """Resample a DEM to the requested pixel size with bilinear resampling.

    The resampled DEM is written to ``<output_path>/DEM_reshape.tif``. That file
    is returned opened for reading; the caller is responsible for closing it.

    :param DEM_file: path of the initial DEM raster
    :type DEM_file: str
    :param pixel_size: target pixel size as ``[size_x, size_y]``
    :type pixel_size: list
    :param output_path: folder where ``DEM_reshape.tif`` is written
    :type output_path: str
    :return: the resampled DEM, opened for reading
    :rtype: rasterio dataset
    """
    reshaped_path = os.path.join(output_path, 'DEM_reshape.tif')

    with rasterio.open(DEM_file, 'r') as src:
        # ratio between the source and target resolutions (> 1 means a finer output)
        upscale_factorx = src.transform[0] / pixel_size[0]
        upscale_factory = abs(src.transform[4]) / pixel_size[1]

        dst = src.read(
            out_shape=(
                src.count,
                max(1, int(src.height * upscale_factory)),
                max(1, int(src.width * upscale_factorx))
            ),
            resampling=Resampling.bilinear
        )

        # stretch the source transform so the new grid covers the same extent
        transform = src.transform * Affine.scale(
            src.width / dst.shape[-1],
            src.height / dst.shape[-2])
        new_meta = src.meta.copy()
        new_meta.update({"driver": "GTiff",
                         "height": dst.shape[-2],
                         "width": dst.shape[-1],
                         "transform": transform,
                         "crs": src.crs,
                         })

    with rasterio.open(reshaped_path, 'w', **new_meta) as out_n:
        out_n.write(dst)

    return rasterio.open(reshaped_path)

def process_section(section, WSE, pixel_size):
    """Resample a cross-section at a regular lateral step on each side of its axis.

    Points with ``Xt < 0`` form the left side and points with ``Xt >= 0`` form the
    right side. Each side is resampled independently. Sides with at least two
    points use ``np.arange`` from the minimum to the maximum ``Xt`` (maximum
    excluded), with linear interpolation.

    A side with a single point is replaced by two copies of that point, placed at
    ``Xt - pixel_size`` and ``Xt + pixel_size``. A side with no point uses a dummy
    point at ``Xt = 0`` whose values are all 0.

    :param section: cross-section; reads ``section.coord.array['Xt']``, ``['X']``,
        ``['Y']`` and ``section.coord.values['B']``, ``['M']``
    :param WSE: water surface elevation; the depth is ``WSE - B``
    :type WSE: float
    :param pixel_size: lateral resampling step
    :type pixel_size: float
    :return: dict with keys ``d_new`` (Xt), ``V_new`` (M values), ``h_new`` (depth),
        ``x_sub``, ``y_sub`` and ``z_sub`` (bed B); each value is the left side
        followed by the right side
    :rtype: dict
    """
    Xt = section.coord.array['Xt']
    B = section.coord.values['B']
    M = section.coord.values['M']
    X = section.coord.array['X']
    Y = section.coord.array['Y']

    results = {}

    for side, side_mask in (('left', Xt < 0), ('right', Xt >= 0)):
        dist = Xt[side_mask]
        v = M[side_mask]
        h = WSE - B[side_mask]
        x = X[side_mask]
        y = Y[side_mask]
        z = B[side_mask]

        if len(dist) < 2:
            if len(dist) == 0:
                # empty side: a single dummy point at Xt = 0 with zero values
                dist, v, h, x, y, z = [0], [0], [0], [0], [0], [0]
            # the single point is duplicated one step on each side of its position
            d_new = [dist[0] - pixel_size, dist[0] + pixel_size]
            V_new = [v[0], v[0]]
            # both entries are the depth (previously [v[0], h[0]] mixed in the M value)
            h_new = [h[0], h[0]]
            x_sub = [x[0], x[0]]
            y_sub = [y[0], y[0]]
            z_sub = [z[0], z[0]]
        else:
            # NOTE(review): np.interp expects increasing Xt; presumably each side is sorted
            d_new = np.arange(np.min(dist), np.max(dist), pixel_size)
            V_new = np.interp(d_new, dist, v)
            h_new = np.interp(d_new, dist, h)
            x_sub = np.interp(d_new, dist, x)
            y_sub = np.interp(d_new, dist, y)
            z_sub = np.interp(d_new, dist, z)

        results[side] = {
            'd_new': d_new,
            'V_new': V_new,
            'h_new': h_new,
            'x_sub': x_sub,
            'y_sub': y_sub,
            'z_sub': z_sub,
        }

    return {key: np.hstack((results['left'][key], results['right'][key])) for key in results['left']}


