# -*- coding: utf-8 -*-

"""
***************************************************************************
    BandTableWidget.py
    ---------------------
    Date                 : 2021-09-14
    Copyright            : (C) 2021 by J. Pierson, UMR 6554 LETG, CNRS
    Email                : julie.pierson@univ-brest.fr
    Based upon           : ReliefColorsWidget.py (C) 2016 by Alexander Bruy
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""

__author__ = 'J. Pierson, UMR 6554 LETG, CNRS'
__date__ = '2021-09-14'
__copyright__ = '(C) 2021 by J. Pierson, UMR 6554 LETG, CNRS'

import os

from qgis.PyQt import uic
from qgis.PyQt.QtCore import pyqtSlot
from qgis.PyQt.QtWidgets import (QTreeWidgetItem,
                                 QMessageBox,
                                 QInputDialog,
                                 )
from qgis.core import (QgsApplication, 
                       QgsProject,
                       QgsMapLayer,
                       QgsMapLayerProxyModel)
from processing.gui.wrappers import WidgetWrapper
from osgeo import gdal

# Directory containing this module; the Qt Designer form is expected next to it.
pluginPath = os.path.dirname(__file__)
# Compile 'bandtablewidgetbase.ui' at import time. uic.loadUiType returns the
# generated form class (WIDGET, provides setupUi) and the Qt base widget class
# it targets (BASE); BandTableWidget below inherits from both.
WIDGET, BASE = uic.loadUiType(os.path.join(pluginPath, 'bandtablewidgetbase.ui'))


class BandTableWidget(BASE, WIDGET):
    """Table listing the bands of a raster layer, with an editable new-name column.

    Column 0 holds the 1-based band number, column 1 the band description
    read through GDAL, and column 2 a new band name typed by the user
    (double-click a cell of that column to edit it).
    """

    def __init__(self):
        super().__init__(None)
        self.setupUi(self)

        self.cmbLayers.setFilters(QgsMapLayerProxyModel.RasterLayer)
        self.cmbLayers.layerChanged.connect(self.layerChanged)
        self.btnRemove.setIcon(QgsApplication.getThemeIcon('/symbologyRemove.svg'))
        self.btnUp.setIcon(QgsApplication.getThemeIcon('/mActionArrowUp.svg'))
        self.btnDown.setIcon(QgsApplication.getThemeIcon('/mActionArrowDown.svg'))
        # Assigned by setLayer(); not read anywhere else in this class.
        self.layer = None

        # When the algorithm dialog opens, list the bands of the raster
        # layer initially selected in the combobox.
        self.layerChanged()

    def layerChanged(self):
        """Refill the table with the bands of the raster selected in the combobox.

        The table is always cleared first. It stays empty when no layer is
        selected or when GDAL cannot open the layer source (presumably a
        layer from a non-GDAL provider).
        """
        self._removeBandData()
        layer = self.cmbLayers.currentLayer()
        if layer is None:
            return
        try:
            dataset = gdal.Open(layer.source())
        except RuntimeError:
            # gdal.Open raises instead of returning None when
            # gdal.UseExceptions() is active.
            dataset = None
        if dataset is None:
            return
        # GDAL band indices are 1-based.
        for band_number in range(1, dataset.RasterCount + 1):
            band = dataset.GetRasterBand(band_number)
            self._addBandData(str(band_number), band.GetDescription())

    def _addBandData(self, bandnumber, bandname):
        """Append a row with the band number and current band name.

        The new-name column (2) is left empty.
        """
        item = QTreeWidgetItem()
        item.setText(0, bandnumber)
        item.setText(1, bandname)
        self.bandClassTree.addTopLevelItem(item)

    def _removeBandData(self):
        """Remove every row from the table."""
        self.bandClassTree.clear()

    def _moveSelectedItems(self, step):
        """Move every selected row by ``step`` positions (+1 = down, -1 = up).

        Rows nearest the destination edge are processed first. This stops
        adjacent selected rows from swapping with each other and cancelling
        the move. A row is not moved past the table edge, nor onto a
        selected row that was itself blocked.
        """
        tree = self.bandClassTree
        rows = [(tree.indexOfTopLevelItem(item), item)
                for item in tree.selectedItems()]
        rows.sort(key=lambda row: row[0], reverse=step > 0)
        # Index a moving row must not land on: just past the edge at first,
        # then the position of the last row that could not move.
        blocked = tree.topLevelItemCount() if step > 0 else -1
        for index, item in rows:
            target = index + step
            if index < 0 or target == blocked:
                # Not a top-level row, or nowhere to go: later rows stop here.
                blocked = index
                continue
            tree.takeTopLevelItem(index)
            tree.insertTopLevelItem(target, item)
            tree.setCurrentItem(item)

    @pyqtSlot()
    def on_btnRemove_clicked(self):
        """Delete the selected rows when the remove button is clicked."""
        root = self.bandClassTree.invisibleRootItem()
        for item in self.bandClassTree.selectedItems():
            root.removeChild(item)

    @pyqtSlot()
    def on_btnDown_clicked(self):
        """Move the selected rows one position down."""
        self._moveSelectedItems(1)

    @pyqtSlot()
    def on_btnUp_clicked(self):
        """Move the selected rows one position up."""
        self._moveSelectedItems(-1)

    @pyqtSlot(QTreeWidgetItem, int)
    def on_bandClassTree_itemDoubleClicked(self, item, column):
        """Prompt for a new band name when a cell of column 2 is double-clicked."""
        if not item:
            return

        if column == 2:
            d, ok = QInputDialog.getText(None,
                                         self.tr('Band name'),
                                         self.tr('Enter new band name')
                                         )
            if ok:
                item.setText(2, str(d))

    def bandNames(self):
        """Return the table content as one ``[number, old name, new name]`` list per row.

        Example: ``[['1', 'old name 1', 'new name 1'], ['2', 'old name 2', '']]``.
        """
        tree = self.bandClassTree
        items = (tree.topLevelItem(i) for i in range(tree.topLevelItemCount()))
        return [[item.text(0), item.text(1), item.text(2)]
                for item in items if item]

    def setLayer(self, layer):
        """Store ``layer`` and call updateTable() with it."""
        self.layer = layer
        self.updateTable(layer)

    def updateTable(self, layer):
        """Append a placeholder row ('0.00', '0.00'); ``layer`` is not used.

        NOTE(review): this looks like a leftover from the colour-table widget
        this class is based on (ReliefColorsWidget) and unrelated to band
        names. Confirm whether it is still needed.
        """
        item = QTreeWidgetItem()
        item.setText(0, '0.00')
        item.setText(1, '0.00')
        self.bandClassTree.addTopLevelItem(item)

    def setValue(self, value):
        """Restore the widget from a string produced by value().

        Expected format: ``layer name;layer source;row;row;...``, where each
        row is ``band number,old name,new name``. The layer is looked up by
        name among the layers loaded in the project. If ``value`` is empty or
        None, or the layer is not loaded, the table is refilled from the
        raster currently selected in the combobox, so table and combobox
        stay consistent.
        """
        params = value.split(';') if value else []
        layers = (QgsProject.instance().mapLayersByName(params[0])
                  if params else [])
        if not layers:
            self.layerChanged()
            return

        self.cmbLayers.setLayer(layers[0])
        # setLayer() may trigger layerChanged(), which refills the table from
        # GDAL; replace that content with the saved rows.
        self.bandClassTree.clear()
        for row in params[2:]:
            # Pad so a row with fewer than three fields cannot raise IndexError.
            fields = row.split(',') + ['', '', '']
            item = QTreeWidgetItem()
            for column in range(3):
                item.setText(column, fields[column])
            self.bandClassTree.addTopLevelItem(item)

    def value(self):
        """Serialise the widget as ``layer name;layer source;row;row;...``.

        Each row is ``band number,old name,new name``. The layer name is
        included so setValue() can reselect the layer, e.g. when the
        algorithm is re-run from the toolbox history. Returns None when no
        raster layer is selected.

        NOTE(review): names or paths containing ';' or ',' cannot be
        round-tripped through this format.
        """
        layer = self.cmbLayers.currentLayer()
        if layer is None:
            return None
        parts = [layer.name(), layer.source()]
        parts.extend('{0},{1},{2}'.format(*row) for row in self.bandNames())
        return ';'.join(parts)


class BandTableWidgetWrapper(WidgetWrapper):
    """Processing widget wrapper that exposes BandTableWidget as a parameter widget."""

    def createWidget(self):
        """Create the band table widget shown in the algorithm dialog."""
        return BandTableWidget()

    def postInitialize(self, wrappers):
        """Link this widget to the wrapper of its parent parameter, if configured.

        The parent parameter name is read from a ``parent`` attribute on this
        wrapper's parameter. When that attribute is absent there is nothing
        to link, so return instead of raising AttributeError.
        """
        parent_name = getattr(self.param, 'parent', None)
        if parent_name is None:
            return
        for wrapper in wrappers:
            # NOTE(review): with QGIS 3 parameter definitions ``name`` is a
            # method rather than a string, so this comparison may never match.
            # Confirm before relying on parent linkage.
            if wrapper.param.name == parent_name:
                self.setLayer(wrapper.value())
                wrapper.widgetValueHasChanged.connect(self.parentValueChanged)
                break

    def parentValueChanged(self, wrapper):
        """Forward the parent wrapper's new value to the widget."""
        self.setLayer(wrapper.parameterValue())

    def setLayer(self, layer):
        """Pass ``layer`` to the widget, converting a map layer to its source.

        BandTableWidget.setLayer() already calls updateTable(), so the table
        is updated exactly once. The extra explicit updateTable() call that
        duplicated the update was removed.
        """
        if isinstance(layer, QgsMapLayer):
            layer = layer.source()
        self.widget.setLayer(layer)

    def setValue(self, value):
        """Restore the widget state from a serialised value."""
        self.widget.setValue(value)

    def value(self):
        """Return the widget's serialised value."""
        return self.widget.value()
