
import geopandas as gpd
import numpy as np
from pathlib import Path
from .paras_2 import diameterFromHeight,decay_tree_potential
from .geotools2 import pointDelineate,optimized_treesegmentation,clipBigVRTrasterByFeature,sampleValues,valuesByDistance
from .ecoindices import diversityIndices
from .geostats import getisord,diversityIndicator
from qgis.core import QgsProcessingFeedback
import tempfile
from datetime import datetime
from osgeo import gdal
import pandas as pd

def crsValidation(layer,ref_layer):
    """Return ``layer`` reprojected into the CRS of ``ref_layer``.

    Despite the name, no validation is performed: the result of
    ``layer.to_crs(ref_layer.crs)`` is returned as-is.
    """
    return layer.to_crs(ref_layer.crs)
 
def calcArea(layer):
    """Add an ``area_ha`` column to ``layer`` in place and return it.

    The value is ``layer.geometry.area / 10000`` rounded to 3 decimals. This is
    hectares only if the layer's CRS uses metres (NOTE(review): assumption,
    not checked here).
    """
    hectares = layer.geometry.area / 10000
    layer['area_ha'] = np.round(hectares, 3)
    return layer

def mapTrees(chm,layer,minheight):
    """Detect individual trees in a canopy height model (CHM), stand by stand.

    The CHM is clipped per feature of ``layer`` (keyed by ``id``) into a
    temporary folder with ``clipBigVRTrasterByFeature``. Each clipped raster
    is passed to ``optimized_treesegmentation``. The resulting tree tables are
    concatenated and spatially joined back onto the stands.

    Args:
        chm: CHM raster source, passed straight to ``clipBigVRTrasterByFeature``.
        layer: Stand polygons (GeoDataFrame). It is modified in place: an
            ``id`` column (1..n) is added if missing, and ``area_ha`` is added.
        minheight: Minimum height forwarded to ``optimized_treesegmentation``.

    Returns:
        GeoDataFrame of trees intersecting a stand. Each tree carries the
        stand's ``area_ha`` and ``site_id``, which is the stand's index label.

    Raises:
        ValueError: if no clipped rasters were produced to segment.
    """
    import os  # only needed here, for portable temporary paths

    feedback = QgsProcessingFeedback()

    # tempfile.TemporaryFile() is not a usable path source: on POSIX its
    # ``name`` is an integer file descriptor. Use a private temp directory
    # and build both outputs inside it with os.path (no hard-coded "\\").
    workdir = tempfile.mkdtemp(prefix="maptrees_")
    # Clips go to a sub-folder that does not exist yet, inside an existing
    # parent. The trailing separator is kept because the folder is handed to
    # the clip helper as a path prefix.
    # NOTE(review): assumes the helper creates the folder; confirm.
    output_fold = os.path.join(workdir, "clips", "")
    output = os.path.join(workdir, "raster.vrt")

    if 'id' not in layer.columns:
        layer['id'] = range(1, len(layer) + 1)
    clipBigVRTrasterByFeature(chm,"id",layer,output_fold,output)
    now = datetime.now().strftime("%H:%M:%S")
    feedback.pushInfo(now+"\tDetecting individual trees from CHM and segmenting tree canopies")

    # Sorted so the processing order is deterministic (iterdir order is not).
    clip_files = sorted(p for p in Path(output_fold).iterdir() if p.is_file())
    standcount = len(layer)
    treelist = []
    for done, clip in enumerate(clip_files, start=1):
        ts = optimized_treesegmentation(str(clip),minheight,None)
        # 1-based count of rasters processed so far.
        feedback.pushInfo("Mapping trees: "+str(done)+"/"+str(standcount))
        # The second element of the segmentation result is the tree table.
        treelist.append(ts[1])

    if not treelist:
        raise ValueError("No clipped CHM rasters found to segment in " + output_fold)

    combined_trees = gpd.GeoDataFrame(pd.concat(treelist, ignore_index=True))
    stands = calcArea(layer)[["area_ha","geometry"]]
    combined_trees = gpd.sjoin(combined_trees, stands, how="inner", predicate="intersects")
    # The matched stand's index label becomes the site identifier.
    combined_trees = combined_trees.rename({"index_right":"site_id"},axis=1)

    return combined_trees

def joinVectorData(trees,vegezone,grid):
    """Attach the attributes of ``vegezone`` and then ``grid`` to each tree.

    Both joins are left spatial joins with the ``intersects`` predicate.
    Trees with no matching polygon are kept with missing attributes, and a
    tree intersecting several polygons is repeated once per match. The
    ``index_right`` column that ``gpd.sjoin`` adds is dropped after each join.
    """
    joined = trees
    for overlay in (vegezone, grid):
        joined = gpd.sjoin(joined, overlay, how="left", predicate="intersects")
        joined = joined.drop(["index_right"], axis=1)
    return joined

def assignTreespecies(layer,h_field,tsh_fields):
    """Label each row with the ``tsh_fields`` column closest to ``h_field``.

    For every row, the absolute difference between ``layer[h_field]`` and each
    candidate column in ``tsh_fields`` is computed. The label of the smallest
    difference is written to a new ``treespecie`` column. Ties go to the
    earliest column in ``tsh_fields``, and NaN candidates are skipped.

    The work is vectorized, so it does not depend on the row index being
    unique. A per-row ``.loc`` lookup would break on duplicate labels.

    Args:
        layer: DataFrame holding ``h_field`` and all ``tsh_fields`` columns.
            It is modified in place.
        h_field: Column with the value to match (e.g. tree height).
        tsh_fields: Candidate column labels. The winning label itself is
            stored, so callers can use e.g. species codes as column names.

    Returns:
        ``layer`` with the added ``treespecie`` column.
    """
    candidates = layer[tsh_fields].astype(float)
    target = layer[h_field].astype(float)
    gaps = candidates.sub(target, axis=0).abs()
    # Assign raw values: positional, so no index alignment is involved.
    layer['treespecie'] = gaps.idxmin(axis=1).to_numpy()
    return layer

def assignHeightGroup(h_field,h_mean):
    """Bucket a height relative to a mean height.

    Returns 1 when ``h_field`` is below 90 % of ``h_mean``, 3 when it is
    above 110 % of ``h_mean``, and 2 otherwise (including both boundaries).
    """
    if h_field < h_mean * 0.9:
        return 1
    return 3 if h_field > h_mean * 1.1 else 2

def calcDWP(zone,dbh,fertilityclass,treespecies):
    """Return the polynomial ``dwp`` value for one tree, capped at 2.0.

    Polynomial coefficients come from ``decay_tree_potential('zone<zone>')``,
    indexed as ``[fertilityclass][treespecies]``. They are evaluated at
    ``dbh`` with ``np.poly1d``.

    - Fertility classes above 6 are treated as class 6.
    - Species code 29 is looked up as 3.
    - A class or species code that is <= 0 yields 0.0.

    NOTE(review): "dwp" presumably means dead-wood / decay-tree potential;
    confirm against ``paras_2``.
    """
    params = decay_tree_potential('zone' + str(zone))
    fert = min(fertilityclass, 6)
    species = 3 if treespecies == 29 else treespecies

    if fert <= 0 or species <= 0:
        value = 0.0
    else:
        value = np.poly1d(params[fert][species])(dbh)

    return min(value, 2.0)

def forestType(treelist):
    """Classify a group of species codes as a forest type.

    Code 1 counts as pine, code 2 as spruce, and every other value as
    deciduous. The type with the largest share is returned if that share is
    at least 50 %; on a tie the earlier of Pine, Spruce, Deciduous wins.
    Otherwise the result is 'Mixed'.

    Raises ZeroDivisionError when ``treelist`` is empty.
    """
    total = len(treelist)
    n_pine = sum(1 for code in treelist if code == 1)
    n_spruce = sum(1 for code in treelist if code == 2)
    shares = {
        'Pine': n_pine / total,
        'Spruce': n_spruce / total,
        'Deciduous': (total - n_pine - n_spruce) / total,
    }
    # max() keeps the first maximal key in insertion order on ties.
    leader = max(shares, key=shares.get)
    return leader if shares[leader] >= 0.50 else 'Mixed'



def treeDeduct(sites,vege_gdb,grid_gpd,chm,dtw):
    """Build the per-tree attribute table for the stands in ``sites``.

    Steps:
      1. Detect trees from ``chm`` (minimum height 2) per stand.
      2. Join ``vege_gdb`` and ``grid_gpd`` attributes, then lower-case the
         column names.
      3. Keep trees with ``fertilityclass > 0`` and drop duplicate geometries.
      4. Assign a species code (1, 2 or 29) by the closest mean height.
      5. Derive ``foresttype`` from the species found within distance 20.
      6. Estimate ``dbh``, add the per-site mean height ``sitehmean``, and
         sample the ``dtw`` raster (values divided by 100).

    Args:
        sites: Stand polygons (GeoDataFrame).
        vege_gdb: Polygons joined first by ``joinVectorData``.
        grid_gpd: Polygons joined second by ``joinVectorData``. Together with
            ``vege_gdb`` they must supply ``fertilityclass`` and
            ``meanheightpine`` / ``meanheightspruce`` / ``meanheightdeciduous``
            (case-insensitive).
        chm: CHM raster source for ``mapTrees``.
        dtw: Raster path/name readable by ``gdal.Open``.

    Returns:
        GeoDataFrame with one row per tree.

    Raises:
        RuntimeError: if GDAL cannot open ``dtw`` (when GDAL returns None).
    """
    # Reprojects ``sites`` onto its own CRS, i.e. works on a reprojected copy.
    layer = crsValidation(sites,sites)
    layer = calcArea(layer)
    trees = mapTrees(chm,layer,2)
    trees = joinVectorData(trees,vege_gdb,grid_gpd)
    trees.columns = trees.columns.str.lower()
    trees = trees[trees['fertilityclass'] > 0]
    # A tree intersecting several polygons appears once per match after the
    # spatial joins; keep the first occurrence of each point.
    trees = trees[~trees.duplicated(subset='geometry')]
    # Name the mean-height columns by species code, so the closest column's
    # label is directly the species code.
    trees = trees.rename({"meanheightpine":1,
                        "meanheightspruce":2,
                        "meanheightdeciduous":29},axis=1)
    # Fills every column. For the mean heights, -9999 ensures a missing value
    # is never the closest match.
    trees = trees.fillna(-9999)
    trees = assignTreespecies(trees,"height",[1,2,29])
    trees = trees.drop([1,2,29],axis=1)

    trees = valuesByDistance(trees,"treespecie",20)
    trees['foresttype'] = trees['list_treespecie'].apply(forestType)
    trees = trees.drop(['list_treespecie','distances'],axis=1)
    trees['dbh'] = trees.apply(lambda x: diameterFromHeight(x['treespecie'],x['height'],3),axis=1)
    trees['sitehmean'] = trees.groupby("site_id")['height'].transform('mean')

    dtw_data = gdal.Open(dtw)
    if dtw_data is None:
        raise RuntimeError("Could not open raster: " + str(dtw))
    try:
        trees = sampleValues(trees,dtw_data,"dtw")
    finally:
        # Drop the reference so GDAL closes the dataset, even on error.
        dtw_data = None
    # NOTE(review): assumes the raster stores dtw scaled by 100; confirm units.
    trees['dtw'] = trees['dtw'] / 100

    return trees

   
def ecosystemIndicators(trees,ds_radius,gini_radius):
    """Add per-tree ecological indicator columns to ``trees``.

    Adds:
      * ``edtw``: ``log10(max(dtw) - dtw + 1)``, largest where ``dtw`` is smallest.
      * ``dwp``: ``calcDWP`` evaluated per tree from ``paajakonro``, ``dbh``,
        ``fertilityclass`` and ``treespecie``.
      * ``ds``: the ``simpson`` index from ``diversityIndices`` on
        ``treespecie`` within ``ds_radius``.
      * ``deci_prop``: share of species code 29 in each tree's
        ``list_treespecie`` (as produced by ``diversityIndices``).
      * ``gini``: the ``diversityIndicator`` 'gini' result on ``height``
        within ``gini_radius``.

    The helper columns ``list_treespecie`` and ``nearindices`` are dropped.
    """
    dtw = trees['dtw']
    trees["edtw"] = np.log10((np.max(dtw) - dtw) + 1)

    def _row_dwp(row):
        return calcDWP(row.paajakonro, row.dbh, row.fertilityclass, row.treespecie)

    trees["dwp"] = trees.apply(_row_dwp, axis=1)

    trees = diversityIndices(trees, "treespecie", ds_radius, ["simpson"])
    trees = trees.rename({"simpson": "ds"}, axis=1)

    def _deciduous_share(neighbours):
        return sum(1 for code in neighbours if code == 29) / len(neighbours)

    trees['deci_prop'] = trees['list_treespecie'].apply(_deciduous_share)
    trees = trees.drop(["list_treespecie", "nearindices"], axis=1)

    gini_column = 'gini' + str(gini_radius) + 'height'
    trees = diversityIndicator(trees, "height", gini_radius, 'gini', 1)
    return trees.rename({gini_column: 'gini'}, axis=1)

def _hotspotsBySite(trees, hs_radius, indices):
    """Run ``getisord`` (gaussian) for every indicator, one site at a time.

    Returns the per-site results concatenated into one GeoDataFrame with a
    fresh 0..n-1 index.
    """
    per_site = []
    for site in sorted(trees['site_id'].unique()):
        site_trees = trees[trees['site_id'] == site]
        for indicator in indices:
            site_trees = getisord(site_trees, indicator, hs_radius, "gaussian")
        per_site.append(site_trees)
    return gpd.GeoDataFrame(pd.concat(per_site, ignore_index=True))


def _gamaScore(points, weights, tiers=('P', 'S', 'A'), tier_weights=(0.5, 0.3, 0.2)):
    """Return the GAMA score of every row of ``points`` as a Series.

    ``weights`` maps an indicator name to a tier: 'P' (primary), 'S'
    (secondary) or 'A' (additional); other values are ignored. Within a tier,
    each ``<indicator>_GiZ`` column gets an equal share (1/n), and the tier
    total is multiplied by its tier weight. The tier totals are then summed.
    """
    tier_totals = {}
    for tier, tier_weight in zip(tiers, tier_weights):
        giz_cols = [name + "_GiZ" for name, level in weights.items() if level == tier]
        if not giz_cols:
            continue
        share = 1 / len(giz_cols)
        # skipna=False: a NaN indicator makes that tier's total NaN.
        tier_totals[tier + "_sum"] = tier_weight * (points[giz_cols] * share).sum(axis=1, skipna=False)
    # Across tiers the default skipna=True applies, so a NaN tier adds
    # nothing. With no tiers at all, every score is 0.
    return pd.DataFrame(tier_totals, index=points.index).sum(axis=1)


def spatialPlanning(trees,treecount,hs_radius,weights):
    """Score trees with GAMA and propose retention trees for every site.

    Steps:
      1. Getis-Ord hotspot analysis (``getisord``, gaussian kernel, radius
         ``hs_radius``) of ``edtw``, ``dwp``, ``ds`` and ``gini``, per site.
         NOTE(review): the scoring step reads ``<indicator>_GiZ`` columns,
         presumably added by ``getisord``.
      2. GIS-automated multi-decision analysis (GAMA): tier weights
         P 0.5 / S 0.3 / A 0.2, with equal shares within a tier, as assigned
         by ``weights``.
      3. Each site proposes its top ``ceil(area_ha * treecount)`` trees by
         ``gama``. The proposed trees of each site are passed to
         ``pointDelineate`` (parameter 5, with the aggregation spec below).

    Args:
        trees: Per-tree GeoDataFrame with ``site_id``, ``area_ha`` and the
            indicator columns.
        treecount: Retention trees required per area unit of ``area_ha``.
        hs_radius: Radius for the hotspot analysis.
        weights: Mapping of indicator name -> 'P' / 'S' / 'A'.

    Returns:
        ``(combined_point, retention)``:
          * ``combined_point``: all trees with added ``gama``,
            ``rtreecount`` and ``proposed`` (0/1) columns.
          * ``retention``: the concatenated ``pointDelineate`` outputs.
    """
    combined_point = _hotspotsBySite(trees, hs_radius, ['edtw','dwp','ds','gini'])

    combined_point['gama'] = _gamaScore(combined_point, weights)

    # Retention-tree quota, derived from the stand area of each tree's site.
    combined_point["rtreecount"] = np.ceil(combined_point["area_ha"] * treecount).astype("int")

    # Rank within each site by gama, highest first. method="first" breaks
    # ties by row order, so at most rtreecount trees per site are proposed.
    site_rank = combined_point.groupby('site_id')['gama'].rank(method="first", ascending=False)
    combined_point['proposed'] = (site_rank <= combined_point["rtreecount"]).astype("int")
    selection_df = combined_point[combined_point['proposed'] == 1]

    agg = {'site_id':'first',
        'height':'mean',
        'dbh':'mean',
        'deci_prop':'mean',
        'ds':'mean',
        'gini':'mean',
        'dwp':'mean',
        'edtw':'mean',
        'ds_GiZ':'mean',
        'gini_GiZ':'mean',
        'dwp_GiZ':'mean',
        'edtw_GiZ':'mean',
        'treespecie':list}

    retentiotrees = []
    # Every site is visited, including one whose selection is empty.
    for site in sorted(combined_point['site_id'].unique()):
        site_selection = selection_df[selection_df['site_id'] == site]
        retentiotrees.append(pointDelineate(site_selection, 5, agg))

    retention = gpd.GeoDataFrame(pd.concat(retentiotrees, ignore_index=True))

    return combined_point,retention