# -*- coding: utf-8 -*-
"""
/***************************************************************************
 Xyt
                                 A QGIS plugin
 Display the temporal dimension of geo data
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2020-01-21
        git sha              : $Format:%H$
        copyright            : (C) 2020 by Edouard Klein
        email                : edouardklein@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU Affero General Public License as published by  *
 *   the Free Software Foundation; either version 3 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import datetime
import os.path

import dateutil
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import ticker
from qgis.core import (Qgis, QgsExpressionContextUtils, QgsMapLayer,
                       QgsMessageLog, QgsProject, QgsVectorLayer,
                       QgsWkbTypes)
from qgis.PyQt.QtCore import (QCoreApplication, QObject, QSettings,
                              QTranslator, Qt, pyqtSignal)
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import QAction

# Initialize Qt resources from file resources.py
from .resources import *
# Import the code for the DockWidget
from .xyt_dockwidget import XytDockWidget


def info(msg):
    """Write ``msg`` to the QGIS message log, under the "xyt" tab, at Info level."""
    tag, level = "xyt", Qgis.Info
    QgsMessageLog.logMessage(msg, tag, level)


class QGISPlugin:
    """QGIS Plugin Implementation.

    Plugin Builder scaffolding: registers a toolbar/menu action and, when it
    is triggered, shows an XytDockWidget in the bottom dock area.
    """

    def __init__(self, iface):
        """Constructor.

        :param iface: An interface instance that will be passed to this class
            which provides the hook by which you can manipulate the QGIS
            application at run time.
        :type iface: QgsInterface
        """
        # Save reference to the QGIS interface
        self.iface = iface

        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)

        # initialize locale. The setting may be unset (None) or empty: fall
        # back to '' so that no translation file matches instead of crashing
        # on None[0:2].
        locale = (QSettings().value('locale/userLocale') or '')[0:2]
        locale_path = os.path.join(
            self.plugin_dir,
            'i18n',
            'Xyt_{}.qm'.format(locale))

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Declare instance attributes
        self.actions = []
        self.menu = self.tr(u'&xyt')
        # TODO: We are going to let the user set this up in a future iteration
        self.toolbar = self.iface.addToolBar(u'Xyt')
        self.toolbar.setObjectName(u'Xyt')

        self.pluginIsActive = False
        self.dockwidget = None

    # noinspection PyMethodMayBeStatic
    def tr(self, message):
        """Get the translation for a string using Qt translation API.

        We implement this ourselves since we do not inherit QObject.

        :param message: String for translation.
        :type message: str, QString

        :returns: Translated version of message.
        :rtype: QString
        """
        # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
        return QCoreApplication.translate('Xyt', message)

    def add_action(
            self,
            icon_path,
            text,
            callback,
            enabled_flag=True,
            add_to_menu=True,
            add_to_toolbar=True,
            status_tip=None,
            whats_this=None,
            parent=None):
        """Add a toolbar icon to the toolbar.

        :param icon_path: Path to the icon for this action. Can be a resource
            path (e.g. ':/plugins/foo/bar.png') or a normal file system path.
        :type icon_path: str

        :param text: Text that should be shown in menu items for this action.
        :type text: str

        :param callback: Function to be called when the action is triggered.
        :type callback: function

        :param enabled_flag: A flag indicating if the action should be enabled
            by default. Defaults to True.
        :type enabled_flag: bool

        :param add_to_menu: Flag indicating whether the action should also
            be added to the menu. Defaults to True.
        :type add_to_menu: bool

        :param add_to_toolbar: Flag indicating whether the action should also
            be added to the toolbar. Defaults to True.
        :type add_to_toolbar: bool

        :param status_tip: Optional text to show in the status bar when the
            mouse pointer hovers over the action (QAction.setStatusTip).
        :type status_tip: str

        :param parent: Parent widget for the new action. Defaults None.
        :type parent: QWidget

        :param whats_this: Optional text shown in the "What's This?" help
            popup for the action (QAction.setWhatsThis).
        :type whats_this: str

        :returns: The action that was created. Note that the action is also
            added to self.actions list.
        :rtype: QAction
        """

        icon = QIcon(icon_path)
        action = QAction(icon, text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)

        if status_tip is not None:
            action.setStatusTip(status_tip)

        if whats_this is not None:
            action.setWhatsThis(whats_this)

        if add_to_toolbar:
            self.toolbar.addAction(action)

        if add_to_menu:
            self.iface.addPluginToMenu(
                self.menu,
                action)

        self.actions.append(action)

        return action

    def initGui(self):
        """Create the menu entries and toolbar icons inside the QGIS GUI."""

        icon_path = ':/plugins/xyt/icon.png'
        self.add_action(
            icon_path,
            text=self.tr(u'xyt'),
            callback=self.run,
            parent=self.iface.mainWindow())

    def onClosePlugin(self):
        """Cleanup necessary items here when plugin dockwidget is closed"""

        # disconnects
        self.dockwidget.closingPlugin.disconnect(self.onClosePlugin)

        # remove this statement if dockwidget is to remain
        # for reuse if plugin is reopened
        self.dockwidget = None

        self.pluginIsActive = False

    def unload(self):
        """Removes the plugin menu item, icon, toolbar and dock from QGIS GUI."""

        for action in self.actions:
            # Same menu title as the one used in add_action
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        # A dock left open would otherwise survive the plugin being unloaded.
        if self.dockwidget is not None:
            self.iface.removeDockWidget(self.dockwidget)
        # The toolbar is owned by the QGIS main window: dropping the Python
        # reference alone would leave an empty toolbar behind.
        self.toolbar.deleteLater()
        del self.toolbar

    def run(self):
        """Run method that loads and starts the plugin"""

        if not self.pluginIsActive:
            self.pluginIsActive = True

            # dockwidget may not exist if:
            #    first run of plugin
            #    removed on close (see self.onClosePlugin method)
            if self.dockwidget is None:
                # Create the dockwidget (after translation) and keep reference
                self.dockwidget = XytDockWidget()

            # connect to provide cleanup on closing of dockwidget
            self.dockwidget.closingPlugin.connect(self.onClosePlugin)

            # show the dockwidget
            # TODO: fix to allow choice of dock location
            self.iface.addDockWidget(Qt.BottomDockWidgetArea, self.dockwidget)
            self.dockwidget.show()


def str2epoch(s):
    """Return the number of seconds since the epoch of the datetime in ``s``.

    ``s`` is parsed with dateutil's flexible parser. Naive datetimes are
    interpreted in the local timezone (``datetime.timestamp`` semantics).
    """
    # A bare `import dateutil` does not load the `parser` submodule on
    # python-dateutil < 2.9 (no lazy submodule import). Import it explicitly
    # instead of relying on another library having imported it already.
    import dateutil.parser
    return dateutil.parser.parse(s).timestamp()


def epoch2str(n):
    """Format the POSIX timestamp ``n`` as an ISO 8601 string in local time."""
    moment = datetime.datetime.fromtimestamp(n)
    return moment.isoformat()


class Xyt(QGISPlugin, QObject):
    """Plugin plotting the temporal distribution of the selected layers.

    Whenever the layer tree selection changes, the feature dates of the
    selected vector layers are histogrammed in a matplotlib figure embedded
    in the plugin's dock widget. With exactly two layers, a
    TimeProximityFilter is drawn on top of the histograms.
    """

    # Emitted to request a (queued) repaint of the graph.
    redraw = pyqtSignal()

    def __init__(self, iface):
        QGISPlugin.__init__(self, iface)
        QObject.__init__(self)
        self.layers = []  # The layers to display
        self.attributes = {}  # The column to display for each layer
        self.fig = plt.figure()
        self.ax = None
        self.timefilter = TimeProximityFilter(self)
        self.lock = False  # To avoid unwanted redraws
        self.dockwidget = None  # Our UI object

        iface.layerTreeView().currentLayerChanged.connect(
            lambda layer: self.change_layers(layer),
            type=Qt.QueuedConnection)  # If not QueuedConnection,
        # The slot is called BEFORE the selection actually changes
        self.redraw.connect(self.graph_painter,
                            type=Qt.QueuedConnection)

        sns.set_style('dark')

    def change_layers(self, layer):
        """Update self's state to the newly selected layers.

        :param layer: the new current layer. It is unused: the whole layer
            tree selection, restricted to vector layers, is read instead.
        """
        self.layers = [l for l in self.iface.layerTreeView().selectedLayers()
                       if l.type() == QgsMapLayer.VectorLayer]
        self.redraw.emit()

    def run(self):
        """Show the dock widget, embed the matplotlib canvas in it and draw."""
        QGISPlugin.run(self)
        self.dockwidget.horizontalLayout.addWidget(self.fig.canvas)
        self.graph_painter()
        self.fig.canvas.draw()

    def graph_painter(self):
        """Manage the display of the matplotlib graph.

        Each selected layer's feature dates are drawn as a histogram with rug
        on its own twin y-axis, with x ticks formatted as ISO dates. With
        exactly two layers the time proximity filter is drawn as well.
        Nothing happens while ``self.lock`` is set or before the dock widget
        exists.
        """
        info("graph_painter is called")
        info(f"self.layers is {[l.name() for l in self.layers]}")
        if self.lock or self.dockwidget is None:
            # Checked before notifying the user: this slot fires on every
            # layer tree selection change, even when the plugin is not shown.
            return
        self.iface.messageBar().pushMessage(
            f'Xyt processing layers  {[l.name() for l in self.layers]}',
            duration=2)
        # Extract and convert the data
        data = {l.name(): {
            "dates": self.layer_features_as_epoch(l),
            "color": l.renderer().symbol().color().name()}
                for l in self.layers}
        # Draw the data
        self.fig.clf()
        self.ax = self.fig.add_axes([0, 0.15, 1, 0.85])
        self.ax.xaxis.set_major_formatter(ticker.FuncFormatter(
            lambda x, pos: epoch2str(x)))
        for k, v in data.items():
            ax = self.ax.twinx()
            sns.distplot(ax=ax,
                         a=list(v['dates'].values()),
                         rug=True,
                         kde=False,
                         color=v['color'],
                         label=k)
        if len(data) == 2:
            self.timefilter.draw(data)
        self.fig.canvas.draw()

    def layer_features_as_epoch(self, l):
        """Return a id<->epoch dict of the given layer's features.

        The date column index comes from the layer variable ``date_col``
        (set in the layer's properties). When the variable is missing or is
        not an integer, column 26 (GeoFADET's DateISO) is used.
        """
        date_col = QgsExpressionContextUtils.layerScope(l).variable('date_col')
        try:
            i = int(date_col)
        except (TypeError, ValueError):
            # TypeError: no such variable (None). ValueError: not a number.
            if date_col is not None:
                info(f"Invalid date_col {date_col!r} for layer {l.name()},"
                     f" using column 26")
            i = 26
        return {f.id(): str2epoch(f[i])
                for f in l.getFeatures()}


class TimeProximityFilter(QObject):
    """Filter the features of two layers by their proximity in time.

    For each feature of each of the two layers given to draw(), the gap (in
    seconds) to the temporally nearest feature of the other layer is
    computed and plotted against the feature's date. A red horizontal line
    marks the current temporal resolution, which acts as the threshold.
    Scrolling on the graph changes the threshold. A middle click extracts
    every run of consecutive features under the threshold into a new
    memory layer.
    """

    # Emitted to ask for a (queued) redraw of the threshold line.
    redraw = pyqtSignal()

    def __init__(self, xyt):
        """:param xyt: the Xyt plugin instance owning the matplotlib figure
            (``xyt.fig``), its main axes (``xyt.ax``) and ``xyt.iface``."""
        QObject.__init__(self)
        self.xyt = xyt
        self.xyt.fig.canvas.mpl_connect('scroll_event', self.onwheel)
        self.xyt.fig.canvas.mpl_connect('button_release_event', self.onclick)
        self.temporal_resolution = '1d'  # '<n><unit>', unit among m, h, d
        # Layer name -> {'dates', 'color', 'temporal differences'}. Empty
        # until draw() has been given two layers, so early clicks are no-ops.
        self.layers = {}
        self.xlim = None  # x range spanned by the threshold line
        self.threshold_line = None
        self.threshold_ax = None
        self.threshold_text = None
        self.redraw.connect(self.draw,
                            type=Qt.QueuedConnection)

    def draw(self, *args):
        """Draw the temporal differences and/or the threshold line.

        With one argument (Xyt.graph_painter's dict mapping each of exactly
        two layer names to ``{'dates': {id: epoch}, 'color': ...}``), the
        temporal differences are recomputed and plotted on a new twin axes.
        With no argument (redraw signal), only the threshold line is moved.
        """
        if len(args) == 1:
            # Called by xyt
            self.layers = args[0]
            assert len(self.layers) == 2, "We need exactly 2 layers"
            keys = list(self.layers.keys())
            if not all(self.layers[k]['dates'] for k in keys):
                # A nearest feature cannot be found in an empty layer.
                info("Time proximity filter: a layer has no feature")
                self.layers = {}
                self.threshold_ax = None
                return
            self._compute_temporal_differences(keys)

            self.threshold_ax = self.xyt.ax.twinx()
            self.xlim = self.threshold_ax.get_xlim()
            for k in keys:
                dates = self.layers[k]['dates']
                diffs = self.layers[k]['temporal differences']
                sns.lineplot(ax=self.threshold_ax,
                             color=self.layers[k]['color'],
                             x=list(dates.values()),
                             y=[diffs[i] for i in dates])

        if self.threshold_ax is None:
            # Nothing plotted yet, e.g. a scroll before two layers are shown
            return
        if self.threshold_line is not None:
            # Remove old line
            self.threshold_line.pop(0).remove()
        if self.threshold_text is not None:
            # Remove old text
            self.threshold_text.remove()
        threshold = self.resolution_as_seconds()
        self.threshold_line = self.threshold_ax.plot(
            self.xlim,
            [threshold] * 2,
            "red")
        self.threshold_text = self.threshold_ax.text(
            self.xlim[0],
            threshold,
            self.temporal_resolution,
            color="red")
        self.xyt.fig.canvas.draw()

    def _compute_temporal_differences(self, keys):
        """Store under 'temporal differences' of both layers an id -> gap
        dict (ordered by ascending date) to the other layer's nearest date."""
        first_ids, first_epochs = self._sorted_ids_and_epochs(
            self.layers[keys[0]]['dates'])
        second_ids, second_epochs = self._sorted_ids_and_epochs(
            self.layers[keys[1]]['dates'])
        self.layers[keys[0]]['temporal differences'] = dict(zip(
            first_ids, self._nearest_gaps(first_epochs, second_epochs)))
        self.layers[keys[1]]['temporal differences'] = dict(zip(
            second_ids, self._nearest_gaps(second_epochs, first_epochs)))

    @staticmethod
    def _sorted_ids_and_epochs(dates):
        """Split an id -> epoch dict into a list of ids and a numpy array of
        epochs, both ordered by ascending epoch."""
        ordered = sorted(dates.items(), key=lambda item: item[1])
        return ([i for i, _ in ordered],
                np.array([epoch for _, epoch in ordered], dtype=float))

    @staticmethod
    def _nearest_gaps(values, sorted_refs):
        """For each element of ``values``, the absolute difference to the
        closest element of the ascending, non-empty array ``sorted_refs``.

        The answer lies at or just before each value's insertion point.
        This costs O((n + m) log m) instead of an n x m difference matrix.
        """
        pos = np.searchsorted(sorted_refs, values)
        last = len(sorted_refs) - 1
        before = sorted_refs[np.clip(pos - 1, 0, last)]
        after = sorted_refs[np.clip(pos, 0, last)]
        return np.minimum(np.abs(values - before), np.abs(values - after))

    def onwheel(self, event):
        """Change the temporal resolution when one wheels on the graph"""
        if event.step < 0:
            self.increase_resolution()
        else:  # step > 0
            self.decrease_resolution()
        self.redraw.emit()
        info(f"ON WHEEL, resolution is now: {self.temporal_resolution}"
             f" ({self.resolution_as_seconds()}s)")

    def onclick(self, event):
        """Extract temporally filtered data in new layers on a middle click"""
        info("ON CLICK")
        if event.button != matplotlib.backend_bases.MouseButton.MIDDLE:
            self.xyt.iface.messageBar().pushMessage(
                'Click but no middle click',
                duration=2, level=Qgis.Warning)
            return
        info("ON MIDDLE CLICK")
        self.xyt.iface.messageBar().pushMessage(
            'Extracting overlapping layers',
            duration=2)
        self.create_layers()

    def create_layers(self):
        """Create one memory layer per run of features under the threshold.

        Per layer, features are scanned in ascending date order. Each maximal
        run of consecutive features whose temporal difference is <= the
        current resolution becomes a layer named '<layer>_<run index>'.
        """
        status = self.xyt.iface.statusBarIface()
        status.showMessage("Extracting layers: start")
        threshold = self.resolution_as_seconds()
        nb_layers = len(self.layers)
        new_layers = {}
        for i, lname in enumerate(self.layers, start=1):
            status.showMessage(
                f"Extracting layers: Work on layer {i}/{nb_layers}: {lname}")
            differences = self.layers[lname].get('temporal differences', {})
            total = len(differences)
            runs = []
            indices = []
            for j, (k, d) in enumerate(differences.items()):
                if j % 100 == 0:
                    status.showMessage(
                        f"Extracting layers: Work on layer {i}/{nb_layers}: "
                        f"{lname} {j / total * 100:2.1f}%")
                if d <= threshold:
                    indices.append(k)
                elif indices:
                    # Temporal difference went above the threshold: close run
                    runs.append(indices)
                    indices = []
            if indices:
                # A run reaching the last feature must not be dropped
                runs.append(indices)
            for cnt, run in enumerate(runs):
                new_lname = f'{lname}_{cnt}'
                new_layers[new_lname] = {"name": lname, "indices": run}
                self.xyt.iface.messageBar().pushMessage(
                    f'Will create layer {new_lname} with'
                    f' {len(run)} points')
        for i, (new_lname, spec) in enumerate(new_layers.items(), start=1):
            status.showMessage(
                f"creating layer {i}/{len(new_layers)}: {new_lname}")
            self.duplicate_and_filter_layer(
                spec['name'], new_lname, spec['indices'])

    def duplicate_and_filter_layer(self, old_lname, new_lname, feat_ids):
        """Create a new layer with the same structure as the old one,
        duplicating only the specified features.

        :returns: (original layer, new memory layer added to the project)
        """
        # FIXME: This assumes different names for all layers
        layer = QgsProject.instance().mapLayersByName(old_lname)[0]

        # The memory provider expects a WKB type name ('Point', 'LineString',
        # 'MultiPolygon', ...). Deriving it from wkbType keeps line and
        # multi-part layers valid.
        geometry_type = QgsWkbTypes.displayString(layer.wkbType())
        mem_layer = QgsVectorLayer(
            f'{geometry_type}?crs={layer.crs().authid()}',
            new_lname,
            'memory')
        old_color = layer.renderer().symbol().color()
        mem_layer.renderer().symbol().setColor(old_color)

        wanted = set(feat_ids)  # O(1) membership tests
        feats = [feat for feat in layer.getFeatures() if feat.id() in wanted]
        mem_layer_data = mem_layer.dataProvider()
        # Use the layer's fields (not only the provider's): the copied
        # features carry the layer's full attribute list.
        mem_layer_data.addAttributes(layer.fields().toList())
        mem_layer.updateFields()
        mem_layer_data.addFeatures(feats)
        QgsProject.instance().addMapLayer(mem_layer)
        return layer, mem_layer

    def decrease_resolution(self):
        """Decrease the temporal resolution one step (floor: '1m')"""
        n, t = int(self.temporal_resolution[:-1]), self.temporal_resolution[-1]
        if n == 1:  # We must change the unit
            if t == 'd':
                t = 'h'
                n = 23
            elif t == 'h':
                t = 'm'
                n = 59
        else:
            n -= 1
        self.temporal_resolution = f'{n}{t}'

    def increase_resolution(self):
        """Increase the temporal resolution one step"""
        n, t = int(self.temporal_resolution[:-1]), self.temporal_resolution[-1]
        if n == 23 and t == 'h':
            n = 1
            t = 'd'
        elif n == 59 and t == 'm':
            n = 1
            t = 'h'
        else:
            n += 1
        self.temporal_resolution = f'{n}{t}'

    def resolution_as_seconds(self):
        """Return the number of seconds corresponding to the temporal
        resolution"""
        n, t = int(self.temporal_resolution[:-1]), self.temporal_resolution[-1]
        return n*(60 if t == 'm' else 60*60 if t == 'h' else 60*60*24)
