# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import logging
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from polaris.analyze.kpi_comparator import KpiComparator
from polaris.analyze.result_kpis import ResultKPIs
from polaris.runs.convergence.convergence_iteration import ConvergenceIteration
from qgis.PyQt.QtWidgets import QDialog, QGridLayout, QTableWidget, QTableWidgetItem, QAbstractItemView

from QPolaris.modules.common_tools import list_iterations

# Module-level logger obtained under the "polaris" logger name.
logger = logging.getLogger("polaris")


class PlotKPIs(QDialog):
    """Dialog that plots a KPI comparison for the selected iterations.

    Two single-column tables sit side by side: one lists the values returned
    by ``list_iterations`` (several rows may be selected at once) and the
    other lists the names returned by ``KpiComparator.available_plots()``
    (one row at a time). Whenever either selection changes and both are
    non-empty, the chosen plot is rebuilt and embedded next to the tables.
    """

    def __init__(self, qgis_project) -> None:
        """Build the widgets and fill both tables.

        Args:
            qgis_project: Object exposing ``iface`` and ``polaris_project``;
                it is also handed to ``list_iterations``.
        """
        super().__init__()
        self.iface = qgis_project.iface
        self.qgis_project = qgis_project
        self.buildUI()
        self.load_table_values()

    def buildUI(self):
        """Create the grid layout and the two (still empty) selection tables."""
        self.grid = QGridLayout()
        self.setLayout(self.grid)

        def build_table(size, lbl):
            # Fixed-width, single-column table whose rows are selected as a whole.
            tbl = QTableWidget()
            tbl.setMaximumWidth(size)
            tbl.setMinimumWidth(size)
            tbl.setColumnCount(1)
            tbl.setColumnWidth(0, size - 20)
            tbl.setHorizontalHeaderLabels([lbl])
            tbl.setSelectionBehavior(QAbstractItemView.SelectRows)
            return tbl

        # Several iterations can be compared in one chart.
        self.table_iter = build_table(220, "Iterations")
        self.table_iter.setSelectionMode(QAbstractItemView.MultiSelection)
        self.grid.addWidget(self.table_iter, 0, 0)  # row 0, column 0

        # Only one metric is plotted at a time.
        self.table_metrics = build_table(220, "Metrics")
        self.table_metrics.setSelectionMode(QAbstractItemView.SingleSelection)
        self.grid.addWidget(self.table_metrics, 0, 1)  # row 0, column 1

        # The chart canvas is created by build_chart once a selection exists.
        self.graphicsView = None
        self.setWindowTitle("Model KPIs")
        self.setGeometry(300, 300, 750, 300)

    def load_table_values(self):
        """Populate both tables and hook their selection signals to ``build_chart``."""
        list_iters = list_iterations(self.qgis_project)
        plots = KpiComparator.available_plots()

        def fill_table(tbl, values):
            # One row per value, shown as its string representation.
            tbl.setRowCount(len(values))
            for row, val in enumerate(values):
                tbl.setItem(row, 0, QTableWidgetItem(str(val)))

        fill_table(self.table_iter, list_iters)
        fill_table(self.table_metrics, plots)

        self.table_iter.itemSelectionChanged.connect(self.build_chart)
        self.table_metrics.itemSelectionChanged.connect(self.build_chart)

    def build_chart(self):
        """Rebuild the embedded chart from the current table selections.

        Does nothing unless at least one iteration and one metric are
        selected. Otherwise a ``ResultKPIs`` is loaded for every selected
        iteration, added to a fresh ``KpiComparator``, and the comparator's
        ``plot_<metric>`` method is called to produce the figure that is shown.
        """
        selected_iters = self.table_iter.selectedItems()
        selected_plots = self.table_metrics.selectedItems()
        if not selected_iters or not selected_plots:
            return

        polaris_project = self.qgis_project.polaris_project
        comparator = KpiComparator()
        db_name = polaris_project.run_config.db_name
        for item in selected_iters:
            iter_name = str(item.text())
            iter_dir = polaris_project.model_path / iter_name
            kpi = ResultKPIs.from_dir(iteration_dir=iter_dir, db_name=db_name)
            comparator.add_run(kpi, iter_name)

        # The metric table is single-selection, so the first item is the only one.
        fig = getattr(comparator, f"plot_{selected_plots[0].text()}")()

        if self.graphicsView is not None:
            # Removing a widget from a layout does not destroy it: without
            # deleteLater() every rebuild would leave another hidden canvas
            # parented to this dialog.
            self.grid.removeWidget(self.graphicsView)
            self.graphicsView.hide()
            self.graphicsView.deleteLater()

        self.graphicsView = FigureCanvas(fig)
        self.grid.addWidget(self.graphicsView, 0, 3, 1, 1)
        self.graphicsView.setMinimumWidth(400)
        self.graphicsView.setMinimumHeight(300)
        # The layout was installed on the dialog in buildUI; only a refresh is needed.
        self.grid.update()
