"""
Linear Referencing plugin
(c) Copyright 2008 Martin Dobias

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.       
"""

from geom import nearestPoint, Point

import math

class XYM(object):
	""" 2D point carrying an optional measure value (m is None when unknown) """

	def __init__(self, x, y, m=None):
		self.x = x
		self.y = y
		self.m = m

	def equalXY(self, pnt2):
		""" True when pnt2 lies at exactly the same X,Y position (M is ignored) """
		if self.x != pnt2.x:
			return False
		return self.y == pnt2.y

	def __repr__(self):
		# an unknown measure is shown as -1
		if self.m is None:
			shown_m = -1
		else:
			shown_m = self.m
		return "(%.1f,%.1f,%.1f) " % (self.x, self.y, shown_m)
		

def XYM2XY(pt):
	""" return the (x, y) tuple of a point; used as the key of node dicts """
	xy = pt.x, pt.y
	return xy
	
class RoadArc(object):
	""" one part of a road: a polyline registered in a node topology

	`nodes` is a dict shared by all arcs of one road: the key is the (x,y)
	of an end point, the value is the list of arcs starting or ending there.
	Constructing an arc registers it at both of its end points.
	"""

	def __init__(self, points, nodes):
		self.points = points
		self._add_node(nodes, points[0])
		self._add_node(nodes, points[-1])

	def _add_node(self, nodes, pt):
		""" add this arc to the list of arcs at node `pt` (list created if missing)

		setdefault returns the already existing list, so every arc at one node
		shares the very same list object (nodes are compared with `is`).
		"""
		nodes.setdefault(XYM2XY(pt), []).append(self)

	def nodeBegin(self, nodes):
		""" return the list of arcs at this arc's first point """
		return nodes[XYM2XY(self.points[0])]

	def nodeEnd(self, nodes):
		""" return the list of arcs at this arc's last point """
		return nodes[XYM2XY(self.points[-1])]

	def takeFromTopology(self, nodes):
		""" unregister this arc from both of its end nodes """
		self.nodeBegin(nodes).remove(self)
		self.nodeEnd(nodes).remove(self)


def _polyline_len(points):
	""" return length of polyline (sum of euclidean segment lengths)

	Returns 0 for an empty or single-point polyline.
	"""
	dist = 0
	# walk consecutive vertex pairs; the previous vertex must advance with
	# each step, otherwise every distance would be measured from points[0]
	for lastPt, pt in zip(points[:-1], points[1:]):
		dist += math.sqrt( (pt.x-lastPt.x)**2 + (pt.y-lastPt.y)**2 )
	return dist

class Road(object):
	""" road for calibration

	A road may be built from several parts (arcs). self.nodes maps an end
	point (x,y) to the list of arcs starting or ending there; that list object
	is shared by all those arcs, so two arcs meet exactly when their node
	lists are the same object (`is`).

	In this module self.points is only set by merge_parts(), so it has to be
	called before the road's points are used.
	"""
	def __init__(self, name, points):
		self.name = name

		self.nodes = {}  # k: (x,y) v: [ arcs ]
		self.arcs = []   # v: RoadArc
		self.add_part(points)

	def add_part(self, points):
		""" append another part (list of XYM points) of the road """
		arc = RoadArc(points, self.nodes)
		self.arcs.append(arc)

	def merge_parts(self, log=None):
		""" merge all parts to one array of points

		Parts sharing an end node are joined repeatedly; afterwards the longest
		remaining polyline becomes self.points. If several disconnected parts
		remain, a warning is written to `log` (when given).
		"""
		# join arcs; after a successful join both arcs were removed and the
		# joined arc appended, so the same index i is examined again
		i = 0
		while i < len(self.arcs):
			arc = self.arcs[i]
			nodeBegin = arc.nodeBegin(self.nodes)
			nodeEnd = arc.nodeEnd(self.nodes)
			if len(nodeBegin) > 1:
				joined = self.join_arcs(arc, nodeBegin, log)
			elif len(nodeEnd) > 1:
				joined = self.join_arcs(arc, nodeEnd, log)
			else:
				joined = False  # no neighbour at either end
			if not joined:
				# isolated part or a cycle: continue with another part
				i += 1

		# now select the longest arc (the first one wins on ties)
		longest = max(self.arcs, key=lambda a: _polyline_len(a.points))
		self.points = longest.points

		# warn if there are more parts
		if log and len(self.arcs) > 1:
			log.write("line '%s' has %d parts! using only the longest part." % (self.name, len(self.arcs)) )

	def join_arcs(self, arc, node, log=None):
		""" join `arc` with another arc meeting it at `node` into one new arc

		Both arcs are removed from self.arcs and from the topology, and the
		joined arc is appended. Returns True on success, False when `node`
		holds no arc other than `arc` (e.g. a closed part).
		Raises ValueError if the two arcs unexpectedly share no node.
		"""
		# select other arc
		otherArc = None
		for candidate in node:
			if candidate is not arc:
				otherArc = candidate
				break
		if otherArc is None:
			if log: log.write("we've come to a cycle (road '%s'), ignoring this part" % self.name)
			return False

		# join, orienting both point lists so the shared node is kept only once
		begin, end = arc.nodeBegin(self.nodes), arc.nodeEnd(self.nodes)
		otherBegin, otherEnd = otherArc.nodeBegin(self.nodes), otherArc.nodeEnd(self.nodes)
		if end is otherBegin:
			new_points = arc.points + otherArc.points[1:]
		elif begin is otherEnd:
			new_points = otherArc.points + arc.points[1:]
		elif end is otherEnd:
			new_points = arc.points[:-1] + list(reversed(otherArc.points))
		elif begin is otherBegin:
			new_points = list(reversed(arc.points)) + otherArc.points[1:]
		else:
			raise ValueError("arcs of road '%s' do not share a node" % self.name)

		# remove both arcs from the topology and from the list of parts
		arc.takeFromTopology(self.nodes)
		self.arcs.remove(arc)
		otherArc.takeFromTopology(self.nodes)
		self.arcs.remove(otherArc)

		# save as new arc (it registers itself in self.nodes)
		self.arcs.append(RoadArc(new_points, self.nodes))

		return True
		

class CalPoint(XYM):
	""" calibration point: a position with a measure, tagged with its road name """

	def __init__(self, road, x, y, m):
		super(CalPoint, self).__init__(x, y, m)
		self.road = road


def calibration(roads, cal_pts, progdlg=None, log=None):
	"""
		Calibrate roads with calibration points.

		roads   -- dict: road name -> Road; every road must already have
		           `points` (e.g. after Road.merge_parts) - fetching and joining
		           roads with the same name is done by the caller
		cal_pts -- calibration points (CalPoint)
		progdlg -- optional progress dialog (reset, setLabelText, setMaximum,
		           value and setValue are used)
		log     -- optional logger (write_header and write are used)

		steps:
		1. align calibration points to road's line
		2. interpolate (and extrapolate) measure values to all road's vertices

		Roads with fewer than two measured vertices are deleted from `roads`.
		Returns the (modified in place) `roads` dict.
	"""

	label = "(4/5) Aligning calibration points..."
	if progdlg:
		progdlg.reset()
		progdlg.setLabelText(label)
		progdlg.setMaximum(len(cal_pts))
	# the log header must not depend on the presence of a progress dialog
	if log: log.write_header(label)

	# 1. align calibration points to roads
	align_cal_points(cal_pts, roads, progdlg, log)

	label = "(5/5) Calculating M values..."
	if progdlg:
		progdlg.reset()
		progdlg.setLabelText(label)
		progdlg.setMaximum(len(roads))
	if log: log.write_header(label)

	# 2. calculate M values for all vertices
	remove_roads = []
	for name in roads:
		if not interpolate_measure(roads[name], log):
			remove_roads.append(name)

		if progdlg: progdlg.setValue(progdlg.value()+1)

	# remove roads that didn't have enough calibration points
	# (done afterwards: the dict must not change size while being iterated)
	for name in remove_roads:
		del roads[name]

	return roads


def align_cal_points(cal_pts, roads, progdlg=None, log=None):
	""" align points to road's linestring together with their M value

	Each calibration point is projected onto its road (looked up by pnt.road
	in `roads`) and the projected position receives the point's M. If the
	projection coincides with a vertex of the matched segment, that vertex's
	M is overwritten; otherwise a new XYM vertex is inserted into road.points.
	Points referring to an unknown road are skipped (and logged).
	"""

	for pnt in cal_pts:

		if progdlg: progdlg.setValue(progdlg.value()+1)

		# `in` rather than dict.has_key(), which is gone in Python 3
		if pnt.road not in roads:
			if log: log.write("skipping unknown road '%s'" % pnt.road)
			continue

		# get road for the calibration point
		road = roads[pnt.road]

		# find out where it belongs on the road
		# NOTE(review): line_index is used below as the index of the first vertex
		# of the matched segment (line_index, line_index+1) - assumed from this
		# usage of geom.nearestPoint, confirm against geom
		line_pnt, line_index = nearestPoint(Point(pnt.x, pnt.y), road.points)

		# TODO: possibly save this connection (pnt - line_pnt) to some layer

		# insert the point where it belongs, BUT if it has the same position as
		# one of the segment's endpoints, only set that vertex's M
		new_pnt = XYM(line_pnt.x, line_pnt.y, pnt.m)
		if new_pnt.equalXY(road.points[line_index]):
			road.points[line_index].m = new_pnt.m
		elif new_pnt.equalXY(road.points[line_index+1]):
			road.points[line_index+1].m = new_pnt.m
		else:
			road.points.insert(line_index+1, new_pnt)
			
	
def _next_pnt_with_m(points, start):
	"""
	Find the first vertex at or after `start` whose M value is known.

	Returns (index, dist): dist is the path length from points[start-1]
	(or from points[0] when start is 0) to that vertex.
	Returns (None, None) when no later vertex has an M value.
	"""

	if start == 0:
		prev = points[start]
	else:
		prev = points[start-1]

	travelled = 0
	index = start
	for candidate in points[start:]:
		travelled += _dist(prev, candidate)
		if candidate.m is not None:
			return (index, travelled)
		prev = candidate
		index += 1

	return None, None


def _dist(p1, p2):
	""" return the Euclidean distance between two points with .x/.y """
	dx = p1.x - p2.x
	dy = p1.y - p2.y
	return math.sqrt(dx*dx + dy*dy)
		
def _get_dist(points, index1, index2):
	""" return the path length along `points` from index1 to index2

	If index2 lies before index1 the length is returned negated (this is
	used for extrapolation before the first measured vertex).
	"""

	if index1 > index2:
		return - _get_dist(points, index2, index1)

	anchor = points[index1]
	tail = list(points[index1+1:index2+1])
	total = 0
	# pairs: (anchor, tail[0]), (tail[0], tail[1]), ...
	for a, b in zip([anchor] + tail, tail):
		total += _dist(a, b)
	return total
		
def _interpolate_point(points, index_pt, index_m1, index_m2, d_12):
	""" interpolate (or extrapolate) M of points[index_pt] linearly along the path

	index_m1, index_m2 -- vertices with known M used as the reference pair
	d_12               -- path length between those two vertices

	When d_12 is 0 (both reference measures at the same position) the slope is
	undefined; M of the first reference vertex is returned instead of raising
	ZeroDivisionError. TODO: speed up (path length is recomputed per call).
	"""

	m1, m2 = points[index_m1].m, points[index_m2].m
	if d_12 == 0:
		return m1

	# calculate distance m1->index_pt (negative if index_pt is before m1)
	d_1pt = _get_dist(points, index_m1, index_pt)

	# float() keeps true division even for integer inputs under Python 2
	return m1 + (m2-m1) * (float(d_1pt) / d_12)


def interpolate_measure(road, log=None):
	""" fill in M for every vertex of road.points that has none

	M is linearly interpolated between consecutive measured vertices along
	the path. Vertices before the first measured vertex use the first two
	measured vertices, and vertices after the last one use the last two, for
	extrapolation.

	Returns True on success. Returns False (and writes to `log` if given)
	when the road has fewer than two measured vertices; nothing is changed then.
	"""

	points = road.points

	# first find two values with M
	# they will be used also for extrapolation at begin (if needed)
	m_last = _next_pnt_with_m(points, 0)[0]
	if m_last is None:
		if log: log.write("road '%s' doesn't have any measure points!" % road.name)
		return False

	m_current, d_last_current = _next_pnt_with_m(points, m_last+1)
	if m_current is None:
		if log: log.write("road '%s' has only one measure point!" % road.name)
		return False

	for i, pnt in enumerate(points):

		if i == m_current:
			# we've reached our current point with M - let's get new one
			# if there are no more points M, we will continue doing extrapolation with current pair
			next_pnt_m, next_dist = _next_pnt_with_m(points, m_current+1)
			if next_pnt_m is not None:
				m_last = m_current
				m_current, d_last_current = next_pnt_m, next_dist

		elif pnt.m is None:
			# calculate the interpolation / extrapolation
			pnt.m = _interpolate_point(points, i, m_last, m_current, d_last_current)

	# interpolation went fine
	return True


##########################################
# testing

def get_test_roads():
	""" return fetched, joined test roads as a dict: road name -> Road

	merge_parts() is called so that each road has `points`, which the
	calibration step reads.
	"""
	road = Road("D1", [ XYM(10,10), XYM(20,20), XYM(30,20), XYM(40,10), XYM(25,10) ] )
	road.merge_parts()
	return { "D1" : road }


def get_test_cal_points():
	""" return calibration points for test road D1 """
	return [
		CalPoint("D1", 15, 16, 10),
		CalPoint("D1", 31, 21, 50),
		CalPoint("D1", 35, 8, 70),
	]

def do_test():
	
	roads = calibration(get_test_roads(), get_test_cal_points())
	for road in roads.values():
		print road.name, road.points

def do_merge_test():
	
	road = Road("D1", [ XYM(10,10), XYM(20,20), XYM(30,20), XYM(40,10), XYM(25,10) ] )
	road.add_part( [ XYM(10,10), XYM(2,2) ] )
	road.add_part( [ XYM(15,15), XYM(25,10) ] )
	road.merge_parts()
	print road.points

if __name__ == '__main__':
	# run the part-merging demo; call do_test() instead for the calibration demo
	do_merge_test()
