# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import logging
import pandas as pd
import qgis
from datetime import timedelta
from math import floor
from os.path import dirname, join
from polaris.utils.database.db_utils import read_and_close
from qgis.PyQt import QtCore, QtWidgets, uic
from qgis.core import QgsSpatialIndex, QgsProject, QgsVectorLayer
from qgis.utils import iface
from typing import Optional

from ..common_tools.point_tool import PointTool
from ..common_tools.path_drawer import PathDrawer

FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/ui_compute_path.ui"))

logger = logging.getLogger("polaris")


class ShortestPathDialog(QtWidgets.QDialog, FORM_CLASS):
    """Dialog that computes and draws a shortest path on the POLARIS supply network.

    Origin and destination are either link IDs (each with a travel direction) or
    location IDs. They can be typed in or picked on the map canvas. Paths are computed
    with ``project.router`` for the selected departure time. They are drawn either as a
    selection on the link layer or as a new scratch layer via ``PathDrawer``.
    """

    # Map tool used to pick features on the canvas. NOTE: created once when the class is
    # defined; ``find_point`` replaces it with a fresh instance attribute after each click.
    clickTool = PointTool(iface.mapCanvas())

    def __init__(self, qgis_project) -> None:
        """Build the dialog for *qgis_project* and wire up its widgets.

        Reads per-link lane counts from the supply database, builds spatial indices over
        the link and location layers, and connects all widget signals.
        """
        QtWidgets.QDialog.__init__(self)
        self.iface = qgis_project.iface
        self._PQgis = qgis_project
        self.project = qgis_project.polaris_project
        self.setupUi(self)
        self.field_types = {}
        self.link_layer = self._PQgis.layers["link"][0]
        # Scratch layer holding the last drawn path; redrawn when "time navigator" is on.
        self.__time_nav_layer: Optional[QgsVectorLayer] = None
        with read_and_close(qgis_project.polaris_project.supply_file) as conn:
            df = pd.read_sql("Select link, lanes_ab, lanes_ba from Link", conn)
            # Indexed by link ID so get_directions can look up lane counts per link.
            self.link_directions = df.set_index(["link"])

        self.path_drawer = PathDrawer({"Links": ["link", self.link_layer]})
        self.link_index = QgsSpatialIndex(self.link_layer.getFeatures())

        self.location_layer = self._PQgis.layers["location"][0]
        self.loc_index = QgsSpatialIndex(self.location_layer.getFeatures())
        self.rad_locations.toggled.connect(self.change_type_of_path)
        self.chb_time_navigator.toggled.connect(self.set_force_new_layer)
        self.chb_speed_render.toggled.connect(self.set_force_new_layer)
        self.rad_links.toggled.connect(self.change_type_of_path)

        self.from_directions = []
        self.path_from.textChanged.connect(self.change_from_directions)
        self.path = None
        self.to_directions = []
        self.path_to.textChanged.connect(self.change_to_directions)
        self.change_from_directions()

        # Departure time in seconds since midnight (28800 s = 08:00).
        self.instant = 28800
        self.time_of_day.valueChanged.connect(self.change_time_of_day)
        self.change_time_of_day()

        self.from_but.clicked.connect(self.search_for_point_from)
        self.to_but.clicked.connect(self.search_for_point_to)
        self.do_path.clicked.connect(self.produces_path)

        self.mode_list.setVisible(False)
        self.but_clear.setVisible(False)
        self.__load_router()
        self.change_type_of_path()
        self.set_force_new_layer()

    def set_force_new_layer(self):
        """Force "new layer" output when time navigation or speed rendering is enabled.

        Both features need a scratch layer, so the output-type choice is hidden while
        either is checked. The dialog height is adjusted to the visible sections.
        """
        can_do_selection = not self.chb_time_navigator.isChecked() and not self.chb_speed_render.isChecked()
        if not can_do_selection:
            self.rdo_new_layer.setChecked(True)
        self.grp_output_type.setVisible(can_do_selection)
        # Pixel offsets presumably match the section heights in the .ui form — TODO confirm.
        self.setFixedHeight(390 - int(can_do_selection) * 20 - int(not self.chb_speed_render.isChecked()) * 18)
        self.frm_render_colors.setVisible(self.chb_speed_render.isChecked())

    def change_type_of_path(self):
        """Switch the dialog between link-based and location-based routing.

        Updates labels and direction widgets, clears both endpoints, and recomputes
        which directions (and therefore whether the Compute button) are available.
        """
        text = "Link ID" if self.rad_links.isChecked() else "Location ID"
        self.lbl_id.setText(text)
        self.from_dir.setVisible(self.rad_links.isChecked())
        self.lbl_dir.setVisible(self.rad_links.isChecked())
        self.to_dir.setVisible(self.rad_links.isChecked())
        self.path_from.setText("")
        self.path_to.setText("")
        self.from_directions.clear()
        self.change_from_directions()
        self.to_directions.clear()
        self.change_to_directions()
        self.do_path.setEnabled(self.rad_locations.isChecked())

    def get_directions(self, link):
        """Return the travel directions available on *link*.

        Args:
            link: Link ID as text, usually taken straight from a line edit.

        Returns:
            A list that contains ``0`` if the link has lanes AB and ``1`` if it has lanes
            BA. The list is empty when *link* is empty, is not an integer, or is not in
            the supply ``Link`` table.
        """
        if not link:
            return []
        try:
            link_id = int(link)
        except ValueError:
            # Called from textChanged: free or partially typed text must not raise
            # inside the Qt slot; it simply has no usable directions.
            return []

        lnk = self.link_directions[self.link_directions.index.isin([link_id])]
        if lnk.shape[0] < 1:
            return []

        directions = [0] if lnk.lanes_ab.values[0] > 0 else []
        if lnk.lanes_ba.values[0] > 0:
            directions.append(1)
        return directions

    def change_from_directions(self):
        """Refresh the origin directions and slider from the origin text box."""
        self.from_directions = self.get_directions(self.path_from.text())
        self.change_directions(self.from_directions, self.from_dir)

    def change_to_directions(self):
        """Refresh the destination directions and slider from the destination text box."""
        self.to_directions = self.get_directions(self.path_to.text())
        self.change_directions(self.to_directions, self.to_dir)

    def change_directions(self, direc, slider):
        """Constrain *slider* to the directions in *direc* (link mode only).

        The Compute button is enabled only when both endpoints have at least one
        available direction. Location mode leaves everything untouched.
        """
        if self.rad_locations.isChecked():
            return
        self.do_path.setEnabled(len(self.from_directions) > 0 and len(self.to_directions) > 0)

        slider.setEnabled(len(direc) > 0)
        if len(direc) == 0:
            return
        slider.setMaximum(max(direc))
        slider.setMinimum(min(direc))

    def __load_router(self):
        """Take the router from the POLARIS project and set the window title."""
        self.router = self.project.router
        self.setWindowTitle("Shortest path")

    def change_time_of_day(self):
        """Update the departure time from the widget and refresh the group title.

        When time navigation is enabled, the last path is redrawn for the new time.
        """
        self.instant = self.time_of_day.value()
        m = floor((self.instant % 3600) / 60)
        h = floor(self.instant / 3600)
        self.grp_time.setTitle(f"Time of day: {h}:{str(m).zfill(2)}")
        if self.chb_time_navigator.isChecked():
            self.__navigate_time()

    def __navigate_time(self):
        """Replace the last scratch-layer path with one routed at the current time."""
        layer = self.__time_nav_layer
        if layer is None:
            return
        # Drop the reference before the layer is removed (and deleted by QGIS). If the
        # re-route below finds no path, we must not keep a dangling layer reference.
        self.__time_nav_layer = None
        try:
            color = layer.renderer().symbol().color()
        except RuntimeError:
            # The underlying C++ layer is gone (e.g. the user removed it); nothing to redraw.
            return
        QgsProject.instance().removeMapLayer(layer)
        self.produces_path(color=color)

    def search_for_point_from(self):
        """Arm the canvas click tool to pick the origin."""
        self.clickTool.clicked.connect(self.fill_path_from)
        self.iface.mapCanvas().setMapTool(self.clickTool)

    def search_for_point_to(self):
        """Arm the canvas click tool to pick the destination."""
        self.iface.mapCanvas().setMapTool(self.clickTool)
        self.clickTool.clicked.connect(self.fill_path_to)

    def fill_path_to(self):
        """Write the ID of the clicked feature into the destination box."""
        item = self.find_point()
        if item < 0:
            return
        self.path_to.setText(str(item))

    @QtCore.pyqtSlot()
    def fill_path_from(self):
        """Write the ID of the clicked feature into the origin box, then pick the destination."""
        item = self.find_point()
        if item < 0:
            return
        self.path_from.setText(str(item))
        self.iface.mapCanvas().setMapTool(self.clickTool)
        self.search_for_point_to()

    def find_point(self):
        """Return the ID of the feature nearest to the last canvas click, or ``-1``.

        In link mode, the 10 nearest links are scanned and the first one with lanes in
        at least one direction is returned. The very closest links may be non-driving
        links. In location mode, the nearest location is returned. The spent click tool
        is replaced with a fresh ``PointTool`` in both cases.

        Returns:
            The link or location ID, or ``-1`` when nothing suitable is found or an
            error occurs. Callers check the result with ``< 0``.
        """
        try:
            point = self.clickTool.point
            self.clickTool = PointTool(self.iface.mapCanvas())
            if self.rad_links.isChecked():
                self.iface.mapCanvas().setMapTool(None)
                for fid in self.link_index.nearestNeighbor(point, 10):
                    link = self.link_layer.getFeature(fid)
                    if link["lanes_ab"] > 0 or link["lanes_ba"] > 0:
                        return link["link"]
                # Must be a plain int: callers compare with ``< 0``, which a tuple breaks.
                return -1

            nearest = self.loc_index.nearestNeighbor(point, 1)[0]
            location = self.location_layer.getFeature(nearest)
            return location["location"]
        except Exception:
            # UI boundary: never let a failed lookup escape into the Qt event loop.
            logger.exception("Could not find a feature near the clicked point")
            return -1

    def produces_path(self, color=False):
        """Route between the dialog's origin and destination and draw the result.

        Args:
            color: Passed through to ``PathDrawer.create_traffic_path_with_scratch_layer``.
                When connected to the Compute button's ``clicked`` signal this presumably
                receives the ``checked`` flag (``False``); time navigation passes the
                previous layer's colour.

        Sets ``self.path`` to the router result, or ``None`` when routing failed. Pushes a
        warning to the QGIS message bar when the IDs are invalid or no path exists.
        """
        self.path = None
        try:
            frm = int(self.path_from.text())
            to = int(self.path_to.text())
        except ValueError:
            msg = self.tr("Origin and destination must both be integer IDs")
            qgis.utils.iface.messageBar().pushMessage(msg, "", level=qgis.core.Qgis.Warning)
            return
        inst = str(timedelta(seconds=self.instant))
        if self.rad_links.isChecked():
            self.path = self.router.route_links(
                link_origin=frm,
                link_destination=to,
                origin_dir=self.from_dir.value(),
                destination_dir=self.to_dir.value(),
                departure_time=self.instant,
            )
        else:
            self.path = self.router.route(frm, to, departure_time=self.instant)

        if self.path is None:
            msg = self.tr(f"No path between {frm} and {to} at {inst}")
            # Warning level: the raw level 3 is Qgis.Success, which shows a success bar.
            qgis.utils.iface.messageBar().pushMessage(msg, "", level=qgis.core.Qgis.Warning)
            return

        layer_name = f"{frm} to {to} at {inst} (Travel time: {int(self.path.travel_time):,})"
        if self.rdo_selection.isChecked():
            self.path_drawer.create_path_with_selection(self.path.links)
        else:
            # Turn the cumulative times along the path into per-link travel times.
            prev = 0
            ttimes = {}
            for link, ttime in zip(self.path.links, self.path.cumulative_time):
                ttimes[link] = ttime - prev
                prev = ttime
            self.__time_nav_layer = self.path_drawer.create_traffic_path_with_scratch_layer(
                path=self.path.links,
                layer_name=layer_name,
                color=color,
                ttimes=ttimes,
                render_speeds=self.chb_speed_render.isChecked(),
            )
        self.lbl_last_path.setVisible(True)
        self.lbl_last_path.setText(layer_name)

    def exit_procedure(self):
        """Close the dialog."""
        self.close()
