# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
# Based on AequilibraE's code for bandwidth map creation
import qgis
import sys
from qgis.PyQt.QtCore import Qt
from copy import deepcopy
from math import ceil
from os.path import dirname, join
from pathlib import Path
from qgis.PyQt import uic
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtWidgets import QDialog
from qgis.core import QgsProject, QgsStyle
from random import randint
from typing import Optional

from QPolaris.modules.common_tools import list_iterations
from QPolaris.modules.menu_actions.open_project import open_comparison_project
from .map_slider import MapSlider
from .traffic_results import TrafficResults, base_interval_results as base_interval
from ..common_tools import running_on_ci

# NOTE(review): presumably these aliases let the custom-widget imports that
# ``uic.loadUiType`` generates for the .ui file (e.g. ``qgsfieldcombobox``)
# resolve to ``qgis.gui``; if so they must stay registered before the form
# is loaded below.
sys.modules["qgsfieldcombobox"] = qgis.gui
sys.modules["qgscolorbutton"] = qgis.gui
sys.modules["qgsmaplayercombobox"] = qgis.gui
# Load the Designer form that sits next to this module; the generated form
# class is used as a base class of ``LinkMapCreationDialog``.
FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/ui_bandwidths.ui"))

# Default metric names for line thickness and colour. They are not referenced
# in this module; presumably imported elsewhere — verify before removing.
start_thickness_field = "ab_link_out_volume"
start_color_field = "ab_link_travel_delay"


class LinkMapCreationDialog(QDialog, FORM_CLASS):
    """Dialog that configures a link results map and opens a ``MapSlider`` to browse it.

    Three analysis modes are offered (see the ``_MODE_*`` constants):

    * a single iteration,
    * a comparison between two iterations of the current model, and
    * a comparison against an iteration of another (alternative) model run.

    The user selects the metric(s) driving line thickness and/or colour, a
    time window and a grouping interval. Pressing run loads the results into
    ``TrafficResults`` objects, builds the metric layers and hands them to a
    ``MapSlider``.
    """

    # Labels shown in ``cob_analysis_mode``. Every mode check compares the
    # combo's ``currentText()`` against these, so they must match the items
    # added in ``__init__``.
    _MODE_ITERATION = "Iteration analysis"
    _MODE_COMPARE_ITERATIONS = "Compare iterations"
    _MODE_COMPARE_MODELS = "Compare across model runs"

    def __init__(self, qgis_project):
        """Set up the UI, load the default (single-iteration) results and connect signals.

        :param qgis_project: QPolaris project wrapper. This dialog reads its
            ``network``, ``iface``, ``layers``, result paths and, in the
            inter-model mode, its ``alternative_project``.
        """
        QDialog.__init__(self)
        self.slider: Optional[MapSlider] = None

        self._PQgis = qgis_project
        self._p = qgis_project.network

        self.iface = qgis_project.iface
        self.setupUi(self)

        self.tot_bands = 0
        # Width of the selectable time window, in hours (see ``change_range``).
        self.htodo = 6
        self.scale = {"width": 10, "spacing": 0, "max_flow": 400}

        self.band_size = 10.0
        self.space_size = 0.01
        self.link_layer = self._PQgis.layers.get("link", [None, None])[0]
        # The fallback above yields None when the project has no link layer.
        # Only register a real layer with the QGIS project.
        if self.link_layer is not None:
            QgsProject.instance().addMapLayer(self.link_layer)

        self.base_path: Optional[Path] = None
        self.alt_path: Optional[Path] = None
        self.color_results: Optional[TrafficResults] = None
        self.results: Optional[TrafficResults] = None

        self.cob_analysis_mode.addItems(
            [self._MODE_ITERATION, self._MODE_COMPARE_ITERATIONS, self._MODE_COMPARE_MODELS]
        )
        self.cob_analysis_mode.currentIndexChanged.connect(self.set_analysis_mode)

        self.set_analysis_mode()
        # Must run before ``change_grouping`` below, which reads ``self.results``
        # through ``add_fields_to_cboxes``.
        self.load_data()

        self.but_open_other_model.clicked.connect(self.load_comparison_project)
        iterations = list_iterations(qgis_project)
        self.cob_base.addItems(iterations)
        self.cob_alternative.addItems(iterations)
        self.cob_base_intermodel.addItems(iterations)

        self.interval_bar.valueChanged.connect(self.change_range)

        self.chb_ramp.toggled.connect(self.deal_with_colors)
        self.chb_thickness.toggled.connect(self.set_thickness)
        self.grouping_slider.valueChanged.connect(self.change_grouping)

        self.but_run.clicked.connect(self.add_bands_to_map)

        self.random_rgb()
        for ramp_name in QgsStyle().defaultStyle().colorRampNames():
            self.cob_colorRamp.addItem(ramp_name)
        self.deal_with_colors()
        self.setFixedHeight(418)
        # ``change_grouping`` also refreshes the metric combo boxes.
        self.change_grouping()

    def _push_error(self, msg):
        """Show ``msg`` as a critical message in the QGIS message bar for five seconds."""
        from qgis.core import Qgis

        # Qgis.Critical (2) is the error level; the value 3 is Qgis.Success.
        self.iface.messageBar().pushMessage("Error", msg, level=Qgis.Critical, duration=5)

    def load_comparison_project(self):
        """Open the comparison project via ``open_comparison_project``, then list its iterations."""
        open_comparison_project(self._PQgis)
        self.populate_comparison()

    def populate_comparison(self):
        """Refill ``cob_alternative_intermodel`` with the entries from ``list_iterations``."""
        self.cob_alternative_intermodel.clear()
        self.cob_alternative_intermodel.addItems(list_iterations(self._PQgis))

    def change_grouping(self):
        """Snap the grouping slider to a multiple of ``base_interval`` and refresh dependent widgets."""
        # Round up to the next multiple of the base aggregation interval.
        # setValue may re-enter this method via the ``valueChanged`` connection.
        snapped = int(base_interval * ceil(self.grouping_slider.value() / base_interval))
        self.grouping_slider.setValue(snapped)
        self.to_group.setText(f"{self.grouping_slider.value()} min")
        self.htodo = min(24, 3 * self.grouping_slider.value())
        # A window covering the whole day leaves nothing to scroll.
        self.interval_bar.setEnabled(self.htodo < 24)
        self.interval_bar.setPageStep(self.htodo * 2 if self.htodo < 24 else 1000)
        self.change_range()
        self.add_fields_to_cboxes()

    def deal_with_colors(self):
        """Enable the colour-ramp box when colouring by metric; otherwise show the single-colour picker."""
        use_ramp = self.chb_ramp.isChecked()
        self.box_color_ramp.setEnabled(use_ramp)

        # A fixed colour is only offered for a single-iteration map without a ramp.
        show_color = not use_ramp and self.cob_analysis_mode.currentText() == self._MODE_ITERATION
        for widget in (self.lab_color, self.mColorButton):
            widget.setVisible(show_color)
            widget.setEnabled(show_color)

    def set_analysis_mode(self):
        """Show only the widgets relevant to the selected analysis mode."""
        mode_val = self.cob_analysis_mode.currentText()
        single_iteration = mode_val == self._MODE_ITERATION

        self.comparebox.setVisible(mode_val == self._MODE_COMPARE_ITERATIONS)
        self.intermodel_box.setVisible(mode_val == self._MODE_COMPARE_MODELS)

        # The thickness/colour toggles are only exposed for single-iteration maps.
        self.box_color_ramp.setVisible(single_iteration)
        self.chb_ramp.setVisible(single_iteration)
        self.chb_thickness.setVisible(single_iteration)
        self.chb_ramp.setChecked(single_iteration)
        if single_iteration:
            self.deal_with_colors()

    def set_thickness(self):
        """Enable thickness options only when requested; otherwise force colour on so something is drawn."""
        use_thickness = self.chb_thickness.isChecked()
        self.box_thickness.setEnabled(use_thickness)
        if not use_thickness:
            self.chb_ramp.setChecked(True)

    @staticmethod
    def _refill_combo(combo, items):
        """Replace ``combo``'s entries with ``items``, keeping the previous selection if it is still offered."""
        previous = combo.currentText()
        combo.clear()
        combo.addItems(items)
        if previous and previous in items:
            combo.setCurrentText(previous)

    def add_fields_to_cboxes(self):
        """List the metrics available for the current grouping in the colour and thickness combos."""
        tables_available = self.results.tables_for_aggregation(self.grouping_slider.value())
        self._refill_combo(self.FieldColor, tables_available)
        self._refill_combo(self.FieldThickness, tables_available)

    def load_data(self):
        """Resolve the base/alternative result files for the current mode and load them.

        Sets ``base_path``/``alt_path``. Rebuilds ``results`` (used for
        thickness) and ``color_results`` (used for colour) from those paths,
        using the current ``htodo``.

        :raises ValueError: if the analysis-mode combo holds an unknown label.
        """
        base_root = Path(self._p.path_to_file).parent
        h5_file = str(self._PQgis.result_h5_path.name)

        def deals_with_root(iter_name, proj_root, res_file):
            # The "root" entry maps to the results file directly under the project folder.
            if iter_name == "root":
                return proj_root / res_file
            return proj_root / iter_name / res_file

        mode = self.cob_analysis_mode.currentText()
        self.base_path = self._PQgis.result_h5_path
        if mode == self._MODE_ITERATION:
            base_name = self._PQgis.result_path.stem
            self.alt_path = None
            alt_name = None
        elif mode == self._MODE_COMPARE_ITERATIONS:
            self.base_path = deals_with_root(self.cob_base.currentText(), base_root, h5_file)
            base_name = self.cob_base.currentText()
            self.alt_path = deals_with_root(self.cob_alternative.currentText(), base_root, h5_file)
            alt_name = self.cob_alternative.currentText()
        elif mode == self._MODE_COMPARE_MODELS:
            self.base_path = deals_with_root(self.cob_base_intermodel.currentText(), base_root, h5_file)
            base_name = self.txt_name_base_comparison.text()

            alt_root = Path(self._PQgis.alternative_project.model_path)
            # The alternative iteration comes from the inter-model combo that
            # ``populate_comparison`` fills, not the same-model ``cob_alternative``.
            self.alt_path = deals_with_root(self.cob_alternative_intermodel.currentText(), alt_root, h5_file)
            alt_name = self.txt_name_alt_comparison.text()
        else:
            raise ValueError("Analysis mode does not exist")

        self.results = TrafficResults(self.base_path, base_name, self.alt_path, alt_name, self.htodo)
        self.color_results = TrafficResults(self.base_path, base_name, self.alt_path, alt_name, self.htodo)

    def add_bands_to_map(self):
        """Validate the user's choices, build the requested metric layers and open the slider."""
        do_thickness = self.chb_thickness.isChecked()
        do_color = self.chb_ramp.isChecked()

        if not (do_thickness or do_color):
            self._push_error("Doing nothing? Try again")
            return

        mode = self.cob_analysis_mode.currentText()
        if mode == self._MODE_COMPARE_ITERATIONS:
            if self.cob_base.currentText() == self.cob_alternative.currentText():
                self._push_error("Comparing something to itself is silly. Try again")
                return
        elif mode == self._MODE_COMPARE_MODELS:
            if not self._PQgis.alternative_project.is_open:
                self._push_error("You need another model to compare. Open it first")
                return

        self.but_run.setEnabled(False)
        try:
            self.load_data()
            if do_thickness:
                self.results.build_metric_layer(
                    self.FieldThickness.currentText(), self.interval_bar.value(), self.grouping_slider.value()
                )

            if do_color:
                self.color_results.build_metric_layer(
                    self.FieldColor.currentText(), self.interval_bar.value(), self.grouping_slider.value()
                )
            self.exit_procedure()
        except Exception:
            # Keep the dialog usable for another attempt instead of leaving the
            # run button permanently disabled, then propagate the error.
            self.but_run.setEnabled(True)
            raise

    def change_range(self):
        """Clamp the window start so the ``htodo``-hour window ends by 24:00, then update the time labels."""
        end_interval = min(24, self.interval_bar.value() + self.htodo)
        start_interval = end_interval - self.htodo

        if start_interval != self.interval_bar.value():
            # Re-enters this method via the ``valueChanged`` connection with the clamped value.
            self.interval_bar.setValue(start_interval)

        self.from_time.setText(f"{start_interval}:00")
        self.to_time.setText(f"{end_interval}:00")

    def random_rgb(self):
        """Seed the single-colour picker with a random RGB colour."""
        red, green, blue = (randint(0, 255) for _ in range(3))
        self.mColorButton.setColor(QColor(red, green, blue))

    def exit_procedure(self):
        """Hand the built results to a ``MapSlider``, close this dialog and run the slider modally."""
        # Convert the window start (hours) into result steps. This presumes
        # ``results.step`` is in seconds — TODO confirm.
        from_interval = int(self.interval_bar.value() * 3600 / self.results.step)
        comparison = self.cob_analysis_mode.currentText() != self._MODE_ITERATION
        self.slider = MapSlider(
            self._PQgis,
            self.results,
            self.color_results,
            from_interval,
            self.cob_colorRamp.currentText(),
            self.chb_invert.isChecked(),
            self.chb_thickness.isChecked(),
            self.chb_ramp.isChecked(),
            self.mColorButton.color(),
            comparison,
        )

        # NOTE(review): this replaces all window flags (including the window
        # type) rather than adding to them. Confirm whether
        # ``windowFlags() | Qt.WindowStaysOnTopHint`` was intended.
        self.slider.setWindowFlags(Qt.WindowStaysOnTopHint)
        self.close()
        if running_on_ci():
            # Do not block on a modal dialog when running in CI.
            return
        self.slider.show()
        self.slider.exec_()
