# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import ast
from os.path import dirname, join

import pandas as pd
import qgis
from qgis.PyQt import uic
from qgis.PyQt.QtWidgets import QDialog, QAbstractItemView, QTableWidgetItem
from qgis.core import QgsProject
from qgis.core import QgsVectorLayer

from polaris.utils.database.db_utils import commit_and_close

# Compile the Qt Designer form that lives in "forms/" next to this module into a form class.
# uic.loadUiType returns (form class, Qt base class); only the form class is kept because
# TrafficDiagnostics inherits QDialog explicitly.
FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/diagnostics.ui"))


class TrafficDiagnostics(QDialog, FORM_CLASS):
    """Dialog to find short links and short detours in the network and edit the supply database.

    Short links can be collapsed (deleted, with their end nodes merged). Short detours can be
    handled either by blocking the turn between two links or by deleting the links involved.
    """

    def __init__(self, _PQgis):
        """Build the dialog, load the link/node layers into the project and wire up the widgets.

        Args:
            _PQgis: plugin object providing ``iface``, ``network`` (with ``tables`` and
                ``diagnostics``), ``layers`` (mapping "link"/"node" to a sequence whose first
                element is the layer) and ``supply_path``.
        """
        QDialog.__init__(self)
        self.setupUi(self)
        self.iface = _PQgis.iface
        self.feed = None
        self._PQgis = _PQgis
        self.links_list = []
        self.all_links = _PQgis.network.tables.get("link")

        fldr = join(dirname(__file__), "styles")
        self.link_layer = self._PQgis.layers["link"][0]
        self.link_layer.loadNamedStyle(join(fldr, "links.qml"), True)

        self.node_layer = self._PQgis.layers["node"][0]

        for layer in [self.link_layer, self.node_layer]:
            QgsProject.instance().addMapLayer(layer)

        self.tab_links.setColumnCount(2)
        self.tab_links.setHorizontalHeaderLabels(["link", "length"])
        self.tab_links.itemSelectionChanged.connect(self.zoom_to_link)

        self.tab_loops.setColumnCount(4)
        self.tab_loops.setHorizontalHeaderLabels(["link", "length", "detour links", "loop length"])
        self.tab_loops.itemSelectionChanged.connect(self.zoom_to_detour)

        for tab in [self.tab_links, self.tab_loops]:
            tab.setSelectionBehavior(QAbstractItemView.SelectRows)
            tab.setSelectionMode(QAbstractItemView.SingleSelection)
            # Cell texts are parsed back into link ids by the actions below, so they must not be
            # editable by the user.
            tab.setEditTriggers(QAbstractItemView.NoEditTriggers)

        self.but_filter_links.clicked.connect(self.find_links)
        self.but_delete_link.clicked.connect(self.collapse_link)

        self.but_find_loops.clicked.connect(self.find_detours)
        self.but_block_turns.clicked.connect(self.block_turn_movement)
        self.but_delete_link_loops.clicked.connect(self.delete_loop)

    @staticmethod
    def _selected_row(table):
        """Return the index of the selected row in ``table``, or None when nothing is selected."""
        items = table.selectedItems()
        return items[0].row() if items else None

    @staticmethod
    def _parse_link_list(text):
        """Parse a cell text such as ``"[12, 34]"`` or ``"12"`` into a list of integer link ids.

        ``ast.literal_eval`` is used instead of ``eval`` so that cell contents are never executed.
        """
        value = ast.literal_eval(text)
        if isinstance(value, (list, tuple)):
            return [int(x) for x in value]
        return [int(value)]

    def find_detours(self):
        """Fill the loops table with the short detours found for the loop length in the spin box.

        Rows are sorted by loop length. Columns: link, link length, detour links, loop length.
        """
        self.tab_loops.setRowCount(0)
        diag = self._PQgis.network.diagnostics

        detours = diag.short_detours(self.spn_loop_length.value(), "links")

        dt = []
        for lid, tours in detours.items():
            lid_length = self.all_links["length"][lid]
            # Plain ints make the rendered list always readable by _parse_link_list
            # (e.g. numpy 2 scalars would otherwise render as "np.int64(12)").
            dt.extend([lid, lid_length, [int(x) for x in link_list], dist] for dist, link_list in tours.items())

        df = pd.DataFrame(dt, columns=["link", "dist", "alternate_links", "loop_length"])
        df.sort_values(by=["loop_length"], inplace=True)
        df.reset_index(drop=True, inplace=True)
        self.tab_loops.setRowCount(df.shape[0])

        # After reset_index the index equals the row position in the table
        for idx, rec in df.iterrows():
            self.tab_loops.setItem(idx, 0, QTableWidgetItem(str(rec.link)))
            self.tab_loops.setItem(idx, 1, QTableWidgetItem(str(rec.dist)))
            self.tab_loops.setItem(idx, 2, QTableWidgetItem(str(rec.alternate_links)))
            self.tab_loops.setItem(idx, 3, QTableWidgetItem(str(round(rec.loop_length, 1))))

    def zoom_to_detour(self):
        """Select and zoom to the detour links of the selected loop row.

        Blocking a turn is enabled only when the detour has exactly two links.
        """
        row = self._selected_row(self.tab_loops)
        if row is None:
            return
        links = self._parse_link_list(self.tab_loops.item(row, 2).text())

        self.but_block_turns.setEnabled(len(links) == 2)
        if not links:
            return

        # Build the IN list explicitly: a one-element tuple would render as "(5,)", which is not
        # a valid QGIS expression.
        id_list = ", ".join(str(lnk) for lnk in links)
        self.link_layer.selectByExpression(f'"link" in ({id_list})', QgsVectorLayer.SetSelection)
        self.zoom_to_link_selection()

    def block_turn_movement(self):
        """Replace the connection between the two detour links of the selected row with a turn override.

        The connection is copied into Turn_Overrides with penalty -1 and then deleted from
        Connection. The row is removed once the change is written. The map extent is preserved.
        """
        extent = self.iface.mapCanvas().extent()
        row = self._selected_row(self.tab_loops)
        if row is None:
            return

        links = self._parse_link_list(self.tab_loops.item(row, 2).text())
        if len(links) != 2:
            # A turn movement is defined between exactly two links
            return

        with commit_and_close(self._PQgis.supply_path, spatial=True) as conn:
            data = conn.execute("select dir, to_dir, node from Connection where link=? and to_link=?", links).fetchone()
            if data is not None:
                override = [links[0], data[0], links[1], data[1], data[2], -1]
                sql_override = "insert into Turn_Overrides(link, dir, to_link, to_dir, node, penalty) VALUES(?,?,?,?,?,?)"
                conn.execute(sql_override, override)
                conn.execute("delete from Connection where link=? and to_link=?", links)

        # No matching connection means there is nothing to block; keep the row in that case
        if data is not None:
            self.tab_loops.removeRow(row)
        self.iface.mapCanvas().setExtent(extent)

    def delete_loop(self):
        """Delete the detour links and the main link of the selected loop row, then remove the row.

        The map extent is preserved.
        """
        extent = self.iface.mapCanvas().extent()
        row = self._selected_row(self.tab_loops)
        if row is None:
            return
        links = self._parse_link_list(self.tab_loops.item(row, 2).text())
        links.extend(self._parse_link_list(self.tab_loops.item(row, 0).text()))
        for lnk in links:
            self.delete_link(lnk)
        # Drop the row, as the other actions do, so the deleted links cannot be acted on again
        self.tab_loops.removeRow(row)
        self.iface.mapCanvas().setExtent(extent)

    def find_links(self):
        """Fill the links table with links shorter than the spin-box threshold, sorted by length."""
        self.tab_links.setRowCount(0)
        diag = self._PQgis.network.diagnostics

        links = diag.short_links(self.spn_link_size.value(), "links")
        rows = [(link_id, dist) for dist in sorted(links) for link_id in links[dist]]

        # Size the table once instead of growing it row by row
        self.tab_links.setRowCount(len(rows))
        for idx, (link_id, dist) in enumerate(rows):
            self.tab_links.setItem(idx, 0, QTableWidgetItem(str(link_id)))
            self.tab_links.setItem(idx, 1, QTableWidgetItem(str(round(dist, 3))))

    def zoom_to_link(self):
        """Select and zoom to the link of the selected row of the links table.

        Does nothing when the selection is empty (e.g. while the table is being cleared).
        """
        row = self._selected_row(self.tab_links)
        if row is None:
            return
        link_id = int(self.tab_links.item(row, 0).text())

        self.link_layer.selectByExpression(f'"link"={link_id}', QgsVectorLayer.SetSelection)
        self.zoom_to_link_selection()

    def collapse_link(self):
        """Remove the selected short link from the table and delete it from the supply database.

        The map extent is preserved.
        """
        extent = self.iface.mapCanvas().extent()
        row = self._selected_row(self.tab_links)
        if row is None:
            return
        link_id = int(self.tab_links.item(row, 0).text())

        self.tab_links.removeRow(row)
        self.delete_link(link_id)
        self.iface.mapCanvas().setExtent(extent)

    def delete_link(self, link_id):
        """Delete ``link_id`` from the Link table and move its node_a onto the geometry of its node_b.

        Does nothing if the link is not present (e.g. it was already deleted through another row).
        """
        with commit_and_close(self._PQgis.supply_path, spatial=True) as conn:
            data_nodes = conn.execute("select node_b, node_a from Link where link=?", [link_id]).fetchone()
            if data_nodes is None:
                return
            conn.execute("delete from link where link=?", [link_id])
            # NOTE(review): the deletion is committed before the node update, presumably so that
            # database triggers see the link already removed. Kept as in the original; confirm.
            conn.commit()
            sql_merge = "update node set geo=(Select geo from node where node=?) where node=?"
            conn.execute(sql_merge, data_nodes)

    def zoom_to_link_selection(self):
        """Zoom the canvas to the selected links, padded by twice the first selected link's length.

        Does nothing when no feature is selected.
        """
        layer = self.link_layer
        feat = next(iter(layer.getSelectedFeatures()), None)
        if feat is None:
            return
        idx = layer.dataProvider().fieldNameIndex("length")
        lgt = 2 * feat.attributes()[idx]

        bb = layer.boundingBoxOfSelected()
        bb.setYMaximum(bb.yMaximum() + lgt)
        bb.setXMaximum(bb.xMaximum() + lgt)
        bb.setYMinimum(bb.yMinimum() - lgt)
        bb.setXMinimum(bb.xMinimum() - lgt)
        self.iface.mapCanvas().setExtent(bb)
