# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import sys
from os.path import dirname, join

import qgis
from qgis.PyQt.QtCore import Qt
from polaris.network.traffic.intersec import Intersection
from polaris.network.traffic.intersection.turn_type import turn_type
from qgis.PyQt import uic
from qgis.PyQt.QtWidgets import QDialog, QAbstractItemView, QTableWidgetItem, QComboBox, QTableWidget, QTabWidget
from qgis.core import QgsProject

from polaris.utils.database.spatialite_utils import connect_spatialite

# Register ``qgis.gui`` under the module names "qgsfieldcombobox" and
# "qgsmaplayercombobox" before the form is loaded (presumably the .ui file
# declares custom widgets under those names and uic imports them — TODO confirm
# against forms/intersection.ui).
sys.modules["qgsfieldcombobox"] = qgis.gui
sys.modules["qgsmaplayercombobox"] = qgis.gui
# Build the form class from the Qt Designer file located next to this module;
# the second value returned by ``loadUiType`` is discarded.
FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/intersection.ui"))


class IntersectionDialog(QDialog, FORM_CLASS):
    def __init__(self, _PQgis, conn, node=None):
        """Build the intersection editor dialog and load the first intersection.

        Args:
            _PQgis: Plugin object. This constructor reads its ``iface``, ``network``,
                ``supply_path`` and ``layers`` attributes and calls its
                ``create_layer_by_name``, ``load_layer_by_name`` and
                ``create_loose_layer`` methods.
            conn: Database connection; every query issued by the dialog goes through it.
            node: Optional node id to preselect in ``cob_nodes``. It is ignored when
                no entry of the combo box matches it.
        """
        QDialog.__init__(self)
        self.iface = _PQgis.iface
        self.setupUi(self)

        self._PQgis = _PQgis
        self._p = _PQgis.network
        self.__data_tables = self._p.tables
        # Placeholder instance; load_intersection() replaces it with the object for the selected node.
        self.intersection = Intersection(data_tables=self._p.tables, path_to_file=_PQgis.supply_path)
        self.conn = conn
        # Per-intersection caches, cleared and refilled by load_intersection()/load_data_for_period().
        self.period_data = {}
        self.connection_greens = {}
        self.connection_rows = {}
        self.timing_data = {}
        self.changes = {}
        self.required_layers = ["Node", "Connection"]

        # With an empty Signal table the "traffic only" node list would be empty, so untick the box.
        # The toggled signal is not connected yet, so this does not trigger a reload.
        if sum(self.conn.execute("Select count(*) from Signal").fetchone()) == 0:
            self.chb_traffic_only.setChecked(False)

        # Ids of the layers already loaded in the QGIS project
        layer_ids = [layer.id() for layer in QgsProject.instance().mapLayers().values()]

        # Make sure every required layer is known to the plugin and present in the project.
        for lyr in self.required_layers:
            if lyr.lower() not in self._PQgis.layers:
                print("Layer was not found, which is weird")
                self._PQgis.create_layer_by_name(lyr)
            layer_id = self._PQgis.layers[lyr.lower()][1]
            if layer_id not in layer_ids:
                self._PQgis.load_layer_by_name(lyr)

        fldr = join(dirname(__file__), "styles")
        self.connection_layer = self._PQgis.layers["connection"][0]
        self.connection_layer.loadNamedStyle(join(fldr, "Connection.qml"), True)

        # One extra "Connection" layer per turn type, each with its own style; skipped when a
        # layer with that name is already in the project.
        layer_names = [layer.name() for layer in QgsProject.instance().mapLayers().values()]
        self.turn_types = ["protected", "permitted", "stop_permit"]
        for lyr in self.turn_types:
            if lyr in layer_names:
                continue
            layer = self._PQgis.create_loose_layer("Connection")
            layer.setName(lyr)
            layer.loadNamedStyle(join(fldr, f"{lyr}.qml"), True)
            QgsProject.instance().addMapLayer(layer)
            self._PQgis.layers[lyr.lower()] = [layer, layer.id()]

        # Widget signal wiring.
        self.cob_nodes.currentIndexChanged.connect(self.load_intersection_first_time)
        self.chb_traffic_only.toggled.connect(self.add_nodes_to_cbox)
        self.but_goto.clicked.connect(self.quick_selection)
        self.but_execute_change.clicked.connect(self.execute_change)
        self.cob_period.currentIndexChanged.connect(self.load_data_for_period)

        self.tab_conn.itemSelectionChanged.connect(self.filter_connections_directly)
        self.tab_phasing.itemSelectionChanged.connect(self.data_for_single_phase)

        self.but_add_connection.clicked.connect(self.add_connection_to_intersection)
        self.but_remove_connection.clicked.connect(self.remove_connection_from_intersection)

        self.but_add_movement_to_phase.clicked.connect(self.add_connection_to_signal_phase)
        self.but_remove_movement_from_phase.clicked.connect(self.remove_connection_from_signal_phase)

        # Fills cob_nodes and, as a side effect, runs load_intersection() once.
        self.add_nodes_to_cbox()

        self.tab_conn.setColumnCount(8)
        self.tab_conn.setHorizontalHeaderLabels(
            ["conn", "link", "to_link", "lanes", "to_lanes", "protected", "permitted", "stop_permit"]
        )

        self.tab_missing_conn.setColumnCount(3)
        self.tab_missing_conn.setHorizontalHeaderLabels(["link", "to_link", "type"])

        # Timing is shown transposed: one row per field, a single "Values" column.
        # NOTE(review): ``df.index._name`` is a private pandas attribute; ``df.index.name`` is the public one.
        df = self.__data_tables.get("Timing")
        fields = [df.index._name] + list(df.columns)
        self.tab_timing.setColumnCount(1)
        self.tab_timing.setHorizontalHeaderLabels(["Values"])
        self.tab_timing.setRowCount(len(fields))
        self.tab_timing.setVerticalHeaderLabels(fields)

        df = self.__data_tables.get("Timing_Nested_Records")
        fields = list(df.columns)
        self.tab_timing_nested.setColumnCount(len(fields))
        self.tab_timing_nested.setHorizontalHeaderLabels(fields)

        df = self.__data_tables.get("Phasing")
        fields = [df.index._name] + list(df.columns)
        self.tab_phasing.setColumnCount(len(fields))
        self.tab_phasing.setHorizontalHeaderLabels(fields)

        df = self.__data_tables.get("Phasing_Nested_Records")
        fields = list(df.columns)
        # Distinct ``value_protect`` values found in the data, in reverse sorted order.
        self.protection_types = sorted(df.value_protect.unique())[::-1]

        for tab in [self.tab_phasing_nested, self.tab_not_in_phase]:
            tab.setColumnCount(len(fields))
            tab.setHorizontalHeaderLabels(fields)

        # These tables select whole rows; the handlers pick out the column-0 item of each selected row.
        for tab in [
            self.tab_conn,
            self.tab_missing_conn,
            self.tab_phasing,
            self.tab_not_in_phase,
            self.tab_phasing_nested,
        ]:
            tab.setSelectionBehavior(QAbstractItemView.SelectRows)

        if node is not None:
            index = self.cob_nodes.findText(str(node), Qt.MatchFixedString)
            if index >= 0:
                self.cob_nodes.setCurrentIndex(index)
        self.signal = None
        self.load_intersection_first_time()
        # NOTE(review): load_intersection_first_time() already ends with a zoom, so this call looks redundant.
        self.zoom_to_intersection()

    def sets_supported_intersection_control(self):
        self.cob_reset.clear()
        items = ["Just rebuild what's there"]
        if self.intersection.supports_signal(self.conn):
            items.extend(
                [
                    "Add signal after Geometric check",
                    "Add signal after OSM check",
                    "Add signal without checking",
                    "Add stop sign",
                ]
            )
            if self.intersection.has_signal(self.conn) or self.intersection.has_stop_sign(self.conn):
                items.append("Remove intersection control")
        for item in items:
            self.cob_reset.addItem(item)

    def execute_change(self):
        txt = self.cob_reset.currentText()

        inter = self.intersection  # type: Intersection
        if txt == "Just rebuild what's there":
            sig_type = "forced" if inter.has_signal(self.conn) else "none"
            sig_type = "stop_sign" if inter.has_stop_sign(self.conn) else sig_type
            inter.rebuild_intersection(self.conn, signal_type=sig_type)
        elif txt == "Add signal after Geometric check":
            inter.rebuild_intersection(self.conn, signal_type="geometric")
        elif txt == "Add signal after OSM check":
            inter.rebuild_intersection(self.conn, signal_type="osm")
        elif txt == "Add signal without checking":
            sig = inter.create_signal(self.conn, compute_and_save=False)
            if sig is not None:
                sig.re_compute(self.conn)
                sig.save(self.conn)
        elif txt == "Add stop sign":
            inter.add_stop_sign(self.conn)
        elif txt == "Remove intersection control":
            if self.intersection.has_signal(self.conn):
                inter.delete_signal(self.conn)
            if self.intersection.has_stop_sign(self.conn):
                inter.delete_stop_sign(self.conn)
        else:
            raise ValueError("Unknown option")

        self.conn.commit()
        self.load_intersection()
        self.sets_supported_intersection_control()

    def closeEvent(self, event):
        """Generate 'question' dialog on clicking 'X' button in title bar.

        Reimplement the closeEvent() event handler to include a 'Question'
        dialog with options on how to proceed - Save, Close, Cancel buttons
        """
        # reply = QMessageBox.question(
        #     self, "Message",
        #     "Are you sure you want to quit? Any unsaved work will be lost.",
        #     QMessageBox.Save | QMessageBox.Close | QMessageBox.Cancel,
        #     QMessageBox.Save)
        #
        # if reply == QMessageBox.Close:
        self.close()
        self.exit_procedure()

    def load_intersection_first_time(self):
        self.changes.clear()
        self.load_intersection()
        self.sets_supported_intersection_control()
        self.zoom_to_intersection()

    def zoom_to_intersection(self):
        lyr = self._PQgis.layers["connection"][0]
        if lyr.featureCount() > 0:
            self.iface.mapCanvas().setExtent(lyr.extent())
        else:
            self.iface.mapCanvas().setExtent(self._PQgis.layers["node"][0].extent())
        self.iface.mapCanvas().refresh()

    def load_intersection(self):
        self.connection_greens.clear()
        self.connection_rows.clear()
        self.timing_data.clear()

        node = int(self.cob_nodes.currentText())

        for lyr in self.required_layers:
            self._PQgis.layers[lyr.lower()][0].setSubsetString(f'"node"={node}')

        self.cob_period.blockSignals(True)
        self.cob_period.clear()
        self.period_data.clear()

        for tab in [
            self.tab_timing_nested,
            self.tab_phasing_nested,
            self.tab_not_in_phase,
            self.tab_conn,
            self.tab_phasing,
            self.tab_missing_conn,
        ]:
            tab.clearSelection()
            tab.setRowCount(0)

        self.intersection = self._p.get_intersection(node)
        if self.intersection.intersection_type in ["Dead-end", "Dead-start", "disconnected"]:
            return

        if self.intersection.has_signal(self.conn):
            self.semaphore = self.intersection.create_signal(self.conn, compute_and_save=False)
            self.signal = int(self.semaphore.signal)
            self.main_tab.setEnabled(True)
            table = self.__data_tables.get("Signal_nested_Records")
            table = table.loc[table.object_id == self.signal, :]
            for _, row in table.iterrows():
                val = f"{row.value_start} - {row.value_end}"
                self.cob_period.addItem(val)
                self.period_data[val] = row
        else:
            self.signal = None

        self.has_signal(self.intersection.has_signal(self.conn))
        table = self.intersection.connections(self.conn)
        table = table[["conn", "link", "to_link", "lanes", "to_lanes", "key"]]
        self.tab_conn.setRowCount(table.shape[0])
        for i, row in table.iterrows():
            for j, val in enumerate(row.values[:-1]):
                self.tab_conn.setItem(i, j, QTableWidgetItem(str(val)))
                self.update()
            self.connection_greens[row.key] = dict.fromkeys(self.turn_types, 0)
            self.connection_rows[row.key] = i
        self.tab_conn.update()

        all_keys = table.key.to_list()
        missing = []
        for inc in self.intersection.incoming:
            outs = [out for out in self.intersection.outgoing if f"{inc.link}-{out.link}" not in all_keys]
            missing.extend([[inc.link, out.link, turn_type(inc, out)] for out in outs])

        self.tab_missing_conn.setRowCount(len(missing))
        if missing:
            for i, data in enumerate(missing):
                for j, val in enumerate(data):
                    self.tab_missing_conn.setItem(i, j, QTableWidgetItem(str(val)))

        self.tab_missing_conn.update()

        self.cob_period.blockSignals(False)

        self.timing_df = self.__data_tables.get("Timing")
        self.timing_df = self.timing_df.loc[self.timing_df.signal == self.signal]
        self.phasing_df = self.__data_tables.get("Phasing")
        self.phasing_df = self.phasing_df.loc[self.phasing_df.signal == self.signal]

        self.load_data_for_period()
        self.hide_turns()
        self.main_tab.setCurrentIndex(0)
        self.zoom_to_intersection()

    def quick_selection(self):
        txt = self.txt_goto.text()

        self.txt_goto.setText("")
        if not txt.isdigit():
            return

        index = self.cob_nodes.findText(txt, Qt.MatchFixedString)
        if index >= 0:
            self.cob_nodes.setCurrentIndex(index)
        self.zoom_to_intersection()

    def add_connection_to_intersection(self):
        sel = self.tab_missing_conn.selectedItems()
        if not sel:
            return
        rows = [s.row() for s in sel if s.column() == 0]
        self.tab_missing_conn.clearSelection()

        for i in rows:
            from_link = int(self.tab_missing_conn.item(i, 0).text())
            to_link = int(self.tab_missing_conn.item(i, 1).text())
            self.intersection.add_movement(from_link, to_link, self.conn)

        self.load_intersection_first_time()

    def add_connection_to_signal_phase(self):
        sel = self.tab_not_in_phase.selectedItems()
        if not sel:
            return

        phase_selected = [s.row() for s in self.tab_phasing.selectedItems() if s.column() == 0][0]
        row = [s.row() for s in sel if s.column() == 0][0]
        self.tab_not_in_phase.clearSelection()

        obj_id = int(self.tab_phasing_nested.item(0, 0).text())
        idx = self.tab_phasing_nested.rowCount()
        mov = self.tab_not_in_phase.item(row, 2).text()
        from_link = int(self.tab_not_in_phase.item(row, 3).text())
        link_dir = int(self.tab_not_in_phase.item(row, 4).text())
        to_link = int(self.tab_not_in_phase.item(row, 5).text())
        penalty_type = self.tab_not_in_phase.cellWidget(row, 6).currentText()

        data = [obj_id, idx, mov, from_link, link_dir, to_link, penalty_type]
        self.conn.execute("Insert into Phasing_Nested_Records Values(?,?,?,?,?,?,?)", data)
        self.conn.commit()
        self.__data_tables.refresh_cache("Phasing_Nested_Records")
        self.__data_tables.refresh_cache("Phasing")
        self.load_intersection_first_time()
        self.main_tab.setCurrentIndex(2)
        self.tab_phasing.selectRow(phase_selected)

    def remove_connection_from_intersection(self):
        # self.remove_items(self.tab_conn, 'turns')
        sel = self.tab_conn.selectedItems()
        if not sel:
            return

        rows = [s.row() for s in sel if s.column() == 0]
        self.tab_conn.clearSelection()

        to_add = self.intersection.has_signal(self.conn)
        if to_add:
            self.intersection.delete_signal(self.conn)

        for i in rows:
            from_link = int(self.tab_conn.item(i, 1).text())
            to_link = int(self.tab_conn.item(i, 2).text())

            self.intersection.block_movement(from_link, to_link, self.conn)

        if to_add:
            self.intersection.create_signal(self.conn)
        self.load_intersection_first_time()

    def remove_connection_from_signal_phase(self):
        sel = self.tab_phasing_nested.selectedItems()
        if not sel:
            return

        phase_selected = [s.row() for s in sel if s.column() == 0][0]
        row = [s.row() for s in sel if s.column() == 0][0]
        self.tab_phasing_nested.clearSelection()

        obj_id = int(self.tab_phasing_nested.item(row, 0).text())
        from_link = int(self.tab_phasing_nested.item(row, 3).text())
        to_link = int(self.tab_phasing_nested.item(row, 5).text())

        data = [obj_id, from_link, to_link]
        sql = "DELETE FROM Phasing_Nested_Records WHERE object_id=? and value_link=? and value_to_link=?"
        self.conn.execute(sql, data)
        self.conn.commit()
        self.__data_tables.refresh_cache("Phasing_Nested_Records")
        self.__data_tables.refresh_cache("Phasing")
        self.load_intersection_first_time()
        self.main_tab.setCurrentIndex(2)
        self.tab_phasing.selectRow(phase_selected)

    def has_signal(self, has):
        tab = self.main_tab  # type: QTabWidget
        tab.currentChanged.connect(self.selects_connections_tab)
        for i in [1, 2]:
            tab.setTabEnabled(i, has)
            tab.setStyleSheet("QTabBar::tab::disabled {width: 0; height: 0; margin: 0; padding: 0; border: none;} ")

        for elem in [self.lbl_period, self.cob_period]:
            elem.setVisible(has)

    def selects_connections_tab(self):
        if self.main_tab.currentIndex() > 0:
            return

        self.hide_turns()
        self.filter_connections_directly()

    def add_nodes_to_cbox(self):
        self.cob_nodes.blockSignals(True)
        qry = (
            "Select nodes from Signal order by nodes"
            if self.chb_traffic_only.isChecked()
            else "Select node from Node order by node"
        )

        self.cob_nodes.clear()
        for node in self.conn.execute(qry).fetchall():
            self.cob_nodes.addItem(str(node[0]))
        self.cob_nodes.blockSignals(False)
        self.load_intersection()

    def exit_procedure(self):
        """Clear the intersection display through ``menu_actions`` and close the dialog."""
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import — TODO confirm).
        from ..menu_actions import clear_intersection_display

        clear_intersection_display(self._PQgis)
        self.close()

    def filter_connections_directly(self):
        for tab in [
            self.tab_timing,
            self.tab_phasing,
            self.tab_timing_nested,
            self.tab_phasing_nested,
            self.tab_not_in_phase,
        ]:  # type: QTableWidget
            tab.clearSelection()

        if sel := self.tab_conn.selectedItems():
            conns = [int(s.text()) for s in sel if s.column() == 0]
            fltr = f'"conn" = {conns[0]}' if len(conns) == 1 else f'"conn" IN {tuple(conns)}'

            self.connection_layer.setSubsetString(fltr)
            self.hide_turns()
        else:
            self.load_intersection()
        self.zoom_to_intersection()

    def load_data_for_period(self):
        self.timing_data.clear()

        period = self.cob_period.currentText()
        if len(period) == 0:
            return

        sig_rec = self.period_data[period]

        for tab in [
            self.tab_timing,
            self.tab_phasing,
            self.tab_timing_nested,
            self.tab_phasing_nested,
            self.tab_not_in_phase,
        ]:  # type: QTableWidget
            for i in range(tab.rowCount()):
                for j in range(tab.columnCount()):
                    tab.setItem(i, j, QTableWidgetItem(""))

        time_df = self.timing_df.loc[self.timing_df.timing == sig_rec.value_timing, :]
        time_df.reset_index(inplace=True)
        row = [row for i, row in time_df.iterrows()][0]
        for i, cell in enumerate(row.values):
            self.tab_timing.setItem(i, 0, QTableWidgetItem(str(cell)))

        table = self.__data_tables.get("Timing_Nested_Records")
        table = table.loc[table.object_id == time_df.timing_id.values[0], :]
        self.tab_timing_nested.setRowCount(0)
        self.tab_timing_nested.setRowCount(table.shape[0])
        for i, (_, row) in enumerate(table.iterrows()):
            for j, cell in enumerate(row.values):
                self.tab_timing_nested.setItem(i, j, QTableWidgetItem(str(cell)))
            self.timing_data[row.value_phase] = row.value_minimum
        phases = self.phasing_df.loc[self.phasing_df.phasing == sig_rec.value_phasing, :]
        phases.reset_index(drop=True, inplace=True)
        self.tab_phasing.setRowCount(0)
        self.tab_phasing.setRowCount(phases.shape[0])
        obj_ids = []
        phn = {}
        for i, row in phases.iterrows():
            for j, cell in enumerate(row.values):
                self.tab_phasing.setItem(i, j, QTableWidgetItem(str(cell)))
            obj_ids.append(phases.phasing_id)
            phn[row.phasing_id] = row.phase

        table = self.__data_tables.get("Phasing_Nested_Records")
        self.phasing_data = table.loc[table.object_id.isin(phases.phasing_id), :]

        # tabulate green times for each turn
        for key, val in self.connection_greens.items():
            self.connection_greens[key] = dict.fromkeys(val.keys(), 0)

        for _, row in self.phasing_data.iterrows():
            phase_id = row.object_id
            phase = phn[phase_id]
            green = self.timing_data[int(phase)]
            connection = f"{row.value_link}-{row.value_to_link}"
            turn_type = str(row.value_protect).lower()
            self.connection_greens[connection] = self.connection_greens.get(connection, {})
            self.connection_greens[connection][turn_type] = self.connection_greens[connection].get(turn_type, 0)
            self.connection_greens[connection][turn_type] += green

        for connection, row in self.connection_rows.items():
            for j, key in enumerate(self.turn_types):
                conn_data = self.connection_greens.get(connection, {})
                green_time = conn_data.get(key, "missing")
                self.tab_conn.setItem(row, 5 + j, QTableWidgetItem(str(green_time)))
        self.iface.mapCanvas().refresh()

    def data_for_single_phase(self):
        self.hide_connections()
        row = list(self.tab_phasing.selectedItems())
        if not row:
            return
        df = self.phasing_data.loc[self.phasing_data.object_id == int(row[0].text()), :]
        df_not_part = self.phasing_data.loc[self.phasing_data.object_id != int(row[0].text()), :]
        df_not_part.drop(columns=["value_protect"], inplace=True)

        # We show only relevant stuff
        df_not_part.drop_duplicates(subset=["value_link", "value_to_link"], inplace=True)
        current_pairs = df.value_link * 10000000000 + df.value_to_link
        available_pairs = df_not_part.value_link * 10000000000 + df_not_part.value_to_link
        df_not_part = df_not_part.loc[~available_pairs.isin(current_pairs), :]

        for tab, pd_df in [[self.tab_phasing_nested, df], [self.tab_not_in_phase, df_not_part]]:
            tab.setRowCount(0)
            tab.setRowCount(pd_df.shape[0])

        pairs = {key: [] for key in self.turn_types}
        for i, (_, row) in enumerate(df.iterrows()):
            for j, cell in enumerate(row.values):
                self.tab_phasing_nested.setItem(i, j, QTableWidgetItem(str(cell)))
            pairs[row.value_protect.lower()].append(f"(link = {row.value_link} and to_link={row.value_to_link})")

        last_col = df_not_part.shape[1]
        for i, (_, row) in enumerate(df_not_part.iterrows()):
            for j, cell in enumerate(row.values):
                self.tab_not_in_phase.setItem(i, j, QTableWidgetItem(str(cell)))
            combo = QComboBox()
            combo.addItems(self.protection_types)
            self.tab_not_in_phase.setCellWidget(i, last_col, combo)
        self.hide_turns()
        for key, val in pairs.items():
            if val:
                fltr = " or ".join(val)
                if key.lower() in self._PQgis.layers:
                    self._PQgis.layers[key.lower()][0].setSubsetString(fltr)
        self.iface.mapCanvas().refresh()

    def hide_connections(self):
        self.connection_layer.setSubsetString("conn=-9999999")

    def hide_turns(self):
        for layer in self.turn_types:
            if layer.lower() in self._PQgis.layers:
                self._PQgis.layers[layer.lower()][0].setSubsetString("conn=-9999999")
