# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import logging
import pandas as pd
import qgis
from datetime import timedelta
from math import floor
from os.path import dirname, join
from polaris.utils.database.db_utils import read_and_close
from polaris.runs.scenario_compression import ScenarioCompression
from qgis.PyQt.QtGui import QColor
from qgis.PyQt import QtCore, QtWidgets, uic
from qgis.PyQt.QtWidgets import QAbstractItemView
from qgis.core import (
    QgsProject,
    QgsVectorLayer,
    QgsSymbol,
    QgsRuleBasedRenderer,
    QgsSingleSymbolRenderer,
    QgsMarkerSymbol,
    QgsSpatialIndex,
)
from qgis.utils import iface
from typing import Optional

from ..common_tools import PandasModel
from ..common_tools.path_drawer import PathDrawer
from ..common_tools.point_tool import PointTool

# Qt Designer form for this dialog, stored in the sibling ``traffic`` package's forms
# directory and compiled once, at import time.
_UI_FILE = join(dirname(__file__), "../traffic/forms/ui_compute_path.ui")
FORM_CLASS, _ = uic.loadUiType(_UI_FILE)

# Messages from this module go to the "polaris" logger.
logger = logging.getLogger("polaris")


class MultimodalPathDialog(QtWidgets.QDialog, FORM_CLASS):
    """Dialog that computes and draws multimodal shortest paths between two locations.

    The user picks an origin and a destination by clicking on the map. The feature of
    the ``location`` layer nearest to each click is used. The user also selects one or
    more modes and a departure time. The project's router is then asked for one path
    per selected mode, and each path is drawn as its own map layer. Moving the
    time-of-day slider recomputes (and replaces) the most recent paths for the current
    origin/destination pair.
    """

    # Kept for backward compatibility only: every dialog instance creates its own tool in
    # ``__init__`` so click-signal connections are never shared between dialogs.
    clickTool = PointTool(iface.mapCanvas())

    def __init__(self, qgis_project) -> None:
        QtWidgets.QDialog.__init__(self)
        self.iface = qgis_project.iface
        self._PQgis = qgis_project
        self.project = qgis_project.polaris_project
        self.has_path = False
        self.setupUi(self)

        # The form contains controls this dialog does not use; hide them and shrink the window.
        hidden_items = [
            self.grp_target,
            self.chb_time_navigator,
            self.chb_speed_render,
            self.grp_output_type,
            self.lbl_dir,
            self.from_dir,
            self.to_dir,
            self.frm_render_colors,
            self.lbl_last_path,
        ]
        for item in hidden_items:
            item.setVisible(False)
        self.setFixedHeight(205)

        # Only these mode ids are offered for routing.
        valid_modes = [4, 5, 7, 8, 11, 13, 15, 29]
        with read_and_close(ScenarioCompression.maybe_extract(qgis_project.demand_path)) as conn:
            modes = pd.read_sql("SELECT * from Mode where mode_id< 100", conn)
        # reset_index makes row labels equal to row positions, which selected_modes() relies on.
        self.modes_avail = modes[modes.mode_id.isin(valid_modes)].reset_index(drop=True)
        self.modes_model = PandasModel(self.modes_avail[["mode_id", "mode_description"]])
        self.mode_list.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.mode_list.setModel(self.modes_model)
        self.mode_list.selectAll()

        # Layer ids of the most recent computation for each "<from>_<to>" key; these are the
        # layers replaced when the time-of-day slider moves. Ids (not layer objects) are kept
        # so a layer the user already deleted does not break later cleanup.
        self.__time_nav_layers = {}
        # Layer ids from earlier (non-navigating) computations; still removed by "clear".
        self.__retained_layers = []
        self.__time_nav_layer: Optional[QgsVectorLayer] = None

        layers = {
            "Links": ["link", self._PQgis.layers["link"][0]],
            "Walk": ["walk_link", self._PQgis.layers["transit_walk"][0]],
            "Transit": ["transit_link", self._PQgis.layers["transit_links"][0]],
            "Bike": ["bike_link", self._PQgis.layers["transit_bike"][0]],
        }
        self.path_drawer = PathDrawer(layers)
        self.location_layer = self._PQgis.layers["location"][0]
        self.loc_index = QgsSpatialIndex(self.location_layer.getFeatures())

        # Per-instance map tool (shadows the class attribute) so clicks reach only this dialog.
        self.clickTool = PointTool(self.iface.mapCanvas())

        self.path = None
        self.instant = 28800  # departure time in seconds after midnight (08:00)
        self.time_of_day.valueChanged.connect(self.change_time_of_day)

        self.from_but.clicked.connect(self.search_for_point_from)
        self.to_but.clicked.connect(self.search_for_point_to)
        # ``clicked`` carries a ``checked`` flag that PyQt would otherwise pass positionally
        # into ``navigating``; wrap the call so a button click is never treated as navigation.
        self.do_path.clicked.connect(lambda: self.produces_path(navigating=False))
        self.but_clear.clicked.connect(self.clear_all_paths)
        self.do_path.setEnabled(True)

        self.__load_router()

    def __load_router(self):
        """Take the router from the POLARIS project and set the window title."""
        self.router = self.project.router
        self.setWindowTitle("Shortest path")

    def __arm_click_tool(self, slot):
        """Make ``slot`` the only receiver of the click tool's ``clicked`` signal and activate the tool.

        Dropping earlier connections stops repeated from/to button presses from stacking
        several handlers on one click.
        """
        try:
            self.clickTool.clicked.disconnect()
        except (TypeError, RuntimeError):
            # Nothing was connected to the signal yet.
            pass
        self.clickTool.clicked.connect(slot)
        self.iface.mapCanvas().setMapTool(self.clickTool)

    @staticmethod
    def __remove_layers(layer_ids):
        """Remove the given layer ids from the current QGIS project.

        Ids that are no longer in the project are skipped by QGIS.
        """
        project = QgsProject.instance()
        for layer_id in layer_ids:
            project.removeMapLayer(layer_id)

    def search_for_point_from(self):
        """Wait for a map click that sets the origin location."""
        self.__arm_click_tool(self.fill_path_from)

    def search_for_point_to(self):
        """Wait for a map click that sets the destination location."""
        self.__arm_click_tool(self.fill_path_to)

    def fill_path_to(self):
        """Write the location nearest to the last map click into the destination field."""
        item = self.find_point()
        if item < 0:
            return
        self.path_to.setText(str(item))

    @QtCore.pyqtSlot()
    def fill_path_from(self):
        """Write the nearest location into the origin field, then start picking the destination."""
        item = self.find_point()
        if item < 0:
            return
        self.path_from.setText(str(item))
        # find_point() replaced the click tool; arm the fresh one for the destination.
        self.search_for_point_to()

    def clear_all_paths(self):
        """Remove every path layer this dialog created and restore the default location symbology."""
        self.__remove_layers(self.__retained_layers)
        for layer_ids in self.__time_nav_layers.values():
            self.__remove_layers(layer_ids)
        self.__retained_layers.clear()
        self.__time_nav_layers.clear()

        symbol = QgsMarkerSymbol.createSimple({"name": "circle", "color": "red", "size": 1.4})
        self.location_layer.setRenderer(QgsSingleSymbolRenderer(symbol))
        self.location_layer.triggerRepaint()
        self.iface.mapCanvas().refresh()

    def find_point(self):
        """Return the ``location`` value of the location feature nearest to the last map click.

        The click tool's point is read and the tool is then replaced by a fresh one.
        Returns -1 (after logging the error) when no location can be determined.
        """
        try:
            point = self.clickTool.point
            self.clickTool = PointTool(self.iface.mapCanvas())
            nearest = self.loc_index.nearestNeighbor(point, 1)[0]
            location = self.location_layer.getFeature(nearest)
            return location["location"]
        except Exception as e:
            logger.error(e.args)
            return -1

    def change_time_of_day(self):
        """Read the departure time from the slider, update the title and recompute the paths."""
        self.instant = self.time_of_day.value()
        m = floor((self.instant % 3600) / 60)
        h = floor(self.instant / 3600)
        self.grp_time.setTitle(f"Time of day: {h}:{str(m).zfill(2)}")
        self.__navigate_time()

    def __navigate_time(self):
        """Recompute the current paths, replacing the layers drawn for the previous time."""
        self.produces_path(navigating=True)

    def produces_path(self, navigating=False):
        """Compute and draw one path layer per selected mode between the chosen locations.

        :param navigating: True when triggered by the time-of-day slider. In that case the
            layers of the previous computation for the same origin/destination pair are
            removed. Otherwise they stay on the map and are still removed by "clear".

        Does nothing if the origin or destination field does not hold an integer location
        id (for example, when the slider moves before both points were picked).
        """
        self.path = None
        try:
            frm = int(self.path_from.text())
            to = int(self.path_to.text())
        except ValueError:
            if not navigating:
                logger.warning("Select both an origin and a destination before computing a path")
            return

        inst = str(timedelta(seconds=self.instant))
        path_key = f"{frm}_{to}"
        previous = self.__time_nav_layers.pop(path_key, [])
        if navigating:
            self.__remove_layers(previous)
        else:
            # Keep tracking the earlier layers so "clear" can still remove them.
            self.__retained_layers.extend(previous)

        layer_ids = []
        self.__time_nav_layers[path_key] = layer_ids
        for _, mode_rec in self.selected_modes().iterrows():
            mode = mode_rec.mode_id
            mode_name = mode_rec.mode_description
            self.path = self.router.multimodal(frm, to, departure_time=self.instant, mode=mode)
            if not self.path.links.shape[0]:
                continue
            layer_name = f"{frm} to {to} at {inst}  by {mode_name} (Travel time: {int(self.path.travel_time):,})"

            # Turn cumulative times into per-link times.
            prev = 0
            ttimes = {}
            for link, ttime in zip(self.path.links, self.path.cumulative_time):
                ttimes[link] = ttime - prev
                prev = ttime
            layer = self.path_drawer.create_multimodal_path(
                path=self.path.links,
                layer_name=layer_name,
                ttimes=ttimes,
                mode_number=int(mode),
            )
            layer_id = layer.id()
            self.iface.layerTreeView().refreshLayerSymbology(layer_id)
            layer_ids.append(layer_id)
        self.__format_orig_dest(frm, to)

    def selected_modes(self):
        """Return the rows of ``modes_avail`` matching the rows selected in the mode list."""
        idx = [x.row() for x in self.mode_list.selectionModel().selectedRows()]
        return self.modes_avail.loc[idx, :]

    def __format_orig_dest(self, frm, to):
        """Symbolize the location layer so only the origin (red) and destination (black) are drawn."""
        ref_symbol = QgsSymbol.defaultSymbol(self.location_layer.geometryType())
        renderer = QgsRuleBasedRenderer(ref_symbol)
        root_rule = renderer.rootRule()

        color_rules = [("Origin", f'"location"={frm}', "red"), ("Destination", f'"location"={to}', "black")]
        for label, expression, color in color_rules:
            # The first child is the renderer's initial rule; use it as the template.
            rule = root_rule.children()[0].clone()
            rule.setLabel(label)
            rule.setFilterExpression(expression)
            rule.symbol().setColor(QColor(color))
            rule.symbol().setSize(6)
            root_rule.appendChild(rule)

        # Drop the template before installing the renderer, so only the two rules remain.
        root_rule.removeChildAt(0)
        self.location_layer.setRenderer(renderer)
        self.location_layer.triggerRepaint()

    def exit_procedure(self):
        """Close the dialog."""
        self.close()
