import numpy as np
from os.path import basename, splitext
from qgis.core import QgsRasterLayer, QgsMultiBandColorRenderer, QgsPalettedRasterRenderer, QgsSingleBandGrayRenderer, QgsSingleBandPseudoColorRenderer, QgsHillshadeRenderer
from qgis.PyQt.QtCore import QDate
from osgeo import gdal

class RtmRasterTimeseries(object):
    """Band-interleaved raster time series backed by a QGIS raster layer.

    The raster bands are interpreted as consecutive observations, each consisting
    of the same ``numberOfBands()`` bands in the same order. Dates and band names
    are derived from ENVI metadata, from ``'<YYYY-MM-DD> - <name>'`` band
    descriptions, or, when neither is usable, from a synthetic fallback
    (one single-band observation per raster band on consecutive days starting
    2000-01-01, named after the source file).
    """

    def __init__(self, layer):
        """Wrap *layer*; anything that is not a QgsRasterLayer yields an invalid instance."""
        if isinstance(layer, QgsRasterLayer):
            self._layer = layer
            (self._dates, self._bands,
             self._numberOfObservations, self._numberOfBands) = self._deriveInformation(layer=layer)
            self._currentIndex = -1
            self._valid = True
        else:
            self._layer = None
            self._dates = self._bands = self._numberOfObservations = self._numberOfBands = None
            self._currentIndex = None
            self._valid = False

    def layer(self):
        """Return the wrapped QgsRasterLayer, or None for an invalid instance."""
        return self._layer

    def dates(self):
        """Return the list of QDate objects, one per observation."""
        return self._dates

    def bands(self):
        """Return the band names of a single observation."""
        return self._bands

    def numberOfBands(self):
        """Return the number of bands per observation."""
        return self._numberOfBands

    def numberOfObservations(self):
        """Return the number of observations in the layer."""
        return self._numberOfObservations

    def isValid(self):
        """Return True if the instance wraps a QgsRasterLayer."""
        return self._valid

    @staticmethod
    def _deriveInformation(layer):
        """Derive ``(dates, bandNames, numberOfObservations, numberOfBands)`` for *layer*.

        Sources are tried in order:

        1. ENVI metadata: ``wavelength`` in decimal years plus ``band_names``.
        2. Per-band descriptions of the form ``'<YYYY-MM-DD> - <name>'``.
        3. A synthetic fallback (see class docstring).

        ``dates`` contains one QDate per observation. ``bandNames`` contains the
        names of one observation's bands.
        """
        assert isinstance(layer, QgsRasterLayer)

        def deriveFromDescriptions(descriptions):
            # Every description must look like '<YYYY-MM-DD> - <band name>'; a single
            # non-matching band invalidates this source so the caller can fall back.
            sep = ' - '
            dates = list()
            names = list()
            for description in descriptions:
                datePart, found, namePart = description.partition(sep)
                if not found:
                    return None
                try:
                    y, m, d = map(int, datePart.strip().split('-'))
                except ValueError:
                    # malformed date (wrong field count or non-integer field)
                    return None
                dates.append(QDate(y, m, d))
                names.append(namePart.strip())
            return dates, names

        def deriveFromFallback(bandCount):
            # One single-band observation per raster band, on consecutive days.
            date0 = QDate(2000, 1, 1)
            dates = [date0.addDays(i) for i in range(bandCount)]
            names = [splitext(basename(layer.source()))[0]]
            return dates, names

        def deriveFromMetadata(metadata):

            def toArray(s, dtype=str):
                # Parse an ENVI list literal such as '{a, b, c}' into a list of dtype.
                if s is None:
                    return None
                return [dtype(v.strip()) for v in s.replace('{', '').replace('}', '').split(',')]

            if not metadata or metadata.get('wavelength') is None:
                return None
            units = metadata.get('wavelength_units', 'decimal years')
            if units.lower() != 'decimal years':
                return None
            if metadata.get('band_names') is None:
                return None
            try:
                dyears = toArray(metadata['wavelength'], float)
            except ValueError:
                # non-numeric wavelength entries: not a decimal-year time axis
                return None
            # NOTE(review): a decimal year with zero fraction (e.g. 2000.0) maps to
            # Dec 31 of the previous year because of the '- 1'; presumably the writer
            # encodes year + (1-based day-of-year) / 366 — confirm against the producer.
            dates = [QDate(int(dy), 1, 1).addDays(round((dy - int(dy)) * 366) - 1) for dy in dyears]
            names = toArray(metadata['band_names'])
            return dates, names

        try:
            ds = gdal.Open(layer.source())
        except RuntimeError:
            # raised instead of returning None when gdal.UseExceptions() is active
            ds = None

        if ds is None:
            # GDAL cannot open the layer source (e.g. a provider-specific URI);
            # rely on what QGIS reports about the layer.
            rasterCount = layer.bandCount()
            info = None
        else:
            rasterCount = ds.RasterCount
            info = deriveFromMetadata(ds.GetMetadata('ENVI'))
            if info is None:
                info = deriveFromDescriptions(
                    [ds.GetRasterBand(i + 1).GetDescription() for i in range(rasterCount)])
            ds = None  # release the GDAL dataset handle
        if info is None:
            info = deriveFromFallback(rasterCount)
        dates_, names_ = info

        # The band-name sequence repeats once per observation; its period (the run
        # of names before the first name reappears) is the number of bands per observation.
        names = [names_[0]]
        for name in names_[1:]:
            if name == names_[0]:
                break
            names.append(name)

        numberOfBands = len(names)
        numberOfObservations = rasterCount // numberOfBands
        # one date per observation: take the date of each observation's first band
        dates = dates_[::numberOfBands]

        for date in dates:
            assert isinstance(date, QDate)
        return dates, names, numberOfObservations, numberOfBands

    def findDateIndex(self, date, snap):
        """Return the observation index matching *date*.

        :param date: QDate to look up.
        :param snap: ``'nearest'`` (smallest absolute day distance), ``'next'``
            (closest date on or after *date*) or ``'previous'`` (closest date on
            or before *date*). Ties resolve to the earliest observation.
        :returns: the observation index, or -1 for an invalid instance.
        :raises Exception: for an unknown *snap* mode.
        """
        if self.isValid():
            if snap == 'nearest':
                dist = [abs(date.daysTo(d)) for d in self._dates]
            elif snap == 'next':
                dist = np.array([date.daysTo(d) for d in self._dates], dtype=np.float32)
                dist[dist < 0] = np.inf  # exclude dates before *date*
            elif snap == 'previous':
                dist = np.array([-date.daysTo(d) for d in self._dates], dtype=np.float32)
                dist[dist < 0] = np.inf  # exclude dates after *date*
            else:
                raise Exception('unknown mode: {}'.format(snap))
            # NOTE(review): if no date qualifies in 'next'/'previous' mode, every
            # distance is inf and argmin returns 0 (the first observation) — confirm
            # that callers expect this rather than a clamp or -1.
            index = np.argmin(dist)
        else:
            index = -1

        return index

    def setDateIndex(self, index):
        """Point the layer's renderer at the bands of observation *index*.

        Each renderer band keeps its position within an observation and is moved
        to the same position inside observation *index*. Requires a valid instance.

        :raises NotImplementedError: if the renderer type cannot be re-banded.
        """
        renderer = self._layer.renderer()
        # Map a 1-based band number to the same in-observation band of observation *index*.
        bandNumber = lambda number: ((number - 1) % self._numberOfBands + 1) + (index * self._numberOfBands)
        if isinstance(renderer, QgsMultiBandColorRenderer):
            renderer.setRedBand(bandNumber(renderer.redBand()))
            renderer.setGreenBand(bandNumber(renderer.greenBand()))
            renderer.setBlueBand(bandNumber(renderer.blueBand()))
        elif isinstance(renderer, QgsPalettedRasterRenderer):
            # QgsPalettedRasterRenderer has no setBand(); QGIS >= 3.10 offers the generic
            # QgsRasterRenderer.setInputBand(), which returns False if the band cannot be set.
            # See https://gis.stackexchange.com/questions/315160/how-to-set-a-new-raster-band-to-a-qgspalettedrasterrenderer-object
            setInputBand = getattr(renderer, 'setInputBand', None)
            if setInputBand is None or not setInputBand(bandNumber(renderer.band())):
                raise NotImplementedError('renderer not supported: {}'.format(renderer))
        elif isinstance(renderer, QgsSingleBandGrayRenderer):
            renderer.setGrayBand(bandNumber(renderer.grayBand()))
        elif isinstance(renderer, QgsSingleBandPseudoColorRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        elif isinstance(renderer, QgsHillshadeRenderer):
            renderer.setBand(bandNumber(renderer.band()))
        elif renderer is None:
            pass
        else:
            raise NotImplementedError('renderer not supported: {}'.format(renderer))