import traceback
import numpy as np
from qgis.core import *
from qgis.gui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from osgeo import gdal

import tempfile
from os.path import join, dirname
from os import listdir
from qgis.gui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from .utils import loadUIFormClass, qgisDataTypeToNumpyDataType

import pyqtgraph as pg

class RasterDataPlotting(QObject):
    """Controller of the 'Raster Data Plotting' QGIS plugin.

    Owns the dock widget (RasterDataPlottingUi) and redraws its plot from the raster data
    currently visible in the QGIS map canvas. Three plot types are supported (see PlotType):
    a 2-d histogram ("scatter") of two raster bands, and spectral/temporal profiles of the
    pixel located at the canvas center.
    """

    pluginName = 'Raster Data Plotting'
    pluginId = pluginName.replace(' ', '')

    class PlotType(object):
        """Plot type codes; they are compared against the current index of the plot type combo box."""
        Scatter = 0
        SpectralProfile = 1
        TemporalProfile = 2

    def __init__(self, iface, parent=None):
        """Create the UI, connect all signals and add the plugin action to the toolbar.

        :param iface: the QGIS application interface (QgisInterface)
        :param parent: optional QObject parent
        """
        QObject.__init__(self, parent)
        assert isinstance(iface, QgisInterface)
        self.iface = iface

        # toolbar action; stays None when it is never created, so unload() can check for it
        self.action1 = None

        # init ui
        self.ui = RasterDataPlottingUi()
        self.ui.setWindowIcon(self.icon())
        self.ui.imageView().view.invertY(False)

        # hide
        self.ui.g1.hide()
        self.ui.bandsX().hide()

        # init plots
        self.profiles = dict()  # profile index -> plot curve created by initProfile()

        # state
        self.cacheBandX = dict()  # layer -> last bandsX() spin box value set for that layer

        # connect signals (every change ends up in plot())
        self.canvas().renderComplete.connect(self.onSomethingChanged)
        self.ui.layerX().layerChanged.connect(self.onLayerXChanged)
        self.ui.layerY().layerChanged.connect(self.onSomethingChanged)
        self.ui.bandXRaster().bandChanged.connect(self.onSomethingChanged)
        self.ui.bandYRaster().bandChanged.connect(self.onSomethingChanged)
        self.ui.bandXRenderer().currentIndexChanged.connect(self.onSomethingChanged)
        self.ui.bandYRenderer().currentIndexChanged.connect(self.onSomethingChanged)
        self.ui.binsX().valueChanged.connect(self.onSomethingChanged)
        self.ui.binsY().valueChanged.connect(self.onSomethingChanged)
        self.ui.bandsX().valueChanged.connect(self.onBandsXChanged)
        self.ui.plotType().currentIndexChanged.connect(self.onPlotTypeChanged)

        # add to plugins toolbar
        if isinstance(self.iface, QgisInterface):
            self.action1 = QAction(self.icon(), self.pluginId, self.iface.mainWindow())
            self.action1.triggered.connect(self.toggleUiVisibility)
            self.iface.addToolBarIcon(self.action1)

    def show(self):
        """Show the plugin dock widget."""
        self.ui.show()

    def toggleUiVisibility(self):
        """Show the dock widget if it is hidden, hide it otherwise."""
        self.ui.setVisible(not self.ui.isVisible())

    def pluginFolder(self):
        """Return the parent folder of the folder containing this module."""
        return join(dirname(__file__), '..')

    def icon(self):
        """Return the plugin icon loaded from 'icon.png' inside pluginFolder()."""
        return QIcon(join(self.pluginFolder(), 'icon.png'))

    def showWaitCursor(self):
        """Set the application wide wait cursor."""
        QApplication.setOverrideCursor(Qt.WaitCursor)

    def hideWaitCursor(self):
        """Undo the last showWaitCursor() call."""
        QApplication.restoreOverrideCursor()

    def handleException(self, error):
        """Print the traceback of the given exception.

        :param error: the exception instance to report
        """
        assert isinstance(error, Exception)
        # print the traceback of *error* itself; traceback.print_exc() would only report the
        # exception currently being handled, which is not necessarily the one passed in
        traceback.print_exception(type(error), error, error.__traceback__)

    def unload(self):
        """Unload the plugin"""
        # stop replotting on canvas redraws once the plugin is unloaded
        try:
            self.canvas().renderComplete.disconnect(self.onSomethingChanged)
        except (TypeError, RuntimeError):
            pass  # signal was not connected (anymore)

        self.iface.removeDockWidget(self.ui)

        # the action was registered with addToolBarIcon(), so it must be removed from the toolbar
        if self.action1 is not None:
            self.iface.removeToolBarIcon(self.action1)

    def canvas(self):
        """Return the QGIS map canvas."""
        canvas = self.iface.mapCanvas()
        assert isinstance(canvas, QgsMapCanvas)
        return canvas

    def onBandsXChanged(self, nbands):
        """Remember the new band count for the current x layer and replot the profiles.

        :param nbands: new value of the bandsX() spin box
        """
        self.cacheBandX[self.ui.layerX().currentLayer()] = nbands
        self.clearProfiles()
        self.plot()

    def onLayerXChanged(self):
        """Adapt the bandsX() spin box to the current x layer (profile plot types only) and replot."""
        plotType = self.plotType()
        if plotType in (self.PlotType.SpectralProfile, self.PlotType.TemporalProfile):
            self.clearProfiles()
            layer = self.ui.layerX().currentLayer()
            if layer is not None:
                assert isinstance(layer, QgsRasterLayer)
                self.ui.bandsX().setMaximum(layer.bandCount())
                # restore the value last used for this layer, or fall back to a default:
                # all bands for a spectral profile, a single band for a temporal profile
                if plotType == self.PlotType.SpectralProfile:
                    self.ui.bandsX().setValue(self.cacheBandX.get(layer, layer.bandCount()))
                if plotType == self.PlotType.TemporalProfile:
                    self.ui.bandsX().setValue(self.cacheBandX.get(layer, 1))
        self.plot()

    def onPlotTypeChanged(self, index):
        """Switch to the plot type with the given combo box index and replot.

        :param index: new index of the plot type combo box (see PlotType)
        """
        self.initPlot(plotType=index)
        if index == self.PlotType.Scatter:
            # remove profile curves left over from a previous profile plot
            self.clearProfiles()
        else:
            # remove the 2-d histogram image left over from a previous scatter plot
            self.ui.imageView().clear()
        self.onLayerXChanged()

    def onSomethingChanged(self, *args):
        """Generic slot that ignores the signal arguments and replots."""
        self.plot()

    def initPlot(self, plotType):
        """Show the histogram widget of the image view for the scatter plot type only.

        :param plotType: one of the PlotType codes
        """
        self.ui.imageView().setHistogramVisibility(visible=plotType == self.PlotType.Scatter)

    def plotReset(self):
        """Hook that is called when nothing can be plotted; currently a no-op."""
        pass

    def plotType(self):
        """Return the current index of the plot type combo box (see PlotType)."""
        return self.ui.plotType().currentIndex()

    def plot(self):
        """Redraw the plot for the currently selected plot type."""
        plotType = self.plotType()
        if plotType == self.PlotType.Scatter:
            self.plotScatter()
        elif plotType in (self.PlotType.SpectralProfile, self.PlotType.TemporalProfile):
            self.plotProfile()

    def _rendererBand(self, uiLayer, uiIndex):
        """Return the band number that the renderer of the selected layer uses at the selected index.

        :param uiLayer: layer combo box (QgsMapLayerComboBox)
        :param uiIndex: combo box whose current index selects one of the renderer bands
        """
        assert isinstance(uiLayer, QgsMapLayerComboBox)
        assert isinstance(uiIndex, QComboBox)
        renderer = uiLayer.currentLayer().renderer()
        # for anything but a multi band color renderer only the first entry is used
        if not isinstance(renderer, QgsMultiBandColorRenderer):
            if uiIndex.currentIndex() != 0:
                uiIndex.setCurrentIndex(0)
        return renderer.usesBands()[uiIndex.currentIndex()]

    def plotScatter(self):
        """Plot the 2-d histogram of the selected x and y bands for the current canvas extent."""

        # return if layers are not valid
        layerX = self.ui.layerX().currentLayer()
        layerY = self.ui.layerY().currentLayer()
        if layerX is None or layerY is None:
            self.plotReset()
            return

        # derive band indices
        if self.ui.bandMode().currentIndex() == 0:  # raster band selection
            bandX = self.ui.bandXRaster().currentBand()
            bandY = self.ui.bandYRaster().currentBand()
        else:  # renderer band selection
            bandX = self._rendererBand(uiLayer=self.ui.layerX(), uiIndex=self.ui.bandXRenderer())
            bandY = self._rendererBand(uiLayer=self.ui.layerY(), uiIndex=self.ui.bandYRenderer())

        # read data and mask
        x, maskX = self.readArray(layer=layerX, band=bandX)
        y, maskY = self.readArray(layer=layerY, band=bandY)

        if x is None or y is None or x.size == 0 or y.size == 0 or x.shape != y.shape:
            self.plotReset()
            return

        # keep pixels that are valid in both bands; non-finite values are dropped as well,
        # because np.histogram2d can not derive a bin range from NaN/inf
        valid = maskX & maskY & np.isfinite(x) & np.isfinite(y)
        x = x[valid]
        y = y[valid]

        # check for empty data (x and y have the same size here)
        if x.size == 0:
            self.plotReset()
            return

        binsX = self.ui.binsX().value()
        binsY = self.ui.binsY().value()
        if binsX < 1 or binsY < 1:
            self.plotReset()
            return

        # use Python floats: max - min on numpy scalars of a narrow integer type can overflow
        minX, maxX = float(x.min()), float(x.max())
        minY, maxY = float(y.min()), float(y.max())

        # widen a zero-width range by +-0.5 (np.histogram2d does the same for constant data),
        # so that the image scale below never becomes zero
        if minX == maxX:
            minX, maxX = minX - 0.5, maxX + 0.5
        if minY == maxY:
            minY, maxY = minY - 0.5, maxY + 0.5

        h = np.histogram2d(x=x, y=y, bins=[binsX, binsY], range=[[minX, maxX], [minY, maxY]], density=True)[0]

        # place and scale the image so that one image pixel covers one histogram bin
        scaleX = (maxX - minX) / float(binsX)
        scaleY = (maxY - minY) / float(binsY)
        self.ui.imageView().setImage(h, pos=[minX, minY], scale=[scaleX, scaleY])
        self.ui.imageView().setLevels(*np.percentile(h, (1, 99)))  # stretch ramp between 1% - 99%

    def plotProfile(self):
        """Plot the profile(s) read at the canvas center, one curve per row of readProfile()."""

        values = self.readProfile(layer=self.ui.layerX().currentLayer())
        if values is None or len(values) == 0:
            return

        x = range(len(values[0]))
        for i, y in enumerate(values):
            if not self.profileInitialized(i):
                self.initProfile(i)
            self.profiles[i].setData(x, y)

    def initProfile(self, index):
        """Create the plot curve for the profile with the given index and clear the image.

        :param index: profile index used as key into self.profiles
        """
        assert isinstance(index, int)
        # Qt comes from the PyQt5.QtCore star import; no separate QtCore name is required
        pen = pg.mkPen(color=(255, 0, 0), width=2, style=Qt.SolidLine)
        self.profiles[index] = self.ui.imageView().plotItem().plot([0, 0], [0, 0], pen=pen)
        self.ui.imageView().clear()

    def clearProfiles(self):
        """Clear all existing profile curves (the curve items themselves are kept for reuse)."""
        for profile in self.profiles.values():
            profile.clear()

    def profileInitialized(self, index):
        """Return whether a plot curve exists for the profile with the given index."""
        assert isinstance(index, int)
        return self.profiles.get(index) is not None

    def readProfile(self, layer):
        """Read the values of all bands of the layer at the canvas center and arrange them as profiles.

        The band values are split into groups of bandsX() consecutive bands ("observations").
        For a spectral profile each row of the result is one observation (x axis: band within
        the observation), for a temporal profile each row is one band (x axis: observation).

        :param layer: raster layer to read from, or None
        :return: 2-d array with one profile per row, or None if nothing can be read or the
            layer band count is not a multiple of bandsX(). For any other plot type the flat
            list of band values is returned.
        """
        if layer is None:
            return None

        # get center
        assert isinstance(layer, QgsRasterLayer), repr(layer)
        provider = layer.dataProvider()
        assert isinstance(provider, QgsRasterDataProvider)
        center = self.canvas().center()
        assert isinstance(center, QgsPointXY)
        extent = self.canvas().extent()
        assert isinstance(extent, QgsRectangle)

        # check how the bands are to be split before doing any reading
        nbands = self.ui.bandsX().value()
        bandCount = layer.bandCount()
        if nbands < 1 or (bandCount % nbands) != 0:
            return None
        nobservations = bandCount // nbands

        # reproject center if needed
        canvasCrs = self.canvas().mapSettings().destinationCrs()
        layerCrs = layer.crs()
        if canvasCrs != layerCrs:
            tr = QgsCoordinateTransform(canvasCrs, layerCrs, QgsProject.instance())
            center = tr.transform(center)
            extent = tr.transform(extent)

        # read data
        size = self.canvas().size()
        identifyResult = provider.identify(point=center, format=QgsRaster.IdentifyFormatValue,
                                           boundingBox=extent, width=size.width(), height=size.height())
        if not identifyResult.isValid():
            return None
        results = identifyResult.results()  # band number (starting at 1) -> value
        if not results:
            return None

        # missing or null values (anything that is not convertible to float) become NaN
        values = []
        for band in range(1, bandCount + 1):
            try:
                values.append(float(results.get(band)))
            except (TypeError, ValueError):
                values.append(np.nan)

        # split data
        if self.plotType() == self.PlotType.TemporalProfile:
            values = np.array(values).reshape((nobservations, nbands)).T
        elif self.plotType() == self.PlotType.SpectralProfile:
            values = np.array(values).reshape((nobservations, nbands))

        return values

    def readArray(self, layer, band):
        """Read one band of the layer for the current canvas extent at canvas resolution.

        :param layer: raster layer to read from, or None
        :param band: band number passed to the data provider
        :return: tuple (array, mask) of flat numpy arrays of equal size, where mask is False
            for no-data pixels; (None, None) if layer is None
        """
        if layer is None:
            return None, None

        # get extent
        assert isinstance(layer, QgsRasterLayer), repr(layer)
        provider = layer.dataProvider()
        assert isinstance(provider, QgsRasterDataProvider)
        extent = self.canvas().extent()
        assert isinstance(extent, QgsRectangle)

        # reproject extent if needed
        canvasCrs = self.canvas().mapSettings().destinationCrs()
        layerCrs = layer.crs()
        if canvasCrs != layerCrs:
            tr = QgsCoordinateTransform(canvasCrs, layerCrs, QgsProject.instance())
            extent = tr.transform(extent)

        # read data
        size = self.canvas().size()
        block = provider.block(band, extent, size.width(), size.height())
        assert isinstance(block, QgsRasterBlock)
        # interpret the raw block bytes directly instead of converting them element by element
        array = np.frombuffer(bytes(block.data()), dtype=qgisDataTypeToNumpyDataType(block.dataType()))

        # calculate mask (the builtin bool is used, the np.bool alias was removed from NumPy)
        mask = np.full_like(array, fill_value=True, dtype=bool)
        noDataValues = [obj.min() for obj in provider.userNoDataValues(band)]
        if provider.sourceHasNoDataValue(band) and provider.useSourceNoDataValue(band):
            noDataValues.append(provider.sourceNoDataValue(band))

        for noDataValue in noDataValues:
            if noDataValue != noDataValue:
                # NaN never compares equal to anything, so it has to be detected explicitly
                mask[np.isnan(array)] = False
            else:
                mask[array == noDataValue] = False

        return array, mask

class RasterDataPlottingUi(QgsDockWidget, loadUIFormClass(pathUi=join(dirname(__file__), 'dockwidget.ui'))):
    """Dock widget built from 'dockwidget.ui' (located next to this module) plus an ImageView.

    Every accessor returns one widget of the form after asserting its expected type.
    """

    def __init__(self, parent=None):
        """Set up the form widgets and append the image view to the vertical layout."""
        QgsDockWidget.__init__(self, parent)
        self.setupUi(self)
        self._imageView = ImageView()
        self.verticalLayout.addWidget(self._imageView)

    def layerX(self):
        """Return the layer combo box for the x layer."""
        widget = self._layerX
        assert isinstance(widget, QgsMapLayerComboBox)
        return widget

    def layerY(self):
        """Return the layer combo box for the y layer."""
        widget = self._layerY
        assert isinstance(widget, QgsMapLayerComboBox)
        return widget

    def bandXRaster(self):
        """Return the raster band combo box for the x band."""
        widget = self._bandXRaster
        assert isinstance(widget, QgsRasterBandComboBox)
        return widget

    def bandYRaster(self):
        """Return the raster band combo box for the y band."""
        widget = self._bandYRaster
        assert isinstance(widget, QgsRasterBandComboBox)
        return widget

    def bandXRenderer(self):
        """Return the combo box selecting the renderer band used for x."""
        widget = self._bandXRenderer
        assert isinstance(widget, QComboBox)
        return widget

    def bandYRenderer(self):
        """Return the combo box selecting the renderer band used for y."""
        widget = self._bandYRenderer
        assert isinstance(widget, QComboBox)
        return widget

    def bandMode(self):
        """Return the band mode combo box."""
        widget = self._bandMode
        assert isinstance(widget, QComboBox)
        return widget

    def binsX(self):
        """Return the spin box holding the number of bins in x direction."""
        widget = self._binsX
        assert isinstance(widget, QSpinBox)
        return widget

    def binsY(self):
        """Return the spin box holding the number of bins in y direction."""
        widget = self._binsY
        assert isinstance(widget, QSpinBox)
        return widget

    def bandsX(self):
        """Return the spin box holding the number of bands for the x layer."""
        widget = self._bandsX
        assert isinstance(widget, QSpinBox)
        return widget

    def imageView(self):
        """Return the image view that was added in the constructor."""
        widget = self._imageView
        assert isinstance(widget, ImageView)
        return widget

    def plotType(self):
        """Return the plot type combo box."""
        widget = self._plotType
        assert isinstance(widget, QComboBox)
        return widget


class ImageView(pg.ImageView):
    """pyqtgraph image view that uses a PlotItem as its view and leaves the aspect ratio unlocked."""

    def __init__(self, *args, **kwargs):
        """Create the view; all arguments are forwarded to pg.ImageView (view must not be given)."""
        item = pg.PlotItem()
        self.plotItem_ = item
        pg.ImageView.__init__(self, *args, view=item, **kwargs)
        self.plotItem().setAspectLocked(lock=False)

    def setHistogramVisibility(self, visible):
        """Show or hide the histogram widget; the ROI and menu buttons are always hidden.

        :param visible: whether the histogram widget is shown
        """
        for button in (self.ui.roiBtn, self.ui.menuBtn):
            button.setVisible(False)
        self.ui.histogram.setVisible(visible)

    def plotItem(self):
        """Return the PlotItem that serves as the view of this image view."""
        item = self.plotItem_
        assert isinstance(item, pg.PlotItem)
        return item
