# -*- coding: utf-8 -*-
"""
/***************************************************************************
 LDMP - A QGIS plugin
 This plugin supports monitoring and reporting of land degradation to the UNCCD 
 and in support of the SDG Land Degradation Neutrality (LDN) target.
                              -------------------
        begin                : 2017-05-23
        git sha              : $Format:%H$
        copyright            : (C) 2017 by Conservation International
        email                : trends.earth@conservation.org
 ***************************************************************************/
"""

import os
import json
import re
import copy
import base64

import datetime
from PyQt4 import QtGui
from PyQt4.QtCore import QSettings, QAbstractTableModel, Qt, pyqtSignal

from osgeo import gdal

from qgis.utils import iface
# Module-level handle to the main QGIS window's message bar; used by
# download_cloud_results to notify the user once a download completes.
mb = iface.messageBar()

from qgis.gui import QgsMessageBar

from LDMP import __version__
from LDMP.gui.DlgJobs import Ui_DlgJobs
from LDMP.gui.DlgJobsDetails import Ui_DlgJobsDetails
from LDMP.plot import DlgPlotTimeries

from LDMP import log
from LDMP.api import get_user_email, get_execution
from LDMP.download import Download, check_hash_against_etag, DownloadError
from LDMP.layers import add_layer
from LDMP.schemas.schemas import LocalRaster, LocalRasterSchema


def json_serial(obj):
    """Fallback ``default`` hook for :func:`json.dump`.

    ``datetime.datetime`` and ``datetime.date`` values are converted with
    ``isoformat()``; any other type raises ``TypeError`` so that ``json``
    reports it as unserializable.
    """
    date_types = (datetime.datetime, datetime.date)
    if not isinstance(obj, date_types):
        raise TypeError("Type {} not serializable".format(type(obj)))
    return obj.isoformat()


def create_gee_json_metadata(json_file, job, data_file):
    """Write a JSON sidecar file describing a downloaded raster.

    ``job`` is deep-copied so the caller's dict is left untouched. The
    ``results['bands']`` list is lifted out of the copy and passed to
    ``LocalRaster`` separately, and the ``raw`` entry is dropped. The result
    is serialized through ``LocalRasterSchema`` into ``json_file``.
    """
    job_copy = copy.deepcopy(job)
    band_list = job_copy['results'].pop('bands')
    job_copy.pop('raw')

    raster_name = os.path.basename(os.path.normpath(data_file))
    local_raster = LocalRaster(raster_name, band_list, job_copy)
    schema = LocalRasterSchema()
    with open(json_file, 'w') as out_fp:
        json.dump(schema.dump(local_raster), out_fp,
                  default=json_serial,
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))


class DlgJobsDetails(QtGui.QDialog, Ui_DlgJobsDetails):
    """Dialog used to display the details of a single job."""

    def __init__(self, parent=None):
        """Create the dialog and build its widgets via ``setupUi``."""
        super(DlgJobsDetails, self).__init__(parent)
        self.setupUi(self)


class DlgJobs(QtGui.QDialog, Ui_DlgJobs):
    """Dialog listing the user's recent jobs and downloading their results."""

    # When a connection to the api starts, emit True. When it ends, emit False
    connectionEvent = pyqtSignal(bool)

    def __init__(self, parent=None):
        """Constructor."""
        super(DlgJobs, self).__init__(parent)

        self.settings = QSettings()

        self.setupUi(self)

        # Job dicts backing the table; filled from the settings cache in
        # showEvent or from the API in btn_refresh.
        self.jobs = []

        self.connection_in_progress = False

        # Set a variable used to record the necessary window width to view all
        # columns
        self._full_width = None

        self.bar = QgsMessageBar()
        self.bar.setSizePolicy(QtGui.QSizePolicy.Minimum, QtGui.QSizePolicy.Fixed)
        self.layout().addWidget(self.bar, 0, 0, Qt.AlignTop)

        self.refresh.clicked.connect(self.btn_refresh)
        self.download.clicked.connect(self.btn_download)

        self.connectionEvent.connect(self.connection_event_changed)

        # Only enable download button if a job is selected
        self.download.setEnabled(False)

    def showEvent(self, event):
        """Show the dialog, populating the table from the cached job list."""
        super(DlgJobs, self).showEvent(event)
        jobs_cache = self.settings.value("LDMP/jobs_cache", None)
        if jobs_cache:
            self.jobs = jobs_cache
            self.update_jobs_table()

    def connection_event_changed(self, flag):
        """Disable buttons while an API call runs; re-enable them afterwards."""
        if flag:
            self.connection_in_progress = True
            self.download.setEnabled(False)
            self.refresh.setEnabled(False)
        else:
            self.connection_in_progress = False
            # Enable the download button if there is a selection
            self.selection_changed()
            self.refresh.setEnabled(True)

    def resizeWindowToColumns(self):
        """Resize the dialog so that every table column is visible."""
        if not self._full_width:
            margins = self.layout().contentsMargins()
            self._full_width = margins.left() + margins.right() + \
                self.jobs_view.frameWidth() * 2 + \
                self.jobs_view.verticalHeader().width() + \
                self.jobs_view.horizontalHeader().length() + \
                self.jobs_view.style().pixelMetric(QtGui.QStyle.PM_ScrollBarExtent)
        self.resize(self._full_width, self.height())

    def _source_row(self, index):
        """Map a view index to the corresponding row of ``self.jobs``.

        The view displays a QSortFilterProxyModel, so view rows only match
        ``self.jobs`` rows while the proxy is unsorted and unfiltered.
        """
        model = self.jobs_view.model()
        if isinstance(model, QtGui.QSortFilterProxyModel):
            index = model.mapToSource(index)
        return index.row()

    def _selected_rows(self):
        """Return the sorted, de-duplicated ``self.jobs`` rows selected in the view."""
        return sorted(set(self._source_row(index)
                          for index in self.jobs_view.selectedIndexes()))

    def selection_changed(self):
        """Enable the download button only when every selected job is FINISHED."""
        if self.connection_in_progress:
            return
        rows = self._selected_rows()
        all_finished = bool(rows) and all(
            self.jobs[row]['status'] == 'FINISHED' for row in rows)
        self.download.setEnabled(all_finished)

    @staticmethod
    def _format_date(value):
        """Format a datetime as 'YYYY/MM/DD (HH:MM)'; ``None`` becomes ''.

        ``None`` is tolerated because a job may not have an end date yet
        (presumably jobs still running -- NOTE(review): confirm with the API).
        """
        if value is None:
            return ''
        return datetime.datetime.strftime(value, '%Y/%m/%d (%H:%M)')

    def _prettify_job(self, job):
        """Add display-oriented fields to ``job`` in place.

        A shallow copy of the unmodified job is stored under 'raw' first.
        """
        # self.jobs will have prettified data for usage in table, so save a
        # backup of the original data under key 'raw'
        job['raw'] = job.copy()
        if job.get('script_id', None):
            # Clean up the script name so the version tag doesn't look so odd
            job['script_name'] = re.sub(r'([0-9]+)_([0-9]+)$', r'(v\g<1>.\g<2>)',
                                        job['script']['name'])
            job['script_description'] = job['script']['description']
        else:
            # Handle case of scripts that have been removed or that are no
            # longer supported
            job['script_name'] = self.tr('Script not found')
            job['script_description'] = self.tr('Script not found')

        # Pretty print dates and pull the metadata sent as input params
        job['start_date'] = self._format_date(job['start_date'])
        job['end_date'] = self._format_date(job.get('end_date'))
        job['task_name'] = job['params'].get('task_name', '')
        job['task_notes'] = job['params'].get('task_notes', '')

    def btn_refresh(self):
        """Fetch the last 14 days of jobs from the API, cache and display them.

        ``self.jobs`` and the table are only replaced when jobs were actually
        retrieved, so a failed refresh keeps the previous, consistent state.

        Returns:
            True if jobs were retrieved and displayed, False otherwise.
        """
        self.connectionEvent.emit(True)
        try:
            email = get_user_email()
            if not email:
                return False
            start_date = datetime.datetime.now() + datetime.timedelta(-14)
            jobs = get_execution(date=start_date.strftime('%Y-%m-%d'))
            if not jobs:
                return False
            for job in jobs:
                self._prettify_job(job)

            self.jobs = jobs
            # Cache jobs for later reuse
            self.settings.setValue("LDMP/jobs_cache", self.jobs)
            self.update_jobs_table()
            return True
        finally:
            # Always re-enable the buttons, even if the API call raised
            self.connectionEvent.emit(False)

    def update_jobs_table(self):
        """Rebuild the table view from ``self.jobs`` (no-op when empty)."""
        if not self.jobs:
            return
        table_model = JobsTableModel(self.jobs, self)
        proxy_model = QtGui.QSortFilterProxyModel()
        proxy_model.setSourceModel(table_model)
        self.jobs_view.setModel(proxy_model)

        # Add "Details" buttons in the column the model leaves empty
        details_column = table_model.colnames_json.index('INVALID')
        for row in range(len(self.jobs)):
            btn = QtGui.QPushButton(self.tr("Details"))
            btn.clicked.connect(self.btn_details)
            self.jobs_view.setIndexWidget(proxy_model.index(row, details_column), btn)

        self.jobs_view.horizontalHeader().setResizeMode(QtGui.QHeaderView.ResizeToContents)
        self.jobs_view.setSelectionBehavior(QtGui.QAbstractItemView.SelectRows)
        # setModel creates a new selection model, so reconnect each time
        self.jobs_view.selectionModel().selectionChanged.connect(self.selection_changed)

        #self.resizeWindowToColumns()

    def btn_details(self):
        """Open a modal dialog showing the details of the clicked job."""
        button = self.sender()
        # Index widgets live in the viewport, so button.pos() is in viewport
        # coordinates, which is what indexAt expects
        index = self.jobs_view.indexAt(button.pos())
        if not index.isValid():
            # An invalid index has row -1, which would select the last job
            return
        job = self.jobs[self._source_row(index)]

        details_dlg = DlgJobsDetails(self)
        details_dlg.task_name.setText(job.get('task_name', ''))
        details_dlg.task_status.setText(job.get('status', ''))
        details_dlg.comments.setText(job.get('task_notes', ''))
        details_dlg.input.setText(json.dumps(job.get('params', ''), indent=4, sort_keys=True))
        details_dlg.output.setText(json.dumps(job.get('results', ''), indent=4, sort_keys=True))
        details_dlg.exec_()

    def _ask_output_basename(self, job):
        """Prompt until the user picks a base filename in a writable folder.

        Returns:
            The chosen path with its extension stripped, or None if the user
            cancelled the file dialog.
        """
        # Setup a string to use in filename window
        if job['task_name']:
            job_info = u'{} ({})'.format(job['script_name'], job['task_name'])
        else:
            job_info = job['script_name']

        while True:
            f = QtGui.QFileDialog.getSaveFileName(
                self,
                self.tr(u'Choose a filename downloading results of: {}').format(job_info),
                self.settings.value("LDMP/output_dir", None),
                self.tr('Base filename (*.json)'))
            # Strip the extension so that it is a basename
            f = os.path.splitext(f)[0]
            if not f:
                return None

            out_dir = os.path.dirname(f)
            if os.access(out_dir, os.W_OK):
                self.settings.setValue("LDMP/output_dir", out_dir)
                log(u"Downloading results to {} with basename {}".format(out_dir, os.path.basename(f)))
                return f
            # Not writable: tell the user and ask again
            QtGui.QMessageBox.critical(
                None, self.tr("Error"),
                self.tr(u"Cannot write to {}. Choose a different base filename.").format(f))

    def btn_download(self):
        """Ask for output filenames for the selected jobs, then download them.

        Returns False if the user cancels a filename prompt (nothing is
        downloaded in that case).

        Raises:
            ValueError: if a selected job has an unrecognized result type.
        """
        rows = self._selected_rows()

        filenames = []
        for row in rows:
            job = self.jobs[row]
            # Some tasks don't need to save data; right now only
            # TimeSeriesTable doesn't need a filename.
            if job['results'].get('type') == 'TimeSeriesTable':
                filenames.append(None)
                continue
            f = self._ask_output_basename(job)
            if not f:
                return False
            filenames.append(f)

        self.close()

        for row, f in zip(rows, filenames):
            job = self.jobs[row]
            log(u"Processing job {}".format(job))
            result_type = job['results'].get('type')
            if result_type == 'CloudResults':
                download_cloud_results(job, f, self.tr)
            elif result_type == 'TimeSeriesTable':
                download_timeseries(job, self.tr)
            else:
                raise ValueError("Unrecognized result type in download results: {}".format(result_type))


class JobsTableModel(QAbstractTableModel):
    """Table model exposing a list of job dicts, one row per job."""

    def __init__(self, datain, parent=None, *args):
        QAbstractTableModel.__init__(self, parent, *args)
        self.jobs = datain

        # Column names as tuples with json name in [0], pretty name in [1].
        # The column whose json name is 'INVALID' has no data in the model;
        # its cells are expected to be filled by a widget placed in the view.
        colname_tuples = [('task_name', QtGui.QApplication.translate('LDMPPlugin', 'Task name')),
                          ('script_name', QtGui.QApplication.translate('LDMPPlugin', 'Job')),
                          ('start_date', QtGui.QApplication.translate('LDMPPlugin', 'Start time')),
                          ('end_date', QtGui.QApplication.translate('LDMPPlugin', 'End time')),
                          ('status', QtGui.QApplication.translate('LDMPPlugin', 'Status')),
                          ('INVALID', QtGui.QApplication.translate('LDMPPlugin', 'Details'))]
        self.colnames_pretty = [x[1] for x in colname_tuples]
        self.colnames_json = [x[0] for x in colname_tuples]

    def rowCount(self, parent=None):
        """Return the number of jobs for the root index.

        A table model has no child items, so any valid parent index has
        zero rows (Qt's table-model contract).
        """
        if parent is not None and parent.isValid():
            return 0
        return len(self.jobs)

    def columnCount(self, parent=None):
        """Return the number of columns (independent of ``parent``)."""
        return len(self.colnames_json)

    def data(self, index, role):
        """Return the display text of a cell; None for other roles or invalid indexes.

        Keys missing from a job (including the 'INVALID' column) display as ''.
        """
        if not index.isValid():
            return None
        elif role != Qt.DisplayRole:
            return None
        return self.jobs[index.row()].get(self.colnames_json[index.column()], '')

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Return the translated column title for horizontal display headers."""
        if role == Qt.DisplayRole and orientation == Qt.Horizontal:
            return self.colnames_pretty[section]
        return QAbstractTableModel.headerData(self, section, orientation, role)


def download_result(url, out_file, job, expected_etag):
    """Download ``url`` into ``out_file`` and verify it against ``expected_etag``.

    Returns the value of ``check_hash_against_etag`` when the download
    produced a response, or ``None`` when it did not. ``job`` is accepted for
    interface compatibility but is not used.
    """
    downloader = Download(url, out_file)
    downloader.start()
    if not downloader.get_resp():
        return None
    return check_hash_against_etag(url, out_file, expected_etag)


def download_cloud_results(job, f, tr, add_to_map=True):
    """Download the raster output(s) of a ``CloudResults`` job.

    A single result file is saved to ``<f>.tif``. Multiple files are saved as
    ``<f>_<n>.tif`` tiles and mosaicked into ``<f>.vrt``; tiles that already
    exist and match their expected hash are not downloaded again. A JSON
    metadata sidecar is then written to ``<f>.json``, bands flagged with
    ``add_to_map`` are added to the map (if ``add_to_map`` is True), and a
    confirmation is pushed to the QGIS message bar.

    Args:
        job: job dict whose ``results`` holds ``urls`` (each with ``url`` and
            a base64 ``md5Hash``) and ``bands``.
        f: output base filename, without extension.
        tr: translation function for user-facing messages.
        add_to_map: whether to add the downloaded bands to the map.

    Returns early, without writing metadata, if there is nothing to download,
    a download fails, or the VRT cannot be built.
    """
    results = job['results']
    urls = results.get('urls') or []
    if not urls:
        log(u"No result URLs found for job, nothing to download")
        return
    json_file = f + '.json'

    if len(urls) > 1:
        # Save a VRT if there are multiple files for this download
        tiles = []
        for n, url in enumerate(urls):
            tile = f + '_{}.tif'.format(n)
            tiles.append(tile)
            expected_hash = base64.b64decode(url['md5Hash']).encode('hex')
            # If file already exists, check its hash and skip redownloading
            # if it matches
            if os.access(tile, os.R_OK) and \
                    check_hash_against_etag(url['url'], tile, expected_hash):
                continue
            if not download_result(url['url'], tile, job, expected_hash):
                return
        # Make a VRT mosaicing the tiles so they can be treated as one file
        # during further processing
        out_file = f + '.vrt'
        vrt = gdal.BuildVRT(out_file, tiles)
        if vrt is None:
            log(u"Failed to build VRT {}".format(out_file))
            return
        # Dereference the dataset so GDAL flushes the VRT to disk before use
        vrt = None
    else:
        url = urls[0]
        out_file = f + '.tif'
        expected_hash = base64.b64decode(url['md5Hash']).encode('hex')
        if not download_result(url['url'], out_file, job, expected_hash):
            return

    create_gee_json_metadata(json_file, job, out_file)

    if add_to_map:
        # GDAL band numbers start at 1
        for band_number, band_info in enumerate(results['bands'], 1):
            if band_info['add_to_map']:
                add_layer(out_file, band_number, band_info)

    mb.pushMessage(tr("Downloaded"),
                   tr(u"Downloaded results to {}").format(out_file),
                   level=0, duration=5)


def download_timeseries(job, tr):
    """Plot the 'mean' series of a ``TimeSeriesTable`` job in a modal dialog.

    Args:
        job: job dict; ``job['results']['table']`` is a list of series dicts
            with ``name``, ``time`` and ``y`` keys.
        tr: translation function for the plot labels.

    Returns None. Nothing is shown if the job has no table, or if the table
    has no series named 'mean'.
    """
    log("processing timeseries results...")
    table = job['results'].get('table', None)
    if not table:
        return None
    data = next((x for x in table if x['name'] == 'mean'), None)
    if data is None:
        log(u"No 'mean' series found in timeseries results")
        return None
    dlg_plot = DlgPlotTimeries()
    labels = {'title': job['task_name'],
              'bottom': tr('Time'),
              'left': [tr('Integrated NDVI'), tr('NDVI x 10000')]}
    dlg_plot.plot_data(data['time'], data['y'], labels)
    dlg_plot.show()
    dlg_plot.exec_()
