# -*- coding: utf-8 -*-
"""
RT Sql Layer
Copyright 2010 Giuseppe Sucameli

based on PostGIS Manager
Copyright 2008 Martin Dobias

Licensed under the terms of GNU GPL v2 (or any later version)
http://www.gnu.org/copyleft/gpl.html


Good resource for metadata extraction:
http://www.alberton.info/postgresql_meta_info.html
System information functions:
http://www.postgresql.org/docs/8.0/static/functions-info.html
"""

import psycopg2
import psycopg2.extensions # for isolation levels
import re

# use unicode! register psycopg2's UNICODE typecaster (no scope argument is given,
# so it is not tied to a single connection or cursor) so that text values come back
# from queries as python unicode objects rather than encoded byte strings
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)


class Table:
	""" a table or view of the database, built from a row of GeoDB.list_geotables() """

	FILTER_NONE = 0x0
	FILTER_NOT_GEOM_TABLES = 0xF
	FILTER_GEOM_TABLES = 0xF0
	FILTER_ALL = 0xFF

	TYPE_TABLE = 0x1
	TYPE_VIEW = 0x2

	GEOM_POINT = 0x10
	GEOM_LINESTRING = 0x20
	GEOM_POLYGON = 0x40
	GEOM_UNKNOWN = 0x80

	def __init__(self, row):
		"""
			row: 10-item sequence (relname, nspname, relkind, owner, reltuples, relpages,
			geometry column, geometry type, coord dimension, srid).

			Sets name, schema, geom_col and ttype; ttype is TYPE_VIEW (relkind 'v') or
			TYPE_TABLE, or-ed with a GEOM_* flag when a geometry type is present.
		"""
		self.name, self.schema, reltype, relowner, row_count, page_count, self.geom_col, geom_type, geom_dim, geom_srid = row

		self.ttype = self.TYPE_VIEW if reltype == 'v' else self.TYPE_TABLE
		if geom_type is None:
			# no geometry column: plain table/view
			return

		# substring tests, so MULTI* variants land in the same group; any other type
		# name (e.g. 'geometry', GEOMETRYCOLLECTION) is flagged as unknown
		if 'POINT' in geom_type:
			self.ttype |= self.GEOM_POINT
		elif 'LINESTRING' in geom_type:
			self.ttype |= self.GEOM_LINESTRING
		elif 'POLYGON' in geom_type:
			self.ttype |= self.GEOM_POLYGON
		else:
			self.ttype |= self.GEOM_UNKNOWN


class TableAttribute:
	""" one column of a table, built from a row of GeoDB.get_table_fields() """

	def __init__(self, row):
		# row: (attnum, name, type name, attlen, atttypmod, notnull, hasdefault, default)
		(num, name, data_type, char_max_len,
		 modifier, notnull, hasdefault, default) = row
		self.num = num
		self.name = name
		self.data_type = data_type
		self.char_max_len = char_max_len
		self.modifier = modifier
		self.notnull = notnull
		self.hasdefault = hasdefault
		self.default = default


class TableConstraint:
	""" class that represents a constraint of a table (relation), built from a row of
	GeoDB.get_table_constraints() """

	TypeCheck, TypeForeignKey, TypePrimaryKey, TypeUnique, TypeExclusion, TypeTrigger, TypeNotNull = range(7)
	# pg_constraint.contype codes; 'x' (exclusion), 't' (constraint trigger) and
	# 'n' (not-null) used to raise KeyError
	types = { "c" : TypeCheck, "f" : TypeForeignKey, "p" : TypePrimaryKey, "u" : TypeUnique,
	          "x" : TypeExclusion, "t" : TypeTrigger, "n" : TypeNotNull }

	on_action = { "a" : "NO ACTION", "r" : "RESTRICT", "c" : "CASCADE", "n" : "SET NULL", "d" : "SET DEFAULT" }
	# PostgreSQL 9.3+ reports MATCH SIMPLE as 's' (older servers use 'u')
	match_types = { "u" : "UNSPECIFIED", "s" : "SIMPLE", "f" : "FULL", "p" : "PARTIAL" }

	def __init__(self, row):
		"""
			row: (name, contype, deferrable, deferred, space separated key column numbers,
			check source, foreign table, on update, on delete, match type, foreign keys)
		"""
		self.name, con_type, self.is_defferable, self.is_deffered, keys = row[:5]
		# the key list may be empty (or NULL) for constraints referencing no column,
		# so split on any whitespace instead of int('')-ing an empty field
		self.keys = [int(k) for k in (keys or '').split()]
		self.con_type = TableConstraint.types[con_type]   # convert to enum
		if self.con_type == TableConstraint.TypeCheck:
			self.check_src = row[5]
		elif self.con_type == TableConstraint.TypeForeignKey:
			self.foreign_table = row[6]
			self.foreign_on_update = TableConstraint.on_action[row[7]]
			self.foreign_on_delete = TableConstraint.on_action[row[8]]
			self.foreign_match_type = TableConstraint.match_types[row[9]]
			self.foreign_keys = row[10]


class DbError(Exception):
	""" database error wrapping a psycopg2 error: keeps its message (msg) and, when the
	error carries a cursor with a query, the failing query (query) as unicode text """

	def __init__(self, error):
		# the message and query are expected to be utf-8 encoded byte strings; decode
		# them leniently, so a decoding problem never hides the real database error
		raw_msg = error.args[0] if error.args else ''
		self.msg = self._to_unicode(raw_msg)
		self.a = raw_msg
		query = getattr(getattr(error, "cursor", None), "query", None)
		self.query = self._to_unicode(query) if query is not None else None

	@staticmethod
	def _to_unicode(text):
		""" return text as unicode: byte strings are decoded as utf-8 (invalid bytes
		become U+FFFD), anything else is converted with %s formatting """
		try:
			return text.decode('utf-8', 'replace')
		except (AttributeError, UnicodeError):
			# already unicode text (or not a string at all)
			return u'%s' % (text,)

	def __str__(self):
		""" utf-8 encoded message, followed by the query when one is known """
		if self.query is None:
			text = self.msg
		else:
			text = self.msg + u"\nQuery:\n" + self.query
		return text.encode('utf-8')
	

class TableField:
	""" description of a column, used to generate CREATE TABLE / ALTER TABLE text """

	def __init__(self, name, data_type, is_null=None, default=None, modifier=None):
		self.name = name
		self.data_type = data_type
		self.is_null = is_null
		self.default = default
		self.modifier = modifier

	def is_null_txt(self):
		""" nullability clause: "NULL" when is_null is truthy, otherwise "NOT NULL" """
		return "NULL" if self.is_null else "NOT NULL"

	def field_def(self, db):
		""" return field definition as used for CREATE TABLE or ALTER TABLE command

		db: object providing _quote() for the column name.
		The modifier is appended as "type(n)" only when it is set and not negative;
		a DEFAULT clause is added only for a non-empty default.
		"""
		if self.modifier and not self.modifier < 0:
			type_txt = "%s(%d)" % (self.data_type, self.modifier)
		else:
			type_txt = self.data_type
		definition = "%s %s %s" % (db._quote(self.name), type_txt, self.is_null_txt())
		if self.default and len(self.default) > 0:
			definition += " DEFAULT %s" % self.default
		return definition
		

class GeoDB:
	
	def __init__(self, host=None, port=None, dbname=None, user=None, passwd=None):
		
		self.host = host
		self.port = port
		self.dbname = dbname
		self.user = user
		self.passwd = passwd
		
		if self.dbname == '' or self.dbname is None:
			self.dbname = self.user
		
		try:
			self.con = psycopg2.connect(self.con_info())
		except psycopg2.OperationalError, e:
			raise DbError(e)
		
		self.has_postgis = self.check_postgis()

		self.check_geometry_columns_table()

		# a counter to ensure that the cursor will be unique
		self.last_cursor_id = 0
		
	def con_info(self):
		con_str = ''
		if self.host:   con_str += "host='%s' "     % self.host
		if self.port:   con_str += "port=%d "       % self.port
		if self.dbname: con_str += "dbname='%s' "   % self.dbname
		if self.user:   con_str += "user='%s' "     % self.user
		if self.passwd: con_str += "password='%s' " % self.passwd
		return con_str
		
	def check_postgis(self):
		""" check whether postgis_version is present in catalog """
		c = self.con.cursor()
		self._exec_sql(c, "SELECT COUNT(*) FROM pg_proc WHERE proname = 'postgis_version'")
		return (c.fetchone()[0] > 0)
		
	def check_geometry_columns_table(self):

		c = self.con.cursor()
		self._exec_sql(c, "SELECT relname FROM pg_class WHERE relname = 'geometry_columns' AND pg_class.relkind IN ('v', 'r')")
		self.has_geometry_columns = (len(c.fetchall()) != 0)
		
		if not self.has_geometry_columns:
			self.has_geometry_columns_access = False
			return
			
		# find out whether has privileges to access geometry_columns table
		self.has_geometry_columns_access = self.get_table_privileges('geometry_columns')[0]


	def list_schemas(self):
		"""
			get list of schemas in tuples: (oid, name, owner, perms)
		"""
		c = self.con.cursor()
		sql = "SELECT oid, nspname, pg_get_userbyid(nspowner), nspacl FROM pg_namespace WHERE nspname !~ '^pg_' AND nspname != 'information_schema'"
		self._exec_sql(c, sql)

		schema_cmp = lambda x,y: -1 if x[1] < y[1] else 1
		
		return sorted(c.fetchall(), cmp=schema_cmp)
			
	def list_geotables(self, schema=None):
		"""
			get list of tables with schemas, whether user has privileges, whether table has geometry column(s) etc.
			
			geometry_columns:
			- f_table_schema
			- f_table_name
			- f_geometry_column
			- coord_dimension
			- srid
			- type
		"""
		c = self.con.cursor()
		
		if schema:
			schema_where = " AND nspname = '%s' " % self._quote_str(schema)
		else:
			schema_where = " AND (nspname != 'information_schema' AND nspname !~ 'pg_') "
			
		# LEFT OUTER JOIN: like LEFT JOIN but if there are more matches, for join, all are used (not only one)
		
		# first find out whether postgis is enabled
		if not self.has_postgis:
			# get all tables and views
			sql = """SELECT pg_class.relname, pg_namespace.nspname, pg_class.relkind, pg_get_userbyid(relowner), reltuples, relpages, NULL, NULL, NULL, NULL
							FROM pg_class
							JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
							WHERE pg_class.relkind IN ('v', 'r')""" + schema_where + "ORDER BY nspname, relname"
		else:
			# discovery of all tables and whether they contain a geometry column
			sql = """SELECT pg_class.relname, pg_namespace.nspname, pg_class.relkind, pg_get_userbyid(relowner), reltuples, relpages, pg_attribute.attname, pg_attribute.atttypid::regtype, NULL, NULL
							FROM pg_class
							JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
							LEFT OUTER JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid AND
									( pg_attribute.atttypid = 'geometry'::regtype
										OR pg_attribute.atttypid IN (SELECT oid FROM pg_type WHERE typbasetype='geometry'::regtype ) )
							WHERE pg_class.relkind IN ('v', 'r')""" + schema_where + "ORDER BY nspname, relname, attname"
						  
		self._exec_sql(c, sql)
		items = c.fetchall()
		
		# get geometry info from geometry_columns if exists
		if self.has_postgis and self.has_geometry_columns and self.has_geometry_columns_access:
			sql = """SELECT relname, nspname, relkind, pg_get_userbyid(relowner), reltuples, relpages,
							geometry_columns.f_geometry_column, geometry_columns.type, geometry_columns.coord_dimension, geometry_columns.srid
							FROM pg_class
						  JOIN pg_namespace ON relnamespace=pg_namespace.oid
						  LEFT OUTER JOIN geometry_columns ON relname=f_table_name AND nspname=f_table_schema
						  WHERE (relkind = 'r' or relkind='v') """ + schema_where + "ORDER BY nspname, relname, f_geometry_column"
			self._exec_sql(c, sql)
			
			# merge geometry info to "items"
			for i, geo_item in enumerate(c.fetchall()):
				if geo_item[7]:
					items[i] = geo_item
			
		return items
	
	
	def get_table_rows(self, table, schema=None):
		c = self.con.cursor()
		self._exec_sql(c, "SELECT COUNT(*) FROM %s" % self._table_name(schema, table))
		return c.fetchone()[0]
		
		
	def get_table_fields(self, table, schema=None):
		""" return list of columns in table """
		c = self.con.cursor()
		schema_where = " AND nspname='%s' " % self._quote_str(schema) if schema is not None else ""
		sql = """SELECT a.attnum AS ordinal_position,
				a.attname AS column_name,
				t.typname AS data_type,
				a.attlen AS char_max_len,
				a.atttypmod AS modifier,
				a.attnotnull AS notnull,
				a.atthasdef AS hasdefault,
				adef.adsrc AS default_value
			FROM pg_class c
			JOIN pg_attribute a ON a.attrelid = c.oid
			JOIN pg_type t ON a.atttypid = t.oid
			JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
			LEFT JOIN pg_attrdef adef ON adef.adrelid = a.attrelid AND adef.adnum = a.attnum
			WHERE
			  c.relname = '%s' %s AND
				a.attnum > 0
			ORDER BY a.attnum""" % (self._quote_str(table), schema_where)

		self._exec_sql(c, sql)
		attrs = []
		for row in c.fetchall():
			attrs.append(TableAttribute(row))
		return attrs


	def get_table_constraints(self, table, schema=None):
		c = self.con.cursor()
		
		schema_where = " AND nspname='%s' " % self._quote_str(schema) if schema is not None else ""
		sql = """SELECT c.conname, c.contype, c.condeferrable, c.condeferred, array_to_string(c.conkey, ' '), c.consrc,
		         t2.relname, c.confupdtype, c.confdeltype, c.confmatchtype, array_to_string(c.confkey, ' ') FROM pg_constraint c
		  LEFT JOIN pg_class t ON c.conrelid = t.oid
			LEFT JOIN pg_class t2 ON c.confrelid = t2.oid
			JOIN pg_namespace nsp ON t.relnamespace = nsp.oid
			WHERE t.relname = '%s' %s """ % (self._quote_str(table), schema_where)
		
		self._exec_sql(c, sql)
		
		constrs = []
		for row in c.fetchall():
			constrs.append(TableConstraint(row))
		return constrs


	def get_table_unique_indexes(self, table, schema=None):
		""" get all the unique indexes """
		schema_where = " AND nspname='%s' " % self._quote_str(schema) if schema is not None else ""
		sql = """SELECT relname, indkey 
						FROM pg_index JOIN pg_class ON pg_index.indrelid=pg_class.oid 
						JOIN pg_namespace nsp ON pg_class.relnamespace = nsp.oid 
							WHERE pg_class.relname='%s' %s 
							AND indisprimary != 't' AND indisunique = 't'""" % (self._quote_str(table), schema_where)
		c = self.con.cursor()
		self._exec_sql(c, sql)
		uniqueIndexes = []
		for row in c.fetchall():
			uniqueIndexes.append(TableIndex(row))
		return uniqueIndexes


	def get_database_privileges(self):
		""" db privileges: (can create schemas, can create temp. tables) """
		sql = "SELECT has_database_privilege('%(d)s', 'CREATE'), has_database_privilege('%(d)s', 'TEMP')" % { 'd' : self._quote_str(self.dbname) }
		c = self.con.cursor()
		self._exec_sql(c, sql)
		return c.fetchone()
		
	def get_schema_privileges(self, schema):
		""" schema privileges: (can create new objects, can access objects in schema) """
		sql = "SELECT has_schema_privilege('%(s)s', 'CREATE'), has_schema_privilege('%(s)s', 'USAGE')" % { 's' : self._quote_str(schema) }
		c = self.con.cursor()
		self._exec_sql(c, sql)
		return c.fetchone()
	
	def get_table_privileges(self, table, schema=None):
		""" table privileges: (select, insert, update, delete) """
		t = self._table_name(schema, table)
		sql = """SELECT has_table_privilege('%(t)s', 'SELECT'), has_table_privilege('%(t)s', 'INSERT'),
		                has_table_privilege('%(t)s', 'UPDATE'), has_table_privilege('%(t)s', 'DELETE')""" % { 't': self._quote_str(t) }
		c = self.con.cursor()
		self._exec_sql(c, sql)
		return c.fetchone()

	def get_named_cursor(self, table=None):
		""" return an unique named cursor, optionally including a table name """
		self.last_cursor_id += 1
		if table is not None:
			table2 = re.sub(r'\W', '_', table.encode('ascii','replace')) # all non-alphanum characters to underscore
			cur_name = "cursor_%d_table_%s" % (self.last_cursor_id, table2)
		else:
			cur_name = "cursor_%d" % self.last_cursor_id
		#cur_name = ("\"db_table_"+self.table+"\"").replace(' ', '_')
		#cur_name = cur_name.encode('ascii','replace').replace('?', '_')
		return self.con.cursor(cur_name)
		
	def _exec_sql(self, cursor, sql):
		try:
			cursor.execute(sql)
		except psycopg2.Error, e:
			# do the rollback to avoid a "current transaction aborted, commands ignored" errors
			self.con.rollback()
			raise DbError(e)
		
	def _exec_sql_and_commit(self, sql):
		""" tries to execute and commit some action, on error it rolls back the change """
		#try:
		c = self.con.cursor()
		self._exec_sql(c, sql)
		self.con.commit()
		#except DbError, e:
		#	self.con.rollback()
		#	raise

	@classmethod
	def _quote(self, identifier):
		identifier = unicode(identifier) # make sure it's python unicode string
		return u'"%s"' % identifier.replace('"', '""')

	@classmethod	
	def _quote_str(self, txt):
		""" make the string safe - replace ' with '' """
		txt = unicode(txt) # make sure it's python unicode string
		return txt.replace("'", "''")
		
	@classmethod
	def _table_name(self, schema, table):
		if not schema:
			return self._quote(table)
		else:
			return u"%s.%s" % (self._quote(schema), self._quote(table))
