######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

import pandas as pd
from qgis.PyQt.QtWidgets import QWidget
from qgis.PyQt.QtCore import pyqtSignal
from qgis.core import QgsMapLayerProxyModel
from .ui.Ui_PlotWidget import Ui_PlotWidget
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT
from ..core.Tools import compute_distance
from shapely.geometry import LineString

class PlotWidget(QWidget, Ui_PlotWidget):
    """Widget plotting data from QGIS layers or from the CEPHEE model.

    Two tabs are driven from this widget:

    * the *QGIS* tab passes a vector layer and an integer typed by the user
      (presumably a feature index, see the "0-N" range label) to
      ``figureWidget.plot_qgis``;
    * the *CEPHEE* tab plots either one cross-section or the longitudinal bed
      profile of a reach of ``BV``, optionally overlaid with one of the
      reach's result tables.

    Emits ``closingWidget`` when the widget is closed.
    """

    closingWidget = pyqtSignal(name="closingWidget")

    # Placeholder item shown in the reach combo box when BV has no reach.
    _NO_REACH_TEXT = "no reach available"

    # Result-type combo box text -> name of the reach attribute holding the
    # matching results DataFrame.
    _RESULT_ATTRIBUTES = {
        "Normal": "resNormal",
        "1D": "res1D",
        "Himposed": "resHimposed",
        "Obs": "resObs",
    }

    def __init__(self, iface, BV, param, parent=None):
        """Build the widget and wire its signals.

        :param iface: QGIS interface object; used to report invalid user input
            in the QGIS message bar.
        :param BV: CEPHEE model object; ``BV.reach`` is the list of reaches
            offered in the CEPHEE tab.
        :param param: stored as ``self.param`` (not used inside this class).
        :param parent: optional Qt parent widget.
        """
        # Initialise the Qt base class before storing anything on self.
        super().__init__(parent)
        self.iface = iface
        self.BV = BV
        self.param = param
        self.setupUi(self)
        self.figureFrame.resize(632, 564)
        # figureWidget is created by setupUi; it is re-initialised here so it
        # receives BV and is parented to the figure frame.
        self.figureWidget.__init__(BV, parent=self.figureFrame)
        self.nav_toolbar = NavigationToolbar2QT(self.figureWidget, self)
        # Drop the toolbar's first action (presumably "Home" — TODO confirm).
        self.nav_toolbar.removeAction(self.nav_toolbar.actions()[0])
        self.verticalLayout.insertWidget(10, self.nav_toolbar)

        # QGIS tab
        self.comboBox_layerSelection.setFilters(QgsMapLayerProxyModel.VectorLayer)
        self.comboBox_layerSelection.currentIndexChanged.connect(self.layerSelected)
        self.layerSelected()
        self.pushButton_plot_qgis.clicked.connect(self.plot_qgis)

        # CEPHEE tab
        self.update_reach_list()
        self.comboBox_reachSelection.currentIndexChanged.connect(self.update_cephee_panel)
        self.checkBox_plotXS.stateChanged.connect(self.update_cephee_panel)
        self.pushButton_plot_cephee.clicked.connect(self.plot_cephee)
        # The reach combo box was filled before its signal was connected, so
        # the panel must be initialised explicitly for the selected reach.
        self.update_cephee_panel()

    def closeEvent(self, event):
        """Notify listeners through ``closingWidget`` and accept the close."""
        self.closingWidget.emit()
        event.accept()

    def resizeEvent(self, event):
        """Keep the figure canvas the same size as its frame."""
        super().resizeEvent(event)
        self.figureWidget.resize(self.figureFrame.size())

    def _warn(self, message):
        """Show *message* as a warning in the QGIS message bar, if available."""
        if self.iface is not None:
            self.iface.messageBar().pushWarning("CEPHEE", message)

    # QGIS tab
    def layerSelected(self):
        """Show the valid index range ("0-N") of the selected vector layer."""
        layer = self.comboBox_layerSelection.currentLayer()
        # Compare with None explicitly: a vector layer's truthiness follows
        # len(), so an empty layer would otherwise look like "no layer".
        if layer is None:
            self.label_nID.setText("-")
            return
        n_features = len(layer)
        self.label_nID.setText(f"0-{n_features - 1}" if n_features > 0 else "-")

    def plot_qgis(self):
        """Plot the entry of the selected layer whose index is typed by the user."""
        map_layer = self.comboBox_layerSelection.currentLayer()
        if map_layer is None:
            self._warn("No vector layer selected.")
            return
        try:
            id_section = int(self.lineEdit_XS_qgis.text())
        except ValueError:
            self._warn("The cross-section id must be an integer.")
            return
        self.figureWidget.plot_qgis(map_layer, id_section)

    # CEPHEE tab
    def update_reach_list(self):
        """Refill the reach combo box from ``BV.reach`` (or a placeholder)."""
        self.comboBox_reachSelection.clear()
        if len(self.BV.reach) > 0:
            for reach in self.BV.reach:
                self.comboBox_reachSelection.addItem(reach.name)
        else:
            self.comboBox_reachSelection.addItem(self._NO_REACH_TEXT)

    def _current_reach(self):
        """Return the reach selected in the combo box, or None if there is none.

        Also returns None for the transient index -1 emitted while the combo
        box is being cleared, instead of silently picking the last reach.
        """
        if self.comboBox_reachSelection.currentText() == self._NO_REACH_TEXT:
            return None
        index = self.comboBox_reachSelection.currentIndex()
        if 0 <= index < len(self.BV.reach):
            return self.BV.reach[index]
        return None

    def update_cephee_panel(self):
        """Sync the CEPHEE-tab inputs with the checkbox and the selected reach.

        The section-id input and the field selector are enabled only in
        cross-section mode. The label shows the range of global section ids of
        the selected reach, and the input is pre-filled with the first one.
        """
        plot_xs = self.checkBox_plotXS.isChecked()
        self.lineEdit_XS_cephee.clear()
        self.lineEdit_XS_cephee.setEnabled(plot_xs)
        self.comboBox_fieldSelection.setEnabled(plot_xs)

        n_sections = "-"
        current_reach = self._current_reach()
        if current_reach is not None and len(current_reach.section) > 0:
            first_id = current_reach.id_first_section
            last_id = first_id + len(current_reach.section) - 1
            n_sections = f"{first_id}-{last_id}"
            # Pre-fill with the first *valid* id: plot_cephee subtracts
            # id_first_section, so a hard-coded "0" is out of range for every
            # reach that does not start at section 0.
            self.lineEdit_XS_cephee.setText(str(first_id))
        self.label_nSections.setText(n_sections)

    def plot_cephee(self):
        """Plot the selected reach: one cross-section or the long profile."""
        reach = self._current_reach()
        if reach is None:
            return
        results = self._selected_results(reach)

        if self.checkBox_plotXS.isChecked():
            # Validate the input before clearing, so a typo keeps the last plot.
            id_section = self._read_section_index(reach)
            if id_section is None:
                return
            self.figureWidget.axes.cla()
            self._plot_cross_section(reach, id_section, results)
        else:
            self.figureWidget.axes.cla()
            self._plot_long_profile(reach, results)

    def _selected_results(self, reach):
        """Return the results DataFrame chosen in the result-type combo box.

        Falls back to an empty DataFrame for an unknown result type, or when
        the reach attribute is missing or None.
        """
        attribute = self._RESULT_ATTRIBUTES.get(self.comboBox_resSelection.currentText())
        results = getattr(reach, attribute, None) if attribute else None
        return results if results is not None else pd.DataFrame()

    def _read_section_index(self, reach):
        """Convert the typed global section id into an index in ``reach.section``.

        Returns None (after warning the user) when the text is not an integer
        or the id lies outside the reach.
        """
        try:
            global_id = int(self.lineEdit_XS_cephee.text())
        except ValueError:
            self._warn("The cross-section id must be an integer.")
            return None
        id_section = global_id - reach.id_first_section
        if not 0 <= id_section < len(reach.section):
            self._warn(f"Cross-section id {global_id} is outside this reach.")
            return None
        return id_section

    def _plot_cross_section(self, reach, id_section, results):
        """Plot one cross-section and, if available, the selected result field.

        Field index 0 draws the section line and the water surface between
        the banks; 1 draws V and 2 draws h against the stored distances.
        """
        section = reach.section[id_section]
        id_field = self.comboBox_fieldSelection.currentIndex()
        if id_field == 0:
            self.figureWidget.plot_z(section.distance, section.line)

        if results.empty:
            return
        # Result rows are keyed by the section index local to the reach.
        row = results[results['idSection'] == id_section]
        if row.empty:
            self._warn(f"No results for cross-section {id_section + reach.id_first_section}.")
            return
        if id_field == 0:
            bank = row['bank'].iloc[0]
            wse = row['WSE'].iloc[0]
            self.figureWidget.plot_WS([bank[0][0], bank[0][1]], [wse, wse])
        elif id_field == 1:
            self.figureWidget.plot_V(row['distance'].iloc[0], row['V'].iloc[0])
        elif id_field == 2:
            # NOTE(review): h is drawn with plot_V, as in the original code —
            # confirm no dedicated depth-plotting method is intended.
            self.figureWidget.plot_V(row['distance'].iloc[0], row['h'].iloc[0])

    def _plot_long_profile(self, reach, results):
        """Plot the bed profile of *reach* and, if available, its water surface.

        The bed line joins each section centre at its ``Zbed`` level (the
        section's minimum elevation, per the original comment) and is drawn
        against ``reach.Xinterp``. The original reach geometry is drawn
        against the distances returned by ``compute_distance``.
        """
        sections = reach.section
        # A LineString needs at least two vertices.
        if len(sections) >= 2:
            line_bed = LineString([(s.centre.x, s.centre.y, s.Zbed) for s in sections])
            self.figureWidget.plot_z(reach.Xinterp, line_bed)
        geometry = reach.geodata['geometry']
        self.figureWidget.plot_z_origin(compute_distance(geometry.coords), geometry)

        if results.empty:
            return
        water_surface = []
        for j in range(len(sections)):
            wse = results.loc[results['idSection'] == j, 'WSE']
            # A section without results leaves a gap (NaN) instead of raising.
            water_surface.append(wse.iloc[0] if not wse.empty else float('nan'))
        self.figureWidget.plot_WS(reach.Xinterp, water_surface)
