# -*- coding: utf-8 -*-

from PyQt4.QtCore import *
from PyQt4.QtGui import *

from qgis.gui import *
from qgis.core import *

import postgis_utils
from postgis_utils import GeoDB as DB

# Retrieves GDAL binaries location
def getGdalPath():
	"""Return the GDAL binaries location read from the persistent settings.

	An empty string is passed as the default value for the lookup.
	"""
	stored = QSettings().value( "/RT_PostrgresExtractor/gdalPath", QVariant( "" ) )
	return stored.toString()

# Stores GDAL binaries location
def setGdalPath(path):
	"""Save `path` as the GDAL binaries location in the persistent settings."""
	wrapped = QVariant( path )
	QSettings().setValue( "/RT_PostrgresExtractor/gdalPath", wrapped )

# Retrieves last used dir from persistent settings
def getLastUsedDir():
    """Return the last used directory read from the persistent settings.

    "." is passed as the default value for the lookup.
    """
    stored = QSettings().value( "/RT_PostrgresExtractor/lastUsedDir", QVariant(".") )
    return stored.toString()

# Stores last used dir in persistent settings
def setLastUsedDir(filePath):
    """Remember the directory of `filePath` in the persistent settings.

    When `filePath` is itself a directory it is stored as is, otherwise the
    path component of the file (QFileInfo.path()) is stored.
    """
    info = QFileInfo(filePath)
    dirPath = info.filePath() if info.isDir() else info.path()
    QSettings().setValue( "/RT_PostrgresExtractor/lastUsedDir", QVariant(dirPath) )


class QueryUtils:
	"""Helpers to normalize a user supplied query (or table name) and to
	inspect the columns and the geometry of its result.

	All the methods are classmethods; their first parameter is the class
	itself even though it is historically named `self`.
	"""

	@classmethod
	def sanitizeQuery(self, query):
		"""Return `query` as a QString holding a single SELECT statement.

		Input that does not start with SELECT (case insensitive) is wrapped as
		"SELECT * FROM <input>". When the statement contains a semicolon that
		is not inside single or double quotes, the semicolon and everything
		after it are dropped.
		"""
		query = QString( query ).trimmed()
		# NOTE(review): "\s*" also accepts zero whitespace, so an identifier that
		# merely starts with "select" (e.g. a table called "selection") is
		# treated as a SELECT statement. Left unchanged to keep the behavior.
		if not query.contains( QRegExp( "^SELECT\\s*", Qt.CaseInsensitive) ):
			query = QString( u"SELECT * FROM %s" % query )

		# commands after ; will be discarded
		rx = QRegExp( """^(?:[^"';]*(?:"[^"]*")*(?:'[^']*')*)*;""" )
		if rx.indexIn( query ) < 0:
			return query
		# the whole match ends with the semicolon itself: strip it
		statement = rx.cap(0)
		return statement.left( statement.length() - 1 )


	@classmethod
	def getAlias(self, query):
		"""Return an alias "t__<n>" that does not occur in the sanitized query.

		n starts at 0 and is incremented until the alias is not found in the
		query text, either bare or wrapped in double quotes, ignoring case.
		"""
		query = self.sanitizeQuery( query )

		aliasIndex = 0
		while True:
			alias = "t__%s" % aliasIndex
			# \1 requires the same (possibly empty) quote on both sides of the alias
			regex = QRegExp( '("?)' + QRegExp.escape(alias) + '\\1' )
			regex.setCaseSensitivity(Qt.CaseInsensitive)
			if not query.contains(regex):
				return alias
			aliasIndex += 1


	@classmethod
	def getQueryWithAlias(self, query):
		"""Return `query` in a form usable after FROM.

		A SELECT statement is sanitized and returned as "(<statement>) AS
		<quoted alias>"; any other input is returned untouched.
		"""
		if QString(query).contains( QRegExp( "^SELECT\\s*", Qt.CaseInsensitive) ):
			query = self.sanitizeQuery(query)
			alias = self.getAlias(query)
			query = u"(%s) AS %s" % ( unicode(query), DB._quote(alias) )
		return query


	@classmethod
	def getFields(self, db, query, exclude = None):
		"""Return the result columns of `query` as a list of QgsField.

		db      -- object providing `con` (the connection) and `_exec_sql(cursor, sql)`
		query   -- SELECT statement or table name
		exclude -- optional container of column names to leave out

		The query is run with LIMIT 0 to read the column description, then the
		type names are looked up in pg_type by oid.

		Raises QueryUtils.DuplicatedFieldsError when two result columns share
		the same name. An empty list is returned when no column is left.
		"""
		query = self.getQueryWithAlias(query)

		# check if there are fields with duplicated names in the result
		queryLimited = u"SELECT * FROM %s LIMIT 0" % ( unicode(query) )

		seenNames = set()
		fieldsList = list()
		typeOids = list()

		c = db.con.cursor()
		try:
			db._exec_sql(c, queryLimited)

			for fld in c.description:
				# make sure there are no duplicated fields
				if fld[0] in seenNames:
					# the exception class is nested in this class, so it has to be
					# reached through the class and not as a module level name
					raise self.DuplicatedFieldsError( QCoreApplication.translate( "RT_PostgresExtractor", "Columns with duplicated names are not allowed in the result. \nUse an alias to make sure there is only one '%1' column." ).arg( fld[0] ) )

				if exclude is not None and fld[0] in exclude:
					continue

				seenNames.add( fld[0] )
				fieldsList.append( fld )

				# collect each type oid once, keeping the order of appearance
				ftype = str( fld[1] )
				if ftype not in typeOids:
					typeOids.append( ftype )
		finally:
			# release the cursor even when the query fails or a duplicate is found
			c.close()

		if not typeOids:
			return []

		# retrieve the type name using the type oid
		typesQuery = u"SELECT oid, typname FROM pg_type WHERE oid IN (%s)" % ",".join( typeOids )

		types = dict()
		c = db.con.cursor()
		try:
			db._exec_sql(c, typesQuery)

			for oid, typename in c.fetchall():
				# anything not recognized below is handled as a string
				variantType = QVariant.String
				typename = QString(typename)

				if typename == "int8":
					variantType = QVariant.LongLong
				elif typename.startsWith("int") or typename == "serial" or typename == "oid":
					variantType = QVariant.Int
				elif typename == "real" or typename == "double precision" or typename.startsWith( "float" ) or typename == "numeric":
					variantType = QVariant.Double

				types[oid] = (variantType, typename)
		finally:
			c.close()

		fields = list()
		for fld in fieldsList:
			variantType, typename = types[ fld[1] ]
			fields.append( QgsField( fld[0], variantType, typename ) )

		return fields

	@classmethod
	def getGeomAndUniqueFields(self, db, query):
		"""Return ( geomCols, uniqueCols ), two lists of QgsField taken from
		the result columns of `query`.

		geomCols holds the columns whose type name is "geometry"; uniqueCols
		holds the ones typed "oid", "serial" or "int4".
		"""
		uniqueCols = list()
		geomCols = list()
		for f in QueryUtils.getFields( db, query ):
			typeName = f.typeName()
			if typeName == "oid" or typeName == "serial" or typeName == "int4":
				uniqueCols.append( f )
			if typeName == "geometry":
				geomCols.append( f )

		return ( geomCols, uniqueCols )


	@classmethod
	def _fetchFirstRow(self, db, sql):
		"""Run `sql` on a new cursor and return its first row (None when there
		is none). The cursor is closed in every case; errors raised by
		db._exec_sql are propagated to the caller."""
		c = db.con.cursor()
		try:
			db._exec_sql(c, sql)
			return c.fetchone()
		finally:
			c.close()


	@classmethod
	def getSridAndGeomType(self, db, query, geom):
		"""Return ( srid, geomtype, dim ) for the column `geom` of `query`.

		The values are read from the first row of the result. When the
		reported type contains "GEOMETRY" a second query looks at the rows to
		pick the matching MULTI* type; if that query fails, or gives no usable
		value, the type found by the first query is kept.

		( None, None, 0 ) is returned when the first query raises
		postgis_utils.DbError or returns no row.
		"""
		query = QueryUtils.getQueryWithAlias(query)
		params = { 'geom' : db._quote(geom), 'table' : query }

		newQuery = u"SELECT srid(%(geom)s), geometrytype(%(geom)s), ST_CoordDim(%(geom)s) FROM %(table)s LIMIT 1" % params
		try:
			row = self._fetchFirstRow( db, newQuery )
		except postgis_utils.DbError:
			return None, None, 0

		if row is None:
			return None, None, 0

		srid, geomtype, dim = row

		geomtype = QString(geomtype).toUpper()
		if geomtype.contains( 'GEOMETRY' ):
			# the literals are joined inside the parentheses first, so that the
			# % operator formats the whole statement (it binds tighter than +)
			newQuery = ( u"select distinct case"
					u" when geometrytype(%(geom)s) IN ('POINT','MULTIPOINT') THEN 'MULTIPOINT'"
					u" when geometrytype(%(geom)s) IN ('LINESTRING','MULTILINESTRING') THEN 'MULTILINESTRING'"
					u" when geometrytype(%(geom)s) IN ('POLYGON','MULTIPOLYGON') THEN 'MULTIPOLYGON'"
					u" end from %(table)s" ) % params

			try:
				row = self._fetchFirstRow( db, newQuery )
			except postgis_utils.DbError:
				return srid, geomtype, dim

			# the CASE has no ELSE branch: it yields NULL for any other type
			if row is not None and row[0] is not None:
				geomtype = QString(row[0]).toUpper()

		return srid, geomtype, dim


	class DuplicatedFieldsError(Exception):
		"""Raised when the result of a query has two columns with the same name.

		The message is available both as `msg` and through the standard
		exception `args`.
		"""
		def __init__(self, msg):
			Exception.__init__(self, msg)
			self.msg = msg
