import boto3
import sqlalchemy
import geopandas as gpd
from lxml import etree
import pandas as pd
from copy import deepcopy

from domain.helpers.db_secret_string_manager import build_connection_string
from utilities.dcmgeometrysdk.dcmgeometry.geometry import Geometries
from utilities.dcmgeometrysdk.dcmgeometry.polygons import PolygonGeom
from utilities.dcmgeometrysdk.dcmgeometry.points import PointGeom
import json
from utilities.dcmgeometrysdk.geometryfunctions.transformationfunctions import helmert_trans_unknown, rubber_sheet, \
    calc_distance
from utilities.dcmgeometrysdk.iuf.iuf_full import parseString
from pathlib import Path
import shapely.geometry as sg
from shapely.ops import nearest_points
from shapely import force_2d
import shapely.affinity as sa
import shapely.ops as so

def _parse_ring(ring):
    """Return the ``XY`` entries of *ring* as a list of rounded ``(x, y)`` tuples."""
    coords = []
    for coord in ring.XY:
        x, y = coord.split(',')
        coords.append(round_tuple((float(x), float(y))))
    return coords


def build_polygons(data):
    """Collect the POLYGON items of a parsed IUF document, one record per PFI.

    Every ``POLYGON_ITEM`` found under ``data.INSERT`` and ``data.REPLACE`` is
    visited.  Items sharing a PFI are merged into a single record: the
    attributes of the first item are kept, while ``connects`` and the polygon
    parts of later items are accumulated onto it.

    Args:
        data: parsed IUF object exposing ``INSERT`` and ``REPLACE`` lists whose
            members carry ``POLYGON`` -> ``POLYGON_NAME`` -> ``POLYGON_ITEM``.

    Returns:
        dict mapping PFI to a record dict with the keys ``'PFI'``, ``'table'``,
        one key per ``ATT`` entry, ``'connects'`` (list of ``(name, value)``
        tuples), ``'geometry'`` (a shapely ``MultiPolygon``) and ``'type_'``
        (the ``original_tagname_`` of the enclosing INSERT/REPLACE element).

    Raises:
        KeyError: if an item has no HEADER, so no ``'PFI'`` is available.
    """
    insert_polygons = {}
    for insert in data.INSERT + data.REPLACE:
        for poly in insert.POLYGON:
            for pn in poly.POLYGON_NAME:
                for pi in pn.POLYGON_ITEM:
                    polygon = {}
                    for ph in pi.HEADER:
                        polygon['PFI'] = ph.PFI
                        polygon['table'] = ph.TABLE

                    exist_poly = insert_polygons.get(polygon.get('PFI'))
                    for pa in pi.ATT:
                        # Split on the first '=' only so values may contain '='.
                        att, val = pa.split('=', 1)
                        polygon[att] = val.strip('"')

                    # First item seen for this PFI becomes the record that
                    # later items with the same PFI are merged into.
                    if not exist_poly:
                        exist_poly = polygon
                    connects = exist_poly.get('connects', [])
                    for connect in pi.CONNECT:
                        att, val = connect.split('=', 1)
                        connects.append((att, val.strip('"')))
                    exist_poly['connects'] = connects

                    outers = [_parse_ring(out) for out in pi.OUTER]
                    # INNER elements without coordinates carry no ring and
                    # are skipped.
                    inner_rings = [ring for ring in map(_parse_ring, pi.INNER) if ring]
                    inner_pols = [sg.Polygon(ring) for ring in inner_rings]

                    pss = []
                    for outer in outers:
                        outer_pol = sg.Polygon(shell=outer)
                        # An inner ring becomes a hole of every outer ring
                        # that fully contains it.
                        holes = [
                            list(inner_pol.exterior.coords)
                            for inner_pol in inner_pols
                            if inner_pol.within(outer_pol)
                        ]
                        pss.append(sg.Polygon(shell=outer, holes=holes) if holes else outer_pol)

                    # Keep the parts already stored for this PFI, after the
                    # parts of the current item.
                    epss = exist_poly.get('geometry')
                    if epss is not None:
                        pss.extend(epss.geoms)
                    exist_poly['geometry'] = sg.MultiPolygon(pss)
                    exist_poly['type_'] = insert.original_tagname_
                    insert_polygons[exist_poly['PFI']] = exist_poly
    return insert_polygons

def build_a_spatial(data):
    """Collect the ASPATIAL items of a parsed IUF document, one record per PFI.

    Every ``ASPATIAL_ITEM`` found under ``data.INSERT`` and ``data.REPLACE`` is
    visited.  When several items share a PFI, the attributes of the first one
    are kept and the ``CONNECT`` entries of the later ones are appended to it.

    Args:
        data: parsed IUF object exposing ``INSERT`` and ``REPLACE`` lists whose
            members carry ``ASPATIAL`` -> ``ASPATIAL_NAME`` -> ``ASPATIAL_ITEM``.

    Returns:
        dict mapping PFI to a record dict with the keys ``'PFI'``, ``'table'``,
        one key per ``ATT`` entry and ``'connects'`` (list of
        ``(name, value)`` tuples).

    Raises:
        KeyError: if an item has no HEADER, so no ``'PFI'`` is available.
    """
    aspatial = {}
    for insert in data.INSERT + data.REPLACE:
        for group in insert.ASPATIAL:
            for named in group.ASPATIAL_NAME:
                for item in named.ASPATIAL_ITEM:
                    record = {}
                    for header in item.HEADER:
                        record['PFI'] = header.PFI
                        record['table'] = header.TABLE

                    existing = aspatial.get(record.get('PFI'))
                    for pa in item.ATT:
                        # Split on the first '=' only so values may contain '='.
                        att, val = pa.split('=', 1)
                        record[att] = val.strip('"')

                    # First item seen for this PFI becomes the stored record.
                    if not existing:
                        existing = record
                    connects = existing.get('connects', [])
                    for connect in item.CONNECT:
                        att, val = connect.split('=', 1)
                        connects.append((att, val.strip('"')))
                    existing['connects'] = connects
                    aspatial[existing['PFI']] = existing

    return aspatial

def create_dcm_point(x):
    """Build a ``PointGeom`` of type ``'boundary'`` from one row of the point table.

    Args:
        x: row-like object providing ``geometry``, ``original_coords`` and ``id``.

    Returns:
        PointGeom whose ``geometry`` is the row geometry (crs set to 7855),
        whose ``original_geom`` is a shapely Point built from
        ``original_coords`` (original_crs set to 7844), and whose ``name`` /
        ``point_oid`` are derived from the row ``id``.
    """
    dcm_point = PointGeom()
    dcm_point.geometry = x.geometry

    original_point = sg.Point(x.original_coords)
    dcm_point.original_geom = original_point
    dcm_point.original_crs = 7844
    dcm_point.name = str(x.id)
    dcm_point.point_oid = x.id
    dcm_point.point_type = 'boundary'
    dcm_point.crs = 7855
    return dcm_point

def create_dcm_polygon(x, points):
    """Build a ``PolygonGeom`` from one row of the polygon table.

    Args:
        x: row-like object providing ``geometry`` and ``poly_name``.
        points: point collection forwarded to ``PolygonGeom.create_polygon``.

    Returns:
        PolygonGeom created with crs 7855 and ``coord_decimals=3``.
    """
    dcm_polygon = PolygonGeom()
    dcm_polygon.create_polygon(
        x.geometry,
        name=x.poly_name,
        crs=7855,
        points=points,
        coord_decimals=3,
    )
    dcm_polygon.crs = 7855
    return dcm_polygon

def round_tuple(x):
    """Return the items of *x* as a tuple of floats rounded to 9 decimal places."""
    rounded = [round(float(component), 9) for component in x]
    return tuple(rounded)


def explode_round_clean_gdf(gdf):
    """Turn the vertices of the geometries in *gdf* into a de-duplicated point table.

    The coordinates of each (2D-forced) geometry are exploded through two
    nesting levels, rounded with ``round_tuple`` and converted to shapely
    Points in CRS 7855.  Rows sharing the same rounded coordinate are reduced
    to their first occurrence.

    Args:
        gdf: GeoDataFrame with an active geometry column.  It is not modified.

    Returns:
        GeoDataFrame with the columns ``id`` (the ``index`` column produced by
        ``reset_index`` after exploding), ``coords``, ``geometry`` and ``wkt``.
    """
    # Work on a copy so the helper 'coords' column is not added to the
    # caller's frame.
    base_gdf = gdf.copy()
    base_gdf['coords'] = base_gdf.geometry.apply(lambda x: sg.mapping(force_2d(x))['coordinates'])
    # Two explode levels — presumably rings -> vertices for single-part
    # polygons; TODO confirm behaviour for multi-part geometries.
    base_gdf = base_gdf.explode('coords').explode('coords').reset_index()
    base_gdf['coords'] = base_gdf.coords.apply(round_tuple)
    base_gdf = gpd.GeoDataFrame(base_gdf, geometry=base_gdf.coords.apply(lambda x: sg.Point(x)), crs=7855)
    base_gdf['wkt'] = base_gdf.geometry.apply(lambda x: x.wkt)
    base_gdf = base_gdf.reset_index()[['index', 'coords', 'geometry', 'wkt']]
    base_gdf.drop_duplicates(['coords'], keep='first', inplace=True)
    base_gdf.rename(columns={'index': 'id'}, inplace=True)
    return base_gdf


def create_target(x):
    """Build a ``PolygonGeom`` in CRS 7855 from the ``geometry`` of row *x*."""
    target = PolygonGeom()
    target.create_polygon(x.geometry, crs=7855)
    return target


# Switch between a workstation run (True) and a run using the default AWS session.
local = True

# Build one connection for the cadastre user and one for the event user.
# NOTE(review): return_engine is only passed in the non-local branch — confirm
# both branches yield something gpd.read_postgis accepts.
if local in [True, 'True']:
    # Workstation run: named AWS profile, connections built with local=True on port 5433.
    session = boto3.Session(profile_name='work-uat')
    connection = build_connection_string(session, 'dcdb_cad_user', local=True, local_port=5433)
    event_connection = build_connection_string(session, 'dcdb_event_user', local=True, local_port=5433)
else:
    # Default credentials; the helper is asked to return an engine.
    session = boto3.Session()
    connection = build_connection_string(session, 'dcdb_cad_user', return_engine=True)
    event_connection = build_connection_string(session, 'dcdb_event_user', return_engine=True)

# Source IUF XML document (parsed with lxml) and the directory results are written to.
infile_str = '/Users/jamesleversha/Downloads/work_files/IUF/inputs/bayside_iuf_data_gda2020_15-jan-2023.xml'
infile = Path(infile_str)
root = etree.parse(infile)
outpath = '/Users/jamesleversha/Downloads/work_files/IUF/outputs'

# Read the PFI and geometry columns of every input layer.
gdfs = []
inshapes = ['/Users/jamesleversha/Downloads/work_files/IUF/inputs/parcel_view.gpkg']

for in_shape in inshapes:
    gdf = gpd.read_file(in_shape)[['PFI', 'geometry']]
    gdfs.append(gdf)

# Combine the layers, split multi-part geometries into single parts, force the
# geometries to 2D and reproject the result to CRS 7855.
gdf = pd.concat(gdfs)
gdf['PFI'] = gdf.PFI.astype('Int64')
gdf = gdf.explode(index_parts=True)
all_polys = gdf.copy()
all_polys['geometry'] = gdf.geometry.apply(lambda x: force_2d(x))
all_polys.to_crs(7855, inplace=True)

# Re-parse the serialised XML tree into the IUF object model and gather the
# polygon records per PFI.
data = parseString(etree.tostring(root), silence=True, print_warnings=False)
information = build_polygons(data)

# One row per XY element in the document: the parsed float tuple plus the raw text.
f = [{'coords': tuple(float(a) for a in i.text.split(',')), 'text_coords': i.text} for i in root.findall('.//XY')]
df = pd.DataFrame(f)
# Point geometries are created in CRS 7844 and reprojected to 7855; the
# 'coords' / 'original_coords' columns keep the unprojected tuples.
df = gpd.GeoDataFrame(df, geometry=df.coords.apply(lambda x: sg.Point(x)), crs=7844)
df.to_crs(7855, inplace=True)
df['original_coords'] = df['coords']
# The row number before de-duplication becomes the point id; the first
# occurrence of each coordinate text is kept.
df = df.reset_index()[['index', 'coords', 'original_coords', 'text_coords', 'geometry']]
df.drop_duplicates(['text_coords'], keep='first', inplace=True)
df.rename(columns={'index': 'id'}, inplace=True)
df['distance'] = -1
df['dcm_geom'] = df.apply(create_dcm_point, axis=1)
# WKT of the unprojected coordinate, used further down to look points up again.
df['original_coords_wkt'] = df['original_coords'].apply(lambda x: sg.Point(x).wkt)
# Geometry container holding every point, keyed by point name.
all_geoms = Geometries()
all_geoms.points = {i.name: i for i in df.dcm_geom.to_list()}
all_geoms.crs = 7855
all_geoms.survey_year = 2022

# Polygon records to a GeoDataFrame (CRS 7844), exploding multi-polygons into
# individual polygons, then reprojecting to 7855.
explode_df = pd.DataFrame([item for item in information.values()])
explode_df = gpd.GeoDataFrame(explode_df, geometry=explode_df['geometry'], crs=7844).explode(index_parts=True)
explode_df.to_crs(7855, inplace=True)
explode_df['PFI'] = explode_df.PFI.astype('Int64')
# keep=False removes every row whose PFI occurs more than once, so a PFI that
# exploded into several parts is dropped entirely.
# NOTE(review): confirm that excluding multi-part PFIs is intended.
explode_df.drop_duplicates('PFI', inplace=True, keep=False)
explode_df['poly_name'] = explode_df['PFI']
explode_df = explode_df[['PFI', 'poly_name', 'geometry']].reset_index(drop=True)
explode_df['dcm_polygon'] = explode_df.apply(lambda x: create_dcm_polygon(x, all_geoms.points), axis=1)

# Register the DCM polygons by name, then let the SDK derive lines from them
# and build its point/line graph (behaviour of these SDK calls is not visible
# here — see the dcmgeometry package).

all_geoms.polygons = {i.name: i for i in explode_df.dcm_polygon.to_list()}
all_geoms.add_lines_from_polygons()
all_geoms.gen_graph_from_points_lines(add_unconnected=True, join_branches=True, add_gen_lines=True)

# Match the IUF points to points that already exist in the database.
# A deep copy is transformed to 7844 so all_geoms itself stays in 7855.
ng = deepcopy(all_geoms)
ng.transform_geometries(7844, use_xml_data=False)
points = sg.MultiPoint([i.geometry for i in ng.points.values()]).wkt

# NOTE(review): existing_points is not used in the visible part of the script.
existing_points = []

# Candidate cadastre points within 0.000000045 (units of SRID 7844, presumably
# degrees — TODO confirm) of any IUF point.  The WKT is interpolated directly
# into the statement text.
sql = f"""SELECT cp.point_id, cp.geom from cadastre.cad_point cp 
      where ST_DWithin(cp.geom, 'SRID=7844;{points}', 0.000000045)"""
edf = gpd.read_postgis(sql, connection)
edf.to_crs(7855, inplace=True)
# Pair each database point with its nearest IUF point, at most .005 CRS units
# away (CRS 7855).
dbgdf = gpd.sjoin_nearest(edf, df, max_distance=.005)

dbpoints = tuple(dbgdf.point_id.tolist())

# Fetch the event geometry of the matched points (dataset_type_id 50).
# NOTE(review): if dbpoints is empty the rendered "in" clause may be invalid
# SQL, depending on the driver — confirm.
sql = f"""SELECT cpe.plan_point_id, cpe.geom from dcm_event.cad_point_event cpe 
            where cpe.plan_point_id in  %(a)s and cpe.dataset_type_id = 50"""
params = {'a': dbpoints}
original_geometry = gpd.read_postgis(sql, event_connection, params=params)
original_geometry.to_crs(7855, inplace=True)

# Both frames carry a 'geom' column, so the merge yields geom_x (cad_point)
# and geom_y (cad_point_event).
dbgdf = pd.merge(dbgdf, original_geometry, left_on='point_id', right_on='plan_point_id')

# For every IUF point with a database match, adopt the event geometry and the
# database point id, and set its ccc flag.
# NOTE(review): `data` is rebound here, shadowing the parsed IUF document.
for p, pvalue in all_geoms.points.items():
    pvalue: PointGeom
    data = dbgdf.loc[dbgdf['original_coords_wkt'] == pvalue.original_geom.wkt]
    if data.empty is False:
        pvalue.geometry = data.geom_y.tolist()[0]
        pvalue.point_oid = data.point_id.tolist()[0]
        pvalue.ccc = True

# Recalculate the geometries via the SDK (leave_constrained=True,
# use_branches=True) and write the result to outpath under the survey
# number 'base'.
all_geoms.recalc_geometries(leave_constrained=True, use_branches=True)
all_geoms.survey_number = 'base'
all_geoms.write_geom_to_file(location=outpath)

# Clear the ccc flag on every point again before the snapping step.
for p, pvalue in all_geoms.points.items():
    pvalue.ccc = False

# Exterior rings of all parcel polygons, combined into one MultiLineString.
all_polys['poly_exterior'] = all_polys.geometry.apply(lambda x: x.exterior)
all_lines = all_polys['poly_exterior'].tolist()
all_lines = sg.MultiLineString([i for i in all_lines])
# Set of the current point geometries.
# NOTE(review): it is only added to below and never read in the visible script.
exist_geoms = set(v.geometry for v in all_geoms.points.values())

# Move points that have no associated_point_oid onto the parcel boundaries.
for k, v in all_geoms.points.items():
    if v.associated_point_oid is None:
        # First try shapely's snap against the boundary lines with tolerance 1.
        # NOTE(review): the local name `np` would shadow the usual numpy alias.
        np = so.snap(v.geometry, all_lines, 1)
        if np != v.geometry:
            v.geometry = np
            exist_geoms.add(np)
        else:
            # Otherwise take the nearest location along the lines and move the
            # point there when calc_distance reports less than .5.
            point = all_lines.interpolate(all_lines.project(v.geometry))
            distance = calc_distance(v.geometry, point)
            if distance < .5:
                v.geometry = point

# Let the SDK refresh its geometries after the point moves, then transform
# everything to CRS 7844.
all_geoms.update_geometries(xml=False, loops=False)
all_geoms.transform_geometries(7844)
# Lookup from each point's original (x, y) to its adjusted (x, y).
og_lookup = {}

for k, v in all_geoms.points.items():
    og_lookup[(v.original_geom.x, v.original_geom.y)] = (v.geometry.x, v.geometry.y)

# Rewrite the text of every XY element whose parsed coordinate is in the
# lookup, formatted with 9 decimal places; other elements are left untouched.
for i in root.findall('.//XY'):
    xy = i.text.split(',')
    xy = tuple(float(x) for x in xy)
    new_coord = og_lookup.get(xy)
    if new_coord is not None:
        new_xy = ','.join([f'{n:.9f}' for n in new_coord])
        i.text = new_xy

# Give every node without text an empty string — presumably so lxml writes
# explicit open/close tags instead of self-closing ones (TODO confirm).
for node in root.iter():
    if node.text is None:
        node.text = ''

# Write the adjusted geometry under the survey number 'tester'.
all_geoms.survey_number = 'tester'
all_geoms.write_geom_to_file(location=outpath)

# Output XML path is derived from the input path: '-output' suffix and
# 'inputs' replaced by 'outputs'.
outfile = infile_str.replace('.xml', '-output.xml').replace('inputs', 'outputs')
root.write(outfile, pretty_print=True)