"""
Linear Referencing plugin
(c) Copyright 2008 Martin Dobias

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.       
"""

import psycopg2

# make psycopg2 return text values as Python unicode objects instead of byte strings
import psycopg2.extensions
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

class DatabaseError(Exception):
	""" Error raised when a database operation fails.

	message -- text describing the failure
	query   -- SQL statement that caused it, or None when not known
	"""
	def __init__(self, message, query=None):
		# initialise the base class too, so that e.args and repr(e)
		# carry the message instead of being empty
		Exception.__init__(self, message)
		self.message = message
		self.query = query

	def __str__(self):
		return "Message: %s\nQuery: %s" % (self.message, self.query)


class Route:
	""" Plain record describing a route layer: the table name, the column
	identifying a route and the column holding its geometry. """
	def __init__(self, table, id_column, geom_column):
		self.table, self.id_column, self.geom_column = table, id_column, geom_column


class PostGIS:
	
	def __init__(self, host=None, port=None, dbname=None, user=None, passwd=None):
		""" Open a connection to the database.

		All arguments are optional; those left empty are omitted from the
		connection string built by con_info().
		Raises DatabaseError when psycopg2 reports an OperationalError
		while connecting.
		"""
		self.host = host
		self.port = port
		self.dbname = dbname
		self.user = user
		self.passwd = passwd

		try:
			self.con = psycopg2.connect(self.con_info())
		except psycopg2.OperationalError as e:
			# str(e) replaces the deprecated e.message attribute
			raise DatabaseError(str(e))

		# remember whether geometry_columns is a view - the table creating
		# methods register new geometry columns differently in that case
		self.geom_col_is_view = self.is_view_geometry_columns()

	def is_view_geometry_columns(self):
		c = self.exec_sql("SELECT relkind = 'v' FROM pg_class WHERE relname = 'geometry_columns' AND relkind IN ('v', 'r')")
		res = c.fetchone()
		return res != None and len(res) != 0 and res[0]

	def con_info(self):
		con_str = ''
		if self.host:   con_str += "host='%s' "     % self.host
		if self.port:   con_str += "port=%d "       % self.port
		if self.dbname: con_str += "dbname='%s' "   % self.dbname
		if self.user:   con_str += "user='%s' "     % self.user
		if self.passwd: con_str += "password='%s' " % self.passwd
		return con_str
		
		
	def get_layer_geometry_type(self, layer):
		c = self.exec_sql("SELECT type FROM geometry_columns WHERE f_table_name = '%s'" % self._quote_str(layer))
		return c.fetchone()[0]
	
	def get_layer_geometry_column(self, layer):
		""" return column with geometry. let's suppose there's only one geom.column """
		c = self.exec_sql("SELECT f_geometry_column FROM geometry_columns WHERE f_table_name = '%s'" % self._quote_str(layer))
		return c.fetchone()[0]

	def list_routes(self):
		# deprecated by usage of lin_ref_routes table
		c = self.exec_sql("SELECT f_table_name FROM geometry_columns WHERE type = 'LINESTRINGM'")
		return map(lambda x: x[0], c.fetchall()) # get first column
			
	
	def list_aspatial_tables(self):
		""" return list of all aspatial tables in database (minus postgis metadata tables) """
		
		# , nspname, relkind
		c = self.exec_sql("""SELECT relname FROM pg_class, pg_namespace
		   WHERE relkind IN ('v', 'r') AND
			       pg_namespace.oid = pg_class.relnamespace AND
						 nspname NOT IN ('information_schema','pg_catalog') AND
						 relname NOT IN ('geometry_columns','spatial_ref_sys') AND
						 relname NOT IN (SELECT f_table_name FROM geometry_columns)""")
		return map(lambda x: x[0], c.fetchall()) # get first column
	
	def list_spatial_tables(self, geom_type=None):
		""" return list of spatial tables in database (possibly with a filter to a specific type """
		where_clause = '' if geom_type is None else "WHERE type = '%s'" % geom_type
		c = self.exec_sql("SELECT f_table_name FROM geometry_columns %s" % where_clause)
		return map(lambda x: x[0], c.fetchall()) # get first column
	
	def table_exists(self, table):
		c = self.exec_sql("SELECT COUNT(*) FROM pg_class WHERE relkind IN ('v','r') AND relname = '%s'" % self._quote_str(table))
		return c.fetchone()[0] != 0

	def list_table_columns(self, table):
		""" return list of table's attributes """
		c = self.exec_sql("""SELECT attname FROM pg_class, pg_attribute, pg_type
			WHERE relname = '%s' AND
						attrelid = pg_class.oid AND
						atttypid = pg_type.oid AND
						attnum >= 1 AND
						typname != 'geometry'
			ORDER BY attnum""" % self._quote_str(table))
		return map(lambda x: x[0], c.fetchall()) # get first column

	def get_route_geom_column(self, route_table):
		""" find out geometry column of the route """
		# deprecated by usage of lin_ref_routes table
		c = self.exec_sql("""SELECT f_geometry_column FROM geometry_columns WHERE f_table_name = '%s'""" % self._quote_str(route_table))
		return c.fetchone()[0]
	
	def create_routes_table(self):
		""" create metadata table with list of routes """
		self.exec_sql("""CREATE TABLE lin_ref_routes (
				table_name varchar(256) NOT NULL,
				route_column varchar(256) NOT NULL,
				geom_column varchar(256) NOT NULL,
				CONSTRAINT lin_ref_routes_pkey PRIMARY KEY (table_name, route_column) )""")
		self.con.commit()
		
	def list_routes_2(self):
		c = self.exec_sql("SELECT table_name FROM lin_ref_routes")
		return map(lambda x: x[0], c.fetchall()) # get first column
	
	def add_route(self, route):
		self.exec_sql("INSERT INTO lin_ref_routes VALUES ('%s', '%s', '%s')" % (self._quote_str(route.table), self._quote_str(route.id_column), self._quote_str(route.geom_column)))
		self.con.commit()
	
	
	def delete_route(self, route):
		self.exec_sql("DELETE FROM lin_ref_routes WHERE table_name = '%s'" % self._quote_str(route))
		self.con.commit()
		try:
			self.exec_sql("SELECT DropGeometryTable('%s')" % self._quote_str(route))
			self.con.commit()
			return True
		except DatabaseError, e:
			self.con.rollback()
			return False
	
	
	def get_route_info(self, table_name):
		""" gets information about route and returns Route() instance """
		c = self.exec_sql("SELECT * FROM lin_ref_routes WHERE table_name = '%s'" % self._quote_str(table_name))
		row = c.fetchone()
		if row is None:
			return None
		return Route(row[0], row[1], row[2]) # table name, route column, geometry column
	
	def create_route(self, table_name):
		""" Create a new route table: a table with fid and name columns, a
		LINESTRINGM column called geom, and a matching lin_ref_routes entry.
		Each of the two statements is committed on its own. """
		create_table = "CREATE TABLE %s (fid serial not null primary key, name varchar(20) not null)" % self._quote(table_name)
		add_geometry = "SELECT AddGeometryColumn('%s','geom', -1, 'LINESTRINGM', 3)" % self._quote_str(table_name)
		for statement in (create_table, add_geometry):
			self.exec_sql(statement)
			self.con.commit()

		self.add_route(Route(table_name, 'name', 'geom'))

	
	def create_table_with_located_points(self, route, events_table, events_route_id, events_measure, output):
		""" create a new table from events table with added geometry with lin. referenced measures as points """
		
		# create new column with ID only if orig. table doesn't have suitable pkey (int4)
		events_pkey = self.get_table_primary_key(events_table)
		if events_pkey:
			new_pkey = ""
		else:
			new_pkey = "nextval('%s')::int4 AS qgis_id, " % self._quote_str(u"%s_seq" % output)
		
		if events_pkey is None:
			# create a temporary sequence for a new column to ensure that we will have a primary key
			c = self.exec_sql("CREATE TEMPORARY SEQUENCE %s" %  self._quote(u"%s_seq" % output))
		
		data = { 'out'   : self._quote(output),
						 'pkey'  : new_pkey,
						 'g'     : self._quote(route.geom_column),
						 'm'     : self._quote(events_measure),
						 'events': self._quote(events_table),
						 'route' : self._quote(route.table),
						 'events_col': self._quote(events_route_id),
						 'route_col' : self._quote(route.id_column) }
		
		sql = """CREATE TABLE %(out)s WITHOUT OIDS AS
		  SELECT %(pkey)s p.*, ST_locate_along_measure(r.%(g)s, p.%(m)s) AS geom_event
			  FROM %(events)s p, %(route)s r  WHERE lower(p.%(events_col)s) = lower(r.%(route_col)s)""" % data
		c = self.exec_sql(sql)
		
		if events_pkey is None:
			self.exec_sql("DROP SEQUENCE %s" % self._quote(u"%s_seq" % output), c)
			self.con.commit()
		
		if not self.geom_col_is_view:
			# add to geometry columns table
			self.exec_sql(u"INSERT INTO geometry_columns VALUES ('','public','%s','geom_event',2,-1,'POINT')" % self._quote_str(output), c)
		else:
			self.exec_sql(u"ALTER TABLE %s ALTER COLUMN geom_event TYPE geometry(PointM,-1)" % self._quote(output), c)
		self.con.commit()
		
		# add primary key
		pkey = events_pkey if events_pkey is not None else 'qgis_id'
		self.exec_sql("ALTER TABLE %s ADD PRIMARY KEY (%s)" % (self._quote(output), pkey), c)
		self.con.commit()		
		
	
	def create_table_with_located_lines(self, route, events_table, events_route_id, events_measure, events_measure2, output):
		""" """
		
		# create new column with ID only if orig. table doesn't have suitable pkey (int4)
		events_pkey = self.get_table_primary_key(events_table)
		if events_pkey:
			new_pkey = ""
		else:
			new_pkey = "nextval('%s')::int4 AS qgis_id, " % self._quote_str(u"%s_seq" % output)
		
		if events_pkey is None:
			# create a temporary sequence for a new column to ensure that we will have a primary key
			c = self.exec_sql("CREATE TEMPORARY SEQUENCE %s" %  self._quote(u"%s_seq" % output))
		
		data = { 'out'  : self._quote(output),
						 'pkey' : new_pkey,
						 'ini'  : "l.%s" % self._quote(events_measure),
						 'fin'  : "l.%s" % self._quote(events_measure2),
						 'g'    : "r.%s" % self._quote(route.geom_column),
						 'route'  : self._quote(route.table),
						 'events' : self._quote(events_table),
		         'events_col' : self._quote(events_route_id),
						 'route_col'  : self._quote(route.id_column) }
		
		sql = """CREATE TABLE %(out)s WITHOUT OIDS AS SELECT %(pkey)s l.*,
		      ST_locate_between_measures(%(g)s, LEAST( %(ini)s,%(fin)s ), GREATEST( %(fin)s,%(ini)s )) AS geom_event
					FROM %(events)s l, %(route)s r WHERE lower(l.%(events_col)s) = lower(r.%(route_col)s)""" % data
		c = self.exec_sql(sql)
		
		if events_pkey is None:
			self.exec_sql("DROP SEQUENCE %s" %  self._quote(u"%s_seq" % output), c)
			self.con.commit()
		
		if not self.geom_col_is_view:
			self.exec_sql(u"INSERT INTO geometry_columns VALUES ('','public','%s','geom_event',2,-1,'LINESTRING')" % self._quote_str(output), c)
		else:
			#FIXME use a better solution than deleting points 
			# to allow to change the column type then geometry_columns is a view
			self.exec_sql(u"DELETE FROM %s WHERE GeometryType(geom_event) LIKE 'POINTM'" % self._quote(output), c)
			self.exec_sql(u"ALTER TABLE %s ALTER COLUMN geom_event TYPE geometry(LineStringM,-1)" % self._quote(output), c)
		self.con.commit()
		
		# add primary key
		pkey = events_pkey if events_pkey is not None else 'qgis_id'
		self.exec_sql("ALTER TABLE %s ADD PRIMARY KEY (%s)" % (self._quote(output), pkey), c)
		self.con.commit()		
	
	
	def get_route_m(self, route, route_name):
		""" return route's coordinates with M value """
		c = self.exec_sql("SELECT ST_asEWKT(%s) FROM %s WHERE %s = '%s'" % (self._quote(route.geom_column), self._quote(route.table), self._quote(route.id_column), self._quote_str(route_name)))
		ewkt = c.fetchone()[0]
		ewkt = ewkt.replace("LINESTRINGM(","").replace(")","")
		toFloat = lambda x: float(x)
		parsePoint = lambda x: map(toFloat, x.strip().split(' '))
		return map(parsePoint, ewkt.split(','))
	
	def get_route_point_m(self, route, route_name, vertex):
		""" return M value of a route at specific vertex (zero-indexed) """
		vertex += 1 # vertices in PostGIS are indexed from 1, in QGIS from 0
		sql = "SELECT ST_m(ST_PointN(%s,%d)) FROM %s WHERE %s = '%s'" % (self._quote(route.geom_column), vertex, self._quote(route.table), self._quote(route.id_column), self._quote_str(route_name))
		c = self.exec_sql(sql)
		return c.fetchone()[0]
	
	def get_route_m_interpolation(self, route, route_name, pnt, beforeVertex):
		geom_col = route.geom_column
		v1 = "ST_PointN(%s,%d)" % (self._quote(route.geom_column), beforeVertex)
		v2 = "ST_PointN(%s,%d)" % (self._quote(route.geom_column), beforeVertex+1)
		pt = "ST_MakePoint(%f,%f)" % (pnt.x(), pnt.y())
		koef = "ST_line_locate_point(ST_MakeLine(%s,%s), %s)" % (v1,v2,pt)
		m = " ST_m(%s) + (ST_m(%s)-ST_m(%s)) * %s" % (v1, v2, v1, koef)
		sql = "select %s from %s where %s = '%s'" % (m, self._quote(route.table), self._quote(route.id_column), self._quote_str(route_name))
		c = self.exec_sql(sql)
		return c.fetchone()[0]
		
	def get_table_primary_key(self, table):
		# FIXME: added cast attnum to text let's see whether it will help
		sql = """SELECT attname FROM pg_constraint
				LEFT JOIN pg_class ON conrelid = pg_class.oid
 				LEFT JOIN pg_attribute ON attrelid = pg_class.oid
				WHERE relname = '%s' and contype='p' and attnum::text = array_to_string(conkey, ' ')""" % self._quote_str(table)
		c = self.exec_sql(sql)
		row = c.fetchone()
		if row is None: # no primary key :-(
			return None
		else:
			return row[0]
		
	def update_row(self, input_layer, input_pkey, pkey_value, values_dict):
		""" update row in a table """
		values = ""
		for col,val in values_dict.iteritems():
			if len(values) > 0:
				values += ", "
			values += "%s = '%s'" % (self._quote(col), self._quote_str(val))
		sql = "UPDATE %s SET %s WHERE %s = '%s'" % (self._quote(input_layer), values, self._quote(input_pkey), self._quote_str(pkey_value))
		self.exec_sql(sql)
		self.con.commit()
	
	def get_calibration_points(self, layer, column_measure):
		c = self.exec_sql("SELECT * FROM ")
		pass
	def get_calibration_lines(self, layer, column_name):
		""" Not implemented yet - does nothing and returns None. """
		return None
	
	def exec_sql(self, sql, cursor=None):
		try:
			if not cursor:
				cursor = self.con.cursor()
			cursor.execute(sql)
			return cursor
		except psycopg2.Error, e:
			# we have to do rollback, otherwise everythin called then will be aborted
			self.con.rollback()
			raise DatabaseError(e.message, e.cursor.query)


	def _quote(self, identifier):
		""" quote identifier if needed """
		identifier = unicode(identifier) # make sure it's python unicode string
		# let's quote it (and double the double-quotes)
		return u'"%s"' % identifier.replace('"', '""')

	def _quote_str(self, txt):
		""" make the string safe - replace ' with '' """
		txt = unicode(txt) # make sure it's python unicode string
		return txt.replace("'", "''")
