######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

import os
#import cv2
import numpy as np
import rasterio.features
from shapely.geometry import LineString,MultiLineString,Polygon,MultiPolygon
from geopandas import GeoDataFrame


def _contours_above_area(binary_mask, area_min_px):
    """Find the contours of *binary_mask* and keep those larger than *area_min_px*.

    Returns ``(contours, kept, areas)`` where ``contours`` is everything
    ``cv2.findContours`` reported, and ``kept`` / ``areas`` are parallel lists
    restricted to contours whose ``cv2.contourArea`` exceeds the threshold.
    """
    contours, _hierarchy = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    kept = []
    areas = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > area_min_px:
            kept.append(contour)
            areas.append(area)
    return contours, kept, areas


def _drop_two_largest(kept, areas):
    """Remove, in place, the two largest contours when more than one was kept.

    The original code describes this as removing the external contour.
    """
    if len(areas) > 1:
        for _ in range(2):
            largest = int(np.argmax(areas))
            kept.pop(largest)
            areas.pop(largest)


def find_dikes(param, DEMfile, areaMin=100, output_name=None, method='avg', width_range=None, height_range=None):
    """Detect raised structures in a DEM and export them as shapefiles.

    For every (width, height) pair a binary mask is built, either by comparing
    the DEM with a local average (``method='avg'``) or by thresholding the
    absolute Laplacian (``method='gradient'``). Contours larger than
    ``areaMin`` are filled and OR-ed into a combined mask, which is filtered
    once more and vectorised.

    Parameters
    ----------
    param : unused by this function; kept for interface compatibility.
    DEMfile : path of the raster opened with ``rasterio.open``; band 1 is read.
    areaMin : minimum contour area, expressed in squared map units.
    output_name : base path of the outputs; ``<output_name>.shp`` receives the
        open contour lines and ``<output_name>poly.shp`` the polygons.
        Defaults to ``'detected_structure'``.
    method : ``'avg'`` or ``'gradient'``.
    width_range, height_range : iterables of widths and heights to scan
        (same length unit as the raster pixel size).

    Returns
    -------
    tuple
        ``(Data_raster, n_contours)``: the filled mask of the last
        (width, height) pair and the number of contours found in the combined
        mask before area filtering.

    Raises
    ------
    ValueError
        If ``method`` is unknown or if a range is missing or empty.

    Notes
    -----
    NOTE(review): ``cv2`` is used here, but its module-level import is
    commented out in this file; it has to be made available for this function
    to run.
    """
    if method not in ('avg', 'gradient'):
        raise ValueError("method must be 'avg' or 'gradient', got %r" % (method,))
    if width_range is None or height_range is None:
        raise ValueError('width_range and height_range must both be provided')
    widths = list(width_range)
    heights = list(height_range)
    if not widths or not heights:
        raise ValueError('width_range and height_range must not be empty')

    # ======================
    # read the DEM (the dataset is closed as soon as the data are loaded)
    # ======================
    with rasterio.open(DEMfile) as raster:
        DTM = raster.read(1)
        transform = raster.transform
        crs = raster.crs
        nodatavals = raster.nodatavals
    pixel_size = transform[0]
    DTM = np.ma.masked_where(DTM == nodatavals, DTM)

    # accumulates the filled contours of every (width, height) pair
    Data_raster_all = np.zeros(DTM.shape, np.uint8)
    # cv2.contourArea works in pixels, so the threshold is converted from
    # squared map units to px2 by dividing by the area of one pixel
    area_min_px = areaMin / (pixel_size ** 2)

    for width in widths:
        for height in heights:
            print(width, height)

            if method == 'avg':
                # window of three widths, converted to pixels (at least 1 px)
                size_kernel_avg = max(1, int(3 * width / pixel_size))
                kernel_avg = np.ones((size_kernel_avg, size_kernel_avg), np.uint8)
                DTM_avg = cv2.filter2D(DTM, -1, kernel_avg) / size_kernel_avg ** 2
                DTM_avg[DTM_avg < 0] = 0
                # cells rising more than `height` above the local average
                Data_raster = np.where(DTM > DTM_avg + height, 255, 0).astype(np.uint8)

            else:  # method == 'gradient'
                treshold_slope_min = height / pixel_size  # in m/m
                treshold_slope_max = 6  # in m/m, to remove contour slope
                print('find slope between ' + str(treshold_slope_min) + ' and ' + str(treshold_slope_max))
                # try a 32-bit output first, fall back to 64-bit
                try:
                    gradDTM = cv2.Laplacian(DTM, cv2.CV_32F)
                except Exception:
                    gradDTM = cv2.Laplacian(DTM, cv2.CV_64F)

                # express the gradient in m/m
                gradDTM = gradDTM / pixel_size
                Data_raster = abs(gradDTM)
                Data_raster[Data_raster <= treshold_slope_min] = 0
                Data_raster[Data_raster > treshold_slope_max] = 0
                Data_raster[Data_raster > 0] = 1
                Data_raster = (Data_raster * 255).astype(np.uint8)

            # contour detection: keep only large contours, then rasterise them
            _found, selected_contours, areas = _contours_above_area(Data_raster, area_min_px)
            _drop_two_largest(selected_contours, areas)

            mask = np.zeros(Data_raster.shape, np.uint8)
            Data_raster = cv2.fillPoly(mask, pts=selected_contours, color=(255, 255, 255))
            Data_raster_all = cv2.bitwise_or(Data_raster_all, Data_raster)

    # filtering of the combined raster
    contours, selected_contours, areas = _contours_above_area(Data_raster_all, area_min_px)
    print(len(contours), len(selected_contours))
    _drop_two_largest(selected_contours, areas)

    mask_all = np.zeros(Data_raster_all.shape, np.uint8)
    # NOTE(review): this filtered combined mask is computed but never returned
    # (the function returns the mask of the last pair) - confirm the intent.
    Data_raster_all = cv2.fillPoly(mask_all, pts=selected_contours, color=(255, 255, 255))

    lines = []
    polygons = []
    for contour in selected_contours:
        # reshape keeps a (n, 2) layout even for very short contours
        cont = contour.reshape(-1, 2)
        if len(cont) < 3:
            # a polygon needs at least three vertices
            continue

        # pixel indices -> map coordinates (rotation terms are not applied)
        X = (cont[:, 0] * transform[0] + transform[2]).tolist()
        Y = (cont[:, 1] * transform[4] + transform[5]).tolist()
        coords = list(zip(X, Y))

        line = LineString(coords)
        # open version of the line: the last vertex is dropped
        open_line = LineString(coords[:-1])

        # close the line by repeating its first point when it is not a ring
        if line.is_ring:
            closed_line = line
        else:
            closed_line = LineString(coords + [coords[0]])

        polygons.append(Polygon(closed_line))
        lines.append(open_line)

    if not output_name:
        output_name = 'detected_structure'

    # the CRS is given at construction: set_crs() returns a new frame, so
    # calling it without keeping the result would leave the CRS unset
    multi_lines = GeoDataFrame(geometry=lines, crs=crs)
    multi_polygones = GeoDataFrame(geometry=polygons, crs=crs)
    multi_lines.to_file(output_name + '.shp')
    multi_polygones.to_file(output_name + 'poly.shp')

    return Data_raster, len(contours)
    
    
def density_zone(DEMfile,param,min_meshsize=1, max_meshsize=500,size_kernel_filter=3):
    """Write a mesh-size raster derived from the smoothed DEM Laplacian.

    The absolute Laplacian of band 1 is divided by the pixel size, capped,
    filtered with a Gaussian kernel and rescaled so that the largest filtered
    value maps to ``min_meshsize`` and the smallest to ``max_meshsize``.

    Parameters
    ----------
    DEMfile : path of the raster opened with ``rasterio.open``.
    param : object read both as ``param['P']['overtopping']`` /
        ``param['P']['averageSize']`` and as ``param.work_path``.
    min_meshsize, max_meshsize : bounds of the output values.
    size_kernel_filter : size passed to ``cv2.getGaussianKernel``.

    Returns
    -------
    str
        Path of the written raster, ``<param.work_path>/density_map``.

    Notes
    -----
    NOTE(review): ``cv2`` is used here, but its module-level import is
    commented out in this file; it has to be made available for this function
    to run.
    """
    # load what is needed, then release the dataset
    with rasterio.open(DEMfile) as raster:
        DTM = raster.read(1)
        pixel_size = raster.transform[0]
        nodatavals = raster.nodatavals
        profile = raster.profile
    DTM = np.ma.masked_where(DTM == nodatavals, DTM)

    treshold_slope_max = param['P']['overtopping'] / param['P']['averageSize']

    # try a 32-bit output first, fall back to 64-bit
    try:
        gradDTM = cv2.Laplacian(DTM, cv2.CV_32F)
    except Exception:
        gradDTM = cv2.Laplacian(DTM, cv2.CV_64F)

    # express the gradient in m/m and cap it to remove contour slopes
    gradDTM = abs(gradDTM) / pixel_size
    slope_cap = treshold_slope_max / pixel_size
    gradDTM[gradDTM > slope_cap] = slope_cap

    sigma = int(param['P']['averageSize'] / pixel_size)  # conversion to pixels
    # NOTE(review): getGaussianKernel returns a column vector, so this
    # filters along one axis only - confirm that this is intended.
    kernel_filter_gauss = cv2.getGaussianKernel(size_kernel_filter, sigma)
    H = cv2.filter2D(gradDTM, -1, kernel_filter_gauss)

    # rescale to the requested mesh sizes
    h_min = np.min(H)
    span = np.max(H) - h_min
    if span == 0:
        # constant field: avoid 0/0 and use the coarsest mesh size everywhere
        NormH = np.ones_like(H)
    else:
        NormH = 1 - (H - h_min) / span
    density_map = NormH * (max_meshsize - min_meshsize) + min_meshsize

    out_name = os.path.join(param.work_path, 'density_map')
    # the profile must be expanded as keyword arguments (the third positional
    # argument of rasterio.open is the driver) and must describe what is written
    profile.update(dtype=density_map.dtype.name, count=1)
    with rasterio.open(out_name, 'w', **profile) as dst:
        dst.write(density_map, 1)

    return out_name


def flat_zone(BV,param,DEMfile):
    """Write a binary raster marking cells where the DEM is locally flat.

    The Laplacian is applied twice to band 1 (dividing by the pixel size each
    time); cells whose absolute result does not exceed machine epsilon are set
    to 1, all others to 0. The output is named ``water_mask``, so flat cells
    are presumably interpreted as water surfaces.

    Parameters
    ----------
    BV : unused by this function; kept for interface compatibility.
    param : object exposing ``work_path``, the output directory.
    DEMfile : path of the raster opened with ``rasterio.open``.

    Returns
    -------
    str
        Path of the written raster, ``<param.work_path>/water_mask``.

    Notes
    -----
    NOTE(review): ``cv2`` is used here, but its module-level import is
    commented out in this file; it has to be made available for this function
    to run.
    """
    # ======================
    # read the DEM (the dataset is closed as soon as the data are loaded)
    # ======================
    with rasterio.open(DEMfile) as raster:
        DTM = raster.read(1)
        pixel_size = raster.transform[0]
        nodatavals = raster.nodatavals
        profile = raster.profile
    DTM = np.ma.masked_where(DTM == nodatavals, DTM)

    def _double_laplacian(ddepth):
        # Laplacian of the Laplacian, each scaled by the pixel size
        first = cv2.Laplacian(DTM, ddepth) / pixel_size
        return cv2.Laplacian(first, ddepth) / pixel_size

    # try a 32-bit output first, fall back to 64-bit
    try:
        curvature = _double_laplacian(cv2.CV_32F)
    except Exception:
        curvature = _double_laplacian(cv2.CV_64F)

    # `eps` was undefined in the original code; machine epsilon of the working
    # dtype is assumed here - NOTE(review): confirm the intended tolerance.
    eps = np.finfo(curvature.dtype).eps
    # built in one step so that a value exactly equal to 1 is not mistaken
    # for an already-flagged flat cell
    water_mask = (np.abs(curvature) <= eps).astype(curvature.dtype)

    out_name = os.path.join(param.work_path, 'water_mask')
    # the profile must be expanded as keyword arguments (the third positional
    # argument of rasterio.open is the driver) and must describe what is written
    profile.update(dtype=water_mask.dtype.name, count=1)
    with rasterio.open(out_name, 'w', **profile) as dst:
        dst.write(water_mask, 1)

    return out_name
