# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from math import ceil
from os.path import dirname, join
from polaris.analyze.tnc_metrics import TNCMetrics

from qgis.PyQt import uic, QtGui
from qgis.PyQt.QtCore import QSizeF
from qgis.PyQt.QtWidgets import QDialog
from qgis.core import QgsDiagramSettings, QgsDiagramLayerSettings, QgsRendererRange, QgsStackedBarDiagram
from qgis.core import QgsProject, QgsStyle, QgsVectorLayerJoinInfo, QgsGraduatedSymbolRenderer, QgsApplication
from qgis.core import QgsSymbol, QgsLinearlyInterpolatedDiagramRenderer, QgsPalLayerSettings, QgsTextFormat
from qgis.core import QgsTextBufferSettings, QgsVectorLayerSimpleLabeling, QgsPieDiagram

from ..common_tools import layer_from_dataframe
from ..common_tools.mapping import color_ramp_shades

# Load the Qt Designer form at import time; the dialog class below inherits from the generated form class
_UI_FILE = join(dirname(__file__), "forms/tnc_navigator.ui")
FORM_CLASS, _ = uic.loadUiType(_UI_FILE)


class TNCViewerDialog(QDialog, FORM_CLASS):
    """Dialog that maps TNC metrics (computed by ``TNCMetrics``) onto the zone, link and location layers.

    Each layer has its own section in the form: radio buttons to turn mapping on/off, a metric combo box,
    a map-type combo box, a colour-ramp selector, a scale slider and a label toggle. The selected metric is
    computed into a separate layer, joined to the base layer with the ``metrics_`` prefix and symbolized.
    """

    # Folder holding the QML styles used to restore each layer's default look
    _STYLE_FOLDER = join(dirname(dirname(__file__)), "style_loader", "styles")

    def __init__(self, _PQgis):
        QDialog.__init__(self)
        self.setupUi(self)
        self.iface = _PQgis.iface
        self.feed = None
        self._PQgis = _PQgis
        self.polaris_project = _PQgis.polaris_project
        self._p = _PQgis.network

        proj = self.polaris_project
        self.tnc = TNCMetrics(proj.supply_file, proj.demand_file, proj.result_file, _PQgis.result_h5_path)

        # ZONE MAPPING
        self.mapped_zones = False
        self.zone_target_metric = ""
        self.zone_metric_layer = None

        self.zones_layer = self._PQgis.layers["zone"][0]
        self.zones_layer.loadNamedStyle(join(self._STYLE_FOLDER, "zone_background.qml"), True)
        self.chb_label_zones.toggled.connect(self.show_label_zones)
        self.rdo_no_map_zones.toggled.connect(self.enable_zone_mapping)
        self.rdo_map_zones.toggled.connect(self.enable_zone_mapping)
        self.cob_zones_map_type.currentIndexChanged.connect(self.allows_colors_zones)
        self.cob_zones_map_info.currentIndexChanged.connect(self.check_driver_filtering)

        self.enable_zone_mapping()
        self.but_map_zones.clicked.connect(self.map_zones)
        self.sld_zone_scale.valueChanged.connect(self.draw_zone_styles)
        self.time_from.timeChanged.connect(self.update_time)
        self.time_to.timeChanged.connect(self.update_time)

        # LINK MAPPING
        self.mapped_links = False
        self.link_target_metric = ""
        self.link_metric_layer = None
        self.link_layer = self._PQgis.layers["link"][0]
        self.link_layer.loadNamedStyle(join(self._STYLE_FOLDER, "link_basic_display.qml"), True)

        self.chb_label_link.toggled.connect(self.show_label_links)
        self.rdo_map_link.toggled.connect(self.enable_link_mapping)
        self.rdo_no_map_link.toggled.connect(self.enable_link_mapping)
        self.cob_link_map_info.currentIndexChanged.connect(self.check_link_filter)
        self.cob_link_map_type.currentIndexChanged.connect(self.allows_colors_links)

        self.enable_link_mapping()
        self.but_map_link.clicked.connect(self.map_links)

        self.sld_link_scale.valueChanged.connect(self.draw_link_styles)

        # LOCATION MAPPING
        self.mapped_locations = False
        self.location_target_metric = ""
        self.location_metric_layer = None
        self.location_layer = self._PQgis.layers["location"][0]
        self.location_layer.loadNamedStyle(join(self._STYLE_FOLDER, "location_basic_display.qml"), True)

        self.chb_label_location.toggled.connect(self.show_label_locations)
        self.rdo_map_location.toggled.connect(self.enable_location_mapping)
        self.rdo_no_map_location.toggled.connect(self.enable_location_mapping)
        self.cob_location_map_info.currentIndexChanged.connect(self.check_driver_filtering)
        self.cob_location_map_type.currentIndexChanged.connect(self.allows_colors_location)

        self.enable_location_mapping()
        self.but_map_location.clicked.connect(self.map_location)

        self.sld_location_scale.valueChanged.connect(self.draw_location_styles)

        for layer in [self.zones_layer, self.link_layer, self.location_layer]:
            QgsProject.instance().addMapLayer(layer)

        self.chb_time.toggled.connect(self.allow_filter_by_time)
        self.but_reset.clicked.connect(self.reset)

        self.check_link_filter()
        self.check_driver_filtering()

    # ------------------------------------------------------------------ locations
    def show_label_locations(self):
        """Label the location layer with the mapped metric (only once locations have been mapped)."""
        self.build_label(
            self.location_layer,
            f"metrics_{self.location_target_metric}",
            self.chb_label_location.isChecked(),
            self.mapped_locations,
            QgsPalLayerSettings.AroundPoint,
        )

    def enable_location_mapping(self):
        """Enable/disable the location widgets and restore the default location style when mapping is off."""
        active = self.rdo_map_location.isChecked()
        for item in [
            self.cob_location_map_type,
            self.cob_location_map_info,
            self.cob_location_color,
            self.sld_location_scale,
            self.chb_label_location,
            self.but_map_location,
        ]:
            item.setEnabled(active)

        for cob in [self.cob_location_map_type, self.cob_location_map_info, self.cob_location_color]:
            cob.clear()

        if active:
            self.cob_location_map_type.addItems(["Color", "Pie chart", "Stacked bars"])
            self.cob_location_map_info.addItems(self.tnc.list_location_metrics())
            self.cob_location_color.addItems(list(QgsStyle.defaultStyle().colorRampNames()))
        else:
            self.mapped_locations = False
            self.location_layer.loadNamedStyle(join(self._STYLE_FOLDER, "location_basic_display.qml"), True)
            self.location_layer.triggerRepaint()
            self._remove_metric_layer(self.location_layer, self.location_metric_layer)
            self.location_metric_layer = None

        self.allows_colors_location()
        self.iface.mapCanvas().refresh()

    def allows_colors_location(self):
        """Show the colour ramp only for "Color" maps; the scale slider only applies to the other map types."""
        is_color = self.cob_location_map_type.currentText() == "Color"
        self.cob_location_color.setVisible(is_color)
        # Keep the slider disabled while location mapping itself is turned off
        self.sld_location_scale.setEnabled(self.rdo_map_location.isChecked() and not is_color)

    def map_location(self):
        """Compute the selected location metric, join it to the location layer and symbolize it."""
        self.mapped_locations = True
        self.but_map_location.setEnabled(False)
        try:
            self._remove_metric_layer(self.location_layer, self.location_metric_layer)
            # The previous metric layer may have been deleted by the project: never keep a dangling reference
            self.location_metric_layer = None

            from_time, to_time, sample = self.__get_global_settings()
            metric = self.cob_location_map_info.currentText()
            self.location_target_metric, metric_function = self.__function_metrics(metric)
            df = metric_function(from_minute=from_time, to_minute=to_time, aggregation="location")
            df.loc[:, :] /= sample

            self.location_metric_layer = layer_from_dataframe(df.reset_index(), "location_metrics")
            self.make_join(self.location_layer, "location", self.location_metric_layer)
            self.show_label_locations()
            self.draw_location_styles()
        except Exception:
            # Do not leave the dialog claiming locations are mapped when the mapping failed
            self.mapped_locations = False
            raise
        finally:
            self.but_map_location.setEnabled(True)

    def draw_location_styles(self):
        """(Re)apply the chosen symbology to the location layer."""
        self.draw_styles(
            self.mapped_locations,
            self.location_target_metric,
            self.location_layer,
            self.cob_location_map_type,
            self.cob_location_color,
            self.sld_location_scale,
            self.but_map_location,
        )

    # ------------------------------------------------------------------ links
    def show_label_links(self):
        """Label the link layer with the mapped metric (only once links have been mapped)."""
        # The joined "metrics_*" field lives on the base link layer, not on the standalone metric layer
        self.build_label(
            self.link_layer,
            f"metrics_{self.link_target_metric}",
            self.chb_label_link.isChecked(),
            self.mapped_links,
            QgsPalLayerSettings.Line,
        )

    def enable_link_mapping(self):
        """Enable/disable the link widgets and restore the default link style when mapping is off."""
        self.check_unit_filtering()
        active = self.rdo_map_link.isChecked()
        for item in [
            self.cob_link_map_type,
            self.cob_link_map_info,
            self.cob_link_color,
            self.sld_link_scale,
            self.chb_label_link,
            self.but_map_link,
        ]:
            item.setEnabled(active)

        for cob in [self.cob_link_map_type, self.cob_link_map_info, self.cob_link_color]:
            cob.clear()

        if active:
            self.cob_link_map_type.addItems(["Color", "Thickness"])
            self.cob_link_map_info.addItems(self.tnc.list_link_metrics())
            self.cob_link_color.addItems(list(QgsStyle.defaultStyle().colorRampNames()))
        else:
            self.mapped_links = False
            self.link_layer.loadNamedStyle(join(self._STYLE_FOLDER, "link_basic_display.qml"), True)
            self.link_layer.triggerRepaint()
            self._remove_metric_layer(self.link_layer, self.link_metric_layer)
            self.link_metric_layer = None

        self.allows_colors_links()
        self.iface.mapCanvas().refresh()

    def allows_colors_links(self):
        """Show the colour ramp only for "Color" maps; the scale slider only applies to the other map types."""
        is_color = self.cob_link_map_type.currentText() == "Color"
        self.cob_link_color.setVisible(is_color)
        # Keep the slider disabled while link mapping itself is turned off
        self.sld_link_scale.setEnabled(self.rdo_map_link.isChecked() and not is_color)

    def map_links(self):
        """Compute the selected link metric, join it to the link layer and symbolize it."""
        self.mapped_links = True
        self.but_map_link.setEnabled(False)
        try:
            self._remove_metric_layer(self.link_layer, self.link_metric_layer)
            # The previous metric layer may have been deleted by the project: never keep a dangling reference
            self.link_metric_layer = None

            from_time, to_time, sample = self.__get_global_settings()
            metric = self.cob_link_map_info.currentText()
            self.link_target_metric, metric_function = self.__link_metrics(metric)
            unit = "mile" if self.link_target_metric == "vmt" else "km"
            df = metric_function(from_minute=from_time, to_minute=to_time, unit=unit)
            df.loc[:, :] /= sample

            self.link_metric_layer = layer_from_dataframe(df.reset_index(), "link_metrics")
            self.make_join(self.link_layer, "link", self.link_metric_layer)
            self.show_label_links()
            self.draw_link_styles()
        except Exception:
            # Do not leave the dialog claiming links are mapped when the mapping failed
            self.mapped_links = False
            raise
        finally:
            self.but_map_link.setEnabled(True)

    def check_link_filter(self):
        """Show the distance-unit filter only when the "link_use" metric is selected."""
        self.glob_filter_unit.setVisible(self.cob_link_map_info.currentText() == "link_use")

    def draw_link_styles(self):
        """(Re)apply the chosen symbology to the link layer."""
        self.draw_styles(
            self.mapped_links,
            self.link_target_metric,
            self.link_layer,
            self.cob_link_map_type,
            self.cob_link_color,
            self.sld_link_scale,
            self.but_map_link,
        )

    def __link_metrics(self, metric_name: str):
        """Return ``(field_name, metric_function)`` for a link metric.

        Raises:
            ValueError: if the metric has no GUI support.
        """
        if metric_name != "link_use":
            raise ValueError("Link metric not implemented in the GUI")
        metric = "vmt" if self.rdo_miles.isChecked() else "vkt"
        return metric, self.tnc.link_use

    # ------------------------------------------------------------------ zones
    def enable_zone_mapping(self):
        """Enable/disable the zone widgets and restore the default zone style when mapping is off."""
        active = self.rdo_map_zones.isChecked()
        for item in [
            self.cob_zones_map_type,
            self.cob_zones_map_info,
            self.cob_zones_color,
            self.sld_zone_scale,
            self.but_map_zones,
            self.lbl_scl_zone,
        ]:
            item.setEnabled(active)

        for cob in [self.cob_zones_map_type, self.cob_zones_map_info, self.cob_zones_color]:
            cob.clear()

        if active:
            self.cob_zones_map_type.addItems(["Color", "Pie chart", "Stacked bars"])
            self.cob_zones_map_info.addItems(self.tnc.list_zone_metrics())
            self.cob_zones_color.addItems(list(QgsStyle.defaultStyle().colorRampNames()))
        else:
            self.mapped_zones = False
            self.zones_layer.loadNamedStyle(join(self._STYLE_FOLDER, "zone_background.qml"), True)
            self.zones_layer.triggerRepaint()
            self._remove_metric_layer(self.zones_layer, self.zone_metric_layer)
            self.zone_metric_layer = None

        # The Map button's enabled state was set above; do not force-enable it while mapping is off
        self.allows_colors_zones()
        self.iface.mapCanvas().refresh()

    def allows_colors_zones(self):
        """Show the colour ramp only for "Color" maps; the scale slider only applies to the other map types."""
        is_color = self.cob_zones_map_type.currentText() == "Color"
        self.cob_zones_color.setVisible(is_color)
        # Keep the slider disabled while zone mapping itself is turned off
        self.sld_zone_scale.setEnabled(self.rdo_map_zones.isChecked() and not is_color)

    def map_zones(self):
        """Compute the selected zone metric, join it to the zone layer and symbolize it."""
        self.mapped_zones = True
        self.but_map_zones.setEnabled(False)
        try:
            self._remove_metric_layer(self.zones_layer, self.zone_metric_layer)
            # The previous metric layer may have been deleted by the project: never keep a dangling reference
            self.zone_metric_layer = None

            from_time, to_time, sample = self.__get_global_settings()
            metric = self.cob_zones_map_info.currentText()
            self.zone_target_metric, metric_function = self.__function_metrics(metric)
            df = metric_function(from_minute=from_time, to_minute=to_time, aggregation="zone")
            df.loc[:, :] /= sample

            self.zone_metric_layer = layer_from_dataframe(df.reset_index(), "zone_metrics")
            self.make_join(self.zones_layer, "zone", self.zone_metric_layer)
            self.show_label_zones()
            self.draw_zone_styles()
        except Exception:
            # Do not leave the dialog claiming zones are mapped when the mapping failed
            self.mapped_zones = False
            raise
        finally:
            self.but_map_zones.setEnabled(True)

    def show_label_zones(self):
        """Label the zone layer with the mapped metric (only once zones have been mapped)."""
        self.build_label(
            self.zones_layer,
            f"metrics_{self.zone_target_metric}",
            self.chb_label_zones.isChecked(),
            self.mapped_zones,
            QgsPalLayerSettings.AroundPoint,
        )

    def draw_zone_styles(self):
        """(Re)apply the chosen symbology to the zone layer."""
        self.draw_styles(
            self.mapped_zones,
            self.zone_target_metric,
            self.zones_layer,
            self.cob_zones_map_type,
            self.cob_zones_color,
            self.sld_zone_scale,
            self.but_map_zones,
        )

    # ------------------------------------------------------------------ shared helpers
    def check_driver_filtering(self):
        """Show the driver-type filter when zones or locations are set to the "initial_locations" metric."""
        requires = [self.cob_zones_map_info.currentText(), self.cob_location_map_info.currentText()]
        self.glob_filter_driver_type.setVisible("initial_locations" in requires)

    def check_unit_filtering(self):
        """Show the distance-unit filter only while link mapping is enabled."""
        # Links are parameterised by unit (rdo_miles), not by driver type
        self.glob_filter_unit.setVisible(self.rdo_map_link.isChecked())

    def __function_metrics(self, metric_name: str):
        """Return ``(field_name, metric_function)`` for a zone/location metric.

        For "initial_locations" the field name is the selected driver type rather than the metric name.

        Raises:
            ValueError: if the metric has no GUI support.
        """
        funcs = {
            "failed_requests": self.tnc.failed_requests,
            "initial_locations": self.tnc.initial_loc_metrics,
            "mean_wait": self.tnc.wait_metrics,
            "revenue": self.tnc.revenue_metrics,
        }
        if metric_name not in funcs:
            raise ValueError("Zone metric not implemented in the GUI")

        field = self.driver_type() if metric_name == "initial_locations" else metric_name
        # Look the function up by the metric name: the driver-type field name is not a key of funcs
        return field, funcs[metric_name]

    def driver_type(self):
        """Return the field name matching the selected driver-type radio button."""
        return "all_vehicles" if self.rdo_all.isChecked() else "human" if self.rdo_h.isChecked() else "automated"

    def _refresh_mapped_layers(self):
        """Recompute every layer (zones, links, locations) that currently displays a metric."""
        if self.mapped_zones:
            self.map_zones()

        if self.mapped_links:
            self.map_links()

        if self.mapped_locations:
            self.map_location()

    def update_time(self):
        """Re-map every mapped layer after the time window changed."""
        self._refresh_mapped_layers()

    def map_diagram(self, fld, max_metric, method, layer, val):
        """Draw a pie-chart or stacked-bar diagram sized linearly from 0 up to ``max_metric``.

        Args:
            fld: joined field holding the metric.
            max_metric: largest metric value; nothing is drawn when it is None (no data).
            method: "Pie chart" or "Stacked bars"; any other value is ignored.
            layer: vector layer receiving the diagram renderer.
            val: diagram size (width and height) for ``max_metric``.
        """
        diagram_types = {"Pie chart": QgsPieDiagram, "Stacked bars": QgsStackedBarDiagram}
        if method not in diagram_types or max_metric is None:
            return
        diagram = diagram_types[method]()

        ds = QgsDiagramSettings()
        category_colors = {fld: QtGui.QColor("#5a09a6")}

        # QList-typed settings: pass real lists rather than dict views
        ds.categoryColors = list(category_colors.values())
        ds.categoryAttributes = list(category_colors.keys())
        ds.categoryLabels = ds.categoryAttributes
        ds.sizeType = 0  # 0 = Millimeters, 1 = Map Units, 2 = Pixels, 4 = Points, 5 = Inches

        # Set renderer:
        dr = QgsLinearlyInterpolatedDiagramRenderer()
        dr.setLowerValue(0)
        dr.setUpperValue(max_metric)
        dr.setUpperSize(QSizeF(val, val))
        dr.setClassificationField(fld)
        dr.setDiagram(diagram)
        dr.setDiagramSettings(ds)

        layer.setDiagramRenderer(dr)
        dls = QgsDiagramLayerSettings()
        dls.setPlacement(4)
        layer.setDiagramLayerSettings(dls)
        layer.triggerRepaint()

    def map_ranges(self, fld, max_metric, method, layer, val, color_ramp_name):
        """Apply a graduated renderer on ``fld`` with a 0/Null class plus equal-width classes up to the max.

        Args:
            fld: joined field holding the metric.
            max_metric: largest metric value (None when there is no data).
            method: "Color" graduates colours; any other value graduates symbol size with a fixed colour.
            layer: vector layer receiving the renderer.
            val: symbol size (upper size when graduating by size).
            color_ramp_name: colour ramp used when ``method`` is "Color".
        """
        intervals = 5  # total number of classes, the first one reserved for 0/Null
        max_metric = intervals if max_metric is None else max_metric
        value_classes = intervals - 1
        # Break values: [0, epsilon] for the 0/Null class, then upper bounds whose last one is ceil(max_metric),
        # so every feature up to the maximum falls in a class
        values = [0, 0.000001] + [ceil(i * (max_metric / value_classes)) for i in range(1, value_classes + 1)]
        color_ramp = color_ramp_shades(color_ramp_name, intervals)
        ranges = []
        for i in range(intervals):
            colour = QtGui.QColor("#1e00ff") if method != "Color" else color_ramp[i]
            symbol = QgsSymbol.defaultSymbol(layer.geometryType())
            symbol.setColor(colour)
            symbol.setOpacity(1)

            if i == 0:
                label = f"0/Null ({fld.replace('metrics_', '')})"
            elif i == 1:
                label = f"Up to {values[i + 1]:,.0f}"
            else:
                label = f"{values[i]:,.0f} to {values[i + 1]:,.0f}"

            ranges.append(QgsRendererRange(values[i], values[i + 1], symbol, label))

        sizes = [0, val] if method != "Color" else [val, val]
        renderer = QgsGraduatedSymbolRenderer("", ranges)
        renderer.setSymbolSizes(*sizes)
        renderer.setClassAttribute(f'''"{fld}"''')

        if method != "Color":
            renderer.setGraduatedMethod(QgsGraduatedSymbolRenderer.GraduatedSize)

        classific_method = QgsApplication.classificationMethodRegistry().method("EqualInterval")
        renderer.setClassificationMethod(classific_method)

        layer.setRenderer(renderer)
        layer.triggerRepaint()
        self.iface.mapCanvas().setExtent(layer.extent())
        self.iface.mapCanvas().refresh()

    def build_label(self, layer, field, active_label, active_map, placement):
        """Label ``layer`` with the rounded value of ``field``, or disable labels when ``active_label`` is False.

        Nothing happens while the corresponding layer has not been mapped (``active_map`` False).
        """
        if not active_map:
            return
        if not active_label:
            layer.setLabelsEnabled(False)
            layer.triggerRepaint()
            return

        label = QgsPalLayerSettings()
        txt_format = QgsTextFormat()
        txt_format.setFont(QtGui.QFont("Arial", 10))
        txt_format.setColor(QtGui.QColor("Black"))
        buff = QgsTextBufferSettings()
        buff.setSize(1)
        buff.setEnabled(True)
        txt_format.setBuffer(buff)
        label.setFormat(txt_format)
        label.fieldName = f"""to_int(round("{field}",0))"""
        label.isExpression = True
        label.placement = placement
        layer.setLabelsEnabled(True)
        layer.setLabeling(QgsVectorLayerSimpleLabeling(label))
        layer.triggerRepaint()

    def reset(self):
        """Recompute every layer that is currently mapped."""
        self._refresh_mapped_layers()

    def allow_filter_by_time(self):
        """Enable the time-window widgets only when time filtering is checked."""
        for item in [self.lbl_time1, self.lbl_time2, self.time_from, self.time_to]:
            item.setEnabled(self.chb_time.isChecked())

    def draw_styles(self, is_mapped, target_metric, layer, cob_method, cob_color, scale, button):
        """Symbolize ``layer`` on its joined ``metrics_<target_metric>`` field.

        Diagrams are used when the map-type combo offers them and a non-"Color" type is selected;
        otherwise a graduated renderer is applied. Does nothing until the layer has been mapped.
        """
        if not is_mapped:
            return
        fld = f"metrics_{target_metric}"
        idx = layer.fields().indexFromName(fld)
        max_metric = layer.maximumValue(idx)
        method = cob_method.currentText()

        has_diagram = "Pie chart" in [cob_method.itemText(i) for i in range(cob_method.count())]

        val = scale.value()
        color_ramp_name = "Blues" if method != "Color" else cob_color.currentText()

        if has_diagram and method != "Color":
            self.map_diagram(fld, max_metric, method, layer, val)
        else:
            self.map_ranges(fld, max_metric, method, layer, val / 2, color_ramp_name)

        button.setEnabled(True)

    def make_join(self, base_layer, join_field, metric_layer):
        """Join ``metric_layer`` to ``base_layer`` on ``join_field``, prefixing joined fields with "metrics_"."""
        lien = QgsVectorLayerJoinInfo()
        lien.setJoinFieldName(join_field)
        lien.setTargetFieldName(join_field)
        lien.setJoinLayerId(metric_layer.id())
        lien.setUsingMemoryCache(True)
        lien.setJoinLayer(metric_layer)
        lien.setPrefix("metrics_")
        base_layer.addJoin(lien)

    def _remove_metric_layer(self, base_layer, metric_layer):
        """Remove a previously built metric layer from the project and drop every join on ``base_layer``.

        No-op when ``metric_layer`` is None (nothing was mapped yet).
        """
        if metric_layer is None:
            return
        QgsProject.instance().removeMapLayers([metric_layer.id()])
        for join_id in [join_info.joinLayerId() for join_info in base_layer.vectorJoins()]:
            base_layer.removeJoin(join_id)

    def __get_global_settings(self):
        """Return ``(from_minute, to_minute, sample_fraction)``; the minutes are None when time filtering is off."""
        minutes_from = self.time_from.time().hour() * 60 + self.time_from.time().minute()
        minutes_to = self.time_to.time().hour() * 60 + self.time_to.time().minute()
        from_time = minutes_from if self.chb_time.isChecked() else None
        to_time = minutes_to if self.chb_time.isChecked() else None
        sample = self.sb_sample.value() / 100
        return from_time, to_time, sample
