from os.path import dirname, basename, join
from qgis.core import *
from qgis.gui import QgisInterface, QgsMapCanvas
from PyQt5.QtCore import QCoreApplication, QTimer
from PyQt5.QtWidgets import QAction
from osgeo import gdal

from .ui import RasterTimeseriesManagerUi

class RasterTimeseriesManagerController(object):
    """Controller that steps a multi-band raster layer through time.

    The raster's bands are expected to be grouped by observation: every
    observation contributes the same sequence of bands, and every band
    description has the form ``<timestamp> - <name>``. Selecting a date
    re-points the layer renderer's band(s) to the corresponding bands of that
    observation.
    """

    def __init__(self, iface):
        assert isinstance(iface, QgisInterface)
        self.iface = iface
        self.ui = RasterTimeseriesManagerUi()

        # Timeseries state. It is initialised before any signal is connected,
        # so slots never see missing attributes.
        self.index = 0  # 0-based index of the currently shown observation
        self.names = None  # band names of a single observation
        self.dates = None  # one timestamp string per observation
        self.numberOfBands = 0
        self.numberOfObservations = 0
        self._layer = None
        self.dateCache = dict()  # layer -> last shown 1-based observation number
        self.animationFrameLength = 250  # milliseconds between animation frames

        # One restartable single-shot timer drives the animation. Toggling
        # Play quickly therefore can never start parallel animation loops.
        self._playTimer = QTimer()
        self._playTimer.setSingleShot(True)
        self._playTimer.timeout.connect(self.play)

        # connect signals
        self.ui.date().valueChanged.connect(self.onDateChanged)
        self.ui.slider().valueChanged.connect(self.onDateChanged)
        self.ui.play().toggled.connect(self.onPlayToggled)
        self.ui.layer().layerChanged.connect(self.setTimeseries)
        self.ui.first().clicked.connect(lambda *args: self.setDate(1))
        self.ui.last().clicked.connect(lambda *args: self.setDate(self.numberOfObservations))
        # self.index is 0-based while setDate expects a 1-based number.
        self.ui.next().clicked.connect(lambda *args: self.setDate(self.index + 2))
        self.ui.previous().clicked.connect(lambda *args: self.setDate(self.index))
        self.ui.groupTime().setEnabled(False)
        self.ui.groupRenderer().setVisible(False)

        # add to plugins menu
        self.action1 = QAction('Toggle visibility', self.iface.mainWindow())
        self.action1.triggered.connect(self.toggleUiVisibility)
        self.iface.addPluginToMenu('Raster Timeseries Manager', self.action1)
        self.action2 = QAction('Load test data', self.iface.mainWindow())
        self.action2.triggered.connect(self.loadTestdata)
        self.iface.addPluginToMenu('Raster Timeseries Manager', self.action2)

    def toggleUiVisibility(self):
        """Show the dock widget if it is hidden, hide it otherwise."""
        self.ui.setVisible(not self.ui.isVisible())

    def loadTestdata(self):
        """Add the bundled test timeseries to the project and select it in the UI."""
        layer = self.iface.addRasterLayer(join(dirname(__file__), '..', 'testdata', 'timeseries.bsq'))
        self.ui.layer().setLayer(layer)

    def unload(self):
        """Unload the plugin: stop playback, remove the dock widget and both menu entries."""
        self._playTimer.stop()
        self.iface.removeDockWidget(self.ui)
        self.iface.removePluginMenu('Raster Timeseries Manager', self.action1)
        self.iface.removePluginMenu('Raster Timeseries Manager', self.action2)

    def canvas(self):
        """Return the QGIS map canvas."""
        canvas = self.iface.mapCanvas()
        assert isinstance(canvas, QgsMapCanvas)
        return canvas

    def show(self):
        """Show the dock widget."""
        self.ui.show()

    @staticmethod
    def _splitDescription(description):
        """Split a band description ``<timestamp> - <name>`` into its parts.

        :param description: band description string.
        :returns: tuple ``(timestamp, name)``, both stripped of surrounding whitespace.
        :raises ValueError: if *description* does not contain the ``' - '`` separator.
        """
        timestamp, separator, name = description.partition(' - ')
        if not separator:
            raise ValueError('invalid band description format "{}", expected: {}'.format(
                description, '<timestamp> - <name>'))
        return timestamp.strip(), name.strip()

    @classmethod
    def _parseBandDescriptions(cls, descriptions):
        """Derive per-observation band names and timestamps from band descriptions.

        :param descriptions: band descriptions in band order.
        :returns: tuple ``(names, dates, numberOfObservations)``.
            *names* holds the band names of one observation. *dates* holds one
            timestamp per complete observation. *numberOfObservations* counts
            the complete observations (a trailing partial observation is ignored).
        :raises ValueError: if there are no descriptions or one is malformed.
        """
        descriptions = list(descriptions)
        if not descriptions:
            raise ValueError('raster has no bands')
        dates, names = zip(*[cls._splitDescription(description) for description in descriptions])
        # The band-name sequence repeats once per observation. The first
        # repetition of the first name marks the start of the second observation.
        try:
            names = names[:names.index(names[0], 1)]
        except ValueError:
            pass  # no repetition: a single observation
        numberOfBands = len(names)
        numberOfObservations = len(descriptions) // numberOfBands
        dates = dates[::numberOfBands][:numberOfObservations]
        return names, dates, numberOfObservations

    def setTimeseries(self, layer):
        """Make *layer* the active timeseries, or deactivate the time controls.

        :param layer: layer selected in the UI. Anything other than a
            ``QgsRasterLayer`` deactivates the time controls.
        :raises ValueError: if GDAL cannot open the layer source or a band
            description is malformed. State is left unchanged in that case.
        """
        if not isinstance(layer, QgsRasterLayer):
            self._layer = None
            self.ui.groupTime().setEnabled(False)
            return

        # extract names and dates from band descriptions
        ds = gdal.Open(layer.source())
        if ds is None:
            raise ValueError('unable to open raster source: {}'.format(layer.source()))
        descriptions = [ds.GetRasterBand(i + 1).GetDescription() for i in range(ds.RasterCount)]
        ds = None  # release the GDAL dataset
        names, dates, numberOfObservations = self._parseBandDescriptions(descriptions)

        self.names = names
        self.dates = dates
        self.numberOfBands = len(names)
        self.numberOfObservations = numberOfObservations
        # Activate the new layer before touching the widgets. Any valueChanged
        # emitted by setRange, and the setDate call below, then act on this
        # layer rather than on the previously selected one.
        self._layer = layer

        # set spinbox min max
        self.ui.date().setRange(1, self.numberOfObservations)
        self.ui.slider().setRange(1, self.numberOfObservations)

        self.ui.groupTime().setEnabled(True)
        self.setDate(number=self.dateCache.get(layer, 1))

    def timeseries(self):
        """Return the active timeseries layer (asserts that one is set)."""
        assert isinstance(self._layer, QgsRasterLayer)
        return self._layer

    def _wrapNumber(self, number):
        """Wrap a 1-based observation number into ``1..numberOfObservations``.

        Numbers past the last observation wrap to the start; numbers below 1
        wrap to the end (0 -> last, -1 -> second to last).
        """
        return (number - 1) % self.numberOfObservations + 1

    def setDate(self, number):
        """Show observation *number* (1-based) of the active timeseries layer.

        Out-of-range numbers wrap around. Does nothing when no timeseries layer
        is active.

        :raises NotImplementedError: if the layer's renderer type is not supported.
        """
        if self._layer is None:
            return

        number = self._wrapNumber(number)
        self.index = number - 1

        renderer = self._layer.renderer()

        def bandNumber(band):
            # Keep the band's position within its observation, but move it
            # into the selected observation.
            return ((band - 1) % self.numberOfBands + 1) + (self.index * self.numberOfBands)

        if isinstance(renderer, QgsMultiBandColorRenderer):
            renderer.setRedBand(bandNumber(renderer.redBand()))
            renderer.setGreenBand(bandNumber(renderer.greenBand()))
            renderer.setBlueBand(bandNumber(renderer.blueBand()))
        elif isinstance(renderer, QgsPalettedRasterRenderer):
            # No way found to change the band of a paletted renderer, see
            # https://gis.stackexchange.com/questions/315160/how-to-set-a-new-raster-band-to-a-qgspalettedrasterrenderer-object
            raise NotImplementedError('renderer not supported: {}'.format(renderer))
        elif isinstance(renderer, QgsSingleBandGrayRenderer):
            renderer.setGrayBand(bandNumber(renderer.grayBand()))
        elif isinstance(renderer, QgsSingleBandPseudoColorRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        elif isinstance(renderer, QgsHillshadeRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        else:
            raise NotImplementedError('renderer not supported: {}'.format(renderer))

        self.timeseries().styleChanged.emit()

        # Update the widgets without re-entering onDateChanged through their
        # valueChanged signals, which would redo all of the work above.
        for widget in (self.ui.date(), self.ui.slider()):
            wasBlocked = widget.blockSignals(True)
            widget.setValue(number)
            widget.blockSignals(wasBlocked)
        self.ui.timestamp().setText(self.dates[number - 1])

        self._layer.triggerRepaint()

        # remember the date so it is restored when this layer is selected again
        self.dateCache[self._layer] = number

    def play(self):
        """Show the current frame, then schedule the next one while Play is checked."""
        self.setDate(number=self.index + 1)
        self.canvas().waitWhileRendering()
        self.index += 1
        if self.ui.play().isChecked():
            # start() restarts a pending timer, so at most one frame is scheduled.
            self._playTimer.start(self.animationFrameLength)

    def onDateChanged(self, number):
        """Slot for the date spinbox and slider: show observation *number*."""
        self.setDate(number=number)

    def onPlayToggled(self, checked):
        """Slot for the Play button: start or stop the animation."""
        if checked:
            self.ui.play().setText('Pause')
            self.play()
        else:
            self.ui.play().setText('Play')
            self._playTimer.stop()
