# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from os.path import dirname, join

import logging
from polaris.utils.db_config import DbConfig
from polaris.hpc.eq_utils import query_workers
from polaris.hpc.eqsql.worker import Worker
from polaris.hpc.eqsql.task import Task
from sqlalchemy import create_engine

from qgis.PyQt import uic
from qgis.PyQt.QtWidgets import QDialog, QTableWidgetItem, QAbstractItemView
from qgis.utils import iface

from datetime import datetime

# Qt Designer form for this dialog, located next to this module.
# loadUiType returns (generated form class, Qt base class); only the generated
# form class is used, as a mixin of ClusterDialog.
_CLUSTER_UI_PATH = join(dirname(__file__), "forms/cluster_info.ui")
FORM_CLASS, _ = uic.loadUiType(_CLUSTER_UI_PATH)


class ClusterDialog(QDialog, FORM_CLASS):
    """Dialog showing cluster worker status and the logs of a selected task.

    Widgets come from the Qt Designer form:

    * ``nodes_table``: single-selection worker overview. Clicking a row jumps
      to the task that worker is running.
    * ``nodes_control_table``: multi-selection copy of the worker overview,
      used to restart or terminate workers.
    * ``tasks_table``: logs of the task whose id is set in ``task_num``.

    All data is read through the SQLAlchemy engine exposed by :attr:`engine`.
    """

    # Column of ``nodes_table`` holding the id of the task a worker runs.
    TASK_ID_COLUMN = 4
    # Column of ``nodes_control_table`` holding the worker id.
    WORKER_ID_COLUMN = 0
    # Index of the ``tabWidget`` page that shows task details.
    TASK_TAB_INDEX = 1

    # Qgis.MessageLevel values used with the QGIS message bar.
    _MSG_INFO = 0
    _MSG_WARNING = 1
    _MSG_CRITICAL = 2

    _REFRESH_FAILED_TEXT = "Failed. Try Again after waiting a few seconds!"

    def __init__(self):
        QDialog.__init__(self)
        self.setupUi(self)

        # The logger must exist before ``self.engine`` is first evaluated (below):
        # the property logs through it when the connection cannot be created.
        self.logger = logging.getLogger("polaris")
        self.task_status = None
        self._engine = None

        for tab in [self.nodes_table, self.tasks_table]:
            tab.setSelectionBehavior(QAbstractItemView.SelectRows)
            tab.setSelectionMode(QAbstractItemView.SingleSelection)

        self.nodes_control_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.nodes_control_table.setSelectionMode(QAbstractItemView.MultiSelection)

        self.refreshButton.clicked.connect(self.do_refresh)
        self.nodes_table.cellClicked.connect(self.select_task)
        self.task_num.valueChanged.connect(self.update_task)
        self.restartNodeButton.clicked.connect(self.restart_workers)
        self.terminateNodeButton.clicked.connect(self.terminate_workers)

        if self.engine is None:
            # No database connection: skip the initial load. Signals stay
            # connected, so pressing Refresh retries the connection later.
            self.close()
            return

        self.do_refresh()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _empty_table(table):
        """Remove all rows and columns from ``table``."""
        table.setRowCount(0)
        table.setColumnCount(0)

    @staticmethod
    def _fill_table(table, df):
        """Replace the contents of ``table`` with the values of DataFrame ``df``.

        Rows are placed by position rather than by the DataFrame index, which is
        not guaranteed to be 0..n-1. ``itertuples`` keeps per-column dtypes, so
        integer columns are not upcast to float as ``iterrows`` would do on
        numeric rows. Every cell is rendered with ``str``.
        """
        table.clearContents()
        table.setRowCount(df.shape[0])
        table.setColumnCount(df.shape[1])
        table.setHorizontalHeaderLabels([str(col) for col in df.columns])
        for row, values in enumerate(df.itertuples(index=False, name=None)):
            for col, value in enumerate(values):
                table.setItem(row, col, QTableWidgetItem(str(value)))
        table.show()
        table.resizeColumnsToContents()
        table.resizeRowsToContents()

    @staticmethod
    def _parse_task_id(text):
        """Convert a table cell's text to an integer task id, or ``None``.

        Cells are rendered with ``str``, so an id may appear as ``"12"`` or, when
        pandas stored the column as float (e.g. because of missing values), as
        ``"12.0"``. Missing values render as ``""``, ``"None"`` or ``"nan"``,
        and all of these yield ``None``.
        """
        text = text.strip()
        try:
            return int(text)
        except ValueError:
            pass
        try:
            value = float(text)
        except ValueError:
            return None
        # is_integer() is False for nan/inf and for fractional values.
        return int(value) if value.is_integer() else None

    # ------------------------------------------------------------------- slots

    def do_refresh(self):
        """Reload both worker tables from the database, then refresh the task logs."""
        self.nodes_table.clearContents()
        self.nodes_control_table.clearContents()

        engine = self.engine
        if engine is None:
            # The engine property already logged why; avoid a second
            # connection attempt from update_task().
            self.last_refreshed_time.setText(self._REFRESH_FAILED_TEXT)
            return

        try:
            worker_status = query_workers(engine, style_df=False)
            for tab in (self.nodes_table, self.nodes_control_table):
                self._fill_table(tab, worker_status)
            self.last_refreshed_time.setText("Last Refreshed: " + datetime.now().strftime("%X"))
        except Exception as e:
            self.last_refreshed_time.setText(self._REFRESH_FAILED_TEXT)
            self.logger.warning(f"Could not refresh worker status: {e.args}")

        # The refresh button should also update the task status
        self.update_task()

    def select_task(self):
        """Switch to the task tab and show the task run by the selected worker, if any."""
        selected_rows = self.nodes_table.selectionModel().selectedRows()
        if not selected_rows:
            return
        # nodes_table is single-selection, so only the first selected row matters.
        item = self.nodes_table.item(selected_rows[0].row(), self.TASK_ID_COLUMN)
        task_id = self._parse_task_id(item.text()) if item is not None else None
        if task_id is None:
            return
        # Switch tab first so update_task() (triggered by valueChanged) knows the
        # task tab is visible when deciding whether to report a missing task.
        self.tabWidget.setCurrentIndex(self.TASK_TAB_INDEX)
        self.task_num.setValue(task_id)

    def restart_workers(self):
        """Restart every idle worker selected in ``nodes_control_table``."""
        self._apply_to_selected_workers("restarted", lambda worker, eng: worker.restart(eng))

    def terminate_workers(self):
        """Terminate every idle worker selected in ``nodes_control_table``."""
        self._apply_to_selected_workers("terminated", lambda worker, eng: worker.terminate(eng))

    def _apply_to_selected_workers(self, verb, action):
        """Run ``action(worker, engine)`` on each selected idle worker, then refresh.

        ``verb`` is the past participle used in user messages ("restarted",
        "terminated"). Workers whose status is not ``"idle"`` are skipped with an
        error message. Success is reported only after ``action`` returns.
        """
        msg = iface.messageBar()
        engine = self.engine
        if engine is None:
            msg.pushMessage("Error", "No connection to the cluster database", level=self._MSG_CRITICAL, duration=2)
            return

        for index in self.nodes_control_table.selectionModel().selectedRows():
            item = self.nodes_control_table.item(index.row(), self.WORKER_ID_COLUMN)
            if item is None or not item.text():
                continue
            worker_id = item.text()
            try:
                worker = Worker.from_id(engine, worker_id)
                if worker.status != "idle":
                    msg.pushMessage(
                        "Error",
                        f"Worker with id: {worker_id}, cannot be {verb} => status = {worker.status}",
                        level=self._MSG_CRITICAL,
                        duration=2,
                    )
                    continue
                action(worker, engine)
            except Exception as e:
                self.logger.warning(f"Worker with id: {worker_id} could not be {verb}: {e.args}")
                msg.pushMessage(
                    "Error",
                    f"Worker with id: {worker_id} could not be {verb}",
                    level=self._MSG_CRITICAL,
                    duration=2,
                )
                continue
            msg.pushMessage("Info", f"Worker with id: {worker_id} has been {verb}", level=self._MSG_INFO, duration=2)

        self.do_refresh()

    def update_task(self):
        """Load the logs of the task whose id is set in ``task_num`` into ``tasks_table``.

        If the task does not exist (or cannot be loaded) the table is emptied.
        A missing task is reported on the QGIS message bar only while the task
        tab is visible.
        """
        self.tasks_table.clearContents()
        engine = self.engine
        if engine is None:
            self._empty_table(self.tasks_table)
            return

        task_id = self.task_num.value()
        try:
            task = Task.from_id(engine, task_id)
            task_df = None if task is None else task.get_logs(engine)
        except Exception as e:
            self.logger.warning(f"Could not load task {task_id}: {e.args}")
            self._empty_table(self.tasks_table)
            return

        if task is None:
            self._empty_table(self.tasks_table)
            if self.tabWidget.currentIndex() == self.TASK_TAB_INDEX:
                iface.messageBar().pushMessage(
                    "Error", f"Task with id: {task_id}, does not exist", level=self._MSG_WARNING, duration=2
                )
            return

        self._fill_table(self.tasks_table, task_df)

    @property
    def engine(self):
        """SQLAlchemy engine for the EQ-SQL database, created lazily on first access.

        Returns ``None`` when the connection cannot be set up; creation is then
        retried on the next access.
        """
        if self._engine is None:
            try:
                self._engine = DbConfig.eqsql_db().create_engine()
            except Exception as e:
                self.logger.warning(f"Cannot setup connection: {e.args}")
                self.logger.warning("If not on ANL server, task status will not be available.")

        return self._engine

    def closeEvent(self, event):
        """Dispose of the engine's connection pool, if one was created, and accept the close."""
        if self._engine is not None:
            self._engine.dispose()
        event.accept()
