# -*- coding: utf-8 -*-
"""
SpatiaLite Manager
Copyright 2010 by Giuseppe Sucameli (Faunalia) and Alessandro Furieri

based on PostGIS Manager
Copyright 2008 Martin Dobias

Licensed under the terms of GNU GPL v2 (or any layer)
http://www.gnu.org/copyleft/gpl.html
"""

from pyspatialite import dbapi2 as sqlite

class TableAttribute:
	""" one column of a table, built from a single row returned by
	PRAGMA table_info (see GeoDB.get_table_fields) """

	def __init__(self, row):
		# unpack first so a malformed row fails before any attribute is set
		num, name, data_type, notnull, default, primary_key = row
		self.num = num
		self.name = name
		self.data_type = data_type
		self.notnull = notnull
		self.default = default
		self.primary_key = primary_key

# NOTE: disabled TableConstraint class, kept only as an inert string literal,
# so it is never defined at runtime and nothing in this module can use it.
# Its constraint/action/match codes look PostgreSQL-specific, presumably
# carried over from PostGIS Manager -- review before re-enabling for SQLite.
'''
class TableConstraint:
	""" class that represents a constraint of a table (relation) """
	
	TypeCheck, TypeForeignKey, TypePrimaryKey, TypeUnique = range(4)
	types = { "c" : TypeCheck, "f" : TypeForeignKey, "p" : TypePrimaryKey, "u" : TypeUnique }
	
	on_action = { "a" : "NO ACTION", "r" : "RESTRICT", "c" : "CASCADE", "n" : "SET NULL", "d" : "SET DEFAULT" }
	match_types = { "u" : "UNSPECIFIED", "f" : "FULL", "p" : "PARTIAL" }
	
	def __init__(self, row):
		self.name, con_type, self.is_defferable, self.is_deffered, keys = row[:5]
		self.keys = map(int, keys.split(' '))
		self.con_type = TableConstraint.types[con_type]   # convert to enum
		if self.con_type == TableConstraint.TypeCheck:
			self.check_src = row[5]
		elif self.con_type == TableConstraint.TypeForeignKey:
			self.foreign_table = row[6]
			self.foreign_on_update = TableConstraint.on_action[row[7]]
			self.foreign_on_delete = TableConstraint.on_action[row[8]]
			self.foreign_match_type = TableConstraint.match_types[row[9]]
			self.foreign_keys = row[10]
'''

class TableIndex:
	""" an index of a table, built by GeoDB.get_table_indexes:
	position in the index list, index name, unique flag and the list of
	indexed column ids """

	def __init__(self, row):
		num, name, unique, columns = row
		self.num = num
		self.name = name
		self.unique = unique
		self.columns = columns


class TableTrigger:
	""" a trigger defined on a table; only its name is known """

	def __init__(self, row):
		# Cursor rows of "SELECT name FROM sqlite_master ..." are 1-tuples
		# like (name,): unwrap them so that 'name' is the trigger name itself
		# and not a tuple. A plain name is accepted unchanged.
		if isinstance(row, (tuple, list)):
			row = row[0]
		self.name = row


class DbError(Exception):
	""" wraps a database driver error, keeping its message (and the failing
	query, when the error carries one) as unicode strings """

	def __init__(self, error):
		# raw message as passed by the driver (usually a utf-8 byte string)
		raw = error.args[0] if error.args else str(error)
		self.a = raw
		self.msg = DbError._to_unicode(raw)
		if hasattr(error, "cursor") and hasattr(error.cursor, "query"):
			self.query = DbError._to_unicode(error.cursor.query)
		else:
			self.query = None

	@staticmethod
	def _to_unicode(value):
		""" return value as unicode.

		Byte strings are decoded as utf-8 with invalid sequences replaced:
		raising here would hide the database error being reported.
		Unicode values are returned unchanged, anything else is converted. """
		if isinstance(value, unicode):
			return value
		if isinstance(value, str):
			return value.decode('utf-8', 'replace')
		return unicode(value)

	def __str__(self):
		""" utf-8 encoded message, followed by the query when known """
		text = self.msg.encode('utf-8')
		if self.query is not None:
			text += "\nQuery:\n" + self.query.encode('utf-8')
		return text
		

class TableField:
	""" description of a column to create, as used by GeoDB.create_table
	and GeoDB.table_add_column """

	def __init__(self, name, data_type, is_null=None, default=None):
		self.name, self.data_type, self.is_null, self.default = name, data_type, is_null, default

	def is_null_txt(self):
		""" return the nullability clause: a falsy is_null (including the
		default None) yields "NOT NULL" """
		return "NULL" if self.is_null else "NOT NULL"

	def field_def(self, db):
		""" return field definition as used for CREATE TABLE or ALTER TABLE command

		'db' must provide _quote() to quote the column name.
		A DEFAULT clause is added whenever a default is set, numeric values
		such as 0 included; None and the empty string mean "no default".
		The default is inserted verbatim, so it must already be an SQL literal.
		"""
		txt = "%s %s %s" % (db._quote(self.name), self.data_type, self.is_null_txt())
		if self.default is not None and self.default != "":
			txt += " DEFAULT %s" % self.default
		return txt
		

class GeoDB:
	""" access layer for a SpatiaLite (or plain SQLite) database

	Every statement goes through _exec_sql(), which rolls the connection back
	and re-raises driver errors as DbError; the data-modifying helpers commit
	immediately through _exec_sql_and_commit().
	"""

	def __init__(self, dbname=None):
		""" open the database file 'dbname' and probe for spatial metadata

		raises DbError if the driver reports an OperationalError on connect
		"""
		self.dbname = dbname

		try:
			self.con = sqlite.connect(self.dbname)
			#self.con.enable_load_extension(True);
			#self.con.execute('SELECT load_extension("libspatialite.so")');
		except sqlite.OperationalError as e:
			raise DbError(e)

		# check_spatialite() also sets self.has_geometry_columns
		self.has_spatialite = self.check_spatialite()

		# a counter to ensure that the cursor will be unique
		self.last_cursor_id = 0

	def get_info(self):
		""" return the SQLite library version string """
		c = self.con.cursor()
		self._exec_sql(c, "SELECT sqlite_version()")
		return c.fetchone()[0]

	def check_spatialite(self):
		""" check if is a valid spatialite db

		Sets and returns self.has_geometry_columns, which is True only when
		CheckSpatialMetaData() returns 1. This is a best-effort probe: any
		failure (e.g. the function is not available) counts as False.
		"""
		try:
			c = self.con.cursor()
			self._exec_sql(c, "SELECT CheckSpatialMetaData()")
			self.has_geometry_columns = c.fetchone()[0] == 1
		except Exception:
			self.has_geometry_columns = False

		return self.has_geometry_columns

	def get_spatialite_info(self):
		""" returns tuple about spatialite support:
			- lib version
			- geos version
			- proj version
		"""
		c = self.con.cursor()
		self._exec_sql(c, "SELECT spatialite_version(), geos_version(), proj4_version()")
		return c.fetchone()

	def list_geotables(self):
		"""
			get list of tables and views, with their geometry column(s) if any

			returns a list of rows, ordered by name:
			(name, 'table' or 'view', geometry column, geometry type,
			 coord dimension, srid)
			the last four items are None when the table has no entry in
			geometry_columns or the database has no spatial metadata.
			A table with several geometry columns yields one row per column.
		"""
		c = self.con.cursor()

		if self.has_geometry_columns:
			sql = """SELECT m.name, m.type, g.f_geometry_column, g.type, g.coord_dimension, g.srid
							FROM sqlite_master AS m LEFT JOIN geometry_columns AS g ON m.name = g.f_table_name
							WHERE m.type in ('table', 'view')
							ORDER BY m.name, g.f_geometry_column"""
		else:
			sql = "SELECT name, type, NULL, NULL, NULL, NULL FROM sqlite_master WHERE type IN ('table', 'view') ORDER BY name"

		self._exec_sql(c, sql)
		return list(c.fetchall())

	def get_table_rows(self, table):
		""" return the number of rows in table """
		c = self.con.cursor()
		self._exec_sql(c, "SELECT COUNT(*) FROM %s" % self._quote(table))
		return c.fetchone()[0]

	def get_table_fields(self, table):
		""" return list of columns in table (TableAttribute instances) """
		c = self.con.cursor()
		self._exec_sql(c, "PRAGMA table_info(%s)" % self._quote(table))
		return [TableAttribute(row) for row in c.fetchall()]

	def get_table_indexes(self, table):
		""" get info about table's indexes (TableIndex instances) """
		c = self.con.cursor()
		self._exec_sql(c, "PRAGMA index_list(%s)" % self._quote(table))

		indexes = []
		for row in c.fetchall():
			# newer SQLite versions may return extra columns after the first
			# three (seq, name, unique): ignore them
			num, name, unique = row[:3]

			c2 = self.con.cursor()
			self._exec_sql(c2, "PRAGMA index_info(%s)" % self._quote(name))
			# keep the column id (second item) of each indexed column
			cols = [info[1] for info in c2.fetchall()]

			indexes.append( TableIndex([num, name, unique, cols]) )

		return indexes

	def get_table_triggers(self, table):
		""" return the triggers defined on table (TableTrigger instances) """
		c = self.con.cursor()
		sql = "SELECT name FROM sqlite_master WHERE tbl_name = %s AND type = 'trigger'" % self._quote_str(table)
		self._exec_sql(c, sql)
		# each row is a 1-tuple: hand over just the trigger name
		return [TableTrigger(row[0]) for row in c.fetchall()]

	# TODO get_table_constraints

	def get_table_estimated_extent(self, geom, table):
		""" return (min x, min y, max x, max y) of the geometries of table

		Despite the name this is not an estimate: it aggregates the MBR of
		every geometry in column 'geom', so the whole table is scanned.
		All four values are None when there is no geometry to aggregate.
		"""
		c = self.con.cursor()
		sql = """ SELECT Min(MbrMinX(%(geom)s)), Min(MbrMinY(%(geom)s)), Max(MbrMaxX(%(geom)s)), Max(MbrMaxY(%(geom)s))
						FROM %(table)s """ % { 'geom' : self._quote(geom), 'table' : self._quote(table) }
		self._exec_sql(c, sql)
		return c.fetchone()

	def get_view_definition(self, view):
		""" returns definition (the CREATE VIEW statement) of the view """
		sql = "SELECT sql FROM sqlite_master WHERE type = 'view' AND name = %s" % self._quote_str(view)
		c = self.con.cursor()
		self._exec_sql(c, sql)
		return c.fetchone()[0]

	def add_geometry_column(self, table, geom_type, geom_column='the_geom', srid=-1, dim='XY'):
		""" add a geometry column through SpatiaLite's AddGeometryColumn()

		'dim' may be a dimension string (e.g. 'XY'), passed as an SQL string
		literal, or a number (e.g. 2), passed as an integer.
		"""
		if isinstance(dim, basestring):
			dim_sql = self._quote_str(dim)
		else:
			dim_sql = "%d" % dim
		# _quote_str() already adds the surrounding single quotes
		sql = "SELECT AddGeometryColumn(%s, %s, %d, %s, %s)" % (self._quote_str(table), self._quote_str(geom_column), srid, self._quote_str(geom_type), dim_sql)
		self._exec_sql_and_commit(sql)

	def delete_geometry_column(self, table, geom_column):
		""" discard a geometry column """
		sql = "SELECT DiscardGeometryColumn(%s, %s)" % (self._quote_str(table), self._quote_str(geom_column))
		self._exec_sql_and_commit(sql)

	def delete_geometry_table(self, table):
		""" delete table with one or more geometries """
		return self.delete_table(table)

	def create_table(self, table, fields, pkey=None):
		""" create ordinary table
				'fields' is array containing instances of TableField
				'pkey' contains name of column to be used as primary key
			returns False (and does nothing) when 'fields' is empty, else True
		"""
		if not fields:
			return False

		defs = [field.field_def(self) for field in fields]
		if pkey:
			defs.append("PRIMARY KEY (%s)" % self._quote(pkey))
		sql = "CREATE TABLE %s (%s)" % (self._quote(table), ", ".join(defs))
		self._exec_sql_and_commit(sql)
		return True

	def delete_table(self, table):
		""" delete table from the database """
		self._exec_sql_and_commit("DROP TABLE %s" % self._quote(table))

	def empty_table(self, table):
		""" delete all rows from table """
		self._exec_sql_and_commit("DELETE FROM %s" % self._quote(table))

	def rename_table(self, table, new_table):
		""" rename a table, keeping geometry_columns in sync when present """
		sql = "ALTER TABLE %s RENAME TO %s" % (self._quote(table), self._quote(new_table))
		self._exec_sql_and_commit(sql)

		# update geometry_columns
		if self.has_geometry_columns:
			sql = "UPDATE geometry_columns SET f_table_name = %s WHERE f_table_name = %s" % (self._quote_str(new_table), self._quote_str(table))
			self._exec_sql_and_commit(sql)

	def create_view(self, name, query):
		""" create view 'name' from the SQL 'query' (inserted verbatim) """
		self._exec_sql_and_commit("CREATE VIEW %s AS %s" % (self._quote(name), query))

	def delete_view(self, name):
		""" drop view 'name' """
		self._exec_sql_and_commit("DROP VIEW %s" % self._quote(name))

	def rename_view(self, name, new_name):
		""" rename view """
		self.rename_table(name, new_name)

	def table_add_column(self, table, field):
		""" add a column to table (passed as TableField instance) """
		sql = "ALTER TABLE %s ADD %s" % (self._quote(table), field.field_def(self))
		self._exec_sql_and_commit(sql)

	def table_delete_trigger(self, trigger):
		""" delete trigger """
		self._exec_sql_and_commit("DROP TRIGGER %s" % self._quote(trigger))

	def create_index(self, table, name, column, unique=True):
		""" create index 'name' on one column of table (UNIQUE by default) """
		unique_str = "UNIQUE " if unique else ""
		sql = "CREATE %sINDEX %s ON %s (%s)" % (unique_str, self._quote(name), self._quote(table), self._quote(column))
		self._exec_sql_and_commit(sql)

	def create_spatial_index(self, table, geom_column='the_geom'):
		""" create a spatial index through SpatiaLite's CreateSpatialIndex() """
		# the function expects string values, not identifiers
		sql = "SELECT CreateSpatialIndex(%s, %s)" % (self._quote_str(table), self._quote_str(geom_column))
		self._exec_sql_and_commit(sql)

	def delete_index(self, name):
		""" drop index 'name' """
		self._exec_sql_and_commit("DROP INDEX %s" % self._quote(name))

	def delete_spatial_index(self, name, geom_column='the_geom'):
		""" discard the spatial index of table 'name' via DiscardSpatialIndex() """
		sql = "SELECT DiscardSpatialIndex(%s, %s)" % (self._quote_str(name), self._quote_str(geom_column))
		self._exec_sql_and_commit(sql)

	def vacuum(self):
		""" run vacuum on the db """
		self._exec_sql_and_commit("VACUUM")

	def sr_info_for_srid(self, srid):
		""" return ref_sys_name of srid from spatial_ref_sys, or None if unknown """
		c = self.con.cursor()
		self._exec_sql(c, "SELECT ref_sys_name FROM spatial_ref_sys WHERE srid = %s" % self._quote_str(srid))
		row = c.fetchone()
		return row[0] if row is not None else None

	def insert_table_row(self, table, values, cursor=None):
		""" insert a row with specified values to a table.
		 if a cursor is specified, it doesn't commit (expecting that there will be more inserts)
		 otherwise it commits immediately """
		# TODO: quote values? they are inserted verbatim, so each one must
		# already be an SQL literal
		sql = "INSERT INTO %s VALUES (%s)" % (self._quote(table), ", ".join(values))
		if cursor:
			self._exec_sql(cursor, sql)
		else:
			self._exec_sql_and_commit(sql)

	def _exec_sql(self, cursor, sql):
		""" execute sql on cursor; on a driver error roll back and raise DbError """
		try:
			cursor.execute(sql)
		except sqlite.Error as e:
			# roll back so the failed statement doesn't leave a pending
			# transaction behind
			self.con.rollback()
			raise DbError(e)

	def _exec_sql_and_commit(self, sql):
		""" tries to execute and commit some action, on error it rolls back the change """
		c = self.con.cursor()
		self._exec_sql(c, sql)
		self.con.commit()

	def _quote(self, identifier):
		""" quote an identifier with double quotes, doubling embedded ones """
		identifier = unicode(identifier) # make sure it's python unicode string
		return u'"%s"' % identifier.replace('"', '""')

	def _quote_str(self, txt):
		""" make an SQL string literal: wrap in ' and replace ' with '' """
		txt = unicode(txt) # make sure it's python unicode string
		return u"'%s'" % txt.replace("'", "''")
		

# for debugging / testing
# usage: python <this file> [path/to/db.sqlite] [table]
if __name__ == '__main__':
	import sys

	dbname = sys.argv[1] if len(sys.argv) > 1 else '/home/brushtyler/Projects/Work/Faunalia/spatialite/test-2.3.sqlite'
	table = sys.argv[2] if len(sys.argv) > 2 else 'trencin'

	db = GeoDB(dbname=dbname)

	# SQLite has no schemas: print the library version instead
	print(db.get_info())
	print('==========')

	for row in db.list_geotables():
		print(row)

	print('==========')

	for idx in db.get_table_indexes(table):
		print((idx.num, idx.name, idx.unique, idx.columns))

	print('==========')

	print(db.get_table_rows(table))
	
