# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from functools import partial
from pathlib import Path
from polaris.project.polaris import Polaris
from polaris.runs.run_utils import get_output_dirs
from qgis.PyQt.QtWidgets import QDialog

from QPolaris.modules.common_tools import GetOutputFolderName, GetOutputFileName, list_iterations
from .load_supply import set_supply_file, create_layers


def open_project(qgis_project):
    """Prompt for a Polaris model and load it into ``qgis_project``.

    If ``qgis_project.open_mode`` is already set, nothing is loaded and
    ``show_message_close_first`` is called instead (presumably asking the user
    to close the current project first). Otherwise the user picks a model via
    ``get_project_path``; the project is loaded only if a path was chosen.

    :param qgis_project: QPolaris project object holding the plugin state.
    """
    if not qgis_project.open_mode:
        config_path = get_project_path(qgis_project)
        # A falsy result means the user cancelled one of the dialogs
        if config_path:
            load_project_from_path(qgis_project, config_path)
        return

    qgis_project.show_message_close_first()


def open_comparison_project(qgis_project):
    """Load a second Polaris project to compare against the one currently open.

    Requires ``qgis_project.open_mode == "project"``. Otherwise
    ``message_no_project`` is called and nothing is loaded. If the user cancels
    the path selection, ``alternative_project`` is left unchanged.

    :param qgis_project: QPolaris project object; on success its
        ``alternative_project`` attribute is set to the loaded ``Polaris`` project.
    """
    # Covers None/"" as well as any other open mode: only a loaded project qualifies
    if qgis_project.open_mode != "project":
        qgis_project.message_no_project()
        return

    tot_path = get_project_path(qgis_project)
    if not tot_path:
        # get_project_path returns None when a dialog is cancelled; without this
        # guard tot_path.parent would raise AttributeError
        return

    qgis_project.alternative_project = Polaris.from_dir(tot_path.parent, tot_path.name)


def get_project_path(qgis_project):
    """Ask the user for a Polaris model folder and return its config file path.

    The folder dialog starts in the parent directory of ``qgis_project.path``.
    If the chosen folder contains ``polaris.yaml``, that file is returned.
    Otherwise a second dialog asks for the convergence control file
    (``.yaml``/``.yml``).

    :param qgis_project: QPolaris project object; its ``path`` seeds the dialogs.
    :return: ``Path`` to the config file, or ``None`` if either dialog was cancelled.
    """
    start_dir = str(Path(qgis_project.path).parent)
    model_folder = GetOutputFolderName(start_dir, "Polaris model folder")
    if not model_folder:
        return None

    default_config = Path(model_folder) / "polaris.yaml"
    if default_config.exists():
        return default_config

    # No polaris.yaml in the chosen folder: let the user pick the YAML file explicitly
    chosen_file, chosen_type = GetOutputFileName(
        QDialog(),
        "Please choose the convergence control file",
        ["Convergence control(*.yaml)", "Convergence control(*.yml)"],
        ".yaml",
        qgis_project.path,
    )
    return Path(chosen_file) if chosen_type else None


def load_project_from_path(qgis_project, tot_path):
    """Load the Polaris project described by ``tot_path`` into ``qgis_project``.

    Sets ``polaris_project``, fills the iteration combo box (``cob_iter``), marks
    the QGIS project as open (``open_mode = "project"``), creates the supply
    layers, and connects iteration changes to ``change_iteration``.

    :param qgis_project: QPolaris project object whose state is populated.
    :param tot_path: ``Path`` to the project's config file; the model folder is
        its parent directory.
    """
    # Replace any previous project descriptor. ``from_dir`` is called on the class
    # (as in open_comparison_project) instead of on a throwaway ``Polaris()`` instance.
    qgis_project.polaris_project = Polaris.from_dir(tot_path.parent, tot_path.name)

    # Items are added before the slot below is connected, so populating the combo
    # box here does not call change_iteration.
    qgis_project.cob_iter.addItems(list_iterations(qgis_project))

    qgis_project.open_mode = "project"
    layers = set_supply_file(qgis_project)
    create_layers(qgis_project, layers)

    # NOTE(review): a new slot is connected on every load; if cob_iter is reused
    # across close/reopen, confirm the close path disconnects earlier slots.
    qgis_project.cob_iter.currentIndexChanged.connect(partial(change_iteration, qgis_project))


def change_iteration(proj, *_signal_args):
    """Point the project's data directory at the iteration selected in ``cob_iter``.

    Used (through ``functools.partial``) as the slot for
    ``cob_iter.currentIndexChanged``, which passes the new index. Extra positional
    arguments from the signal are accepted and ignored, so the slot works without
    relying on PyQt's handling of slots that take fewer arguments than the signal
    provides.

    :param proj: QPolaris project object holding ``cob_iter`` and ``polaris_project``.
    :param _signal_args: Extra arguments emitted by the Qt signal (ignored).
    """
    selected = proj.cob_iter.currentText()
    model_path = proj.polaris_project.model_path
    # "root" selects the model folder itself; any other entry is a sub-folder of it
    proj.polaris_project.run_config.data_dir = model_path if selected == "root" else model_path / selected
