# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
# COPIED FROM AEQUILIBRAE
import logging
import os
from functools import partial
from math import ceil
from os.path import dirname, join

import numpy as np
import pandas as pd
from polaris.skims.highway.highway_skim import HighwaySkim
from polaris.skims.transit.transit_skim import TransitSkim
from polaris.skims.utils.basic_skim import SkimBase
from qgis.core import (
    QgsApplication,
    QgsGraduatedSymbolRenderer,
    QgsProject,
    QgsRendererRange,
    QgsSymbol,
    QgsVectorLayer,
    QgsVectorLayerJoinInfo,
)
from qgis.PyQt import QtWidgets, uic
from qgis.PyQt.QtGui import QColor
from qgis.PyQt.QtWidgets import (
    QAbstractItemView,
    QCheckBox,
    QComboBox,
    QHBoxLayout,
    QLabel,
    QPushButton,
    QRadioButton,
    QSpacerItem,
    QSpinBox,
    QTableView,
    QVBoxLayout,
)

from ..common_tools import GetOutputFileName, NumpyModel, layer_from_dataframe
from ..common_tools.mapping import color_ramp_shades

# Qt Designer form that supplies the dialog's base widgets (set up via setupUi in the dialog)
FORM_CLASS, _ = uic.loadUiType(join(dirname(__file__), "forms/ui_data_viewer.ui"))


class DisplaySkimMatricesDialog(QtWidgets.QDialog, FORM_CLASS):
    """Dialog that shows Polaris skim matrices as a table and, optionally, on the map.

    The user loads either the highway ("Auto") or the transit ("PT") skims of the open
    Polaris project, then picks a metric, a transit mode (transit only) and a time
    interval. The matching matrix is displayed in a ``QTableView``. When the project has
    a supply file, a selected row (origin) or column (destination) can also be drawn on
    the zone layer with a graduated renderer.
    """

    def __init__(self, qgis_project):
        QtWidgets.QDialog.__init__(self)
        self.iface = qgis_project.iface
        self.setupUi(self)
        self.data_to_show = np.array((1, 1))
        self.error = None
        self.logger = logging.getLogger("polaris")
        self.qgis_project = qgis_project
        self.skims = SkimBase()
        self.data_path = ""
        self.data_type = ""
        self.skim_type = ""
        self.__setup = False
        self.intervals = []
        self.indices = np.array(1)
        self.mapping_layer = None
        self.selected_col = None
        self.selected_row = None
        # The zone layer only exists when the project has a supply file; methods that
        # touch it must cope with it being None.
        self.zones_layer = None

        self.but_load_traffic.clicked.connect(partial(self.load_skims, "Auto"))
        self.but_load_transit.clicked.connect(partial(self.load_skims, "PT"))

        if self.qgis_project.supply_path:
            self.zones_layer = self.qgis_project.layers["zone"][0]
            style_fldr = join(dirname(dirname(__file__)), "style_loader", "styles")
            self.zones_layer.loadNamedStyle(join(style_fldr, "zone_background.qml"), True)
            QgsProject.instance().addMapLayer(self.zones_layer)
        self.remove_mapping_layer()

    def load_skims(self, mode: str) -> None:
        """Load the highway (``"Auto"``) or transit (``"PT"``) skims of the open project.

        Does nothing unless the plugin was opened on a project. If the skim file is
        missing on disk, a warning is logged and nothing is loaded.
        """
        if self.qgis_project.open_mode != "project":
            return
        skims = self.qgis_project.polaris_project.skims
        mat = skims.highway_path if mode == "Auto" else skims.transit_path
        if not os.path.isfile(mat):
            self.logger.warning("Skim file not found: %s", mat)
            return
        # Shown in the window title and used as the starting folder for exports
        self.data_path = str(mat)
        self.skim_type = mode
        self.skims = skims.highway if mode == "Auto" else skims.transit
        self.intervals = self.skims.intervals
        self.indices = self.skims.index.zones.astype(np.int32)
        self.continue_with_data(mode)

    def continue_with_data(self, mode: str) -> None:
        """Build the table view and its controls for the skims held in ``self.skims``.

        Args:
            mode: ``"Auto"`` or ``"PT"``. For ``"PT"`` an extra combo box lets the user
                choose the transit mode; otherwise the mode is fixed to ``"AUTO"``.
        """
        # Widgets configured below may fire format_showing. Keep it a no-op until every
        # control for *these* skims exists. Otherwise, on a second load, it would combine
        # the previous load's combo boxes with the new skims.
        self.__setup = False
        self.setWindowTitle(f"File path:  {self.data_path}")

        # NOTE(review): QWidget.setLayout does not replace an existing layout, so a second
        # call builds widgets that may never be shown — confirm how repeated loads
        # should rebuild the dialog.
        self._layout = QVBoxLayout()
        self.table = QTableView()

        self._layout.addWidget(self.table)

        # Settings for displaying
        self.show_layout = QHBoxLayout()

        # Thousand separator
        self.thousand_separator = QCheckBox()
        self.thousand_separator.setChecked(True)
        self.thousand_separator.setText("Thousands separator")
        self.thousand_separator.toggled.connect(self.format_showing)
        self.show_layout.addWidget(self.thousand_separator)

        self.spacer = QSpacerItem(5, 0, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.show_layout.addItem(self.spacer)

        # Decimals
        txt = QLabel()
        txt.setText("Decimal places")
        self.show_layout.addWidget(txt)
        self.decimals = QSpinBox()
        self.decimals.setRange(0, 10)
        self.decimals.setValue(4)
        self.decimals.valueChanged.connect(self.format_showing)

        self.show_layout.addWidget(self.decimals)
        self._layout.addItem(self.show_layout)

        # Mapping controls only make sense when there is a zone layer to draw on
        if self.qgis_project.supply_path:
            self.mapping_layout = QHBoxLayout()

            self.no_mapping = QRadioButton()
            self.no_mapping.setText("No mapping")
            self.no_mapping.toggled.connect(self.set_mapping)

            self.by_row = QRadioButton()
            self.by_row.setText("By origin")
            self.by_row.toggled.connect(self.set_mapping)

            self.by_col = QRadioButton()
            self.by_col.setText("By destination")
            self.by_col.toggled.connect(self.set_mapping)

            self.no_mapping.setChecked(True)

            self.mapping_layout.addWidget(self.no_mapping)
            self.mapping_layout.addWidget(self.by_row)
            self.mapping_layout.addWidget(self.by_col)
            self._layout.addItem(self.mapping_layout)

        # Selectors for which matrix to show: metric, transit mode and time interval
        self.mat_layout = QHBoxLayout()
        self.headers_layout = QHBoxLayout()

        self.skim_list = QComboBox()
        self.skim_list.addItems(self.skims.metrics)
        mlabel = QLabel()
        mlabel.setText(f"{mode} Skim")

        self.skim_list.currentIndexChanged.connect(self.format_showing)
        self.mat_layout.addWidget(self.skim_list)
        self.headers_layout.addWidget(mlabel)

        self.transit_mode = QComboBox()
        if mode.upper() == "PT":
            self.transit_mode.addItems(self.skims.modes)
            self.transit_mode.currentIndexChanged.connect(self.format_showing)
            self.mat_layout.addWidget(self.transit_mode)
            tlabel = QLabel()
            tlabel.setText("Transit mode")
            self.headers_layout.addWidget(tlabel)
        else:
            # Highway skims have a single, fixed mode; the combo box is never shown
            self.transit_mode.addItems(["AUTO"])
            self.transit_mode.setCurrentIndex(0)

        # Adds intervals
        self.intervals_list = QComboBox()
        self.intervals_list.addItems([str(i) for i in self.intervals])
        self.intervals_list.currentIndexChanged.connect(self.format_showing)
        self.mat_layout.addWidget(self.intervals_list)
        ilabel = QLabel()
        ilabel.setText("Time interval")
        self.headers_layout.addWidget(ilabel)

        self._layout.addItem(self.headers_layout)
        self._layout.addItem(self.mat_layout)

        self.but_export = QPushButton()
        self.but_export.setText("Export")
        self.but_export.clicked.connect(self.export)

        self.but_close = QPushButton()
        self.but_close.clicked.connect(self.exit_procedure)
        self.but_close.setText("Close")

        self.but_layout = QHBoxLayout()
        self.but_layout.addWidget(self.but_export)
        self.but_layout.addWidget(self.but_close)

        self._layout.addItem(self.but_layout)

        self.resize(700, 500)
        self.setLayout(self._layout)
        self.__setup = True
        self.format_showing()

    def select_column(self) -> None:
        """Map the selected table column (values towards one destination zone)."""
        self.selected_col = None
        col_id = [col_idx.column() for col_idx in self.table.selectionModel().selectedColumns()]
        if not col_id:
            return
        self.selected_col = col_id[0]
        self.zones_layer.selectByExpression(f'"zone"={self.indices[self.selected_col]}', QgsVectorLayer.SetSelection)
        self.iface.mapCanvas().refresh()

        dt = np.array(self.data_to_show[:, self.selected_col]).reshape(self.indices.shape[0])
        self.map_dt(dt)

    def select_row(self) -> None:
        """Map the selected table row (values from one origin zone)."""
        self.selected_row = None
        row_id = [row_idx.row() for row_idx in self.table.selectionModel().selectedRows()]
        if not row_id:
            return
        self.selected_row = row_id[0]
        self.zones_layer.selectByExpression(f'"zone"={self.indices[self.selected_row]}', QgsVectorLayer.SetSelection)
        dt = np.array(self.data_to_show[self.selected_row, :]).reshape(self.indices.shape[0])
        self.map_dt(dt)

    def map_dt(self, dt) -> None:
        """Join one matrix row/column to the zone layer and restyle it.

        Args:
            dt: one value per zone, aligned with ``self.indices``. NaN values and values
                at or above the skim's "infinite" sentinel are not mapped.
        """
        self.remove_mapping_layer(False)
        df = pd.DataFrame({"zone": self.indices, "data": dt}).dropna()
        df = df[df["data"] < self.skims._infinite]
        self.mapping_layer = layer_from_dataframe(df, "matrix_row")
        self.make_join(self.zones_layer, "zone", self.mapping_layer)
        self.draw_zone_styles()

    def get_data_to_show(self):
        """Return the skim matrix for the selected metric, mode and interval.

        Returns None while the controls are being built, or if there are no intervals.
        """
        if not self.__setup or self.intervals_list.count() == 0:
            return None
        metric = self.skim_list.currentText()
        mode = self.transit_mode.currentText()
        interval = int(self.intervals_list.currentText())
        return self.skims.get_skims(mode=mode, metric=metric, interval=interval)

    def format_showing(self) -> None:
        """Refresh the table with the current matrix and display settings, then re-apply mapping."""
        self.data_to_show = self.get_data_to_show()
        if self.data_to_show is None:
            return
        decimals = self.decimals.value()
        separator = self.thousand_separator.isChecked()
        m = NumpyModel(matrix=self.data_to_show, indices=self.indices, separator=separator, decimals=decimals)
        self.table.clearSpans()
        self.table.setModel(m)
        self.set_mapping()

    def make_join(self, base_layer, join_field, metric_layer) -> None:
        """Join ``metric_layer`` to ``base_layer`` on ``join_field``.

        The joined fields are prefixed with ``metrics_``.
        """
        lien = QgsVectorLayerJoinInfo()
        lien.setJoinFieldName(join_field)
        lien.setTargetFieldName(join_field)
        lien.setJoinLayerId(metric_layer.id())
        lien.setUsingMemoryCache(True)
        lien.setJoinLayer(metric_layer)
        lien.setPrefix("metrics_")
        base_layer.addJoin(lien)

    def draw_zone_styles(self) -> None:
        """Style the zone layer by the joined ``metrics_data`` field using the "Blues" ramp."""
        color_ramp_name = "Blues"
        self.map_ranges("metrics_data", self.zones_layer, color_ramp_name)

    def map_ranges(self, fld, layer, color_ramp_name) -> None:
        """Apply a graduated renderer on field ``fld`` of ``layer`` and zoom to it.

        One class covers 0/NULL (NULLs are coalesced to 0). It is followed by nine
        classes whose upper bounds are ``ceil(k * max / 9)``, where max is the field's
        maximum value.
        """
        idx = layer.fields().indexFromName(fld)
        max_metric = layer.maximumValue(idx)

        num_steps = 9
        max_metric = num_steps if max_metric is None else max_metric
        values = [ceil(i * (max_metric / num_steps)) for i in range(1, num_steps + 1)]
        # A tiny first range isolates exact zeros (and coalesced NULLs)
        values = [0, 0.000001] + values
        # NOTE(review): indexed up to num_steps below, so this assumes
        # color_ramp_shades returns at least num_steps + 1 colours — TODO confirm
        color_ramp = color_ramp_shades(color_ramp_name, num_steps)
        color_ramp[0] = QColor("#feffdf")
        ranges = []
        for i in range(num_steps + 1):
            colour = color_ramp[i]
            symbol = QgsSymbol.defaultSymbol(layer.geometryType())
            symbol.setColor(colour)
            symbol.setOpacity(1)

            if i == 0:
                label = f"0/Null ({fld.replace('metrics_', '')})"
            elif i == 1:
                label = f"Up to {values[i + 1]:,.0f}"
            else:
                label = f"{values[i]:,.0f} to {values[i + 1]:,.0f}"

            ranges.append(QgsRendererRange(values[i], values[i + 1], symbol, label))

        sizes = [0, max_metric]
        renderer = QgsGraduatedSymbolRenderer("", ranges)
        renderer.setSymbolSizes(*sizes)
        renderer.setClassAttribute(f"""coalesce("{fld}", 0)""")

        classific_method = QgsApplication.classificationMethodRegistry().method("EqualInterval")
        renderer.setClassificationMethod(classific_method)

        layer.setRenderer(renderer)
        layer.triggerRepaint()
        self.iface.mapCanvas().setExtent(layer.extent())
        self.iface.mapCanvas().refresh()

    def set_mapping(self) -> None:
        """Apply the mapping mode chosen with the radio buttons to the table and the map.

        Clears the table selection and any mapping layer. For "By origin" or
        "By destination" it then switches the table to row or column selection,
        re-maps the previously selected row/column (if any) and listens for new
        selections.
        """
        self.table.clearSelection()
        self.table.setSelectionMode(QAbstractItemView.SingleSelection)
        if not self.qgis_project.supply_path:
            return

        self.remove_mapping_layer()
        selection_model = self.table.selectionModel()
        if selection_model is None:
            # No matrix has been put in the table yet, so there is nothing to select
            return

        # This method runs on every radio-button toggle and every table refresh. Drop
        # slots connected by earlier calls, so a selection change maps only once, and
        # only in the direction currently chosen (or not at all for "No mapping").
        for slot in (self.select_row, self.select_column):
            try:
                selection_model.selectionChanged.disconnect(slot)
            except (TypeError, RuntimeError):
                pass  # slot was not connected to this selection model

        if self.no_mapping.isChecked():
            return

        if self.by_row.isChecked():
            self.table.setSelectionBehavior(QAbstractItemView.SelectRows)
            self.selected_col = None
            # Compare with None: row 0 is a valid selection
            if self.selected_row is not None:
                self.table.blockSignals(True)
                self.table.selectRow(self.selected_row)
                self.table.blockSignals(False)
                self.select_row()
            selection_model.selectionChanged.connect(self.select_row)
        else:
            self.table.setSelectionBehavior(QAbstractItemView.SelectColumns)
            self.selected_row = None
            # Compare with None: column 0 is a valid selection
            if self.selected_col is not None:
                self.table.blockSignals(True)
                self.table.selectColumn(self.selected_col)
                self.table.blockSignals(False)
                self.select_column()
            selection_model.selectionChanged.connect(self.select_column)

    def export(self) -> None:
        """Save the matrix shown in the table to a CSV file chosen by the user.

        Rows and columns are labelled with the zone ids of the skim index. Does nothing
        if no matrix is displayed or the user cancels the file dialog.
        """
        if not self.__setup or self.data_to_show is None:
            return
        new_name, _ = GetOutputFileName(
            self, "Export skim matrix", ["Comma-separated file(*.csv)"], ".csv", dirname(self.data_path)
        )
        if new_name is None:
            return
        df = pd.DataFrame(self.data_to_show, index=self.indices, columns=self.indices)
        df.to_csv(new_name)

    def remove_mapping_layer(self, clear_selection=True) -> None:
        """Drop the mapping layer and every join on the zone layer.

        Args:
            clear_selection: also clear the zone layer's feature selection.
        """
        if self.mapping_layer is not None:
            QgsProject.instance().removeMapLayers([self.mapping_layer.id()])
        self.mapping_layer = None
        if self.zones_layer is None:
            # Project without supply: there is no zone layer to clean up
            return
        for lien in self.zones_layer.vectorJoins():
            self.zones_layer.removeJoin(lien.joinLayerId())
        if clear_selection:
            self.zones_layer.removeSelection()
        self.zones_layer.triggerRepaint()

    def exit_procedure(self) -> None:
        """Close the dialog; the mapping layer and zone-layer join are not removed here."""
        self.close()