# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
# Based on AequilibraE's code for bandwidth map creation
import logging
import pandas as pd
import qgis
import sys
import time
from modules.results_analysis.traffic_results import TrafficResults
from os.path import dirname, join
from qgis.PyQt import uic
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtWidgets import QDialog
from qgis.core import (
    QgsLineSymbol,
    QgsProject,
    QgsRuleBasedRenderer,
    QgsSimpleLineSymbolLayer,
    QgsStyle,
    QgsSymbol,
    QgsVectorLayer,
)

from QPolaris.modules.common_tools.blocking_signals import block_signals

# The .ui form references QGIS custom widgets by module name; presumably uic
# imports those names when building the form, so each one is aliased to
# qgis.gui (where the widget classes live) before the form is loaded below.
sys.modules.update({module_name: qgis.gui for module_name in ("qgsfieldcombobox", "qgscolorbutton", "qgsmaplayercombobox")})
FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/ui_map_slider.ui"))


class MapSlider(QDialog, FORM_CLASS):
    """Dialog that steps link-level traffic results through time periods on the map.

    Widths and/or colors of the project's ``link`` layer are driven by QGIS
    data-defined expressions over fields joined from ``TrafficResults``. A slider
    and a combo box select the period that is joined. In comparison mode, links
    are colored by the sign of the difference, with a user-set threshold below
    which differences are drawn grey, and summary counts are displayed. Memory
    layers are added to the project to serve as width/color legends.
    """

    def __init__(
        self,
        _PQgis,
        results: TrafficResults,
        color_results: TrafficResults,
        position,
        color_ramp,
        invert_colors,
        do_width: bool,
        do_color: bool,
        default_color,
        comparison,
    ):
        """Build the dialog, wire its widgets and draw the initial map.

        :param _PQgis: plugin handle; must expose ``layers`` (with a ``"link"`` entry),
            ``supply_conn`` and ``iface``.
        :param results: results driving link widths (and comparison colors).
        :param color_results: results driving the color ramp.
        :param position: initial period in seconds (shown as HH:MM via ``time.gmtime``).
        :param color_ramp: name of a color ramp in the default QGIS style.
        :param invert_colors: reverse the direction of the color ramp.
        :param do_width: scale link widths by ``results``.
        :param do_color: color links by ``color_results``.
        :param default_color: ``QColor`` used when links are not colored by value.
        :param comparison: whether ``results`` holds scenario differences.
        """
        QDialog.__init__(self)

        self.setupUi(self)
        self._PQgis = _PQgis
        self.position = position
        self.color_results = color_results
        self.results = results
        self.do_width = do_width
        self.do_color = do_color
        self.color_ramp = color_ramp
        self.drive_side = "right"
        self.invert_colors = invert_colors
        self.default_color = default_color
        self.comparison = comparison
        self.__legend_layers = []  # ids of the legend layers this dialog added to the project
        self.__map_width = 10
        self.__no_road = self.type_filters()

        self.maxwidth = self.results.max_value if self.do_width else 0
        self.maxcolor = self.color_results.max_value if self.do_color else 0

        if self.comparison:
            title = self.append_unit(self.results, "Period slider (Scenario Comparison)")
        else:
            title = "Period slider"
            if do_width:
                title += self.append_unit(self.results, " - Widths")

            if do_color:
                title += self.append_unit(self.color_results, " - Color ramp")
        self.setWindowTitle(title)

        self.link_layer = self._PQgis.layers.get("link", [None, None])[0]

        self.interv = self.results.periods if do_width else self.color_results.periods
        self.step = self.results.step
        self.slider.setMinimum(self.interv[0])
        self.slider.setMaximum(self.interv[-1])
        self.slider.setSliderPosition(self.position)
        self.slider.setSingleStep(self.step)
        self.slider.valueChanged.connect(self.change_interval)
        self.change_interval()
        self.set_controls()
        self.but_reset.clicked.connect(self.reset_scale)
        self.cob_periods.clear()

        self.cob_periods.addItems([time.strftime("%H:%M", time.gmtime(x)) for x in self.interv])

        self.cob_periods.setCurrentIndex(0)
        self.cob_periods.currentIndexChanged.connect(self.refilter_link_flows_cob)
        self.reset_scale()
        self.redo_map()
        self.dsp_color_thresh.setEnabled(self.do_color or self.comparison)
        self.redo_legend_layers()

    def append_unit(self, val, text):
        """Return ``text`` suffixed with the metric name (and speed unit for link speeds)."""
        text = f"{text} --> {val.metric_name}"
        if val.metric_name == "link_speed":
            return f"{text} ({val.speed_unit})"
        return text

    def type_filters(self) -> str:
        """Return a QGIS expression list literal of non-road link types, e.g. ``('WALK','BIKE')``.

        A type is non-road when its ``use_codes`` contain none of AUTO, TRUCK,
        BUS, HOV or TAXI. ``redo_map`` hides links of these types when they carry
        no value.
        """
        sql = """SELECT link_type from link_type where use_codes not like '%AUTO%' and use_codes not like '%TRUCK%' and
                                use_codes not like '%BUS%' and use_codes not like '%HOV%' and use_codes not like '%TAXI%'"""
        with self._PQgis.supply_conn as conn:
            no_road = pd.read_sql(sql, conn).link_type.tolist()
        # Double embedded single quotes so a type name cannot break the expression; drop NULL types
        quoted = [str(link_type).replace("'", "''") for link_type in no_road if link_type is not None]
        return "('" + "','".join(quoted) + "')"

    def set_controls(self):
        """Connect the scale spin boxes and adapt the comparison panel to the mode."""
        self.dsp_thick.valueChanged.connect(self.change_steps)
        self.dsp_color_thresh.valueChanged.connect(self.change_steps)

        if self.comparison:
            self.lbl_color.setText("Threshold")
            self.lbl_color.setToolTip("Absolute differences below this threshold will be mapped in grey")
            self.dsp_color_thresh.setToolTip("Absolute differences below this threshold will be mapped in grey")
            self.populate_combox_threshold(self.dsp_color_thresh)
        else:
            self.comparison_box.setVisible(False)
            self.setFixedHeight(68)

    def populate_combox_threshold(self, dsp):
        """Configure ``dsp`` as the comparison threshold: range [0, results.max_value], 20 steps."""
        if not self.comparison:
            return
        maxv = self.results.max_value
        dsp.setMaximum(maxv)
        dsp.setMinimum(0)
        dsp.setSingleStep(maxv / 20)
        dsp.setValue(0)
        dsp.valueChanged.connect(self.redo_map)

    def reset_scale(self):
        """Restore the default width factor and color multiplier/threshold, then redraw."""
        if not self.do_width:
            # Without width scaling dsp_thick is a fixed line width, so give it a narrow range
            self.dsp_thick.setMinimum(0.1)
            self.dsp_thick.setMaximum(2.0)
        self.dsp_thick.setValue(1.0 if self.do_width else 0.3)
        self.dsp_color_thresh.setValue(0 if self.comparison else 1.0)
        self.redo_map()

    def refilter_link_flows_cob(self):
        """Jump to the period selected in the combo box."""
        self.position = self.interv[self.cob_periods.currentIndex()]
        self.update_interface()

    def do_joins(self):
        """Replace the link layer's joins with the results for the current period."""
        # Collect the ids first so we do not mutate the join list while reading it
        join_ids = [vector_join.joinLayerId() for vector_join in self.link_layer.vectorJoins()]
        for join_id in join_ids:
            self.link_layer.removeJoin(join_id)

        if self.do_width:
            self.results.join_to_links(self.link_layer, interval=self.position, prefix="width_")

        if self.do_color:
            self.color_results.join_to_links(self.link_layer, interval=self.position, prefix="color_")

    def change_steps(self):
        """Adapt spin box step sizes to their current magnitude, then redraw map and legends."""

        def format_dsp(dsp):
            value = dsp.value()
            if value < 1.5:
                step = 0.1
            elif value < 10:
                # NOTE(review): this used to be two tiers (< 4 and < 10) that both used step 2;
                # a finer step for the lower tier may have been intended - confirm.
                step = 2
            else:
                step = 5
            dsp.setSingleStep(step)

        if self.do_width:
            format_dsp(self.dsp_thick)
        else:
            self.dsp_thick.setSingleStep(0.1)

        if self.do_color:
            format_dsp(self.dsp_color_thresh)
        self.redo_map()
        self.redo_legend_layers()

    def redo_map(self):
        """Rebuild the link layer's symbology from the current controls.

        Draws one simple-line symbol layer per direction, offset to opposite sides
        of the link according to ``drive_side``. Width, offset, line style and
        color are QGIS data-defined expressions over the joined result fields.
        Also refreshes the comparison statistics.
        """
        c_ramp = self.color_ramp
        mvalue = self.maxwidth
        mcolor = self.maxcolor

        # Offset sign per direction: AB and BA are drawn on opposite sides of the link
        ab, ba = (1, -1) if self.drive_side == "right" else (-1, 1)

        # Placeholder symbol; its single layer is deleted after both directional layers are appended
        symbol = QgsLineSymbol.createSimple({"name": "square", "color": "red"})
        self.link_layer.renderer().setSymbol(symbol)
        wdt = self.dsp_thick.value()
        clr = self.dsp_color_thresh.value()
        fields = {
            ab: {"field": self.results.joined_ab, "fieldc": self.color_results.joined_ab, "lanes": "lanes_ab"},
            ba: {"field": self.results.joined_ba, "fieldc": self.color_results.joined_ba, "lanes": "lanes_ba"},
        }

        for side in [ab, ba]:
            props = QgsSimpleLineSymbolLayer.create({}).properties()
            field = fields[side]["field"]
            lanes = fields[side]["lanes"]
            fieldc = fields[side]["fieldc"]
            if self.do_width:
                props["width_dd_expression"] = (
                    f'(coalesce(scale_linear(abs("{field}"), 0.05, {mvalue}, 0, {self.__map_width}), 0)) * {wdt}'
                )
                # Links without a value: hidden on laneless/non-road links, dashed otherwise
                props["line_style_expression"] = (
                    f"""if (coalesce("{field}",0) = 0, if("{lanes}" = 0 or "type" in {self.__no_road}, 'no', 'dash'), 'solid')"""
                )
                # Shift each direction by half its band width so AB and BA sit side by side.
                # NOTE(review): the offset scale (domain from 0, minimum 0.05) differs from the width
                # scale, presumably to keep the two directions apart at zero flow - confirm.
                props["offset_dd_expression"] = (
                    f"""{side}* (coalesce(scale_linear(abs("{field}"),0, {mvalue},0.05,{self.__map_width}),0) * {wdt}/2)"""
                )
            else:
                props["width_dd_expression"] = f"{wdt}"
                # Fixed-width lines are offset by half their width (the expression must be balanced)
                props["offset_dd_expression"] = f"{side} * {wdt}/2"
                props["line_style_expression"] = f"""if ("{lanes}" = 0 or "type" in {self.__no_road}, 'no', 'solid')"""

            if self.do_color:
                a, b = (1, 0) if self.invert_colors else (0, 1)
                # The color spin box multiplies the value before [0, maxcolor] is mapped onto the ramp
                props["color_dd_expression"] = (
                    f"""coalesce(ramp_color('{c_ramp}',scale_linear("{fieldc}" * {clr}, 0, {mcolor}, {a}, {b})), '#dddddd')"""
                )
            elif self.comparison:
                # Differences within +/- the threshold spin box are grey; maxcolor is always 0 here
                # because do_color is False, so it must not be used as the threshold.
                props["color_dd_expression"] = (
                    f"""if(coalesce("{field}", 0) > {clr},'#91bfdb',if(coalesce("{field}",0) < -{clr}, '#fc8d59', '#eaeaea'))"""
                )
            else:
                props["line_color"] = ",".join(str(x) for x in self.default_color.getRgb()[:4])
            self.link_layer.renderer().symbol().appendSymbolLayer(QgsSimpleLineSymbolLayer.create(props))

        self.link_layer.renderer().symbol().deleteSymbolLayer(0)
        self.link_layer.triggerRepaint()

        self.recompute_stats()

    def recompute_stats(self):
        """Update the comparison labels: counts above/below +/- threshold and the min/max difference."""
        if not self.comparison:
            return

        thresh = self.dsp_color_thresh.value()
        data = self.results.data
        red, green, maxv, minv = 0, 0, 0, 0
        for field in [self.results.joined_ab, self.results.joined_ba]:
            # Joined names carry a 6-character join prefix (presumably "width_" from do_joins);
            # strip it to get the column in results.data
            values = data[field[6:]]
            red += int((values < -thresh).sum())
            green += int((values > thresh).sum())
            maxv = max(maxv, values.max())
            minv = min(minv, values.min())

        # Two directions per row; everything not red or green is within the threshold
        grey = data.shape[0] * 2 - green - red
        base, alt = self.results.iterations

        self.lbl_red.setText(f"{alt} is bigger:  {red:,}")
        self.lbl_grey.setText(f"Within threshold:  {grey:,}")
        self.lbl_green.setText(f"{base} is bigger:  {green:,}")
        self.lbl_min.setText(f"Min: {round(minv, 2):,}")
        self.lbl_max.setText(f"Max: {round(maxv, 2):,}")

    def change_interval(self):
        """Snap the slider value to a multiple of ``step`` and show that period."""
        self.position = int(round(self.slider.value() / self.step, 0)) * self.step
        self.update_interface()

    @block_signals
    def update_interface(self):
        """Sync label, combo box and slider to ``position``, then rejoin and redraw."""
        self.lbl_position.setText(f"{self.position:,}")
        self.cob_periods.setCurrentText(time.strftime("%H:%M", time.gmtime(self.position)))
        self.slider.setSliderPosition(self.position)

        self.do_joins()
        self.redo_map()

    def redo_legend_layers(self):
        """Remove previously added legend layers and create fresh ones for the current scales."""
        project = QgsProject.instance()
        for layer_id in self.__legend_layers:
            # Remove by id rather than by layer object: if the user already deleted the layer,
            # its Python wrapper is stale, while an unknown id is expected to be ignored.
            project.removeMapLayer(layer_id)
        self.__legend_layers = []
        self.__do_color_ramp_legend()
        self.__do_width_legend()

    @staticmethod
    def _legend_label(value):
        """Format a legend value: two decimals, or a thousands-separated int when it is >= 100."""
        return f"{value:,.2f}" if value < 100 else f"{int(value):,}"

    def __do_color_ramp_legend(self):
        """Add a memory layer whose rule-based renderer shows the color ramp as a legend."""
        if not self.do_color or self.comparison:
            return

        # The map colors value * multiplier over [0, maxcolor]; a non-positive multiplier has no finite scale
        multiplier = self.dsp_color_thresh.value()
        if multiplier <= 0:
            return

        legend_steps = 5
        color_layer = QgsVectorLayer("LineString?crs=4326", f"{self.color_results.metric_name} (Color)", "memory")
        symbol = QgsSymbol.defaultSymbol(color_layer.geometryType())
        symbol.setWidth(2)
        renderer = QgsRuleBasedRenderer(symbol)

        color_ramp = QgsStyle.defaultStyle().colorRamp(self.color_ramp)
        max_legend = self.maxcolor / multiplier
        for interval in range(legend_steps + 1):
            val = interval / legend_steps
            # The first child is the template rule; appended clones go after it
            rule = renderer.rootRule().children()[0].clone()
            rule.setLabel(self._legend_label(val * max_legend))
            # Match the map, which runs the ramp from 1 to 0 when colors are inverted
            rule.symbol().setColor(color_ramp.color(1 - val if self.invert_colors else val))
            renderer.rootRule().appendChild(rule)

        # Drop the template rule
        renderer.rootRule().removeChildAt(0)
        color_layer.setRenderer(renderer)
        QgsProject.instance().addMapLayer(color_layer)
        color_layer.triggerRepaint()
        self.__refresh_layer(color_layer.id())
        self.__legend_layers.append(color_layer.id())

    def __do_width_legend(self):
        """Add a memory layer whose rule-based renderer shows line widths as a legend."""
        if not self.do_width:
            return

        # Map widths are multiplied by this factor; zero widths leave nothing to show in a legend
        width_factor = self.dsp_thick.value()
        if width_factor <= 0:
            return

        legend_steps = 5
        width_layer = QgsVectorLayer("LineString?crs=4326", f"{self.results.metric_name} (Width)", "memory")
        symbol = QgsSymbol.defaultSymbol(width_layer.geometryType())
        symbol.setWidth(0.001)

        if self.do_color:
            # If we are doing ramps, let's get a color from that ramp
            symbol.setColor(QgsStyle.defaultStyle().colorRamp(self.color_ramp).color(0.66))
        else:
            symbol.setColor(QColor("#000000"))

        renderer = QgsRuleBasedRenderer(symbol)
        max_legend = self.maxwidth / width_factor
        for interval in range(legend_steps + 1):
            val = interval / legend_steps
            # The first child is the template rule; appended clones go after it
            rule = renderer.rootRule().children()[0].clone()
            rule.setLabel(self._legend_label(val * max_legend))
            rule.symbol().setWidth(val * self.__map_width)
            renderer.rootRule().appendChild(rule)

        # Drop the template rule
        renderer.rootRule().removeChildAt(0)
        width_layer.setRenderer(renderer)
        QgsProject.instance().addMapLayer(width_layer)
        width_layer.triggerRepaint()
        self.__refresh_layer(width_layer.id())
        self.__legend_layers.append(width_layer.id())

    def __refresh_layer(self, layer_id):
        """Best-effort refresh of a layer's legend symbology in the layer tree; failures are only logged."""
        try:
            self._PQgis.iface.layerTreeView().refreshLayerSymbology(layer_id)
        except Exception as e:
            logging.info("Could not refresh layer. %s", e.args)

    def exit_procedure(self):
        """Close the dialog."""
        self.close()
