######################################################################
#             / ____|  ____|  __ \| |  | |  ____|  ____|             #
#            | |    | |__  | |__) | |__| | |__  | |__                #
#            | |    |  __| |  ___/|  __  |  __| |  __|               #
#            | |____| |____| |    | |  | | |____| |____              #
#             \_____|______|_|    |_|  |_|______|______|             #
######################################################################
#
# CEPHEE
# Copyright (C) 2024 Toulouse INP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details :
# <http://www.gnu.org/licenses/>.
#
######################################################################

from shapely.geometry import Point
from qgis.PyQt.QtWidgets import QDockWidget, QFileDialog
from qgis.PyQt.QtCore import pyqtSignal, QDir, pyqtSlot, Qt
from qgis.core import (QgsVectorLayer, QgsRasterLayer, QgsProject, QgsMapLayerProxyModel, QgsMessageLog, QgsPointXY,
                       Qgis, QgsTask, QgsApplication, QgsCoordinateReferenceSystem)
from .ui.Ui_DataWidget import Ui_DataWidget
from os import path
from qgis.gui import QgsMapToolEmitPoint


class DataWidget(QDockWidget, Ui_DataWidget):
    """Dock widget collecting the inputs of the CEPHEE data-preparation step.

    User choices (working directory, DEM source, river network source,
    outlet point, projection options) are written into ``param``; the
    "process" button then runs :class:`BackgroundProcess` on ``BV`` through
    the QGIS task manager.
    """
    # Emitted from closeEvent() when the dock widget is closed.
    closingWidget = pyqtSignal(name="closingWidget")

    def __init__(self, iface, BV, param, parent = None):
        super(DataWidget, self).__init__(parent)
        self.setupUi(self)
        self.iface = iface
        self.BV = BV
        self.param = param
        # Last task launched by process(); None until the first run.
        self.task = None
        # Default working directory: ~/CEPHEE_project.
        self.param.work_path = path.join(QDir.homePath(), 'CEPHEE_project')
        self.lineEdit_workPath.setText(self.param.work_path)
        self.pushButton_work_path.clicked.connect(self.select_work_path)
        self.pushButton_DEM_path.clicked.connect(self.select_DEM_path)
        self.pushButton_loadRiverFile.clicked.connect(self.load_riverLayer)
        self.comboBox_riverLayerSelection.setFilters(QgsMapLayerProxyModel.VectorLayer)
        self.select_riverLayer()
        self.comboBox_riverLayerSelection.layerChanged.connect(self.select_riverLayer)
        self.comboBox_DEMlayerSelection.setFilters(QgsMapLayerProxyModel.RasterLayer)
        self.pushButton_process.clicked.connect(self.process)
        # Sync the DEM widgets with the initial radio-button state.
        self.option_DEM()
        self.radioButton_pathToFolder.toggled.connect(self.option_DEM)
        self.comboBox_DEMlayerSelection.layerChanged.connect(self.option_DEM)
        self.checkBox_computeRiverFromDEM.stateChanged.connect(self.option_computeRiverFromDEM)
        self.ASC_grid_projection.setCrs(QgsProject.instance().crs())

        # capture coordinates tool
        self.captureCoordinatesTool = CaptureCoordinatesTool(iface.mapCanvas(), self.lineEdit_outletX,
                                                             self.lineEdit_outletY)
        self.pushButton_captureCoordinates.clicked.connect(self.start_capture)
        self.captureCoordinatesTool.capturePoint.connect(self.update_coordinates)

    def closeEvent(self, event):
        """Emit ``closingWidget`` and accept the close event."""
        self.closingWidget.emit()
        event.accept()

    def option_DEM(self):
        """Show the widgets matching the selected DEM source.

        "Path to folder" shows the directory picker and extension combo and
        clears ``param.qgis_DEM_layer``; "QGIS layer" shows the layer combo and
        stores its current layer in ``param.qgis_DEM_layer``.
        """
        if self.radioButton_pathToFolder.isChecked():
            self.comboBox_DEMlayerSelection.hide()
            self.pushButton_DEM_path.show()
            self.lineEdit_DEM_path.show()
            self.comboBox_extension.show()
            self.param.qgis_DEM_layer = None
        elif self.radioButton_QGISlayer.isChecked():
            self.comboBox_DEMlayerSelection.show()
            self.pushButton_DEM_path.hide()
            self.lineEdit_DEM_path.hide()
            self.comboBox_extension.hide()
            self.param.qgis_DEM_layer = self.comboBox_DEMlayerSelection.currentLayer()

    def option_computeRiverFromDEM(self):
        """Enable either the "river from DEM" inputs or the river-layer inputs.

        When the checkbox is checked, the river-layer widgets are disabled and
        ``param['N']['riverPath']`` is cleared.
        """
        if self.checkBox_computeRiverFromDEM.isChecked():
            self.lineEdit_minAccumulativeArea.setEnabled(True)
            self.pushButton_loadRiverFile.setDisabled(True)
            self.comboBox_riverLayerSelection.setDisabled(True)
            self.lineEdit_classMax.setDisabled(True)
            self.lineEdit_minDistJunction.setDisabled(True)
            self.param['N']['riverPath'] = None

        else:
            self.lineEdit_minAccumulativeArea.setDisabled(True)
            self.pushButton_loadRiverFile.setEnabled(True)
            self.comboBox_riverLayerSelection.setEnabled(True)
            self.lineEdit_classMax.setEnabled(True)
            self.lineEdit_minDistJunction.setEnabled(True)

    def select_work_path(self):
        """Ask for the working directory and store it in ``param.work_path``."""
        # Named `directory` so the module-level `os.path` import is not shadowed.
        directory = QFileDialog.getExistingDirectory(parent = None, caption="Select working directory")
        if directory:
            self.param.work_path = directory
            self.lineEdit_workPath.setText(directory)

    def select_DEM_path(self):
        """Ask for the DEM directory and store it in ``param['C']['DEMpath']``."""
        directory = QFileDialog.getExistingDirectory(parent = None, caption="Select DEM directory")
        if directory:
            self.param['C']['DEMpath'] = directory
            self.lineEdit_DEM_path.setText(directory)

    def load_riverLayer(self):
        """Load a river shapefile chosen by the user into the project.

        On success the layer is added to the project, selected in the river
        combo box and its file path stored in ``param['N']['riverPath']``.
        Cancelling the dialog does nothing; an invalid file is logged.
        """
        filename, _ = QFileDialog.getOpenFileName(parent = None, caption = "Open File",
                                                  directory = ".", filter = "Shape file (*.shp)")
        if not filename:
            # Dialog cancelled.
            return

        layer_name = path.splitext(path.basename(filename))[0]
        v_layer = QgsVectorLayer(filename, layer_name, "ogr")
        if not v_layer.isValid():
            QgsMessageLog.logMessage(message=f"Invalid river layer: {filename}", tag="CEPHEE",
                                     level=Qgis.MessageLevel.Warning)
            return
        QgsProject.instance().addMapLayer(v_layer)
        self.comboBox_riverLayerSelection.setLayer(v_layer)
        self.param['N']['riverPath'] = filename

    def select_riverLayer(self):
        """Store the layer selected in the river combo in ``param.qgis_river_layer``."""
        self.param.qgis_river_layer = self.comboBox_riverLayerSelection.currentLayer()

    def ui_to_param(self):
        """Copy every widget value into ``param``.

        Raises:
            ValueError: if a numeric line edit (outlet X/Y, resolution, class
                max, min accumulative area, min junction distance) does not
                contain a valid number.
        """
        self.param.user_outlet_point = Point(float(self.lineEdit_outletX.text()), float(self.lineEdit_outletY.text()))
        self.param['C']['DEM_file_extension'] = self.comboBox_extension.currentText()
        self.param['C']['DEM_CRS_ID'] = self.ASC_grid_projection.crs().authid()
        self.param['C']['resolution'] = float(self.lineEdit_resolution.text())
        self.param['C']['findCatchment'] = True
        self.param['C']['computeGlobal'] = True
        self.param['N']['classeMax'] = int(self.lineEdit_classMax.text())
        self.param['N']['minAccumulativeArea'] = float(self.lineEdit_minAccumulativeArea.text())
        self.param['N']['minDistJunction'] = float(self.lineEdit_minDistJunction.text())
        if self.checkBox_computeRiverFromDEM.isChecked():
            self.param['N']['riverPath'] = None
        else:
            self.select_riverLayer()
            self.param['N']['riverPath'] = 'from QGIS'

        if self.radioButton_interpolation.isChecked():
            self.param['N']['typeProj'] = 'interpolation'
            self.param['N']['interpolationMethod'] = self.comboBox_interpolationMethod.currentText()
        else:
            self.param['N']['typeProj'] = 'raster'
            self.param['N']['interpolationMethod'] = None

    @pyqtSlot(QgsPointXY)
    def update_coordinates(self, pt):
        """Write the captured point (truncated to int) into the outlet fields."""
        self.lineEdit_outletX.setText(str(int(pt.x())))
        self.lineEdit_outletY.setText(str(int(pt.y())))
        self.end_capture()

    def start_capture(self):
        """Activate the coordinate-capture tool, remembering the current tool."""
        canvas = self.iface.mapCanvas()
        current_tool = canvas.mapTool()
        # If the button is clicked while capturing, keep the originally saved
        # tool instead of recording the capture tool itself as the one to restore.
        if current_tool is not self.captureCoordinatesTool:
            self.captureCoordinatesTool.savedMapTool = current_tool
        canvas.setMapTool(self.captureCoordinatesTool)

    def end_capture(self):
        """Restore the map tool that was active before start_capture()."""
        canvas = self.iface.mapCanvas()
        saved_tool = self.captureCoordinatesTool.savedMapTool
        self.captureCoordinatesTool.savedMapTool = None
        # The canvas deactivates the outgoing tool itself in both branches.
        if saved_tool is not None:
            canvas.setMapTool(saved_tool)
        else:
            # No tool was active before capture: just release the capture tool.
            canvas.unsetMapTool(self.captureCoordinatesTool)

    def process(self):
        """Read the form into ``param`` and launch the background processing task.

        Invalid numeric input is reported in the QGIS message bar instead of
        raising out of the Qt slot.
        """
        try:
            self.ui_to_param()
        except ValueError as err:
            self.iface.messageBar().pushMessage("CEPHEE", f"Invalid parameter value: {err}",
                                                level=Qgis.MessageLevel.Warning)
            return
        self.task = BackgroundProcess('process_DATA', self.BV, self.param)
        QgsApplication.taskManager().addTask(self.task)


class BackgroundProcess(QgsTask):
    """QGIS task running the CEPHEE data-preparation pipeline on ``BV``.

    :meth:`run` executes in a worker thread: it drives the ``BV`` processing
    steps and writes the output files into ``param.work_path``. The output
    files are turned into QGIS layers and added to the project in
    :meth:`finished`, which QGIS calls on the main thread (QgsProject and map
    layers must not be used from the worker thread).
    """

    def __init__(self, description, BV, param):
        super().__init__(description, QgsTask.CanCancel)
        self.BV = BV
        self.param = param
        # Exception raised inside run(), if any (logged there, reported in finished()).
        self.exception = None
        # Values recorded by run() so finished() can build the output layers.
        self._work_path = None
        self._crs_string = None
        self._global_DEM_path = None
        self._junction_written = False

    def _work_file(self, filename):
        """Return the path of ``filename`` inside the working directory used by run()."""
        return path.join(self._work_path, filename)

    def run(self):
        """Run the pipeline in the worker thread.

        Returns:
            True on success; False if the task was cancelled or an exception
            was raised. A raised exception is stored in ``self.exception`` and
            its traceback logged, rather than escaping the worker thread.
        """
        try:
            return self._run_pipeline()
        except Exception as err:
            from traceback import format_exc
            self.exception = err
            QgsMessageLog.logMessage(message=format_exc(), tag="CEPHEE", level=Qgis.MessageLevel.Critical)
            return False

    def _run_pipeline(self):
        """Execute the BV processing steps and write their output files.

        Returns False as soon as a cancellation request is seen between stages.
        """
        self._work_path = self.param.work_path
        QgsMessageLog.logMessage(message=str(self.param['C']['DEM_CRS_ID']), tag="CEPHEE", level=Qgis.MessageLevel.Info)
        QgsMessageLog.logMessage(message="Starting process", tag="CEPHEE", level=Qgis.MessageLevel.Info)
        self.setProgress(0)

        self.BV.read_data(self.param)
        # Keep the CRS as a string; the QgsCoordinateReferenceSystem is built in finished().
        self._crs_string = self.BV.crs.to_string()
        QgsMessageLog.logMessage(message="Read DEM and Network", tag="CEPHEE", level=Qgis.MessageLevel.Info)
        self.setProgress(25)
        if self.isCanceled():
            return False

        self.BV.HydroNetwork.to_file(self._work_file('hydro_network.shp'))
        self.BV.projectionOnDEM(self.param['N']['typeProj'], self.param['N']['interpolationMethod'])
        QgsMessageLog.logMessage(message="Network projection completed", tag="CEPHEE", level=Qgis.MessageLevel.Info)
        self.setProgress(50)
        if self.isCanceled():
            return False

        self.BV.order_network(self.param['N']['minDistJunction'])
        self.BV.find_junctionAndOutlet(self.param['N']['minDistJunction'])
        self.BV.setOutlet(self.param.user_outlet_point)
        QgsMessageLog.logMessage(message="Network post-treatment completed", tag="CEPHEE", level=Qgis.MessageLevel.Info)
        self.setProgress(75)
        if self.isCanceled():
            return False

        self._global_DEM_path = self.BV.global_DEM_path
        QgsMessageLog.logMessage(message=f"Global DEM built : {self._global_DEM_path}", tag="CEPHEE",
                                 level=Qgis.MessageLevel.Info)

        self.BV.renameReachFromJunction(self.param)
        # The ordered network is only final after renaming; its layer is built afterwards.
        self.BV.ordered_network.to_file(self._work_file('hydro_network_ordered.shp'))
        if not self.BV.junction.empty:
            self.BV.junction.to_file(self._work_file('junction.shp'))
            self._junction_written = True
        self.setProgress(100)
        return True

    def _add_result_layers(self):
        """Load the files written by run() and add them to the current project.

        Must run on the main thread. Layers that fail to load are logged and skipped.
        """
        out_crs = QgsCoordinateReferenceSystem(self._crs_string)
        hydro_layer = QgsVectorLayer(self._work_file('hydro_network.shp'), 'hydro network')
        hydro_layer.setCrs(out_crs)
        ordered_hydro_layer = QgsVectorLayer(self._work_file('hydro_network_ordered.shp'), 'hydro network ordered')
        ordered_hydro_layer.setCrs(out_crs)
        dem_name = path.splitext(path.basename(self._global_DEM_path))[0]

        # Same order as the layers were historically added to the project.
        layers = [
            QgsRasterLayer(self._work_file('watershed_mask.tif'), 'watershed mask'),
            QgsRasterLayer(self._work_file('flow_acc.tif'), 'flow accumulation'),
            QgsRasterLayer(self._global_DEM_path, dem_name),
            hydro_layer,
            ordered_hydro_layer,
        ]
        if self._junction_written:
            layers.append(QgsVectorLayer(self._work_file('junction.shp'), 'junction'))

        project = QgsProject.instance()
        for layer in layers:
            if not layer.isValid():
                QgsMessageLog.logMessage(message=f"Could not load layer '{layer.name()}' from {layer.source()}",
                                         tag="CEPHEE", level=Qgis.MessageLevel.Warning)
                continue
            project.addMapLayer(layer)

    def finished(self, result):
        """Report the outcome; on success add the output layers to the project.

        Args:
            result: the value returned by run().
        """
        if result:
            self._add_result_layers()
            QgsMessageLog.logMessage(message="Process OK", tag="CEPHEE", level=Qgis.MessageLevel.Info)
        elif self.exception is None and self.isCanceled():
            QgsMessageLog.logMessage(message="Process cancelled", tag="CEPHEE", level=Qgis.MessageLevel.Warning)
        else:
            QgsMessageLog.logMessage(message="Process Error", tag="CEPHEE", level=Qgis.MessageLevel.Critical)


class CaptureCoordinatesTool(QgsMapToolEmitPoint):
    """Map tool that captures the map coordinates of a click on the canvas.

    Each mouse-button release emits :attr:`capturePoint` with the clicked
    position converted to map coordinates. Restoring the previous map tool
    is left to the owner, which can keep it in :attr:`savedMapTool`.
    """
    # Emitted with the clicked position, in map coordinates.
    capturePoint = pyqtSignal(QgsPointXY)
    # Declared for callers; this class never emits it.
    captureFinished = pyqtSignal()

    def __init__(self, canvas, X, Y):
        """Create the tool for ``canvas``.

        Args:
            canvas: the map canvas the tool operates on.
            X, Y: the outlet coordinate widgets; currently unused, kept for
                signature compatibility.
        """
        QgsMapToolEmitPoint.__init__(self, canvas)
        # NOTE(review): this attribute shadows the inherited QgsMapTool.canvas()
        # method; kept because external code may read it.
        self.canvas = canvas
        # Map tool to restore once a point has been captured (managed by the owner).
        self.savedMapTool = None

    def activate(self):
        """Activate the tool and show a crosshair cursor on the canvas."""
        # Run the base activation (emits `activated`) before customizing the cursor.
        super().activate()
        self.canvas.setCursor(Qt.CrossCursor)

    def canvasReleaseEvent(self, event):
        """Emit capturePoint with the map coordinates of the released click."""
        pt = self.toMapCoordinates(event.originalPixelPoint())
        self.capturePoint.emit(pt)
