# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import os
from polaris.network.network import Network
from polaris.utils.database.db_utils import has_table, count_table
from polaris.utils.database.db_utils import read_and_close
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtWidgets import (
    QDialog,
    QTableWidgetItem,
    QWidget,
    QVBoxLayout,
    QTableWidget,
)
from tempfile import gettempdir

from ..common_tools import GetOutputFileName


def load_supply_file(qgis_project):
    """Ask for a Polaris network file and open it as the project's supply model.

    If the project already has something open (``open_mode`` is truthy), the
    "close first" message is shown and nothing else happens. Otherwise the
    project's layer registry is cleared, a file name is requested through
    ``GetOutputFileName`` (starting from ``qgis_project.path``), and a
    non-empty answer is passed on to ``supply_setter``.

    Args:
        qgis_project: The plugin/project object holding ``open_mode``,
            ``layers`` and ``path``.
    """
    if qgis_project.open_mode:
        qgis_project.show_message_close_first()
        return

    qgis_project.layers.clear()

    # Only the chosen path matters here; the second returned value is unused.
    chosen_file, _ = GetOutputFileName(
        QDialog(),
        "Polaris network",
        ["Polaris Network(*.sqlite)"],
        ".sqlite",
        qgis_project.path,
    )
    if not chosen_file:
        # An empty file name means there is nothing to open.
        return

    supply_setter(qgis_project, chosen_file)


def supply_setter(qgis_project, file_path):
    """Open ``file_path`` as the supply network of ``qgis_project``.

    Marks the project as being in ``"supply"`` mode, loads the network from
    the file, rebuilds the geo-layer listing and finally creates every layer
    that listing reports.

    Args:
        qgis_project: The plugin/project object to configure.
        file_path: Path of the network file handed to ``Network.from_file``.
    """
    qgis_project.open_mode = "supply"

    network = Network().from_file(file_path, False)
    qgis_project._network = network

    geo_layer_names = set_supply_file(qgis_project)
    create_layers(qgis_project, geo_layer_names)


def set_supply_file(qgis_project):
    """Register the supply database on the project and rebuild the "Geo layers" tab.

    The function performs these steps in order:

    1. If ``str(qgis_project.supply_path)`` differs from ``qgis_project.path``,
       store it in ``qgis_project.path`` and write it to
       ``polaris_last_folder.txt`` in the system temporary directory.
    2. Inspect the ``Geo_Consistency_Controller`` table through
       ``qgis_project.supply_conn``: show an error message when the table is
       missing, or the geo-consistency message when it contains rows.
    3. Read every ``f_table_name`` from ``geometry_columns``.
    4. Build a one-column table widget listing those names, remove all
       existing tabs from ``qgis_project.projectManager`` and add the new
       widget as the "Geo layers" tab.

    Args:
        qgis_project: The plugin/project object. It must expose
            ``supply_path``, ``path``, ``supply_conn``, ``projectManager``,
            ``load_geo_layer`` and the message helpers used below.

    Returns:
        list: The table names read from ``geometry_columns``.
    """
    fp = str(qgis_project.supply_path)
    if fp != qgis_project.path:
        qgis_project.path = fp
        pth = os.path.join(gettempdir(), "polaris_last_folder.txt")
        with open(pth, "w") as file:
            file.write(fp)

    with qgis_project.supply_conn as conn:
        if not has_table(conn, "Geo_Consistency_Controller"):
            qgis_project.show_error_message(
                "Your project needs to be upgraded. It does not have the Geo_Consistency_Controller table."
            )
        elif count_table(conn, "Geo_Consistency_Controller") > 0:
            qgis_project.show_message_geoconsistency_issues()

    with read_and_close(qgis_project.supply_path, spatial=True) as conn:
        sql = "select f_table_name from geometry_columns;"
        layers = [x[0] for x in conn.execute(sql).fetchall()]

    # Stays None if widget construction fails, so the code after the ``try``
    # never touches an unbound name.
    descr = None
    try:
        qgis_project.contents = []
        descrlayout = QVBoxLayout()
        qgis_project.geo_layers_table = QTableWidget()
        qgis_project.geo_layers_table.doubleClicked.connect(qgis_project.load_geo_layer)
        qgis_project.geo_layers_table.setRowCount(len(layers))
        qgis_project.geo_layers_table.setColumnCount(1)
        qgis_project.geo_layers_table.horizontalHeader().hide()
        for i, f in enumerate(layers):
            item1 = QTableWidgetItem(f)
            # Enabled and selectable only: the editable flag is not set.
            item1.setFlags(Qt.ItemIsEnabled | Qt.ItemIsSelectable)
            qgis_project.geo_layers_table.setItem(i, 0, item1)
        descrlayout.addWidget(qgis_project.geo_layers_table)
        descr = QWidget()
        descr.setLayout(descrlayout)
        qgis_project.tabContents = [(descr, "Geo layers")]

        # Remove from the highest index down. Removing in ascending order
        # shifts the remaining tabs to lower indices, so every other tab
        # would be skipped and left behind.
        for i in reversed(range(qgis_project.projectManager.count())):
            qgis_project.projectManager.removeTab(i)
    except RuntimeError:
        # Best effort: a RuntimeError raised while building the widgets or
        # clearing the tabs must not prevent the layer list from being returned.
        pass

    if descr is not None:
        qgis_project.projectManager.addTab(descr, "Geo layers")
    return layers


def create_layers(qgis_project, layers):
    """Reset the project's layer registry and create each named layer.

    Kept as its own function on purpose: it cannot be tested automatically
    (it raises a segfault on Docker), so it must stay separate from the
    testable setup code.

    Args:
        qgis_project: The plugin/project object owning ``layers`` and
            ``create_layer_by_name``.
        layers: Iterable of layer names to create, in order.
    """
    # Start from an empty registry so only the requested layers are kept in memory.
    qgis_project.layers.clear()

    for layer_name in layers:
        qgis_project.create_layer_by_name(layer_name)
