#
# This file is part of GHydraulics
#
# GHydraulicsInpWriter.py - Write INP files
#
# Copyright 2007 - 2013 Steffen Macke <sdteffen@sdteffen.de>
#
# GHydraulics is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GHydraulics is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
#

import math
import re
import numpy
import numpy.linalg
import os
from pickle import *
from PyQt4.QtCore import *
from PyQt4.QtGui import *
from qgis.core import *
from qgis.gui import *
from qgis import *
from EpanetModel import *
from GHydraulicsCommon import *
from GHydraulicsException import *
from GHydraulicsModel import *

# Write EPANET INP file
# Write EPANET INP file.
# The constructor collects the text of every generated INP section into
# self.sections (section name -> text); write() then merges those sections
# into a copy of the template INP file.
class GHydraulicsInpWriter(GHydraulicsCommon):
    TITLE = 'Save EPANET INP file'

    # Matches a section header line such as "[PIPES]" in the template
    _SECTION_RE = re.compile(r' *\[([A-Z]+)\].*')

    # Write the INP file to the given filename.
    # Template lines are copied verbatim, except for sections this writer
    # generated: those are written from self.sections and the template's own
    # lines for that section are skipped up to the next section header.
    # filename: path of the INP file to create (opened with mode 'w')
    # backdrop: if true, a [BACKDROP] section (and BMP image) is written
    def write(self, filename, backdrop):
        self.filename = filename
        self.backdrop = backdrop
        section = None
        # 'with' closes both files even if writing a section raises
        with open(self.templateFilename) as template, open(filename, 'w') as inpfile:
            self.inpfile = inpfile
            for line in template:
                header = self._SECTION_RE.match(line)
                if header is not None:
                    section = self.writeSection(header.group(1))
                if section is None:
                    inpfile.write(line)

    # Write the given section to the INP file.
    # Returns the section name when this writer produced the section (so the
    # template's lines for it are skipped), None otherwise.
    def writeSection(self, section):
        content = self.sections.get(section, '')
        if content:
            self.writeSectionLabel(section)
            self.inpfile.write(content)
            return section
        if 'BACKDROP' == section and self.backdrop:
            self.writeBackdropSection()
            return section
        return None

    # Convert an attribute value to a string, mapping NULL to '' to prevent
    # NULL values in the INP file
    def getString(self, value):
        text = str(value)
        if 'NULL' == text:
            return ''
        return text

    # Setup coordinate transformation from crs to the canvas CRS.
    # self.crstransform is False when no transformation is needed or possible
    # (either CRS invalid, or the transformation is a no-op).
    def setLayerCrs(self, crs):
        self.crstransform = False
        canvascrs = self.iface.mapCanvas().mapRenderer().destinationCrs()
        if crs.isValid() and canvascrs.isValid():
            transform = QgsCoordinateTransform(crs, canvascrs)
            if not transform.isShortCircuited():
                self.crstransform = transform

    # Transform coordinates from layer to canvas CRS where necessary.
    # Returns [x, y]; unchanged when no transformation is active.
    def transformXY(self, x, y):
        if self.crstransform:
            pnt = self.crstransform.transform(QgsPoint(x, y))
            return [pnt.x(), pnt.y()]
        return [x, y]

    # Return all registered vector layers whose name equals name
    def _vectorLayersNamed(self, name):
        return [layer for layer in QgsMapLayerRegistry.instance().mapLayers().values()
                if layer.type() == QgsMapLayer.VectorLayer and layer.name() == name]

    # Return the index of field in provider.
    # Raises GHydraulicsException mentioning layer name if it is missing.
    def _requireFieldIndex(self, provider, field, name):
        fieldidx = provider.fieldNameIndex(field)
        if -1 == fieldidx:
            raise GHydraulicsException('ERROR: Failed to locate '+field+' field in layer '+name)
        return fieldidx

    # Return the field indices of EpanetModel.COLUMNS[section] in layer,
    # in column order
    def _columnIndices(self, layer, section, name):
        provider = layer.dataProvider()
        return [self._requireFieldIndex(provider, field, name)
                for field in EpanetModel.COLUMNS[section]]

    # Extract point features of the layers registered for section.
    # Appends one line per node to self.sections[section] and the node's
    # (transformed) position to the COORDINATES section. The node id is the
    # first non-empty column value.
    def getNodes(self, section):
        nodes = self.sections.get(section, '')
        if section not in self.layers:
            return
        for name in self.layers[section]:
            for layer in self._vectorLayersNamed(name):
                self.setLayerCrs(layer.crs())
                fieldIndices = self._columnIndices(layer, section, name)
                for feature in layer.getFeatures():
                    geometry = feature.geometry()
                    if geometry.type() != QGis.Point:
                        continue
                    attrs = feature.attributes()
                    values = [self.getString(attrs[fieldidx]) for fieldidx in fieldIndices]
                    nodes += ''.join(value + ' ' for value in values) + '\n'
                    nodeId = next((value for value in values if value != ''), '')
                    point = self.getFirstMultiPoint(geometry)
                    (x, y) = self.transformXY(point.x(), point.y())
                    self.addXY(EpanetModel.COORDINATES, nodeId, x, y)
                    self.xcoords.append(float(x))
                    self.ycoords.append(float(y))
        self.sections[section] = nodes + '\n'

    # Extract line features of the pipe layers into the PIPES section.
    # NODE1 values renamed by virtual lines (self.virtualnodes) are replaced
    # by the virtual junction id. Interior polyline vertices go to VERTICES.
    def getPipes(self):
        if EpanetModel.PIPES not in self.layers:
            return
        pipes = self.sections.get(EpanetModel.PIPES, '')
        columns = EpanetModel.COLUMNS[EpanetModel.PIPES]
        for name in self.layers[EpanetModel.PIPES]:
            for layer in self._vectorLayersNamed(name):
                self.setLayerCrs(layer.crs())
                fieldIndices = self._columnIndices(layer, EpanetModel.PIPES, name)
                node1idx = dict(zip(columns, fieldIndices)).get(EpanetModel.NODE1, -1)
                for feature in layer.getFeatures():
                    geometry = feature.geometry()
                    if geometry.type() != QGis.Line:
                        continue
                    attrs = feature.attributes()
                    pipeId = ''
                    for fieldidx in fieldIndices:
                        attribute = self.getString(attrs[fieldidx])
                        # Use dynamic (virtual) nodes where necessary
                        if fieldidx == node1idx and attribute in self.virtualnodes:
                            attribute = self.virtualnodes[attribute]
                        pipes += attribute + ' '
                        if '' == pipeId:
                            pipeId = attribute
                    pipes += '\n'
                    # First and last points are the pipe's end nodes; only
                    # the interior points are vertices
                    for vertex in geometry.asPolyline()[1:-1]:
                        (x, y) = self.transformXY(vertex.x(), vertex.y())
                        self.xcoords.append(float(x))
                        self.ycoords.append(float(y))
                        self.addXY('VERTICES', pipeId, x, y)
        self.sections[EpanetModel.PIPES] = pipes + '\n'

    # Virtual lines are lines in EPANET and nodes in QGIS.
    # Each point feature becomes a link between a junction at the point and
    # a virtual junction placed on the pipe whose NODE1 refers to it.
    # Shut off valves (TYPE 'SOV') are written as pipes instead.
    def getVirtualLines(self, section):
        if section not in self.layers:
            return
        lines = ''
        columns = EpanetModel.COLUMNS[section]
        isValves = EpanetModel.VALVES == section
        for name in self.layers[section]:
            for layer in self._vectorLayersNamed(name):
                self.setLayerCrs(layer.crs())
                fieldIndices = self._columnIndices(layer, section, name)
                # Valve columns needed for SOV handling; -1 when unused
                typeidx = diameteridx = minorlossidx = settingidx = -1
                if isValves:
                    byField = dict(zip(columns, fieldIndices))
                    typeidx = byField.get(EpanetModel.TYPE, -1)
                    diameteridx = byField.get(EpanetModel.DIAMETER, -1)
                    minorlossidx = byField.get(EpanetModel.MINORLOSS, -1)
                    settingidx = byField.get(EpanetModel.SETTING, -1)
                    if -1 in (typeidx, diameteridx, minorlossidx, settingidx):
                        raise GHydraulicsException('ERROR: Failed to locate type, diameter, minorloss or status in layer '+name)
                # getSecondVirtualLineJunction() switches to the pipe layer's
                # transformation; keep this layer's one to restore it
                layerTransform = self.crstransform
                for feature in layer.getFeatures():
                    geometry = feature.geometry()
                    if geometry.type() != QGis.Point:
                        continue
                    attrs = feature.attributes()
                    # Handle shut off valves
                    sov = -1 != typeidx and self.getString(attrs[typeidx]) == 'SOV'
                    lineId = ''
                    virtualId = ''
                    # Leave out the elevation (first column)
                    for fieldidx in fieldIndices[1:]:
                        attribute = self.getString(attrs[fieldidx])
                        if not sov:
                            lines += attribute + ' '
                        if '' == lineId:
                            lineId = attribute
                            virtualId = lineId + GHydraulicsModel.VIRTUAL_POSTFIX
                            self.virtualnodes[lineId] = virtualId
                            if not sov:
                                lines += lineId + ' ' + virtualId + ' '
                    if not sov:
                        lines += '\n'
                    # write first point
                    point = self.getFirstMultiPoint(geometry)
                    elevation = self.getString(attrs[fieldIndices[0]])
                    self.addJunction(lineId, elevation, '0', '', point.x(), point.y())
                    # write the second point
                    length = self.getSecondVirtualLineJunction(lineId, elevation)
                    self.crstransform = layerTransform
                    if sov:
                        # ;ID Node1 Node2 Length Diameter Roughness MinorLoss Status
                        self.sections[EpanetModel.PIPES] += ' '.join([
                            lineId, lineId, virtualId,
                            self.getString(length),
                            self.getString(attrs[diameteridx]),
                            '1',
                            self.getString(attrs[minorlossidx]),
                            self.getString(attrs[settingidx])]) + '\n'
        self.sections[section] = lines + '\n'

    # Insert the virtual node into the pipe whose NODE1 is id (or its
    # virtual id): a junction is added on the pipe's first segment.
    # Returns the distance from the pipe start to that junction, measured in
    # the pipe layer's own coordinates.
    # Raises GHydraulicsException if no such pipe exists.
    def getSecondVirtualLineJunction(self, id, elevation):
        virtualId = self.virtualnodes[id]
        # Find the referenced pipe
        for name in self.layers.get(EpanetModel.PIPES, []):
            for layer in self._vectorLayersNamed(name):
                node1idx = self._requireFieldIndex(layer.dataProvider(), EpanetModel.NODE1, name)
                self.setLayerCrs(layer.crs())
                for feature in layer.getFeatures():
                    geometry = feature.geometry()
                    if geometry.type() != QGis.Line:
                        continue
                    node1 = self.getString(feature.attributes()[node1idx])
                    if node1 != id and node1 != virtualId:
                        continue
                    line = geometry.asPolyline()
                    # No usable first segment (asPolyline() may be empty)
                    if len(line) < 2:
                        continue
                    start = line[0]
                    end = line[1]
                    # NOTE(review): VIRTUAL_LINE_LENGTH is compared with the
                    # *squared* distance, and the offset below is always of
                    # length 1 in layer units — confirm this is intended
                    if GHydraulicsModel.VIRTUAL_LINE_LENGTH > start.sqrDist(end):
                        p = [(start.x() + end.x()) / 2, (start.y() + end.y()) / 2]
                    else:
                        sv = numpy.array([start.x(), start.y()])
                        dv = numpy.array([end.x(), end.y()]) - sv
                        p = sv + dv / numpy.linalg.norm(dv)
                    self.addJunction(virtualId, elevation, '0', '', p[0], p[1])
                    return math.hypot(p[0] - start.x(), p[1] - start.y())
        raise GHydraulicsException('ERROR: Failed to locate Pipe with NODE1 named '+id)

    # Add a junction to the JUNCTIONS buffer and its transformed position to
    # the COORDINATES buffer. elevation, demand and pattern must be strings.
    def addJunction(self, id, elevation, demand, pattern, x, y):
        (x, y) = self.transformXY(x, y)
        self.sections[EpanetModel.JUNCTIONS] = self.sections[EpanetModel.JUNCTIONS] + id + ' ' + elevation + ' ' + demand + ' ' + pattern + '\n'
        self.addXY(EpanetModel.COORDINATES, id, x, y)

    # Append an "id x y" line to section (one of COORDINATES or VERTICES)
    def addXY(self, section, id, x, y):
        self.sections[section] = self.sections[section] + id + ' ' + self.getString(x) + ' ' + self.getString(y) + '\n'

    # Write a section label to the INP file
    def writeSectionLabel(self, section):
        self.inpfile.write('['+section+'] ; created by GHydraulics\n')

    # Write out the backdrop section and save the map canvas as a BMP image
    # next to the INP file
    def writeBackdropSection(self):
        self.writeSectionLabel('BACKDROP')
        backdropfile = self.getBackdropFromInp(str(self.filename))
        canvas = self.iface.mapCanvas()
        canvas.saveAsImage(backdropfile, None, 'BMP')
        # The saved image shows the current view, so use its extent
        extent = canvas.extent()
        mins = str(extent.xMinimum()) + ' ' + str(extent.yMinimum())
        maxs = str(extent.xMaximum()) + ' ' + str(extent.yMaximum())
        self.inpfile.write('DIMENSIONS ' + mins + ' ' + maxs + '\n')
        mapunits = canvas.mapUnits()
        if not canvas.mapRenderer().destinationCrs().isValid():
            # No canvas CRS: take the units of the model layers instead
            for section in EpanetModel.GIS_SECTIONS:
                for name in self.layers.get(section, []):
                    for layer in self._vectorLayersNamed(name):
                        mapunits = layer.crs().mapUnits()
        units = 'None'
        if mapunits in GHydraulicsModel.UNITMAP:
            units = GHydraulicsModel.UNITMAP[mapunits]
        self.inpfile.write('UNITS '+units+'\n')
        self.inpfile.write('FILE "' + os.path.basename(backdropfile) + '"\n')
        self.inpfile.write('OFFSET 0.00 0.00\n\n')

    # Canonical backdrop name from inp file: same path, '.bmp' extension
    def getBackdropFromInp(self, inpfilename):
        return os.path.splitext(inpfilename)[0]+'.bmp'

    # Collect all INP sections from the model layers.
    # templateFilename: INP file used as template by write()
    # iface: QGIS interface, used to access the map canvas
    def __init__(self, templateFilename, iface):
        self.templateFilename = templateFilename
        self.iface = iface
        # Transformed node and vertex coordinates collected while reading
        self.xcoords = []
        self.ycoords = []
        # presumably fills self.layers (section -> layer names); defined in
        # GHydraulicsCommon — verify
        self.getLayers()
        self.sections = {EpanetModel.JUNCTIONS: '', EpanetModel.PIPES: ''}
        # Dictionary of those node1 values that change because of virtual lines
        self.virtualnodes = {}
        # Transform coordinates, where necessary
        self.crstransform = False
        for section in EpanetModel.COORDINATE_DATA_SECTIONS:
            self.sections[section] = ''
        # Virtual lines first: they fill self.virtualnodes, used by getPipes()
        for section in EpanetModel.VIRTUAL_LINE_SECTIONS:
            self.getVirtualLines(section)
        for section in EpanetModel.COORDINATE_SECTIONS:
            self.getNodes(section)
        self.getPipes()