# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from typing import Optional

import qgis
from qgis.core import (
    QgsFeature,
    QgsField,
    QgsMarkerLineSymbolLayer,
    QgsMarkerSymbol,
    QgsProject,
    QgsRuleBasedRenderer,
    QgsSimpleLineSymbolLayer,
    QgsSymbol,
    QgsVectorLayer,
)
from qgis.PyQt.QtCore import QVariant


class PathDrawer:
    """Draw paths on the QGIS map canvas using features of existing line layers.

    A path is a sequence of feature ids (values of each layer's index column).
    Paths are drawn either by selecting features on an existing layer or by
    copying geometries into new in-memory scratch layers.
    """

    def __init__(self, line_layers) -> None:
        """
        :param line_layers: mapping of layer name to an ``[index_col, layer]`` pair.
            ``index_col`` is the name of the field holding each feature's id, and
            ``layer`` is a QGIS vector layer. Names are stored lower-cased. The
            path-drawing methods read the ``"links"`` entry for CRS and fields.
        """
        self.layers = {name.lower(): pair for name, pair in line_layers.items()}

        # Caches filled lazily by load_features(): feature id -> feature / layer name
        self.line_features = {}
        self.modes = {}
        # Marker colour per mode name, used by the multimodal rule-based renderer
        self.__colors = {
            "walk": "green",
            "bike": "blue",
            "transit": "yellow",
            "links": "red",
        }

        # Marker shape per mode_number for multimodal paths; unknown numbers use "square"
        self.__shapes = {
            4: "triangle",
            5: "square",
            7: "circle",
            8: "cross_fill",
            11: "asterisk_fill",
            13: "arrowhead",
            15: "star_diamond",
            29: "filled_arrowhead",
        }

    def load_features(self):
        """Cache every feature of every registered layer, keyed by its index-column value.

        Also records in ``self.modes`` which (lower-cased) layer name each id came
        from. The cache is built only once; later calls return immediately.
        """
        if self.line_features:
            return
        for mode, (index_col, lyr) in self.layers.items():
            for feature in lyr.getFeatures():
                feature_id = feature[index_col]
                # NOTE(review): ids are assumed unique across layers; an id repeated in a
                # later layer overwrites the earlier entry.
                self.line_features[feature_id] = feature
                self.modes[feature_id] = mode

    def create_path_with_selection(self, path: list, layer_name="links"):
        """Select the features of ``layer_name`` whose ``link`` field is one of ``path``.

        The selection is made with ``selectByExpression``. An empty ``path`` clears
        the layer's selection rather than submitting an empty expression.
        """
        layer = self.layers[layer_name][1]
        clauses = [f"link={int(link_id)}" for link_id in path]
        if not clauses:
            layer.removeSelection()
            return
        layer.selectByExpression(" or ".join(clauses))

    def create_traffic_path_with_scratch_layer(
        self, layer_name, path, color=None, ttimes: Optional[dict] = None, render_speeds=False
    ):
        """Copy the link features of ``path`` into a new memory layer and add it to the project.

        :param layer_name: name of the new memory layer
        :param path: sequence of link ids; each must be a key of ``self.line_features``
        :param color: optional value passed to the layer symbol's ``setColor``
        :param ttimes: optional mapping of link id -> travel time. When given, the
            ``pth_travel_time`` and ``pth_speed`` fields are added and filled.
        :param render_speeds: colour links with a data-defined expression based on
            ``pth_speed`` relative to ``fspd_ab``/``fspd_ba``. This is only meaningful
            when ``ttimes`` is given, because the expression reads ``pth_speed``.
        :return: the new memory layer
        """
        self.load_features()

        links_provider = self.layers["links"][1].dataProvider()
        crs = links_provider.crs().authid()
        vl = QgsVectorLayer("LineString?crs={}".format(crs), layer_name, "memory")
        pr = vl.dataProvider()

        # The scratch layer mirrors the links layer's schema so copied features fit
        pr.addAttributes(links_provider.fields())
        vl.updateFields()  # tell the vector layer to fetch changes from the provider

        pr.addFeatures([self.line_features[k] for k in path])
        # Features were added through the provider, so refresh the layer's cached extent
        vl.updateExtents()

        QgsProject.instance().addMapLayer(vl)

        symbol = vl.renderer().symbol()
        symbol.setWidth(2)

        # Travel times/speeds are stored even when they are not used for rendering
        if ttimes:
            self.__add_travel_times(pr, ttimes, vl)

        if color:
            symbol.setColor(color)
        if render_speeds:
            props = symbol.symbolLayers()[0].properties()
            r = """coalesce(ramp_color('Spectral',scale_linear( "pth_speed" / coalesce(coalesce( "fspd_ab" , "fspd_ba"), "pth_speed" ), 0, 1.2, 0, 1)), '#000000')"""
            props["color_dd_expression"] = r
            # Replace the default line layer with an equivalent one carrying the expression
            symbol.appendSymbolLayer(QgsSimpleLineSymbolLayer.create(props))
            symbol.deleteSymbolLayer(0)

        qgis.utils.iface.mapCanvas().refresh()
        return vl

    def __add_travel_times(self, pr, ttimes, vl):
        """Add the ``pth_travel_time`` and ``pth_speed`` fields to ``vl`` and fill them.

        :param pr: data provider of ``vl``
        :param ttimes: mapping of link id (value of the ``link`` field) -> travel time.
            Every link on the layer must have an entry.
        :param vl: scratch layer whose features were copied from the links layer

        ``pth_speed`` is ``length / travel time``. A link with a zero (falsy) travel
        time gets a NULL speed instead of raising ``ZeroDivisionError``.
        """
        pr.addAttributes([QgsField("pth_travel_time", QVariant.Double), QgsField("pth_speed", QVariant.Double)])
        vl.updateFields()  # tell the vector layer to fetch changes from the provider
        tt_fid = pr.fieldNameIndex("pth_travel_time")
        spd_fid = pr.fieldNameIndex("pth_speed")
        # NOTE(review): assumes the copied fields include "link" and "length"; a missing
        # field gives index -1, which would silently read the last attribute instead.
        lnk_fid = pr.fieldNameIndex("link")
        lngth_fid = pr.fieldNameIndex("length")
        attr_changes = {}
        for feat in vl.getFeatures():
            attrs = feat.attributes()
            travel_time = ttimes[attrs[lnk_fid]]
            speed = float(attrs[lngth_fid] / travel_time) if travel_time else None
            attr_changes[feat.id()] = {tt_fid: float(travel_time), spd_fid: speed}
        pr.changeAttributeValues(attr_changes)
        vl.commitChanges()

    def create_multimodal_path(self, layer_name, path, ttimes, mode_number):
        """Build a memory layer for a multimodal path and add it to the project.

        Each output feature copies the geometry of a cached feature from ``path``.
        Its attributes are the layer name the id came from (``mode``) and its
        travel time (``pth_travel_time``).

        :param layer_name: name of the new memory layer
        :param path: sequence of ids; each must be a key of ``self.line_features``
        :param ttimes: mapping of id -> travel time for every id in ``path``
        :param mode_number: selects the marker shape (see ``__init__``)
        :return: the new memory layer
        """
        self.load_features()

        crs = self.layers["links"][1].dataProvider().crs().authid()
        vl = QgsVectorLayer("LineString?crs={}".format(crs), layer_name, "memory")
        pr = vl.dataProvider()

        pr.addAttributes([QgsField("mode", QVariant.String), QgsField("pth_travel_time", QVariant.Double)])
        vl.updateFields()  # tell the vector layer to fetch changes from the provider

        layer_fields = vl.fields()

        features = []
        for item in path:
            feat = QgsFeature(layer_fields)
            feat.setGeometry(self.line_features[item].geometry())
            feat.setAttributes([self.modes[item], float(ttimes[item])])
            features.append(feat)
        pr.addFeatures(features)
        vl.updateExtents()

        QgsProject.instance().addMapLayer(vl)

        self.__format_multimodal_path(vl, mode_number)
        vl.triggerRepaint()
        qgis.utils.iface.mapCanvas().refresh()
        return vl

    def __format_multimodal_path(self, layer, mode_number):
        """Apply a rule-based renderer to ``layer`` with one marker-line rule per mode.

        Each rule matches ``"mode" LIKE '<mode>'`` and draws markers of that mode's
        colour. The marker shape comes from ``mode_number``, or "square" if unknown.
        """
        symbol = QgsSymbol.defaultSymbol(layer.geometryType())
        symbol.setWidth(2)
        renderer = QgsRuleBasedRenderer(symbol)
        root_rule = renderer.rootRule()

        shape = self.__shapes.get(mode_number, "square")

        for mode, color in self.__colors.items():
            # Clone the renderer's initial child rule as a template for this mode
            rule = root_rule.children()[0].clone()
            rule.setLabel(mode)
            rule.setFilterExpression(f""""mode" LIKE '{mode}'""")

            # NOTE(review): "interval" is probably not a simple-marker property and is
            # likely ignored here. Marker spacing is set on the marker-line layer
            # (setInterval); confirm the intended spacing.
            marker = QgsMarkerSymbol.createSimple(
                {"name": shape, "size": "2", "outline_style": "no", "interval": 6, "color": color}
            )
            # The first constructor argument is rotateMarker (markers follow the line)
            marker_line = QgsMarkerLineSymbolLayer(True)
            marker_line.setSubSymbol(marker)

            rule_symbol = rule.symbols()[0]
            rule_symbol.appendSymbolLayer(marker_line)
            rule_symbol.deleteSymbolLayer(0)
            root_rule.appendChild(rule)

        layer.setRenderer(renderer)
        layer.triggerRepaint()
        # Drop the template rule so only the per-mode rules remain
        renderer.rootRule().removeChildAt(0)
