import traceback
from osgeo import gdal, ogr, osr
import os, sys, re, io
from os.path import dirname, basename, exists
from os import makedirs
from qgis.PyQt.QtXml import QDomDocument
from qgis.PyQt import uic
from qgis.core import *
from qgis.PyQt.QtCore import QSize
from qgis.PyQt.QtGui import QColor

import numpy as np
#from qgis.gui import *

# short aliases for frequently used os.path functions
jp = os.path.join
dn = os.path.dirname

# cache of form classes created by loadUIFormClass(), keyed by *.ui file path,
# so that each <myui>.ui file is parsed only once
FORM_CLASSES = {}
QGIS_RESOURCE_WARNINGS = set()

def loadUIFormClass(pathUi:str, from_imports=False, resourceSuffix:str='', fixQGISRessourceFileReferences=True, _modifiedui=None):
    """
    Loads Qt UI files (*.ui) while taking care on QgsCustomWidgets.
    Uses PyQt's uic.loadUiType (see https://www.riverbankcomputing.com/static/Docs/PyQt5/designer.html#the-uic-module).

    Before loading, the *.ui XML is modified in memory:
     - the <header> of custom widgets whose class name starts with 'Qgs' is set to 'qgis.gui'
     - a resource reference to a local QGIS checkout's images/images.qrc is redirected to ':/images/images.qrc'
    Missing *.qrc files (other than QGIS' images.qrc) are reported on stderr.
    The resulting form class is cached in FORM_CLASSES, i.e. each *.ui file is parsed only once.

    :param pathUi: *.ui file path
    :param from_imports: currently not forwarded to uic.loadUiType; kept for backward compatibility.
    :param resourceSuffix: suffix appended to the basename of any resource file specified in the .ui file to
        create the name of the Python module generated from the resource file by pyrcc. Passed to
        uic.loadUiType as `resource_suffix`. Defaults to '' (no suffix).
    :param fixQGISRessourceFileReferences: currently unused; the QGIS images.qrc reference is always rewritten.
    :param _modifiedui: optional file path. If set, the modified *.ui XML is written to it (useful for debugging).
    :return: the form class, e.g. to be used in a class definition like MyClassUI(QFrame, loadUi('myclassui.ui'))
    :raises AssertionError: if pathUi is not an existing file
    :raises Exception: if uic.loadUiType fails to create the form class
    """
    assert os.path.isfile(pathUi), '*.ui file does not exist: {}'.format(pathUi)

    if pathUi in FORM_CLASSES:
        return FORM_CLASSES[pathUi]

    with open(pathUi, 'r', encoding='utf-8') as f:
        txt = f.read()

    dirUi = os.path.dirname(pathUi)

    def resolvePath(path):
        # *.qrc locations are relative to the directory of the *.ui file, unless they are absolute
        return path if os.path.isabs(path) else os.path.join(dirUi, path)

    # <include location="*.qrc"/> entries that refer to non-existing files, as (xml line, location) tuples
    missing = [(line, path)
               for line, path in re.findall(r'(<include location="(.*\.qrc)"/>)', txt)
               if not os.path.isfile(resolvePath(path))]

    # redirect a resource reference to a local QGIS source checkout to the QGIS resource path
    match = re.search(r'resource="[^:].*/QGIS[^/"]*/images/images.qrc"', txt)
    if match:
        txt = txt.replace(match.group(), 'resource=":/images/images.qrc"')

    # missing references to QGIS' images/images.qrc are not reported, all others are
    missingQrc = [(line, path) for line, path in missing
                  if not re.search(r'.*(?i:qgis)/images/images\.qrc.*', line)]
    if len(missingQrc) > 0:
        print('{}\nrefers to {} none-existing resource (*.qrc) file(s):'.format(pathUi, len(missingQrc)),
              file=sys.stderr)
        for i, (line, path) in enumerate(missingQrc, start=1):
            print('{}: "{}"'.format(i, path), file=sys.stderr)

    doc = QDomDocument()
    doc.setContent(txt)

    # let uic import QGIS custom widgets from the qgis.gui python module instead of C++ headers
    customWidgets = doc.elementsByTagName('customwidget')
    for i in range(customWidgets.count()):
        widget = customWidgets.item(i).toElement()
        className = str(widget.firstChildElement('class').firstChild().nodeValue())
        if className.startswith('Qgs'):
            widget.firstChildElement('header').firstChild().setNodeValue('qgis.gui')

    # collect resource file locations
    includes = doc.elementsByTagName('include')
    qrcPaths = []
    for i in range(includes.count()):
        location = includes.item(i).toElement().attribute('location')
        if location.endswith('.qrc'):
            qrcPaths.append(resolvePath(location))

    xml = doc.toString()
    if isinstance(_modifiedui, str):
        with open(_modifiedui, 'w', encoding='utf-8') as f:
            f.write(xml)

    # temporarily make existing resource file directories available on the python path (sys.path)
    tmpDirs = []
    for qrcPath in qrcPaths:
        # qrcPath is already resolved relative to the *.ui file and must not be joined with its directory again
        d = os.path.dirname(os.path.abspath(qrcPath))
        if os.path.isdir(d) and d not in sys.path and d not in tmpDirs:
            tmpDirs.append(d)
    sys.path.extend(tmpDirs)

    # resource modules that are not imported yet; they are removed from sys.modules again after loading
    mockupModules = [os.path.splitext(os.path.basename(p))[0] for p in qrcPaths]
    mockupModules = [m for m in mockupModules if m not in sys.modules]

    buffer = io.StringIO(xml)  # modified XML, read by uic
    try:
        FORM_CLASS, _ = uic.loadUiType(buffer, resource_suffix=resourceSuffix)
    except Exception as ex1:
        print(xml, file=sys.stderr)
        raise Exception('Unable to load {}\n{}'.format(pathUi, ex1)) from ex1
    finally:
        # restore sys.modules / sys.path even if loading failed
        buffer.close()
        for mockupModule in mockupModules:
            sys.modules.pop(mockupModule, None)
        for d in tmpDirs:
            if d in sys.path:
                sys.path.remove(d)

    FORM_CLASSES[pathUi] = FORM_CLASS
    return FORM_CLASS

def qgisDataTypeToNumpyDataType(dataType):
    """
    Returns the numpy data type that corresponds to a QGIS raster data type.

    :param dataType: QGIS data type value, e.g. Qgis.Float32
    :return: numpy scalar type, or None for Qgis.UnknownDataType
    :raises Exception: if the data type is not one of the types listed below
    """
    knownTypes = (
        (Qgis.Byte, np.uint8),
        (Qgis.Float32, np.float32),
        (Qgis.Float64, np.float64),
        (Qgis.Int16, np.int16),
        (Qgis.Int32, np.int32),
        (Qgis.UInt16, np.uint16),
        (Qgis.UInt32, np.uint32),
        (Qgis.UnknownDataType, None),
    )
    for qgisType, numpyType in knownTypes:
        if dataType == qgisType:
            return numpyType
    raise Exception('unsupported data type: {}'.format(dataType))

def toFloat(s, default=0):
    """
    Converts a value into a float.

    :param s: value accepted by float(), e.g. a number or a numeric string
    :param default: value returned if `s` cannot be converted
    :return: float(s), or `default` if the conversion fails
    """
    try:
        return float(s)
    except (TypeError, ValueError, OverflowError):
        # only conversion errors fall back to the default; KeyboardInterrupt/SystemExit still propagate
        return default

def roundPlotRangeValue(v):
    """
    Rounds a value that is used as plot range limit.

    :param v: value convertible by toFloat(); non-convertible values become 0
    :return: v rounded to 4 decimals if |v| < 1, to 2 decimals if |v| < 100, otherwise truncated to int.
             NaN and +/-inf are returned unchanged (as float), because int() cannot convert them.
    """
    import math

    v = toFloat(v)
    if not math.isfinite(v):
        return v
    if abs(v) < 1:
        return round(v, 4)
    if abs(v) < 100:
        return round(v, 2)
    return int(v)

def roundPlotHistogramRangeValue(v):
    """Returns `v` converted with int(), i.e. floats are truncated towards zero."""
    truncated = int(v)
    return truncated

def version():
    """
    Returns the plugin version read from the `version=` entry of the metadata.txt file
    located in the parent directory of the directory that contains this module.

    :return: version string, or None if metadata.txt contains no line starting with `version=`
    """
    metadata = os.path.abspath(os.path.join(__file__, '..', '..', 'metadata.txt'))
    # explicit encoding: do not depend on the platform's default encoding
    with open(metadata, encoding='utf-8') as f:
        for line in f:
            if line.startswith('version='):
                # split only once, so values containing '=' are kept complete
                return line.split('=', 1)[1].strip()
    return None

# GDAL virtual in-memory file used to export the vector features in fidArray() and colorArray()
_TEMP_MASK_FILENAME = '/vsimem/rasterdataplotting/mask.gpkg'


def _writeTempVectorFile(layer, crs, extent, onlySelectedFeatures, filename=_TEMP_MASK_FILENAME):
    """
    Writes the features of a vector layer that intersect an extent into a (temporary) GeoPackage file.

    :param layer: QgsVectorLayer to export
    :param crs: QgsCoordinateReferenceSystem the exported geometries are reprojected into
    :param extent: QgsRectangle in `crs`; only features intersecting it are written
    :param onlySelectedFeatures: if True, only the selected features are written
    :param filename: output file path
    """
    options = QgsVectorFileWriter.SaveVectorOptions()
    options.fileEncoding = 'System'
    # SaveVectorOptions.ct reprojects the exported geometries, so it has to transform
    # from the layer CRS (source) into the target CRS (destination)
    options.ct = QgsCoordinateTransform(layer.crs(), crs, QgsProject.instance())
    options.driverName = 'GPKG'
    options.onlySelectedFeatures = onlySelectedFeatures
    options.skipAttributeCreation = False
    options.filterExtent = extent
    QgsVectorFileWriter.writeAsVectorFormat(layer, filename, options)


def _copyWithIntegerField(filename, crs, fieldName, valueForFeature):
    """
    Copies the geometries of the first layer of an OGR readable file into an in-memory OGR layer
    that has a single integer field. The input file is removed with gdal.Unlink afterwards.

    :param filename: path of the vector file to read
    :param crs: QgsCoordinateReferenceSystem assigned to the in-memory layer
    :param fieldName: name of the integer field to create
    :param valueForFeature: callable (index, ogr.Feature) -> int that returns the field value of each feature
    :return: (ogr.DataSource, ogr.Layer); the data source is returned too, so that it stays
             referenced while the layer is used
    """
    vds1 = ogr.Open(filename)
    try:
        if vds1 is None:
            raise Exception('unable to open {}'.format(filename))
        layer1 = vds1.GetLayerByIndex(0)
        assert isinstance(layer1, ogr.Layer)

        vds2 = ogr.GetDriverByName('Memory').CreateDataSource('wrk')
        srs = osr.SpatialReference(crs.toWkt())
        layer2 = vds2.CreateLayer('layer', srs=srs, geom_type=layer1.GetGeomType())
        assert isinstance(layer2, ogr.Layer)
        layer2.CreateField(ogr.FieldDefn(fieldName, ogr.OFTInteger))

        for index, feature in enumerate(layer1):
            assert isinstance(feature, ogr.Feature)
            value = valueForFeature(index, feature)
            outFeature = ogr.Feature(layer2.GetLayerDefn())
            outFeature.SetGeometry(feature.GetGeometryRef())
            outFeature.SetField(fieldName, value)
            layer2.CreateFeature(outFeature)
    finally:
        # close the input file before deleting it, also if copying failed
        layer1 = None
        vds1 = None
        gdal.Unlink(filename)
    return vds2, layer2


def _rasterizeIntegerField(ogrLayer, fieldName, crs, extent, size, allTouched=False):
    """
    Burns the values of an integer field into an in-memory Int32 raster that covers `extent` with `size` pixels.
    Pixels not covered by any geometry keep the value -1.

    :return: numpy array read from the single raster band (rows = size.height(), columns = size.width())
    """
    xmin, xmax = extent.xMinimum(), extent.xMaximum()
    ymin, ymax = extent.yMinimum(), extent.yMaximum()
    xsize, ysize = size.width(), size.height()
    xres = (xmax - xmin) / xsize
    yres = (ymax - ymin) / ysize
    # north-up geotransform with the origin in the upper-left corner
    geotransform = (xmin, xres, 0.0, ymax, 0.0, -yres)

    ds = gdal.GetDriverByName('MEM').Create('', xsize, ysize, 1, gdal.GDT_Int32)
    assert isinstance(ds, gdal.Dataset)
    ds.GetRasterBand(1).Fill(-1)
    ds.SetProjection(crs.toWkt())
    ds.SetGeoTransform(geotransform)

    rasterizeLayerOptions = []
    if allTouched:
        rasterizeLayerOptions.append('ALL_TOUCHED=TRUE')
    rasterizeLayerOptions.append('ATTRIBUTE=' + fieldName)

    gdal.RasterizeLayer(ds, [1], ogrLayer, options=rasterizeLayerOptions)
    return ds.ReadAsArray()


def fidArray(layer, crs, extent, size, onlySelectedFeatures, allTouched=False):
    """
    Rasterizes a vector layer into an array of feature indices.

    :param layer: QgsVectorLayer
    :param crs: QgsCoordinateReferenceSystem of the output raster
    :param extent: QgsRectangle (in `crs`) covered by the output raster
    :param size: QSize, raster width and height in pixels
    :param onlySelectedFeatures: if True, only selected features are rasterized
    :param allTouched: if True, GDAL's ALL_TOUCHED=TRUE rasterize option is used
    :return: Int32 numpy array with the 0-based index of the burned feature per pixel, -1 for background.
             The index refers to the order of the exported features, not to the layer's feature ids.
    """
    assert isinstance(layer, QgsVectorLayer)
    assert isinstance(crs, QgsCoordinateReferenceSystem)
    assert isinstance(extent, QgsRectangle)
    assert isinstance(size, QSize)

    _writeTempVectorFile(layer, crs, extent, onlySelectedFeatures)
    # vds keeps the in-memory data source of ogrLayer referenced until rasterization is done
    vds, ogrLayer = _copyWithIntegerField(_TEMP_MASK_FILENAME, crs, '_FID', lambda index, feature: index)
    return _rasterizeIntegerField(ogrLayer, '_FID', crs, extent, size, allTouched)


def colorArray(layer, crs, extent, size, onlySelectedFeatures, allTouched=False):
    """
    Rasterizes a vector layer into an array of renderer colors, encoded as integers 1RRRGGGBBB
    (e.g. red (255, 0, 0) -> 1255000000).

    Categorized and graduated renderers (also when embedded in an inverted polygon renderer) use the color
    of the feature's class, single symbol renderers use the symbol color. Other renderers and failed
    class lookups use red.

    :param layer: QgsVectorLayer
    :param crs: QgsCoordinateReferenceSystem of the output raster
    :param extent: QgsRectangle (in `crs`) covered by the output raster
    :param size: QSize, raster width and height in pixels
    :param onlySelectedFeatures: if True, only selected features are rasterized
    :param allTouched: if True, GDAL's ALL_TOUCHED=TRUE rasterize option is used
    :return: Int32 numpy array with the encoded color per pixel, -1 for background
    """
    assert isinstance(layer, QgsVectorLayer)
    assert isinstance(crs, QgsCoordinateReferenceSystem)
    assert isinstance(extent, QgsRectangle)
    assert isinstance(size, QSize)

    _writeTempVectorFile(layer, crs, extent, onlySelectedFeatures)

    # Get renderer colors.
    renderer = layer.renderer()
    defaultColor = QColor('red')

    if isinstance(renderer, QgsInvertedPolygonRenderer):
        renderer = renderer.embeddedRenderer()

    if isinstance(renderer, QgsCategorizedSymbolRenderer):

        def getColor(feature):
            try:
                legendClassificationAttribute = renderer.legendClassificationAttribute()
                legendClassificationValue = feature.GetField(legendClassificationAttribute)
                categoryIndex = renderer.categoryIndexForValue(legendClassificationValue)
                # NOTE(review): if categoryIndexForValue() returns -1 for unmatched values, the last
                # category is used here - confirm whether this is intended
                category = renderer.categories()[categoryIndex]
                color = category.symbol().color()
            except Exception:
                traceback.print_exc()
                color = defaultColor
            return color

    elif isinstance(renderer, QgsGraduatedSymbolRenderer):

        def getColor(feature):
            try:
                legendClassificationAttribute = renderer.legendClassificationAttribute()
                legendClassificationValue = feature.GetField(legendClassificationAttribute)
                color = renderer.symbolForValue(legendClassificationValue).color()
            except Exception:
                traceback.print_exc()
                color = defaultColor
            return color

    elif isinstance(renderer, QgsSingleSymbolRenderer):

        def getColor(feature):
            return renderer.symbol().color()

    else:

        def getColor(feature):
            return defaultColor

    def encodeColor(qcolor):
        # 1RRRGGGBBB; the maximum 1255255255 still fits into a signed 32 bit integer
        return 1000000000 + qcolor.red() * 1000000 + qcolor.green() * 1000 + qcolor.blue()

    # vds keeps the in-memory data source of ogrLayer referenced until rasterization is done
    vds, ogrLayer = _copyWithIntegerField(_TEMP_MASK_FILENAME, crs, '_COLOR',
                                          lambda index, feature: encodeColor(getColor(feature)))
    return _rasterizeIntegerField(ogrLayer, '_COLOR', crs, extent, size, allTouched)
