import traceback
import webbrowser
import tempfile
import threading, subprocess
import numpy as np
from os.path import dirname, basename, join, splitext, exists
from os import makedirs, chdir
from datetime import date as Date, datetime
from qgis.core import *
from qgis.gui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from osgeo import gdal

from .ui import RasterTimeseriesManagerUi

class RasterTimeseriesManagerController(object):
    """Controller connecting the RasterTimeseriesManager dock widget to a raster timeseries layer.

    A timeseries layer stores its bands observation by observation: observation ``k`` (0-based)
    occupies the bands ``k * numberOfBands + 1`` to ``(k + 1) * numberOfBands``. Selecting a date
    re-points the layer renderer to the bands of the observation closest to that date.
    """

    pluginName = 'RasterTimeseriesManager'

    # printf-style name of an exported frame (index zero-padded to 10 digits); used for writing
    # the frames and as the FFmpeg image-sequence input, so both always refer to the same files
    FRAME_PATTERN = 'frame%010d.png'

    class Setting(object):
        """Keys (below ``pluginName``) of the values persisted in QgsSettings."""
        FFMPEG_BIN = 'ffmpeg_bin'
        IMAGE_MAGICK_BIN = 'imageMagick_bin'

    def __init__(self, iface):
        """Create the dock widget, restore persisted settings and connect the UI signals.

        :param iface: the QGIS interface; the plugin menu actions are only created when it is a
            ``QgisInterface``
        """
        self.iface = iface
        self.ui = RasterTimeseriesManagerUi()

        # state of the selected timeseries; initialized before any signal gets connected
        self.date = QDate()
        self.index = None  # 0-based observation currently shown by the renderer
        self.names = None
        self.dates = None
        self.numberOfObservations = None
        self.numberOfBands = None
        self._layer = None
        self.dateCache = dict()  # last selected date per layer id
        self.animationFrameLength = 250  # milliseconds between two animation frames
        self.action1 = None
        self.action2 = None

        self.widgetsToBeEnabled = [self.ui.tab(), self.ui.layerSplit()]

        # init widgets
        self.ui.status().hide()
        self.ui.fileFfmpeg().setFilePath(self.setting(key=self.Setting.FFMPEG_BIN, default=''))
        self.ui.fileImageMagick().setFilePath(self.setting(key=self.Setting.IMAGE_MAGICK_BIN, default=''))
        self.ui.fileTimeseries().setFilePath(join(tempfile.gettempdir(), self.pluginName, 'timeseries.vrt'))

        # todo: ts creation has issues, hide it for now
        self.ui._tab.removeTab(2)

        # connect signals
        self.ui.date().dateChanged.connect(self.setDate)
        self.ui.slider().valueChanged.connect(self.setNumber)
        self.ui.play().toggled.connect(self.onPlayToggled)
        self.ui.layer().layerChanged.connect(self.setTimeseries)
        self.ui.next().clicked.connect(lambda *args: self.setDate(self.nextDate()))
        self.ui.previous().clicked.connect(lambda *args: self.setDate(self.nextDate(reversed=True)))
        self.ui.layerSplit().clicked.connect(self.onLayerSplitClicked)
        self.ui.layerMerge().clicked.connect(self.onLayerMergeClicked)
        self.ui.saveFrames().clicked.connect(self.onSaveFrames)
        self.ui.saveMp4().clicked.connect(self.onSaveMp4)
        self.ui.saveGif().clicked.connect(self.onSaveGif)
        self.ui.resetRange().clicked.connect(self.resetRange)
        self.ui.openFolder().clicked.connect(lambda *args: webbrowser.open(self.ui.saveFolder().filePath()))

        # nothing to work on until a raster layer is selected
        for w in self.widgetsToBeEnabled:
            w.setEnabled(False)

        # add to plugins menu
        if isinstance(self.iface, QgisInterface):
            self.action1 = QAction('Toggle visibility', self.iface.mainWindow())
            self.action1.triggered.connect(self.toggleUiVisibility)
            self.iface.addPluginToMenu('Raster Timeseries Manager', self.action1)
            self.action2 = QAction('Load test data', self.iface.mainWindow())
            self.action2.triggered.connect(self.loadTestdata)
            self.iface.addPluginToMenu('Raster Timeseries Manager', self.action2)

    def handleException(self, error):
        """Report an unexpected error: print the traceback and show a status message for 5 s.

        Must be called from inside an ``except`` block, because ``traceback.print_exc()`` prints
        the exception currently being handled.
        """
        assert isinstance(error, Exception)
        traceback.print_exc()
        self.setStatus()
        self.ui.status().show()
        QTimer.singleShot(5000, lambda: self.ui.status().hide())

    def setStatus(self):
        """Put the generic 'unexpected error' text into the status label."""
        self.ui.status().setText('Unexpected error: see Python Console [Ctrl+Alt+P] log for details. ')

    def setSetting(self, key, value):
        """Persist ``value`` under ``<pluginName>/<key>`` in QgsSettings."""
        s = QgsSettings()
        s.setValue('{}/{}'.format(self.pluginName, key), value)

    def setting(self, key, default=None):
        """Return the value persisted under ``<pluginName>/<key>``, or ``default`` if unset."""
        s = QgsSettings()
        return s.value('{}/{}'.format(self.pluginName, key), default)

    def toggleUiVisibility(self):
        """Show the dock widget if it is hidden, hide it otherwise."""
        self.ui.setVisible(not self.ui.isVisible())

    def loadTestdata(self):
        """Add the bundled test timeseries to the project and select it."""
        layer = self.iface.addRasterLayer(join(dirname(__file__), '..', 'testdata', 'timeseries.bsq'))
        self.ui.layer().setLayer(layer)
        self.setTimeseries(layer=layer)

    def unload(self):
        """Unload the plugin: remove the dock widget and every plugin menu entry that was added."""
        self.iface.removeDockWidget(self.ui)
        # the actions only exist when __init__ got a QgisInterface
        for action in (getattr(self, 'action1', None), getattr(self, 'action2', None)):
            if action is not None:
                self.iface.removePluginMenu('Raster Timeseries Manager', action)

    def canvas(self):
        """Return the QGIS map canvas."""
        canvas = self.iface.mapCanvas()
        assert isinstance(canvas, QgsMapCanvas)
        return canvas

    def show(self):
        """Show the dock widget."""
        self.ui.show()

    def setTimeseries(self, layer):
        """Select ``layer`` as the current timeseries, or clear the selection if it is no raster layer.

        For a raster layer, dates and band names are derived (see :meth:`deriveInformation`), the
        slider and date range are reset to the full timeseries and the date last selected for this
        layer (or its first date) is shown.
        """
        isRasterLayer = isinstance(layer, QgsRasterLayer)
        if isRasterLayer:
            # read the remembered date before the widgets below are reset (they may emit signals
            # that select another date and overwrite the cache entry)
            cachedDate = self.dateCache.get(layer.id())

            self._layer = layer
            # force the next setDate() to re-point the renderer of the new layer
            self.index = None

            self.dates, self.names, self.numberOfObservations, self.numberOfBands = self.deriveInformation(layer=layer)
            self.ui.slider().setRange(1, self.numberOfObservations)

            self.ui.dateRangeStart().setDate(self.dates[0])
            self.ui.dateRangeEnd().setDate(self.dates[-1])
            self.setDate(date=self.dates[0] if cachedDate is None else cachedDate)
        else:
            self._layer = None

        for w in self.widgetsToBeEnabled:
            w.setEnabled(isRasterLayer)

    def timeseries(self):
        """Return the current timeseries layer (asserts that one is selected)."""
        assert isinstance(self._layer, QgsRasterLayer)
        return self._layer

    def setNumber(self, number):
        """Select the observation with the 1-based ``number`` (slider value)."""
        self.setIndex(index=number-1)

    def setIndex(self, index):
        """Select the observation with the 0-based ``index``."""
        self.setDate(date=self.dates[index])

    def findIndex(self, date):
        """Return the 0-based index of the observation date closest to ``date`` (first one on ties)."""
        return int(np.argmin([abs(date.daysTo(d)) for d in self.dates]))

    def setDate(self, date):
        """Select ``date`` and show the observation closest to it.

        Dates after the end of the selected date range wrap to its start, dates before its start
        wrap to its end (this makes playback loop). The date widget and slider are updated without
        emitting signals; the renderer is only re-pointed when the shown observation changes.

        :param date: QDate to select; ignored while no timeseries layer is set
        """
        assert isinstance(date, QDate)

        if self._layer is None:
            return

        if date > self.ui.dateRangeEnd().date():
            date = self.ui.dateRangeStart().date()
        elif date < self.ui.dateRangeStart().date():
            date = self.ui.dateRangeEnd().date()

        self.date = date
        self.dateCache[self._layer.id()] = date

        index = self.findIndex(date)

        # update the gui without emitting signals
        self.ui.date().blockSignals(True)
        self.ui.date().setDate(self.date)
        # show the 1-based observation number and its offset in days from the selected date
        off = self.date.daysTo(self.dates[index])
        sign = '+' if off >= 0 else '-'
        self.ui.date().setDisplayFormat('yyyy-MM-dd ({}) {}{}'.format(index + 1, sign, abs(off)))
        self.ui.date().blockSignals(False)

        self.ui.slider().blockSignals(True)
        self.ui.slider().setValue(index + 1)
        self.ui.slider().blockSignals(False)

        if index != self.index:  # only update if index changes
            self.index = index
            self._setRendererObservation(index)
            self._layer.triggerRepaint()

    def _setRendererObservation(self, index):
        """Re-point the band(s) used by the layer renderer to the bands of observation ``index``.

        :raises NotImplementedError: for renderer types that are not supported
        """
        renderer = self._layer.renderer()

        def bandNumber(number):
            # keep the position of the band within its observation, move it to observation `index`
            return ((number - 1) % self.numberOfBands + 1) + (index * self.numberOfBands)

        if isinstance(renderer, QgsMultiBandColorRenderer):
            renderer.setRedBand(bandNumber(renderer.redBand()))
            renderer.setGreenBand(bandNumber(renderer.greenBand()))
            renderer.setBlueBand(bandNumber(renderer.blueBand()))
        elif isinstance(renderer, QgsPalettedRasterRenderer):
            # no band setter specific to this renderer, see
            # https://gis.stackexchange.com/questions/315160/how-to-set-a-new-raster-band-to-a-qgspalettedrasterrenderer-object
            # use the generic QgsRasterRenderer.setInputBand() where the QGIS version provides it,
            # otherwise the displayed band is left unchanged
            if hasattr(renderer, 'setInputBand'):
                renderer.setInputBand(bandNumber(renderer.band()))
        elif isinstance(renderer, QgsSingleBandGrayRenderer):
            renderer.setGrayBand(bandNumber(renderer.grayBand()))
        elif isinstance(renderer, QgsSingleBandPseudoColorRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        elif isinstance(renderer, QgsHillshadeRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        elif renderer is None:
            pass
        else:
            raise NotImplementedError('renderer not supported: {}'.format(renderer))

    def nextDate(self, reversed=False):
        """Return the date one step (``stepSize`` x ``stepUnit``) after the current date.

        With ``reversed=True`` the step goes backwards. For the 'indices' unit, stepping past the
        first/last observation returns a date just outside the selected date range, so that
        :meth:`setDate` wraps around.

        :raises Exception: for an unknown step unit
        """
        unit = self.ui.stepUnit().currentText()
        step = self.ui.stepSize().value()
        if reversed:
            step = -step

        if unit == 'days':
            date = self.date.addDays(step)
        elif unit == 'weeks':
            date = self.date.addDays(7*step)
        elif unit == 'months':
            date = self.date.addMonths(step)
        elif unit == 'years':
            date = self.date.addYears(step)
        elif unit == 'indices':
            index = self.findIndex(self.date) + step
            if index < 0:
                date = self.ui.dateRangeStart().date().addDays(-1)
            elif index >= len(self.dates):
                date = self.ui.dateRangeEnd().date().addDays(1)
            else:
                date = self.dates[index]
        else:
            raise Exception('unknown unit')
        return date

    def play(self):
        """Show the next date and reschedule itself while the play button is checked."""
        self.setDate(date=self.nextDate())
        self.canvas().waitWhileRendering()
        if self.ui.play().isChecked():
            QTimer.singleShot(self.animationFrameLength, self.play)

    def _callWithWaitCursor(self, func, **kwargs):
        """Call ``func(**kwargs)`` under a wait cursor.

        The cursor is always restored. Exceptions are reported via :meth:`handleException`, in
        which case ``None`` is returned.
        """
        QApplication.setOverrideCursor(Qt.WaitCursor)
        try:
            return func(**kwargs)
        except Exception as error:
            self.handleException(error)
            return None
        finally:
            QApplication.restoreOverrideCursor()

    def onSaveFrames(self):
        """Export the frames and open the output folder."""
        folder = self._callWithWaitCursor(self.saveFrames)
        if folder is not None:
            webbrowser.open(folder)

    def saveFrames(self):
        """Render the map canvas for every date of the selected date range into PNG frames.

        Dates are visited with the current step settings: forwards from the range start for a
        positive step size, backwards from the range end for a negative one. Frame ``i`` is
        written to ``FRAME_PATTERN % i`` in the save folder.

        :returns: the save folder, or None when the step size is 0 (nothing is rendered)
        """
        folder = self.ui.saveFolder().filePath()

        if not exists(folder):
            makedirs(folder)

        step = self.ui.stepSize().value()
        if step == 0:
            return None
        rangeStart = self.ui.dateRangeStart().date()
        rangeEnd = self.ui.dateRangeEnd().date()
        date = rangeStart if step > 0 else rangeEnd

        size = self.canvas().size()
        image = QImage(size, QImage.Format_RGB32)

        self.ui.setEnabled(False)
        self.ui.tab().setCurrentIndex(0)
        try:
            i = 0
            # stop as soon as the step leaves the selected date range
            while (step > 0 and date <= rangeEnd) or (step < 0 and date >= rangeStart):
                self.setDate(date)
                self.canvas().waitWhileRendering()

                image.fill(QColor('white'))
                painter = QPainter(image)
                try:
                    job = QgsMapRendererCustomPainterJob(self.canvas().mapSettings(), painter)
                    job.renderSynchronously()
                finally:
                    painter.end()

                image.save(join(folder, self.FRAME_PATTERN % i))

                date = self.nextDate()
                i += 1
        finally:
            # never leave the dock widget disabled, even if rendering failed
            self.ui.setEnabled(True)
        return folder

    def onSaveMp4(self):
        """Validate the FFmpeg binary, remember it, encode the frames to MP4 and open the video."""
        ffmpeg_bin = self.ui.fileFfmpeg().filePath()

        if not exists(str(ffmpeg_bin)) or not basename(ffmpeg_bin).startswith('ffmpeg'):
            msg = QMessageBox(parent=self.ui,
                              text='Select the FFmpeg binary under settings.')
            msg.setWindowTitle('Wrong or missing FFmpeg binary.')
            msg.exec()
            return

        self.setSetting(key=self.Setting.FFMPEG_BIN, value=ffmpeg_bin)
        filename = self._callWithWaitCursor(self.saveMp4, ffmpeg_bin=ffmpeg_bin)
        if filename is not None:
            webbrowser.open(filename)

    @staticmethod
    def _ffmpegCommand(ffmpeg_bin, fps, video):
        """Return the FFmpeg argument list encoding the frames (relative to the cwd) into ``video``."""
        return [ffmpeg_bin, '-r', str(fps), '-i', RasterTimeseriesManagerController.FRAME_PATTERN,
                '-vcodec', 'libx264', '-y', '-an',
                # pad width and height up to even numbers; options must precede the output file
                '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
                video]

    def saveMp4(self, fps=10, ffmpeg_bin="ffmpeg"):
        """Encode the exported frames of the save folder into ``video.mp4`` with FFmpeg.

        :returns: path of the video, or None if FFmpeg exited with a non-zero code
        """
        folder = self.ui.saveFolder().filePath()
        video = 'video.mp4'
        cmd = self._ffmpegCommand(ffmpeg_bin=ffmpeg_bin, fps=fps, video=video)

        print(' '.join(cmd))

        # run inside the save folder (relative input pattern and output name) without changing
        # the working directory of the whole QGIS process; no shell, so arguments pass unaltered
        res = subprocess.call(cmd, stdin=subprocess.PIPE, cwd=folder)

        if res != 0:
            return None
        else:
            return join(folder, video)

    def onSaveGif(self):
        """Validate the ImageMagick binary, remember it, build the animated GIF and open it."""
        imageMagick_bin = self.ui.fileImageMagick().filePath()

        if not exists(str(imageMagick_bin)) or not basename(imageMagick_bin).startswith(('magick', 'convert')):
            msg = QMessageBox(parent=self.ui,
                              text='Select the ImageMagick binary under settings.')
            msg.setWindowTitle('Wrong or missing ImageMagick binary.')
            msg.exec()
            return

        self.setSetting(key=self.Setting.IMAGE_MAGICK_BIN, value=imageMagick_bin)
        filename = self._callWithWaitCursor(self.saveGif, imageMagick_bin=imageMagick_bin)
        if filename is not None:
            webbrowser.open(filename)

    @staticmethod
    def _imageMagickCommand(imageMagick_bin, folder, video):
        """Return the ImageMagick argument list combining the frames in ``folder`` into ``video``."""
        # the wildcard is passed unexpanded (no shell); ImageMagick expands it itself
        frames = join(folder, 'frame*.png')
        return [imageMagick_bin, '-delay', '20', frames, '-loop', '0', video]

    def saveGif(self, imageMagick_bin='magick'):
        """Combine the exported frames of the save folder into an endlessly looping ``video.gif``.

        :returns: path of the GIF, or None if ImageMagick exited with a non-zero code
        """
        folder = self.ui.saveFolder().filePath()
        video = join(folder, 'video.gif')
        cmd = self._imageMagickCommand(imageMagick_bin=imageMagick_bin, folder=folder, video=video)
        print(' '.join(cmd))

        # run with the ImageMagick folder as working directory of the child process only
        res = subprocess.call(cmd, stdin=subprocess.PIPE, cwd=dirname(imageMagick_bin) or None)

        if res != 0:
            return None
        else:
            return video

    def onDateChanged(self, date):
        """Select ``date``."""
        self.setDate(date=date)

    def onDateNumberChanged(self, number):
        """Select the observation with the 1-based ``number``."""
        self.setDate(date=self.dates[number-1])

    def onPlayToggled(self, checked):
        """Start playback (and label the button 'Pause') when checked, label it 'Play' otherwise."""
        if checked:
            self.ui.play().setText('Pause')
            self.play()
        else:
            self.ui.play().setText('Play')

    def onLayerSplitClicked(self):
        """Split the current timeseries into per-band VRTs inside a random /vsimem/ folder."""
        self.splitTimeseries(dirname=join('/vsimem/{}/splitted'.format(self.pluginName),
                                          str(np.random.randint(100000, 999999))))

    def splitTimeseries(self, dirname):
        """Create one VRT per band name holding that band of every observation, and add them to QGIS.

        :param dirname: output folder; GDAL virtual paths (``/vsi...``) are used as they are
        :returns: the added raster layers, in the order of ``self.names``
        """
        # GDAL virtual file systems (e.g. /vsimem/) need no directories on the local disk
        if not dirname.startswith('/vsi') and not exists(dirname):
            makedirs(dirname)

        tsFilename = self.timeseries().source()
        ds = gdal.Open(tsFilename)
        filenames = list()
        for i, name in enumerate(self.names):
            filename = join(dirname, '{}.vrt'.format(name))
            # band i + 1 of every observation
            bandList = list(range(i + 1, self.numberOfObservations * self.numberOfBands + 1, self.numberOfBands))
            vrt = gdal.BuildVRT(destName=filename, srcDSOrSrcDSTab=ds, options=gdal.BuildVRTOptions(bandList=bandList))
            for vrtBandIndex, bandNumber in enumerate(bandList):
                description = ds.GetRasterBand(bandNumber).GetDescription()
                vrt.GetRasterBand(vrtBandIndex + 1).SetDescription(description)
            filenames.append(filename)
            vrt = None  # release the dataset before QGIS opens the file

        layers = [self.iface.addRasterLayer(filename) for filename in reversed(filenames)]
        return list(reversed(layers))

    def onLayerMergeClicked(self):
        """Merge all checked raster layers into one timeseries VRT and select it."""
        layers = QgsProject.instance().layerTreeRoot().checkedLayers()
        layers = [layer for layer in layers if isinstance(layer, QgsRasterLayer)]
        print(layers)
        if len(layers) > 1:
            try:
                layer = self.mergeTimeseries(filename=self.ui.fileTimeseries().filePath(),
                                             layers=layers)
            except Exception as error:
                self.handleException(error)
                return
            layer = self.iface.addRasterLayer(layer.source())
            self.ui.layer().setLayer(layer=layer)

    def mergeTimeseries(self, filename, layers):
        """Stack several equally sized raster timeseries into one VRT timeseries.

        Band ``i`` of every layer is written consecutively (in ``layers`` order), giving
        ``RasterCount * len(layers)`` output bands. Each output band is described as
        ``'yyyy-MM-dd - name'`` so that :meth:`deriveInformation` can read the result back.

        :returns: a QgsRasterLayer of the written VRT
        :raises Exception: if the rasters differ in size or band count
        """
        from xml.sax.saxutils import escape

        folder = dirname(filename)
        if folder and not exists(folder):
            makedirs(folder)

        infoss = dict()

        ds0 = None
        for layer in layers:

            ds = gdal.Open(layer.source())
            if ds0 is None:
                ds0 = ds

            if (ds0.RasterXSize != ds.RasterXSize or
                    ds0.RasterYSize != ds.RasterYSize or
                    ds0.RasterCount != ds.RasterCount):
                raise Exception('Can not merge timeseries, raster sizes do not match.')

            infoss[layer] = RasterTimeseriesManagerController.deriveInformation(layer=layer)

        vrt = gdal.GetDriverByName('VRT').Create(filename, ds0.RasterXSize, ds0.RasterYSize,
                                                 ds0.RasterCount * len(layers), ds0.GetRasterBand(1).DataType)

        vrt.SetProjection(ds0.GetProjection())
        vrt.SetGeoTransform(ds0.GetGeoTransform())

        xml = '''<SimpleSource>
              <SourceFilename relativeToVRT="0">{}</SourceFilename>
              <SourceBand>{}</SourceBand>
            </SimpleSource>'''

        vrtBandNumber = 1
        for i in range(ds0.RasterCount):
            for layer in layers:
                # the source path is embedded in XML, so characters like '&' must be escaped
                vrt.GetRasterBand(vrtBandNumber).SetMetadataItem("source_0",
                                                                 xml.format(escape(layer.source()), i+1),
                                                                 "new_vrt_sources")
                dates, names, numberOfObservations, numberOfBands = infoss[layer]
                # band i belongs to observation i // numberOfBands and band name i % numberOfBands
                date = dates[i // numberOfBands]
                assert isinstance(date, QDate)
                sdate = '{}-{}-{}'.format(str(date.year()).zfill(4), str(date.month()).zfill(2), str(date.day()).zfill(2))
                description = '{} - {}'.format(sdate, names[i % numberOfBands])
                print(description)
                vrt.GetRasterBand(vrtBandNumber).SetDescription(description)
                vrtBandNumber += 1
        vrt = None  # release the dataset so the VRT is written before it is opened below

        layer = QgsRasterLayer(filename)
        assert layer.isValid()
        return layer

    def resetRange(self):
        """Reset the selected date range to the first and last observation (no-op without a layer)."""
        if self.dates is None:
            return
        self.ui.dateRangeStart().setDate(self.dates[0])
        self.ui.dateRangeEnd().setDate(self.dates[-1])

    def onLayerToMemoryClicked(self):
        """Placeholder, not implemented."""
        pass

    @staticmethod
    def deriveInformation(layer):
        """Derive the observation dates and band names of a raster timeseries.

        Sources are tried in this order:

        1. ENVI metadata: ``wavelength`` holds decimal years (the only supported
           ``wavelength_units``), ``band_names`` the band names,
        2. band descriptions of the form ``'yyyy-MM-dd - name'``,
        3. fallback: one band per observation, daily dates from 2000-01-01, named after the file.

        Band names repeat once per observation; the number of bands per observation is the number
        of names before the first name occurs again.

        :returns: tuple ``(dates, names, numberOfObservations, numberOfBands)``; ``dates`` holds
            one QDate per observation, ``names`` one name per band of an observation
        :raises Exception: if the raster can not be opened or its ENVI metadata is inconsistent
        """
        assert isinstance(layer, QgsRasterLayer)

        def deriveFromDescriptions(descriptions):
            # returns None as soon as one description does not match 'yyyy-MM-dd - name'
            dates = list()
            names = list()
            sep = ' - '
            for description in descriptions:
                i = description.find(sep)
                if i == -1:
                    return None
                try:
                    y, m, d = map(int, description[:i].strip().split('-'))
                except ValueError:
                    return None  # the part before the separator is not a date
                dates.append(QDate(y, m, d))
                names.append(description[i + len(sep):].strip())
            return dates, names

        def deriveFromFallback():
            date0 = QDate(2000, 1, 1)
            dates = [date0.addDays(i) for i in range(layer.bandCount())]
            names = [splitext(basename(layer.source()))[0]]
            return dates, names

        def deriveFromMetadata(metadata):

            def toArray(s, dtype=str):
                # parse an ENVI list like '{a, b, c}'
                if s is None:
                    return None
                else:
                    return [dtype(v.strip()) for v in s.replace('{', '').replace('}', '').split(',')]

            if metadata.get('wavelength') is None:
                return None
            else:
                units = metadata.get('wavelength_units', 'decimal years')
                if units.lower() == 'decimal years':
                    dyears = toArray(metadata['wavelength'], float)
                    dates = [QDate(int(dy), 1, 1).addDays(round((dy - int(dy)) * 366) - 1) for dy in dyears]
                else:
                    raise Exception('unknown time unit: "wavelength units = {}"'.format(units))

            if metadata.get('band_names') is None:
                raise Exception('use "band names" item in "ENVI" domain to specify the band names')
            else:
                names = toArray(metadata['band_names'])
            return dates, names

        ds = gdal.Open(layer.source())
        if ds is None:
            raise Exception('can not open raster: {}'.format(layer.source()))
        info = deriveFromMetadata(ds.GetMetadata('ENVI'))
        if info is None:
            info = deriveFromDescriptions([ds.GetRasterBand(i + 1).GetDescription() for i in range(ds.RasterCount)])
            if info is None:
                info = deriveFromFallback()
        dates_, names_ = info

        # names of one observation: everything before the first name repeats
        names = [names_[0]]
        for name in names_[1:]:
            if name == names_[0]:
                break
            names.append(name)

        numberOfBands = len(names)
        numberOfObservations = ds.RasterCount // numberOfBands
        dates = dates_[::numberOfBands]  # one date per observation

        for date in dates:
            assert isinstance(date, QDate)
        return dates, names, numberOfObservations, numberOfBands

# (no module-level statements: a stray import-time ``gdal.TranslateOptions()`` call had no effect)