import progress
from decorators import cache_result
import time
import psycopg2
import binascii
from simplifier import *
import quadtree
import grid

@cache_result("geoms.json")
def get_geoms():
    """Fetch ``(pkid, location)`` rows for nations near a fixed test point.

    Selects rows of the ``nation`` table whose ``location`` satisfies
    ``ST_DWithin(location, POINT(-1.74 36.42) [SRID 4326], 50.0)``, ordered
    by ``pkid``.  The result is presumably persisted to "geoms.json" by the
    project's ``cache_result`` decorator -- confirm its semantics there.

    Returns:
        list of ``(pkid, location)`` tuples in ``pkid`` order.  Callers
        hex-decode ``location`` as WKB.
    """
    db = psycopg2.connect("dbname=gisdb host=localhost user=postgres password=postgres port=5432")
    # Close the cursor and connection even if the query fails; the original
    # left both open for the life of the process.
    try:
        cursor = db.cursor()
        try:
            cursor.execute("select pkid, location from nation where st_dwithin(location, st_geomfromtext('SRID=4326;POINT(-1.74 36.42)'), 50.0) order by pkid")
            return [(pkid, wkb_str) for (pkid, wkb_str) in cursor]
        finally:
            cursor.close()
    finally:
        db.close()


def test():
    """Benchmark segment-intersection counting on a sample of nation geometries.

    Steps:

    1. Load the first 40 rows returned by ``get_geoms()`` into a ``ChainDB``,
       hex-decoding each location to WKB first.
    2. Split every chain into segments between consecutive points.
    3. Count crossing segment pairs found through a spatial index.

    Two index back-ends are available as nested helpers: ``quadtest``
    (``quadtree.QuadTree``) and ``gridtest`` (``grid.Grid``). Only
    ``gridtest`` is run. Timings and counts are printed to stdout.

    NOTE(review): ``ChainDB``, ``P_COORD`` and ``geotool`` are not imported
    explicitly. Presumably they arrive via ``from simplifier import *``;
    confirm.
    """
    # Single-argument print(...) prints identically under Python 2 and 3.
    print("Load ChainDB")
    cdb = ChainDB()
    for (pkid, wkb_str) in progress.bar(get_geoms()[:40]):
        # Locations come back hex-encoded; decode to raw WKB bytes.
        wkb = binascii.a2b_hex(wkb_str)
        cdb.add_geometry(pkid, wkb)

    def get_segment(seg_id):
        """Return the ``(start, end)`` point coordinates for ``seg_id``.

        ``seg_id`` is ``(si, c, (i, j))``. ``c`` indexes ``cdb.chains``;
        ``i`` and ``j`` index points within that chain. ``si`` is unused here.
        """
        # Unpack in the body rather than in the signature: tuple parameter
        # unpacking is Python 2-only (PEP 3113).
        (_si, c, (i, j)) = seg_id
        points = cdb.chains[c][1]
        return (points[i][P_COORD], points[j][P_COORD])

    def iter_segments(indexed_chains):
        """Yield ``(seg_id, line)`` for every consecutive point pair.

        ``indexed_chains`` yields ``(c, (parent, chain))``, as
        ``enumerate(cdb.chains)`` does. The ``si`` in each ``seg_id`` is a
        running counter across all chains. It therefore equals the segment's
        index in a list built by appending every yielded item.
        """
        si = 0
        for (c, (_parent, chain)) in indexed_chains:
            for (i, _point) in enumerate(chain):
                if i == 0:
                    continue  # a segment needs a previous point
                seg_id = (si, c, (i - 1, i))
                yield seg_id, get_segment(seg_id)
                si += 1

    def count_intersections(entries, hit):
        """Test candidate pairs from a spatial index and print the totals.

        ``entries`` is a list whose items have a ``seg_id`` as element 0.
        ``hit(entry, line)`` returns the candidate ``seg_id``s for an entry.
        A pair is tested only when ``seg_id1 > seg_id2``. Self-hits and the
        mirrored half of a pair reported from both sides are counted as
        "skipped".
        """
        intersections = 0
        pairs = 0
        pairs_skip = 0
        for entry in progress.bar(entries):
            seg_id1 = entry[0]
            line1 = get_segment(seg_id1)
            for seg_id2 in hit(entry, line1):
                pairs += 1
                if seg_id1 > seg_id2:
                    if geotool.crosses(line1, get_segment(seg_id2)):
                        intersections += 1
                else:
                    pairs_skip += 1
        print("Pairs = %s" % pairs)
        print("Pairs skipped = %s" % pairs_skip)
        print("Total intersections = %s" % intersections)

    def quadtest():
        """Count intersections using a ``quadtree.QuadTree`` index."""
        segments = [quadtree.segment(seg_id, s[0], s[1])
                    for (seg_id, s) in iter_segments(enumerate(cdb.chains))]
        print("Load %s segments" % len(segments))

        t = time.time()
        qtree = quadtree.QuadTree(segments)
        print("Qtree load = %.3fs" % (time.time() - t))

        # The quadtree is queried with the segment object itself.
        count_intersections(segments, lambda s1, line1: qtree.hit(s1))

    def gridtest():
        """Count intersections using a ``grid.Grid`` index.

        The reported load time includes segment construction.
        """
        t = time.time()
        G = grid.Grid()
        segments = []
        chains = progress.bar(enumerate(cdb.chains), size=len(cdb.chains))
        for (seg_id, s) in iter_segments(chains):
            G.add(seg_id, s)
            segments.append((seg_id, s))
        print("Grid load = %.3fs" % (time.time() - t))
        print("Num boxes = %s" % len(G.boxes))

        # The grid is queried with the segment's coordinates.
        count_intersections(segments, lambda s1, line1: G.hit(line1))

    # quadtest()  # alternative back-end; enable to compare with the grid
    gridtest()






if __name__ == "__main__":
    # Script entry point: run the intersection benchmark driver.
    test()
