# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import pandas as pd
import re
import sqlalchemy as sa
import sys
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtCore import pyqtSignal, QRect, QPoint
from qgis.PyQt.QtGui import QPainter, QColor, QLinearGradient, QPen
from qgis.PyQt.QtWidgets import (
    QApplication,
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QTableWidget,
    QTableWidgetItem,
    QCheckBox,
    QLabel,
    QDialog,
    QDateEdit,
    QSpinBox,
    QAbstractItemView,
    QHeaderView,
)
from qgis.PyQt.QtWidgets import QSlider
from datetime import timedelta
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from polaris.utils.db_config import DbConfig
from qgis.utils import iface

from QPolaris.modules.common_tools.blocking_signals import block_signals
from QPolaris.modules.common_tools.double_slider import SingleBarDoubleSlider
from QPolaris.modules.cluster.charts.run_time import run_time_chart
from QPolaris.modules.cluster.charts.transit import transit_charts


class PerformanceAnalysisDialog(QDialog):
    """Interactive dialog that charts model performance statistics from the stats database.

    Layout, top to bottom:

    * a row of selection tables: the chart to draw, the model(s), and three
      "dynamic" tables whose content depends on the chart family (machine group
      and iteration type for the run-time chart; peak period, agency and mode for
      the transit charts), each paired with a checkbox;
    * a date-range slider with matching from/to date editors;
    * the rendered matplotlib chart.
    """

    # Chart names listed in the chart table; also used to dispatch in the slots below.
    _RUN_TIME_CHART = "run time"
    _TRANSIT_CHARTS = ("Transit VMT", "Transit PMT", "Transit Occupancy")

    def __init__(self):
        super().__init__()
        self.engine = DbConfig.stats_db().create_engine()
        # Per-dialog cache of query results. Created before initUI because selecting
        # rows while the UI is being built already fires slots that may read it.
        self.__data = {}
        self.initUI()

    def initUI(self):
        """Create all widgets and layouts, load the lookup tables and connect the signals."""
        self.main_layout = QVBoxLayout()
        # Which chart family the dynamic tables currently hold ("run time", "transit" or "")
        self.current_chart_set = ""

        # First row: chart list, models and three chart-dependent tables
        tables_layout = QHBoxLayout()

        # Chart table. Its selection signal is connected only after populate_chart_list,
        # so the initial selectRow(0) does not trigger change_chart_selection.
        charts_layout = QVBoxLayout()
        self.tbl_chart = self._create_table(trigger_function=None)
        self.populate_chart_list()
        self.tbl_chart.itemSelectionChanged.connect(self.change_chart_selection)
        self.chk_ci_runs = QCheckBox("CI runs only")
        self.chk_ci_runs.setChecked(False)
        charts_layout.addWidget(self.tbl_chart)
        charts_layout.addWidget(self.chk_ci_runs)

        # Models table
        models_layout = QVBoxLayout()
        self.tbl_models = self._create_table(trigger_function=self.change_models)
        self.chk_models = QCheckBox("Collapse models")
        self.chk_models.setChecked(True)
        models_layout.addWidget(self.tbl_models)
        models_layout.addWidget(self.chk_models)

        # First dynamic table (machine group for run time, peak period for transit)
        machine_layout = QVBoxLayout()
        self.tbl_dyn1 = self._create_table(trigger_function=self.build_chart)
        self.chk_dyn1 = QCheckBox("Collapse values")
        self.chk_dyn1.setChecked(True)
        machine_layout.addWidget(self.tbl_dyn1)
        machine_layout.addWidget(self.chk_dyn1)

        # Second dynamic table (iteration type for run time, agency for transit)
        iter_layout = QVBoxLayout()
        self.tbl_dyn2 = self._create_table(trigger_function=self.build_chart)
        self.chk_dyn2 = QCheckBox("Collapse values")
        self.chk_dyn2.setChecked(True)
        iter_layout.addWidget(self.tbl_dyn2)
        iter_layout.addWidget(self.chk_dyn2)

        # Third dynamic table (mode; only shown for transit charts)
        dyn_layout = QVBoxLayout()
        self.tbl_dyn3 = self._create_table(trigger_function=self.build_chart)
        self.chk_dyn3 = QCheckBox("Collapse values")
        self.chk_dyn3.setChecked(True)
        dyn_layout.addWidget(self.tbl_dyn3)
        dyn_layout.addWidget(self.chk_dyn3)

        for checkbox in [self.chk_dyn1, self.chk_dyn2, self.chk_dyn3, self.chk_ci_runs, self.chk_models]:
            checkbox.toggled.connect(self.build_chart)

        for column_layout in (charts_layout, models_layout, machine_layout, iter_layout, dyn_layout):
            tables_layout.addLayout(column_layout)

        self.main_layout.addLayout(tables_layout)

        # Full date range available in the database. The row is fetched while the
        # connection is still open; reading a result after its connection has been
        # closed is not reliable.
        with self.engine.connect() as connection:
            res_date = connection.execute(sa.text("SELECT min(created_at), max(created_at) FROM iterations"))
            first_created, last_created = res_date.fetchone()
        self.min_date = first_created.date()
        self.max_date = last_created.date()

        # Second row: date range (editors + slider)
        sliders_layout = QHBoxLayout()
        slider1_layout = QVBoxLayout()

        self.from_date = QDateEdit(self.min_date)
        self.to_date = QDateEdit(self.max_date)

        for date_edit in [self.from_date, self.to_date]:
            date_edit.setMinimumDate(self.min_date)
            date_edit.setMaximumDate(self.max_date)
            date_edit.setDisplayFormat("dd/MM/yyyy")

        date_slider_header_layout = QHBoxLayout()
        date_slider_header_layout.addStretch()
        lbl1 = QLabel("From date")
        lbl1.setMaximumWidth(70)
        date_slider_header_layout.addWidget(lbl1)
        date_slider_header_layout.addWidget(self.from_date)

        lbl2 = QLabel("     To date")
        lbl2.setMaximumWidth(70)
        date_slider_header_layout.addWidget(lbl2)
        date_slider_header_layout.addWidget(self.to_date)
        date_slider_header_layout.addStretch()

        slider1_layout.addLayout(date_slider_header_layout)

        # Slider positions are day offsets from self.min_date (see slider_date_changed)
        self.slider_date = SingleBarDoubleSlider(0, (self.max_date - self.min_date).days)
        slider1_layout.addWidget(self.slider_date)

        sliders_layout.addLayout(slider1_layout)
        self.main_layout.addLayout(sliders_layout)

        # Third row: the matplotlib canvas, created by build_chart
        self.graphicsView = None

        self.setLayout(self.main_layout)
        self.setWindowTitle("Performance Analysis")

        # Selecting the first model (in populate_model_names) already triggers a first
        # build_chart, which uses the machine groups, so those are loaded first.
        self.machines_by_type = {}
        self.populate_machines_by_type()

        self.model_names = pd.DataFrame()
        self.populate_model_names()

        self.iteration_types = pd.DataFrame()
        self.populate_iteration_types()

        self.slider_date.rangeChanged.connect(self.slider_date_changed)
        self.from_date.dateChanged.connect(self.change_min_date)
        self.to_date.dateChanged.connect(self.change_max_date)

        # change_chart_selection never runs for the initial chart, so apply its widget
        # visibility here (after setLayout, so the widgets are already parented).
        self._update_chart_controls(self._chart)

    @block_signals
    def slider_date_changed(self, low, high):
        """Sync the date editors with the slider range (day offsets from ``min_date``) and redraw."""
        fdate = self.min_date + timedelta(days=low)
        tdate = self.min_date + timedelta(days=high)

        # Keep the two editors from crossing each other
        self.from_date.setDate(fdate)
        self.from_date.setMaximumDate(tdate)

        self.to_date.setDate(tdate)
        self.to_date.setMinimumDate(fdate)
        self.build_chart()

    @block_signals
    def change_max_date(self, qdate):
        """Move the slider's upper handle to ``qdate``, cap the "from" editor at it and redraw."""
        max_date = qdate.toPyDate()
        self.slider_date.setHighValue((max_date - self.min_date).days)
        self.from_date.setMaximumDate(max_date)
        self.build_chart()

    @block_signals
    def change_min_date(self, qdate):
        """Move the slider's lower handle to ``qdate``, floor the "to" editor at it and redraw."""
        min_date = qdate.toPyDate()
        self.slider_date.setLowValue((min_date - self.min_date).days)
        self.to_date.setMinimumDate(min_date)
        self.build_chart()

    def _create_table(self, trigger_function=None, min_height=250):
        """Create a styled, row-selecting QTableWidget.

        :param trigger_function: optional slot connected to ``itemSelectionChanged``.
        :param min_height: minimum widget height in pixels.
        """
        table = QTableWidget()
        table.setStyleSheet(
            """
            QTableWidget {
                background-color: white;
                alternate-background-color: #F0F0F0;
                selection-background-color: #4A90E2;
            }
            QHeaderView::section {
                background-color: #E0E0E0;
                padding: 4px;
                border: 1px solid #D0D0D0;
                font-weight: bold;
            }
        """
        )
        table.setAlternatingRowColors(True)
        if trigger_function is not None:
            table.itemSelectionChanged.connect(trigger_function)
        table.setMinimumHeight(min_height)
        table.setSelectionBehavior(QAbstractItemView.SelectRows)
        header = table.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.Stretch)
        return table

    def populate_model_names(self):
        """Load the model names and list their prettified form in the model table, selecting the first row.

        The displayed name is the DB name lower-cased and then capitalised; names
        that differ only by case therefore collapse into a single row.
        """
        self.model_names = pd.read_sql("SELECT count(*), model_name  FROM iterations group by model_name", self.engine)
        self.model_names["model_name_pretty"] = self.model_names["model_name"].str.lower().str.capitalize()

        # NULL model names become NaN above and cannot be sorted against strings
        model_names = sorted(self.model_names.model_name_pretty.dropna().unique())
        self.tbl_models.setRowCount(len(model_names))
        self.tbl_models.setColumnCount(1)
        self.tbl_models.setColumnWidth(0, 200)
        self.tbl_models.setHorizontalHeaderLabels(["Model"])

        for row, model_name in enumerate(model_names):
            self.tbl_models.setItem(row, 0, QTableWidgetItem(model_name))

        # Fires change_models (and hence build_chart) once the models table is connected
        self.tbl_models.selectRow(0)

    def populate_iteration_types(self):
        """List the distinct (lower-cased) iteration types in the second dynamic table."""
        self.iteration_types = pd.read_sql(
            "SELECT count(*), iteration_type  FROM iterations group by iteration_type", self.engine
        )
        self.iteration_types["iter_type_pretty"] = self.iteration_types["iteration_type"].str.lower()

        # NULL iteration types become NaN above and cannot be sorted against strings
        pretty_names = sorted(self.iteration_types.iter_type_pretty.dropna().unique())
        self.tbl_dyn2.clearContents()
        self.tbl_dyn2.setRowCount(len(pretty_names))
        self.tbl_dyn2.setColumnCount(1)
        self.tbl_dyn2.setColumnWidth(0, 200)
        self.tbl_dyn2.setHorizontalHeaderLabels(["Iteration Type"])

        for row, iteration_type in enumerate(pretty_names):
            self.tbl_dyn2.setItem(row, 0, QTableWidgetItem(iteration_type))

    def populate_machines_by_type(self):
        """Group the machines found in ``iterations`` and list the group names in the first dynamic table."""
        from QPolaris.modules.cluster.charts.machine_groups import machine_grouping

        machines = pd.read_sql("SELECT count(*), machine FROM iterations group by machine", self.engine)

        self.machines_by_type = machine_grouping(machines)

        self.tbl_dyn1.clearContents()
        self.tbl_dyn1.setRowCount(len(self.machines_by_type))
        self.tbl_dyn1.setColumnCount(1)
        self.tbl_dyn1.setColumnWidth(0, 200)
        self.tbl_dyn1.setHorizontalHeaderLabels(["Machine group"])

        for row, mach_type in enumerate(sorted(self.machines_by_type.keys())):
            self.tbl_dyn1.setItem(row, 0, QTableWidgetItem(mach_type))

    def populate_chart_list(self):
        """Fill the single-selection chart table and select its first entry (the run-time chart)."""
        charts = [self._RUN_TIME_CHART, *self._TRANSIT_CHARTS]

        self.tbl_chart.setRowCount(len(charts))
        self.tbl_chart.setColumnCount(1)
        self.tbl_chart.setHorizontalHeaderLabels(["Chart"])
        self.tbl_chart.setColumnWidth(0, 200)

        for row, chart_name in enumerate(charts):
            self.tbl_chart.setItem(row, 0, QTableWidgetItem(chart_name))
        self.tbl_chart.setSelectionMode(QAbstractItemView.SingleSelection)
        self.tbl_chart.selectRow(0)

    @property
    def _chart(self):
        """Name of the selected chart, or None when nothing is selected."""
        selected = self.get_selected_value(self.tbl_chart)
        if not selected:
            return None
        return selected[0]

    def _update_chart_controls(self, chart):
        """Show only the controls that apply to ``chart``."""
        is_run_time = chart == self._RUN_TIME_CHART
        self.chk_ci_runs.setVisible(not is_run_time)
        self.tbl_dyn3.setVisible(not is_run_time)
        self.chk_dyn3.setVisible(not is_run_time)
        self.chk_models.setVisible(is_run_time)

    def _require_single_model(self):
        """Return True if exactly one model is selected; otherwise warn in the QGIS message bar and return False."""
        if len(self.get_selected_value(self.tbl_models)) == 1:
            return True
        msg = iface.messageBar()
        message = "Please select a single model to produce this chart"
        msg.pushMessage("Error", message, level=2, duration=5)
        return False

    def build_chart(self):
        """Render the selected chart with the current filters and show it in place of the previous one.

        :raises ValueError: if the selected chart name is not a known chart.
        """
        chart = self._chart
        if not chart:
            return

        date_range = {"from": self.from_date.date().toPyDate(), "to": self.to_date.date().toPyDate()}

        if chart == self._RUN_TIME_CHART:
            args = {
                "model_name": self.get_selected_value(self.tbl_models),
                "machine_grouping": self.machines_by_type,
                "iteration_type": self.get_selected_value(self.tbl_dyn2),
                "machine": self.get_selected_value(self.tbl_dyn1),
                "date": date_range,
                "dissolving": {
                    "machine_type": self.chk_dyn1.isChecked(),
                    "model": self.chk_models.isChecked(),
                    "iteration": self.chk_dyn2.isChecked(),
                },
            }

            # Empty selections are left out rather than passed as empty lists
            # (presumably so run_time_chart falls back to not filtering on them).
            for arg in ["model_name", "iteration_type", "machine"]:
                if not args[arg]:
                    args.pop(arg)

            fig = run_time_chart(self.engine, **args)
        elif chart in self._TRANSIT_CHARTS:
            # Transit charts are drawn for a single model only
            if not self._require_single_model():
                return
            args = {
                "model_name": self.get_selected_value(self.tbl_models)[0],
                "peak_periods": self.get_selected_value(self.tbl_dyn1),
                "agencies": self.get_selected_value(self.tbl_dyn2),
                "modes": self.get_selected_value(self.tbl_dyn3),
                "date": date_range,
                "dissolving": {
                    "peak_periods": self.chk_dyn1.isChecked(),
                    "agencies": self.chk_dyn2.isChecked(),
                    "modes": self.chk_dyn3.isChecked(),
                },
                "chart": chart,
                "ci_only": self.chk_ci_runs.isChecked(),
                # NOTE(review): raises KeyError if machine_grouping produced no "CI Runners" group
                "machines": self.machines_by_type["CI Runners"].machine,
            }
            fig = transit_charts(self.engine, **args)
        else:
            raise ValueError(f"Unknown chart type: {chart}")

        self._show_figure(fig)

    def _show_figure(self, fig):
        """Replace the current chart canvas (if any) with a new canvas showing ``fig``."""
        if self.graphicsView is not None:
            self.main_layout.removeWidget(self.graphicsView)
            self.graphicsView.hide()
            # removeWidget does not destroy the widget: without this every redraw would
            # leave another hidden canvas (and its figure) parented to the dialog.
            self.graphicsView.deleteLater()

        self.graphicsView = FigureCanvas(fig)
        self.main_layout.addWidget(self.graphicsView)
        self.graphicsView.setMinimumWidth(600)
        self.graphicsView.setMinimumHeight(500)
        self.main_layout.update()

    def get_table(self, table):
        """Return the full contents of database table ``table`` as a DataFrame, cached per table name."""
        key = ("table", table)
        if key not in self.__data:
            # Quote the identifier (doubling embedded quotes) instead of interpolating it raw
            quoted = str(table).replace('"', '""')
            self.__data[key] = pd.read_sql(f'SELECT * FROM "{quoted}"', self.engine)
        return self.__data[key]

    def change_chart_selection(self):
        """Slot for the chart table: adapt the visible controls and dynamic tables, then redraw.

        :raises ValueError: if the selected chart name is not a known chart.
        """
        chart = self._chart
        if not chart:
            return

        self._update_chart_controls(chart)

        if chart == self._RUN_TIME_CHART:
            self.selected_runtime_chart()
        elif chart in self._TRANSIT_CHARTS:
            # Transit charts are drawn for a single model only
            if not self._require_single_model():
                return
            self.selected_transit_charts()
        else:
            raise ValueError(f"Unknown chart type: {chart}")
        self.build_chart()

    def change_models(self):
        """Slot for the models table: refresh the model-dependent tables, then redraw."""
        # Force selected_transit_charts to reload the tables for the newly selected model
        self.current_chart_set = ""
        self.selected_transit_charts()
        self.build_chart()

    @block_signals
    def selected_runtime_chart(self):
        """Load the run-time tables (machine groups, iteration types) unless they are already shown."""
        if self._chart != self._RUN_TIME_CHART:
            return

        if self.current_chart_set == "run time":
            return
        self.current_chart_set = "run time"

        self.populate_machines_by_type()
        self.populate_iteration_types()

    def _transit_data(self):
        """Load once and return transit VMT/PMT/occupancy results with their model name, indexed by ``created_at``.

        Only 'normal' iterations are kept and, per ``iteration_number``, only the row
        with the greatest ``convergence_uuid``.
        NOTE(review): the ROW_NUMBER partition is by iteration_number alone, across all
        models and runs — confirm that is intended.
        """
        if "pt_data" not in self.__data:

            sql = """WITH valid_uuids AS (
                        WITH ranked_records AS (
                            SELECT
                                *,
                                ROW_NUMBER() OVER (PARTITION BY iteration_number ORDER BY convergence_uuid DESC) AS rn
                            FROM
                                iterations
                        )
                        SELECT iteration_uuid FROM ranked_records WHERE rn = 1
                    )
                    SELECT
                        pt.*,
                        i.model_name,
                        i.created_at
                    FROM
                        transit_vmt_pmt_occ_by_period pt
                    JOIN
                        iterations i
                    ON
                        pt.iteration_uuid = i.iteration_uuid
                    WHERE
                        i.iteration_type = 'normal'
                        AND i.iteration_uuid IN (SELECT iteration_uuid FROM valid_uuids);"""

            self.__data["pt_data"] = pd.read_sql(sql, self.engine).set_index("created_at")
        return self.__data["pt_data"]

    @block_signals
    def selected_transit_charts(self):
        """Fill the dynamic tables with the peak periods, agencies and modes of the selected model.

        Does nothing unless a transit chart is selected, the tables are not already
        populated for transit, and exactly one model is selected.
        """
        if self._chart not in self._TRANSIT_CHARTS:
            return

        if self.current_chart_set == "transit":
            return

        models = self.get_selected_value(self.tbl_models)
        if len(models) != 1:
            # Leave current_chart_set untouched so the tables are filled as soon as a
            # single model is selected (build_chart warns the user in the meantime).
            return

        pt_data = self._transit_data()
        # The models table shows case-normalised names (see populate_model_names), so
        # match case-insensitively; a boolean mask also avoids quoting issues of .query().
        data = pt_data[pt_data["model_name"].str.lower() == models[0].lower()]

        headers = ["Peak Period", "Agency", "Mode"]
        box_data = [data.peak_period.unique().tolist(), data.agency.unique().tolist(), data["mode"].unique().tolist()]
        boxes = [self.tbl_dyn1, self.tbl_dyn2, self.tbl_dyn3]

        for header, values, box in zip(headers, box_data, boxes):
            box.clearContents()
            box.setRowCount(len(values))
            box.setColumnCount(1)
            box.setHorizontalHeaderLabels([header])
            box.setColumnWidth(0, 200)

            for row, value in enumerate(values):
                box.setItem(row, 0, QTableWidgetItem(str(value)))

        # Only mark the tables as populated once they actually have been
        self.current_chart_set = "transit"

    @staticmethod
    def get_selected_value(table):
        """Return the text of every selected item in ``table`` (an empty list if none)."""
        return [selected_item.text() for selected_item in table.selectedItems()]
