from PyQt4.QtGui import *
from PyQt4.QtCore import *
from qgis.core import *
from qgis.gui import *
import traceback, datetime, numpy, os, struct

from MagicWand.modules.colormath.color_diff import delta_e_cie2000
from MagicWand.modules.colormath.color_objects import LabColor
from MagicWand.config.SettingsInterface import SettingsInterface

class MapGraph(QObject):
    def __init__(self):
        """Initialise the helper from the stored plugin settings.

        Reads the ``tolerance``, ``approximation``, ``viewSteps`` and
        ``timeout`` entries of the dictionary returned by
        ``SettingsInterface.loadSettings()`` and prepares the lists of
        neighbour direction names used by the adjacency helpers.
        """
        QObject.__init__(self)

        self.setInt = SettingsInterface()
        loaded = self.setInt.loadSettings()
        self.tolerance = loaded["tolerance"]
        self.approximation = loaded["approximation"]
        self.viewSteps = loaded["viewSteps"]
        self.timeout = loaded["timeout"]

        # The eight neighbour directions, listed in rotation order from "up".
        self.directions = [
            "up", "up-right", "right", "down-right",
            "down", "down-left", "left", "up-left",
        ]
        # The diagonal directions occupy the odd positions of the list above.
        self.cornerDirections = self.directions[1::2]


    def pixelSimili(self, lab1, lab2):
        """Tell whether two colours are close enough to be considered similar.

        The distance between the colours is whatever ``delta_e_cie2000``
        returns when called with ``Kl``, ``Kc`` and ``Kh`` all set to 1.

        :param lab1: first colour, in the form ``delta_e_cie2000`` accepts
        :param lab2: second colour, same form as ``lab1``
        :return: True when the distance does not exceed ``self.tolerance``
        """
        distance = delta_e_cie2000(lab1, lab2, Kl=1, Kc=1, Kh=1)
        margin = self.tolerance - distance
        return margin >= 0

    def getPixelRectangleNxN(self, qpoint, raster=None):
        """Build the sampling window centred on ``qpoint``.

        The window reaches ``self.approximation`` pixels away from the centre
        in both directions along each axis.

        :param qpoint: window centre; must expose ``x()`` and ``y()``
        :param raster: optional raster layer giving the pixel size through
            ``rasterUnitsPerPixelX()`` / ``rasterUnitsPerPixelY()``; when it
            is omitted a pixel size of 0.5 is used on both axes
        :return: ``QgsRectangle(xMin, yMin, xMax, yMax)``
        """
        # Identity test: None is a singleton and `is` cannot be influenced by
        # comparison operators defined on the layer object.
        if raster is not None:
            pixelSizeX = raster.rasterUnitsPerPixelX()
            pixelSizeY = raster.rasterUnitsPerPixelY()
        else:
            pixelSizeX = 0.5
            pixelSizeY = 0.5

        # Distance from the centre to the window edge on each axis.
        halfWidth = pixelSizeX * self.approximation
        halfHeight = pixelSizeY * self.approximation

        centreX = qpoint.x()
        centreY = qpoint.y()

        return QgsRectangle(
            centreX - halfWidth,
            centreY - halfHeight,
            centreX + halfWidth,
            centreY + halfHeight,
        )

    def getPointValue(self, point, raster):
        rect = self.getPixelRectangleNxN(point, raster)
        labColor = self.getRectValue(rect, raster)[0]
        return labColor

    def getRectValue(self, rect, raster):
        """Average the raster colour over a 3x3 grid of points covering ``rect``.

        On each axis the sample coordinates are the rectangle minimum, its
        midpoint and its maximum. Every grid point is read with
        ``getRGBColor``; points where any channel is missing are skipped.

        :param rect: rectangle exposing ``xMinimum()``, ``xMaximum()``,
            ``yMinimum()`` and ``yMaximum()``
        :param raster: raster layer to sample
        :return: tuple ``(LabColor, (r, g, b))`` built from the per-channel
            average of the usable samples; when no sample is usable both
            elements are built from ``(-1, -1, -1)``
        """
        halfWidth = (rect.xMaximum() - rect.xMinimum())/2
        halfHeight = (rect.yMaximum() - rect.yMinimum())/2

        # Minimum, midpoint and maximum along each axis.
        xs = [rect.xMinimum() + (halfWidth * step) for step in range(0, 3)]
        ys = [rect.yMinimum() + (halfHeight * step) for step in range(0, 3)]

        usable = []
        for x in xs:
            for y in ys:
                r, g, b = self.getRGBColor(QgsPoint(x, y), raster)
                # NOTE(review): equality test kept as-is rather than
                # `is not None`, so values that merely compare equal to None
                # are filtered exactly as before.
                if r != None and g != None and b != None:
                    usable.append((r, g, b))

        count = len(usable)
        if count > 0:
            rFin = sum(sample[0] for sample in usable)/count
            gFin = sum(sample[1] for sample in usable)/count
            bFin = sum(sample[2] for sample in usable)/count
            averaged = (rFin, gFin, bFin)
            return (LabColor.fromRGB(averaged), averaged)

        # No usable sample: fall back to the (-1, -1, -1) marker colour.
        return (LabColor.fromRGB((-1, -1, -1)), (-1, -1, -1))

    def getRGBColor(self, qgspoint, raster):
        """Read the red, green and blue values of ``raster`` at ``qgspoint``.

        Only layers whose renderer is a ``QgsMultiBandColorRenderer`` are
        handled. For any other renderer an information box is shown and all
        three channels are returned as None.

        For each channel the value of the renderer's band is taken from the
        data provider's identify result and, when the renderer has a contrast
        enhancement for that channel, passed through ``enhanceContrast``.

        :param qgspoint: position to identify
        :param raster: raster layer to read
        :return: tuple ``(r, g, b)``; a channel is None when its band has no
            entry or no value in the identify result, or when the resulting
            value equals -1
        """
        renderer = raster.renderer()
        if not isinstance(renderer, QgsMultiBandColorRenderer):
            QMessageBox.information(None, "", "Raster renderer type can be only multi band color (QgsMultiBandColorRenderer) at the moment. Change it from the specific layer properties -> style")
            return None, None, None

        identified = raster.dataProvider().identify(qgspoint, QgsRaster.IdentifyFormatValue)
        bandValues = identified.results()

        def channelValue(bandNumber, enhancementGetter):
            # .get() instead of indexing: a band missing from the identify
            # result yields an absent channel rather than a KeyError.
            raw = bandValues.get(bandNumber)
            # NOTE(review): deliberately an equality test, not `is not None`.
            # The QGIS bindings may hand back a NULL placeholder that compares
            # equal to None without being None - confirm before changing.
            if raw != None:
                enhancement = enhancementGetter()
                if enhancement:
                    value = enhancement.enhanceContrast(raw)
                else:
                    value = raw
                # -1 is treated as "no usable value" for the channel.
                if value == -1:
                    return None
                return value
            return None

        r = channelValue(renderer.redBand(), renderer.redContrastEnhancement)
        g = channelValue(renderer.greenBand(), renderer.greenContrastEnhancement)
        b = channelValue(renderer.blueBand(), renderer.blueContrastEnhancement)
        return r, g, b

    def getAllAdjacentPixelPositionsNxN(self, qpoint, raster):
        points = []
        for ind, values in enumerate(self.directions):
            toAdd = (self.getAdjacentPixelPositionNxN(qpoint, raster, values), values)
            points.append(toAdd)
        return points

    def getAdjacentValidPixels(self, visitati, qpoint, raster, colorFirst, extent):
        adjacentPoints = self.getAllAdjacentPixelPositionsNxN(qpoint, raster)
        validPixels = set([])
        for point, direction in adjacentPoints:
            if extent.contains(point) and point not in visitati:
                labColor = self.getPointValue(point, raster)

                isValid = self.pixelSimili(colorFirst, labColor)

                if isValid:
                    validPixels.add(point)#isValid))

        final_valid_pixels = validPixels
        """
        final_valid_pixels = []
        for i, element in enumerate(validPixels):
            if element[3]:
                direction = element[2]
                if direction in self.cornerDirections:
                    prec_i = (i - 1) % len(validPixels)
                    succ_i = (i + 1) % len(validPixels)
                    if validPixels[prec_i][3] or validPixels[succ_i][3]:
                        final_valid_pixels.append((element[0], element[1], element[2]))
                else:
                    final_valid_pixels.append((element[0], element[1], element[2]))

                final_valid_pixels.append((element[0], element[1], element[2]))
        """
        return final_valid_pixels

    def getAdjacentPixelPositionNxN(self, qpoint, raster, side):
        """Return the position next to ``qpoint`` in the direction ``side``.

        The step on each axis is the raster pixel size multiplied by
        ``self.approximation`` and doubled. "up" adds the step to y, "down"
        subtracts it; "right" adds the step to x, "left" subtracts it.

        :param qpoint: starting position; must expose ``x()`` and ``y()``
        :param raster: raster layer supplying the pixel size
        :param side: one of the eight direction names ("up", "up-right", ...)
        :return: a new ``QgsPoint``; ``QgsPoint(0, 0)`` when ``side`` is not
            a known direction name
        """
        stepX = ( raster.rasterUnitsPerPixelX() * self.approximation) * 2
        stepY = ( raster.rasterUnitsPerPixelY() * self.approximation) * 2

        # Direction name -> (sign of the x move, sign of the y move).
        signs = {
            "up": (0, 1),
            "up-right": (1, 1),
            "right": (1, 0),
            "down-right": (1, -1),
            "down": (0, -1),
            "down-left": (-1, -1),
            "left": (-1, 0),
            "up-left": (-1, 1),
        }

        neighbour = QgsPoint(qpoint.x(), qpoint.y())
        move = signs.get(side)
        if move is None:
            return QgsPoint(0,0)

        signX, signY = move
        if signY > 0:
            neighbour.setY(neighbour.y()+stepY)
        elif signY < 0:
            neighbour.setY(neighbour.y()-stepY)

        if signX > 0:
            neighbour.setX(neighbour.x()+stepX)
        elif signX < 0:
            neighbour.setX(neighbour.x()-stepX)

        return neighbour