# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from polaris.analyze.mapping.flow_lines import delaunay_assignment, desire_lines
from polaris.analyze.trip_metrics import TripMetrics
from polaris.utils.database.data_table_access import DataTableAccess
from polaris.utils.database.db_utils import read_and_close, has_table
from qgis.PyQt.QtCore import QTime
from qgis.PyQt.QtGui import QFont
from qgis.PyQt.QtWidgets import QApplication
from qgis.PyQt.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QComboBox, QLabel, QDialog, QTimeEdit, QPushButton
from qgis.core import QgsLineSymbol, QgsSimpleLineSymbolLayer
from qgis.utils import iface

from QPolaris.modules.common_tools.common_matrix_extraction import CommonMatrixExtractionDialog
from QPolaris.modules.common_tools.geo_layer_from_geodataframe import layer_from_geodataframe


class DesireDelaunayDialog(CommonMatrixExtractionDialog):
    """Dialog that builds a desire-line or Delaunay-line map from a trip matrix.

    It extends ``CommonMatrixExtractionDialog`` with two additions:

    * a map-type selector ("Desire Lines" / "Delaunay Lines");
    * a trip-mode selector choosing which flow fields drive the line widths.

    Several attributes used here are not defined in this class and are
    presumably provided by the base class (TODO confirm): ``project``,
    ``from_time``, ``to_time``, ``aggregation``, ``go``, ``first_box``,
    ``main_layout``, ``drive_side`` and ``iface``.
    """

    # Mode names offered when the demand database has no ``Mode`` table.
    _FALLBACK_MODES = (
        "SOV",
        "TRUCK",
        "BUS",
        "RAIL",
        "BICYCLE",
        "WALK",
        "TAXI",
        "SCHOOLBUS",
        "MD_TRUCK",
        "HD_TRUCK",
        "BPLATE",
        "LD_TRUCK",
    )

    def __init__(self, _PQgis):
        super().__init__(_PQgis)
        self.complete_ui()
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.__data = {}

    def complete_ui(self):
        """Add the map-type and trip-mode selectors to the base layout and wire the build button."""
        self.setFixedSize(560, 250)

        # Mapping type, inserted at the front of the base dialog's first row
        lbl1 = QLabel("Map type")
        lbl1.setFixedWidth(75)
        self.map_types = QComboBox()
        self.map_types.setFixedWidth(160)
        self.map_types.addItems(["Desire Lines", "Delaunay Lines"])
        self.map_types.setCurrentIndex(0)

        self.first_box.insertStretch(0)
        self.first_box.insertWidget(0, self.map_types)
        self.first_box.insertWidget(0, lbl1)

        # Trip mode whose flows control the line widths
        thickness_selector = QHBoxLayout()

        lbl5 = QLabel("Trip mapping mode")
        lbl5.setFixedWidth(160)
        self.mode_selector = QComboBox()
        self.mode_selector.setFixedWidth(160)
        self.mode_selector.setToolTip("Select the trip mode to be used for the map")
        self.load_mode_names()
        thickness_selector.addWidget(lbl5)
        thickness_selector.addWidget(self.mode_selector)
        thickness_selector.addStretch()

        self.main_layout.insertLayout(2, thickness_selector)

        # Execution button name and action
        self.go.setText("Build map")
        self.go.clicked.connect(self.load_data)

        self.setWindowTitle("Desire and Delaunay lines")

    @staticmethod
    def _seconds_since_midnight(time_edit):
        """Return the hour/minute shown by a time widget as whole seconds since midnight.

        Expects an object whose ``time()`` returns something with ``hour()`` and
        ``minute()`` (e.g. a ``QTimeEdit``). Seconds are ignored.
        """
        qtime = time_edit.time()
        return int(qtime.hour() * 3600 + qtime.minute() * 60)

    def load_data(self):
        """Build the trip matrix for the selected window, turn it into lines and map them.

        The "Build map" button is disabled while the work runs. It is re-enabled
        afterwards even if an error is raised, so the dialog stays usable after a
        failure. The dialog is closed only on success.
        """
        self.go.setEnabled(False)
        try:
            trip_metrics = TripMetrics(self.project.supply_path, self.project.demand_path)

            # Trip start-time window, in seconds since midnight
            fromtime = self._seconds_since_midnight(self.from_time)
            totime = self._seconds_since_midnight(self.to_time)
            agg = self.aggregation.currentText().lower()
            matrix = trip_metrics.trip_matrix(from_start_time=fromtime, to_start_time=totime, aggregation=agg)

            func = desire_lines if self.map_types.currentText() == "Desire Lines" else delaunay_assignment
            gdf = func(self.project.supply_path, aggregation=agg, matrix=matrix)
            gdf.columns = gdf.columns.str.lower()
            # Row-wise sums over every per-direction column; these "total_*" fields
            # are what make_map falls back to when the chosen mode has no fields.
            # NOTE(review): if `func` already returns total_* columns they would be
            # included in these sums — confirm against flow_lines.
            gdf["total_ab"] = gdf.filter(regex=r"_ab$").sum(axis=1)
            gdf["total_ba"] = gdf.filter(regex=r"_ba$").sum(axis=1)
            gdf["total_tot"] = gdf.filter(regex=r"_tot$").sum(axis=1)
            self.make_map(gdf)
        finally:
            self.go.setEnabled(True)

        self.close()

    def load_mode_names(self):
        """Fill the mode selector with "Total" followed by the available trip modes.

        Modes are read from the demand database's ``Mode`` table:

        * only rows with ``mode_id < 999`` are used;
        * descriptions that are empty are skipped;
        * descriptions containing ``_AND_`` or ``NEST`` are skipped.

        When that table does not exist, ``_FALLBACK_MODES`` is used instead.
        """
        with read_and_close(self.project.demand_path) as conn:
            if has_table(conn, "Mode"):
                modes = DataTableAccess(self.project.demand_path).get("Mode", conn).query("mode_id<999")
                excepted = ("_AND_", "NEST")
                # Guard against null descriptions, which would break the substring test
                data_modes = [
                    x for x in modes.mode_description.tolist() if x and not any(y in x for y in excepted)
                ]
            else:
                data_modes = list(self._FALLBACK_MODES)

        # Populate the widget after the database connection has been released
        self.mode_selector.addItems(["Total"] + data_modes)
        self.mode_selector.setCurrentIndex(0)

    def make_map(self, gdf):
        """Add ``gdf`` as a line layer drawn as two directional flow bands.

        Each link gets one band for the AB direction and one for BA:

        * width scales linearly with the flow in ``<mode>_ab`` / ``<mode>_ba``;
        * each band is shifted by half its width, so the two sit side by side;
        * links with zero or null flow in a direction get no band for it.

        The selected mode's fields are used only if both exist in ``gdf``;
        otherwise the ``total_ab`` / ``total_ba`` fields are used.
        """
        layer = layer_from_geodataframe(gdf=gdf, geo_type="linestring", layer_name=self.map_types.currentText())
        # Offset sign per direction depends on the driving side
        ab, ba = (1, -1) if self.drive_side == "right" else (-1, 1)

        symbol = QgsLineSymbol.createSimple({"name": "square", "color": "black"})
        layer.renderer().setSymbol(symbol)

        # Use the chosen mode only when the layer actually carries both of its directional fields
        mode = self.mode_selector.currentText().lower()
        has_mode_fields = all(f"{mode}_{direction}" in gdf.columns for direction in ("ab", "ba"))
        prefix = mode if has_mode_fields else "total"

        fields = {ab: f"{prefix}_ab", ba: f"{prefix}_ba"}

        mvalue = gdf[list(fields.values())].max().max()
        # An empty or all-zero layer gives NaN/0 here, which would make the
        # scale_linear domain invalid; any positive value keeps the expression valid
        # (all bands are hidden by the line-style rule in that case anyway).
        if not mvalue > 0:
            mvalue = 1.0

        for side, field in fields.items():
            # Same scaled width drives both the band width and its offset
            width_expr = f'coalesce(scale_linear(abs("{field}"), 0.05, {mvalue}, 0, 10), 0)'
            symbol_layer = QgsSimpleLineSymbolLayer.create({})
            props = symbol_layer.properties()
            props["width_dd_expression"] = f"({width_expr})"
            props["line_style_expression"] = f"""if (coalesce("{field}",0) = 0, 'no', 'solid')"""
            props["offset_dd_expression"] = f"{side} * ({width_expr} / 2)"
            props["line_color"] = "0,0,139,255"
            layer.renderer().symbol().appendSymbolLayer(QgsSimpleLineSymbolLayer.create(props))

        # Remove the base symbol's own layer so only the two directional bands remain
        layer.renderer().symbol().deleteSymbolLayer(0)
        layer.triggerRepaint()
        self.iface.mapCanvas().setExtent(layer.extent())
