# -*- coding: utf-8 -*-

from PyQt4.QtCore import *
from PyQt4.QtGui import *

from qgis.gui import *
from qgis.core import *

from DlgDbError import DlgDbError

import psycopg2
import postgis_utils
from postgis_utils import GeoDB as DB
import zipfile

import Utils
from Utils import QueryUtils
from ui.WizPage5_ui import Ui_WizardPage

class WizardPage(QWizardPage, Ui_WizardPage):
	
	def __init__(self, iface, wizState):
		QWizardPage.__init__(self)
		self.iface = iface
		self.wizState = wizState

		self.setupUi(self)
		self.stopped = False

		self.process = QProcess(self)
		self.process.setProcessChannelMode(QProcess.MergedChannels)	
		self.connect(self.process, SIGNAL("error(QProcess::ProcessError)"), self.onProcessError)
		self.connect(self.process, SIGNAL("finished(int, QProcess::ExitStatus)"), self.onProcessFinished)

	def initializePage(self):
		self.progressBar.setValue(0)
		self.setInformations()

		# but first let the wizard chance to show the page
		QTimer.singleShot(50, self.prepare)

	def setInformations(self):
		# get the intersection archive info
		archiveType = self.wizState.cuttingArchiveModeName
		if self.wizState.cuttingArchiveMode == self.wizState.CUTTING_DRAWN:
			archiveName = archiveType
		elif self.wizState.cuttingArchiveMode == self.wizState.CUTTING_LAYER:
			archiveName = self.wizState.cuttingArchiveLayer.name()
		else:
			schema, table, geom, sql = self.wizState.cuttingArchiveLayer
			archiveName = sql if sql != None else DB._table_name(schema, table)
		archiveTitle = self.wizState.cuttingArchiveTitle

		# get the archives names
		archives = []
		for i in range(len(self.wizState.archiveToCutLayers)):
			schema, table, geom, sql, geomtype = self.wizState.archiveToCutLayers[i]
			archives.append(sql if sql != None else DB._table_name(schema, table))

		# get the export options
		exportUnchanged = self.tr( 'Yes' ) if self.wizState.exportOptions == self.wizState.EXPORT_WITHOUT_CUT else self.tr( 'No' )
		bufferLength = str(self.wizState.cuttingBuffer) if self.wizState.cuttingBuffer != None else self.tr( 'No' )

		# get the output options
		directory = self.wizState.outputDir
		prefix = self.wizState.outputFilePrefix if self.wizState.outputFilePrefix != None else ''
		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SPATIALITE:
			format = 'SQLite'
		elif self.wizState.outputFormat == self.wizState.OUT_FORMAT_KML:
			format = 'KML'
		else:
			format = 'ESRI Shapefile'

		infos = self.informations.toHtml()
		infos = infos.arg( archiveType ).arg( archiveName ).arg( archiveTitle )
		infos = infos.arg( '<br>'.join(archives) )
		infos = infos.arg( exportUnchanged ).arg( bufferLength )
		infos = infos.arg( directory ).arg( format ).arg( prefix )
		self.informations.setHtml( infos )

	def onClosing(self):
		self.stop()

	# stop the command execution
	def stop(self):
		self.stopped = True

		self.process.kill()
		self.removeParamsFile()

		QApplication.restoreOverrideCursor()


	def prepare(self):
		QApplication.setOverrideCursor(QCursor(Qt.WaitCursor))

		if self.wizState.cuttingArchiveMode == self.wizState.CUTTING_POSTGRES:
			connName, db = self.wizState.cuttingArchiveDb
			schema, table, geom, sql = self.wizState.cuttingArchiveLayer
			title = self.wizState.cuttingArchiveTitle

			if not sql:
				tableName = db._table_name(schema, table)
			else:
				tableName = QueryUtils.getQueryWithAlias(sql)

			geomCol = db._quote( geom )
			cutTitle = db._quote( title )
			
			query = u"SELECT encode(ST_asEWKB(%s), 'hex'), %s FROM %s" % (geomCol, cutTitle, tableName)

			try:
				self.cursor1 = db.con.cursor()
				db._exec_sql(self.cursor1, query)

				self.counter1 = self.cursor1.rowcount
			except postgis_utils.DbError, e:
				self.labelMsg.setText( self.tr( "An error occurs, please click Cancel button" ) )
				QApplication.restoreOverrideCursor()
				DlgDbError.showError(e, self)
				return

		elif self.wizState.cuttingArchiveMode == self.wizState.CUTTING_LAYER:
			vl = self.wizState.cuttingArchiveLayer

			self.vlProvider = vl.dataProvider()
			self.titleFieldIndex = self.vlProvider.fieldNameIndex( self.wizState.cuttingArchiveTitle )
			self.vlProvider.select( [self.titleFieldIndex] )

			self.counter1 = self.vlProvider.featureCount()

		else: # polygon by hand
			self.counter1 = 1

		self.counter2 = len( self.wizState.archiveToCutLayers )

		self.title2index = {}

		self.index1 = 0
		self.index2 = 0
		self.runLevel1()


	# called by the prepare function
	# retrieve the polygon at self.index1 from the intersecting archive 
	# and call the runLevel2 function
	def runLevel1(self):
		if self.stopped:
			return

		if self.index1 >= self.counter1:
			return self.onFinished()

		self.outputFiles = []
		srid = self.wizState.srid
		connName, db = self.wizState.archiveToCutDb

		if self.wizState.cuttingArchiveMode == self.wizState.CUTTING_POSTGRES:
			try:
				row = self.cursor1.fetchone()
				cutter = row[0]
				cutTitle = row[1]

			except postgis_utils.DbError, e:
				self.labelMsg.setText( self.tr( "An error occurs, please click Cancel button" ) )
				QApplication.restoreOverrideCursor()
				DlgDbError.showError(e, self)
				return

		else:
			if self.wizState.cuttingArchiveMode == self.wizState.CUTTING_DRAWN:
				polygonGeometry = self.wizState.cuttingArchiveLayer
				cutTitle = self.wizState.cuttingArchiveTitle
				cutter = polygonGeometry.asWkb()

			elif self.wizState.cuttingArchiveMode == self.wizState.CUTTING_LAYER:
				f = QgsFeature()
				if not self.vlProvider.nextFeature(f):
					return

				cutTitle = f.attributeMap()[self.titleFieldIndex].toString()
				cutter = f.geometry().asWkb()

			cutter = str( QByteArray( str(cutter) ).toHex() )

			query = u"SELECT encode(ST_asEWKB(ST_GeomFromWKB(decode('%s', 'hex'), %d)), 'hex')" % (cutter, srid)
			try:
				cursor = db.con.cursor()
				db._exec_sql(cursor, query)

				row = cursor.fetchone()
				cutter = row[0]

			except postgis_utils.DbError, e:
				self.labelMsg.setText( self.tr( "An error occurs, please click Cancel button" ) )
				QApplication.restoreOverrideCursor()
				DlgDbError.showError(e, self)
				return

		self.cutter = cutter
		self.cutTitle = cutTitle
		self.runLevel2()

	# called by either the runLevel1 or the onProcessFinished function
	# create the query using the polygon in self.cutter
	# and call the cutGeometry function to exec the intersection
	def runLevel2(self):
		if self.stopped:
			return

		if self.index2 >= self.counter2:
			self.createZipArchive(self.cutTitle, self.outputFiles)
			self.outputFiles = []

			self.index1 += 1
			self.index2 = 0
			self.runLevel1()
			return

		# get the parameters used in the query 
		cutBuffer = self.wizState.cuttingBuffer
		cutBufferVal = '%.2f' % cutBuffer if cutBuffer else 'null'
		trulyCut = self.wizState.exportOptions == self.wizState.EXPORT_AFTER_CUT

		connName, db = self.wizState.archiveToCutDb
		schema, table, geom, sql, geomtype = self.wizState.archiveToCutLayers[self.index2]

		if not sql:
			tableName = db._table_name(schema, table)
		else:
			tableName = QueryUtils.getQueryWithAlias( sql )
			table = "query%d" % self.index2

		geomCol = db._quote( geom )

		# get fields
		fieldsListOld, fieldsListNew = self.createFieldsList( db, tableName, [geom] )

		geomTypeStr = 'Polygon'
		if geomtype.contains('POINT'):
			geomTypeInt = 1
			geomTypeStr = 'Point'
		elif geomtype.contains('LINESTRING'):
			geomTypeInt = 2
			geomTypeStr = 'LineString'
		else:
			geomTypeInt = 3
			geomTypeStr = 'Polygon'

		# get MULTI* geometries
		if geomtype.contains( 'POINT' ) or geomtype.contains( 'LINESTRING' ) or geomtype.contains( 'POLYGON' ):
			if not geomtype.startsWith( 'MULTI' ):
				geomtype.prepend( 'MULTI' )

		realCutterStr = cutterStr = u"ST_GeomFromEWKB(decode('%s', 'hex'))" % self.cutter
		if cutBuffer != None:
			realCutterStr = u"ST_Buffer(%s, %s)" % (cutterStr, cutBufferVal)

		intersectionFuncStr = u"ST_Intersection(%(geom)s, %(cutter)s) AS %(geom)s" % {'geom':geomCol, 'cutter':realCutterStr}
		intersectsFuncStr = u"ST_Intersects(%s, %s)" % (geomCol, realCutterStr)
		touchesFuncStr = u"ST_Touches(%s, %s)" % (geomCol, realCutterStr)

		# construct the query
		if trulyCut:
			# list of fields to retrieve
			flds = u'%s, %s' % (fieldsListOld, intersectionFuncStr) if len(fieldsListOld) > 0 else intersectionFuncStr

			queryLevel3 = u"SELECT %s FROM %s WHERE %s=true" % (flds, tableName, intersectsFuncStr)
		else:
			queryLevel3 = u"SELECT * FROM %s WHERE %s=true" % (tableName, intersectsFuncStr)
		queryLevel3 = QueryUtils.getQueryWithAlias( queryLevel3 )

		geometryComputationStr = """
CASE 
	WHEN ST_GeometryType(%(geomCol)s)='ST_GeometryCollection' THEN 
		ST_Multi(ST_CollectionExtract(%(geomCol)s, %(type)d)) 
	ELSE 
		ST_Multi(%(geomCol)s) 
END AS %(geomCol)s
""" % { 'geomCol': geomCol, 'type': geomTypeInt }


		# list of fields to retrieve
		flds = u'%s, %s' % (fieldsListNew, geometryComputationStr) if len(fieldsListNew) > 0 else geometryComputationStr

		if trulyCut:
			# get only the geometries of the correct type
			whereStr = "ST_GeometryType(%(geomCol)s) IN ('ST_%(type)s', 'ST_Multi%(type)s', 'ST_GeometryCollection')" % { 'geomCol': geomCol, 'type': geomTypeStr }

			queryLevel2 = u"SELECT %s FROM %s WHERE %s" % (flds, queryLevel3, whereStr)
		else:
			queryLevel2 = u"SELECT %s FROM %s WHERE %s=false" % (flds, queryLevel3, touchesFuncStr)
		queryLevel2 = QueryUtils.getQueryWithAlias( queryLevel2 )

		# remove the empy geometries from the output
		query = u"SELECT *, CURRENT_DATE AS data_query FROM %s WHERE ST_isEmpty(%s)=false" % ( queryLevel2, geomCol )

		# now create the output
		self.saveVectorFileByOgr( db, self.cutTitle, table, query, geomtype )


	def saveVectorFileByOgr(self, db, cutTitle, cuttedArchiveName, query, geomtype):
		prefix = self.wizState.outputFilePrefix
		directory = self.wizState.outputDir

		#srid = self.wizState.srid
		#srs = QgsCoordinateReferenceSystem()
		#srs.createFromSrid(srid)

		# set the filename
		filename = ""
		if prefix:
			filename += "%s_" % prefix
		filename += "%s" % cutTitle

		if self.wizState.outputFormat != self.wizState.OUT_FORMAT_SPATIALITE:
			filename += "_%s" % cuttedArchiveName

		# remove spaces and quotes from filename
		filename = QString(filename).remove("'").remove('"').replace(' ', '_')
		filename = unicode(filename)

		# append to filename the number of the archives within the container (zip or sqlite)
		num = 1
		if self.title2index.has_key( filename ):
			t2i = self.title2index[ filename ]
			if t2i.has_key( cuttedArchiveName ):
				num = t2i[ cuttedArchiveName ] + 1
		else:
			self.title2index[ filename ] = {}
		self.title2index[ filename ][ cuttedArchiveName ] = num

		processedArchiveName = cuttedArchiveName
		if num > 1:
			if self.wizState.outputFormat != self.wizState.OUT_FORMAT_SPATIALITE:
				filename += "_%d" % num
			else:
				processedArchiveName += "_%d" % num

		arguments = QStringList() 
		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SPATIALITE:
			arguments << '-f' << 'SQLite'
			ext = ".sqlite"
		elif self.wizState.outputFormat == self.wizState.OUT_FORMAT_SHAPEFILE:
			arguments << '-f' << 'ESRI Shapefile'
			ext = ".shp"
		elif self.wizState.outputFormat == self.wizState.OUT_FORMAT_KML:
			arguments << '-f' << 'KML'
			ext = ".kml"

		outputPath = directory + QDir.separator() + filename + ext
		outputPath = unicode(outputPath)

		fileExists = outputPath in self.outputFiles or QFileInfo(outputPath).exists()

		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SPATIALITE:
			if fileExists:
				arguments << '-update' << '-append'
			else:
				arguments << '-dsco' << 'SPATIALITE=YES'
				arguments << '-lco' << 'OVERWRITE=yes'
				arguments << '-lco' << 'LAUNDER=yes'

		arguments << '%s' % outputPath	# output
		arguments << 'PG:%s' % db.con_info()	# input

		arguments << '-nln' << processedArchiveName
		arguments << '-nlt' << geomtype

		# use a file to pass the sql string
		paramsFilePath = outputPath + '~'
		paramsFilePath = unicode(paramsFilePath)

		paramsFile = open(paramsFilePath, 'w')
		bytes = '-sql "%s"' % query.replace('"', '\\"').replace('\n', ' ')
		paramsFile.write( unicode(bytes).encode('utf8') )
		paramsFile.close()

		arguments << '--optfile' << paramsFilePath

		self.outputPath = outputPath
		self.paramsFilePath = paramsFilePath

		gdalPath = Utils.getGdalPath()
		if not gdalPath.isEmpty():
			env = self.process.environment()
			if not env.contains( QRegExp( "^PATH=(.*)", Qt.CaseInsensitive ) ):
				env << u"PATH=%s" % unicode(gdalPath)
			else:
				import platform
				newPath = "PATH=\\1"
				newPath += ";" if platform.system() == "Windows" else ":"
				newPath += gdalPath
				env.replaceInStrings( QRegExp( "^PATH=(.*)", Qt.CaseInsensitive ), newPath )
			self.process.setEnvironment( env )

		#print 'ogr2ogr', unicode(arguments.join(' ')).encode('utf-8')
		self.process.start('ogr2ogr', arguments, QIODevice.ReadOnly)

		if self.stopped:
			self.stop()
			return

		msg = self.tr( "Using geometry \"%1\" extract from \"%2\"" ).arg( cutTitle ).arg( cuttedArchiveName )
		if self.index1 > 0 or self.index2 > 0:
			self.outputBrowser.insertHtml( '<br>' )
		self.outputBrowser.insertHtml( '<span style="margin: 0px; font-weight:600;">%s</span><br>' % msg  )
		if self.lockToLastLineCheck.isChecked(): self.outputBrowser.ensureCursorVisible()

	def createZipArchive(self, cutTitle, outFiles):
		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SPATIALITE:
			return

		prefix = self.wizState.outputFilePrefix
		directory = self.wizState.outputDir

		# create the list of files to add to the zip
		filesToZip = []
		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SHAPEFILE:
			exts = ["shp", "prj", "shx", "dbf"]
			for f in self.outputFiles:
				basename = QFileInfo(f).baseName()
				for e in exts:
					inFn = directory + QDir.separator() + basename + ".%s" % e
					filesToZip.append( inFn )
		else:
			filesToZip = self.outputFiles

		# construct the zip filename
		filename = ""
		if prefix:
			filename += "%s_" % prefix
		filename += "%s" % (cutTitle)

		# remove spaces and quotes from filename
		zipBaseFn = QString(filename).remove("'").remove('"').replace(' ', '_')
		filename = zipBaseFn + ".zip"

		outputPath = directory + QDir.separator() + filename

		# if exists open the zip file, otherwise create it
		mode = 'w' if not QFileInfo(outputPath).exists() else 'a'
		zip = zipfile.ZipFile(unicode(outputPath), mode, zipfile.ZIP_DEFLATED)

		# append the files into the list to the zip
		for inFn in filesToZip:
			inFnInfo = QFileInfo(inFn)
			if inFnInfo.exists():
				internalName = inFnInfo.fileName()
				if True or self.wizState.useShpShortName:
					internalName = internalName.right( internalName.length() - (zipBaseFn.length()+1) )
				zip.write(unicode(inFn).encode('utf-8'), unicode(internalName).encode('latin-1'))
				QFile(inFn).remove()
		zip.close()


	def removeParamsFile(self):
		if not hasattr(self, 'paramsFilePath'):
			return

		try:
			f = QFile(self.paramsFilePath)
			if f != None: f.remove()
		except: pass
		return


	def onProcessFinished(self, exitCode, status):
		self.removeParamsFile()
		if self.stopped:
			return

		if status == QProcess.CrashExit:
			self.stop()
			return

		msg = QString(unicode(self.process.readAll(), 'utf-8'))
		if not msg.isEmpty():
			self.outputBrowser.insertHtml( '<span style="margin: 0px;">%s</span>' % msg.replace('\n', '<br>') )
			if self.lockToLastLineCheck.isChecked(): self.outputBrowser.ensureCursorVisible()
		if exitCode != 0 or msg.contains( "FAILURE" ):
			self.onError()
			return 

		self.outputFiles.append(self.outputPath)

		self.index2 += 1
		self.updateProgress(self.index2 + self.counter2 * self.index1)
		self.runLevel2()


	def onProcessError(self, error):
		self.removeParamsFile()
		if self.stopped:
			return

		if error == QProcess.FailedToStart:
			msg = self.tr( "The process failed to start. Either the invoked program is missing, or you may have insufficient permissions to invoke the program." )

			#QMessageBox.warning(self, self.tr( "Unable to e" ), msg )
			if False:
				# unable to find ogr2ogr, let's ask for it.
				QApplication.restoreOverrideCursor()
				path = QFileDialog.getExistingDirectory(self, u"Select the path to ogr2ogr executable", Utils.getGdalPath())
				if not path.isEmpty():
					Utils.setGdalPath( path )
					# try to run ogr2ogr again
					QTimer.singleShot(50, self.prepare)
					return

		elif error == QProcess.Crashed:
			msg = self.tr( "The process crashed some time after starting successfully." )
		else:
			msg = self.tr( "An unknown error occurred." )

		self.onError(msg)


	def onError(self, errorMsg = QString()):
		if not errorMsg.isEmpty():
			self.outputBrowser.insertHtml( '<span style="color:red; margin:0px; font-weight:600;">ERROR: %s</span><br>' % errorMsg.replace('\n', '<br>') )
			if self.lockToLastLineCheck.isChecked(): self.outputBrowser.ensureCursorVisible()
		else:
			errorMsg = self.tr( 'For more information, please see the log.' )

		self.labelMsg.setText( self.tr( "An error occurs, please click Cancel button" ) )
		QMessageBox.critical(self, self.tr( "An error occurs" ), errorMsg )
		self.stop()


	def onFinished(self):
		self.updateProgress(self.counter1 * self.counter2)

		self.outputBrowser.insertHtml( '<br><span style="color:green; margin:0px; font-weight:600;">FINISHED without issues</span>' )
		if self.lockToLastLineCheck.isChecked(): self.outputBrowser.ensureCursorVisible()

		self.labelMsg.setText( self.tr( "Terminated, please click Finish button" ) )
		QMessageBox.information(self, self.tr( "Finished" ), self.tr( "The process finished without issues." ) )
		self.stop()

		# disable the cancel button
		wiz = self.wizard()
		cancelBtn = wiz.button( QWizard.CancelButton )
		cancelBtn.setEnabled(False)

		self.emit( SIGNAL( "completeChanged()" ) )


	def updateProgress(self, val):
		tot = self.counter1 * self.counter2
		self.progressBar.setValue( val * 100.0 / tot )
		QCoreApplication.processEvents()


	def createFieldsList(self, db, query, exclude = None):
		fields = QueryUtils.getFields(db, query, exclude)
		if len(fields) <= 0:
			return '', ''

		fieldsList = list()
		if self.wizState.outputFormat == self.wizState.OUT_FORMAT_SHAPEFILE:
			# fix names longer than 10 chars
			fieldsMap = dict()
			for findex, f in enumerate(fields):
				longName = shortName = unicode(f.name())
				if len(longName) >= 10:
					shortName = longName[:9] + "*"

				if not fieldsMap.has_key( shortName ):
					fieldsMap[ shortName ] = list()

				value = (findex, longName)
				fieldsMap[ shortName ].append( value )

			for k, namesList in fieldsMap.iteritems():
				if len(namesList) == 1:
					findex, newname = namesList[0]

					value = ( findex, db._quote(newname), db._quote(newname) )
					fieldsList.append( value )
				else:
					appendIndex = k[-1:] == '*'
					i = 0
					for findex, oldname in namesList:
						if not appendIndex:
							newname = k
						else:
							newname = k[:9] + str(i)

						value = (findex, db._quote(newname), db._quote(oldname))
						fieldsList.append( value )
						i += 1
		else:
			for findex, f in enumerate(fields):
				name = unicode(f.name())
				value = (findex, db._quote(name), db._quote(name))
				fieldsList.append( value )

		# if the fields list is empty use *, otherwise get the fields order by index
		if len(fieldsList) <= 0:
			value = (0, '*', '*')
			fieldsList.append( value )
		else:
			fieldsList.sort( key=lambda x: x[0] )
			fieldsList = map( lambda x: (x[1], x[2]), fieldsList )

		fieldsListOld = ''
		fieldsListNew = ''
		for index, f in enumerate(fieldsList):
			if index > 0:
				fieldsListOld += ', '
				fieldsListNew += ', '
			fieldsListOld += f[1]
			fieldsListNew += f[1]
			if f[0] != f[1]:
				fieldsListNew += " AS %s" % f[0]

		return fieldsListOld, fieldsListNew


	def validatePage(self):
		return True

	def isComplete(self):
		return self.progressBar.value() >= 100


class WriterError(Exception):
	"""Exception whose human-readable message is exposed as ``msg``."""

	def __init__(self, msg):
		# keep the message on the instance so that handlers can read it back
		self.msg = msg

