# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import tables
from qgis.core import QgsVectorLayerJoinInfo

from ..common_tools import layer_from_dataframe


# The lists below classify the datasets found under "/link_moe" according to how
# TrafficResults.get_from_hdf5 combines several base intervals into one longer period.

# Aggregated by summing the values of the grouped intervals (NaNs ignored)
sum_agg = ["link_out_volume", "link_in_volume"]

# Aggregated as an average weighted by "link_out_volume"
avg_out = [
    "link_density",
    "num_vehicles_in_link",
    "link_travel_time",
    "link_out_flow_rate",
    "link_density_ratio",
    "link_out_flow_ratio",
]

# Aggregated as an average weighted by "link_in_volume"
avg_in = [
    "link_queue_length",
    "entry_queue_length",
    "link_in_flow_rate",
    "link_in_flow_ratio",
    "link_travel_delay",
]

# Aggregated by taking the maximum of the grouped intervals (NaNs ignored)
max_agg = ["volume_cum_HDT", "volume_cum_MDT"]

# Only available at the base interval: requesting an aggregation raises a ValueError
cannot_aggregate = [
    "link_speed_ratio",
    "link_travel_time_ratio",
]

# Metrics that have a dedicated aggregation routine (speed = link length / mean travel time)
specific = ["link_speed"]

# Datasets that are never offered as mappable metrics (excluded from TrafficResults.tables)
cannot_map = [
    "link_lengths",
    "link_uids",
    "link_travel_delay_standard_deviation",
    "link_travel_time_standard_deviation",
]

# Interval of the stored results. The aggregation code treats them as two-minute results,
# so aggregations are expressed as multiples of this value
base_interval_results = 2


class TrafficResults:
    """Loads link measures of effectiveness (MOE) from HDF5 result files and prepares them for mapping.

    Metrics are read from the ``/link_moe`` group of ``base_path``. When ``alt_path`` is provided, every
    loaded metric is the difference *base - alternative*.

    NOTE(review): link identifiers are read from the base file only, so a comparison presumably requires
    both files to list links in the same order - confirm against the files being compared.
    """

    def __init__(
        self,
        base_path: Path,
        base_name: str,
        alt_path: Optional[Path] = None,
        alt_name: Optional[str] = None,
        htodo=6,
    ):
        """Opens the base result file to list the mappable metrics and the link identifiers.

        Args:
            base_path: HDF5 file with the base results
            base_name: Name of the base iteration
            alt_path: Optional HDF5 file to compare against. Its presence turns comparison mode on
            alt_name: Name of the alternative iteration
            htodo: Number of hours loaded by each call to ``load_metric``
        """
        self.base_path = base_path
        self.alt_path = alt_path
        self.conn = None
        self.max_value = -1
        self.ab_lf_lyr = None
        self.ba_lf_lyr = None
        self.iterations = [base_name, alt_name]
        self.metric_name = ""
        self.joined_ab = ""
        self.joined_ba = ""
        self.base_interval = base_interval_results
        self.speed_unit = "mph"

        self.comparison = alt_path is not None
        self.htodo = htodo
        self.interval: List[int] = []
        self.periods: List[int] = []
        self.step: int = 0

        with tables.open_file(self.base_path, "r") as h5file:
            self.tables = [k.name for k in h5file.get_node("/link_moe") if k.name not in cannot_map]
            self.__link_ids = np.array(h5file.get_node("/link_moe/link_uids")).flatten()

        self.load_parameters()

        self.metric = ""
        # Aggregations divide by flows/travel times that may be zero. Note this changes NumPy's global error state
        _ = np.seterr(divide="ignore")

    def tables_for_aggregation(self, aggregation: int):
        """Returns the names of the metrics that can be mapped for the given aggregation interval."""
        if aggregation == self.base_interval:
            return self.tables
        return [t for t in self.tables if t not in cannot_aggregate]

    def load_parameters(self):
        """Resets ``step`` (seconds per stored interval) and ``periods`` (start second of each interval).

        Both are derived from the number of rows of ``link_in_volume``, which are taken to span 24 hours.
        """
        with tables.open_file(self.base_path, "r") as h5file:
            n_intervals = h5file.get_node("/link_moe/link_in_volume").shape[0]
            self.step = int(86400 / n_intervals)
            self.periods = [int(i * self.step) for i in range(n_intervals)]

    def load_metric(self, metric_name: str, from_hour, aggregation: int):
        """Loads ``htodo`` hours of a metric, starting at ``from_hour``, as a per-link DataFrame.

        Args:
            metric_name: Name of the dataset under ``/link_moe``
            from_hour: Hour of the day at which the loaded window starts
            aggregation: Length of each output period, in the same unit as ``base_interval``

        Returns:
            DataFrame with a ``link`` column plus one ``t_<period>_ab`` and one ``t_<period>_ba`` column
            for each period listed in ``self.periods``

        Raises:
            ValueError: If the metric is unknown, cannot be aggregated, or ``aggregation`` is shorter than
                the base interval
        """
        # results are for every two minutes, so any period aggregation must take that into consideration
        per_agg = int(aggregation / self.base_interval)

        # A previous call may have rescaled step/periods, so we start again from the values in the file
        self.load_parameters()
        from_interval = int(from_hour * 3600 / self.step)
        self.interval = [from_interval, from_interval + int(self.htodo * 3600 / self.step)]
        self.periods = self.periods[from_interval : self.interval[1]]

        arr = self.get_from_hdf5(self.base_path, metric_name, self.interval[0], self.interval[1], aggregation)

        if self.comparison:
            arr -= self.get_from_hdf5(self.alt_path, metric_name, self.interval[0], self.interval[1], aggregation)

        # From (periods, links) to (links, periods), so that each period becomes a column
        arr = arr.transpose()
        self.max_value = np.nanmax(arr)
        self.metric_name = metric_name

        self.step *= per_agg
        if per_agg > 1:
            # NOTE(review): aggregated periods are labelled from 0 rather than from the start of the loaded
            # window, unlike the non-aggregated case above. Kept as is because callers may rely on these labels
            self.periods = [int(i * self.step) for i in range(arr.shape[1])]

        columns = [f"t_{p:06}_ab" for p in self.periods]
        df = pd.DataFrame(arr, columns=columns)

        return self.__bidirectional_df(df.assign(link_uid=self.__link_ids))

    def get_from_hdf5(self, file_path, metric_name: str, inter1: int, inter2: int, aggregation) -> np.ndarray:
        """Reads a metric from a result file and aggregates it over time when requested.

        Args:
            file_path: HDF5 file to read from
            metric_name: Name of the dataset under ``/link_moe``
            inter1: First row (interval) to read. The whole dataset is read if either bound is negative
            inter2: Row (interval) at which to stop reading, exclusive
            aggregation: Length of each output period, in the same unit as ``base_interval``

        Returns:
            Array of floats with one row per output period. A trailing period that is only partially covered
            by the loaded rows is computed from the rows that are available

        Raises:
            ValueError: If ``aggregation`` is shorter than the base interval, or if the metric is unknown or
                cannot be aggregated
        """
        per_agg = int(aggregation / self.base_interval)
        if per_agg < 1:
            raise ValueError(f"Aggregation ({aggregation}) cannot be shorter than the base interval ({self.base_interval})")

        with tables.open_file(file_path, "r") as h5file:
            d = h5file.get_node(f"/link_moe/{metric_name}")
            if inter1 >= 0 and inter2 >= 0:
                arr = np.array(d[inter1:inter2, :], np.float64)
            else:
                arr = np.array(d, np.float64)

        if per_agg == 1:
            return self.__unit_conversion(arr, metric_name)

        if metric_name in avg_out or metric_name in avg_in:
            agg_arr = self.__weighted_average_metrics(metric_name, arr, file_path, inter1, inter2, per_agg)

        elif metric_name in sum_agg:
            agg_arr = np.nansum(self.__group_intervals(arr, per_agg), axis=1)

        elif metric_name in specific:
            if metric_name == "link_speed":
                agg_arr = self.__build_aggregate_link_speed(file_path, inter1, inter2, per_agg)
            else:
                raise ValueError(f"What am I doing here? We don't have a special case for: {metric_name}")
        elif metric_name in max_agg:
            agg_arr = np.nanmax(self.__group_intervals(arr, per_agg), axis=1)

        elif metric_name in cannot_aggregate:
            raise ValueError(f"Cannot aggregate metric {metric_name}")
        else:
            raise ValueError(f"What am I doing here? This metric does not exist: {metric_name}")

        return self.__unit_conversion(agg_arr, metric_name)

    @staticmethod
    def __group_intervals(arr: np.ndarray, per_agg: int) -> np.ndarray:
        """Reshapes a (intervals, links) array into (periods, per_agg, links), ready to be reduced on axis 1.

        When the number of intervals is not a multiple of ``per_agg``, the last period is completed with NaNs,
        which the nan-aware reductions used by the callers ignore.

        Raises:
            ValueError: If ``per_agg`` is smaller than one
        """
        if per_agg < 1:
            raise ValueError(f"Cannot group intervals in sets of {per_agg}")

        n_rows, n_cols = arr.shape
        n_groups = -(-n_rows // per_agg)  # Ceiling division
        missing = n_groups * per_agg - n_rows
        if missing:
            arr = np.vstack([arr, np.full((missing, n_cols), np.nan)])
        return arr.reshape((n_groups, per_agg, n_cols))

    def __build_aggregate_link_speed(self, file_path, inter1, inter2, per_agg):
        """Computes the speed for aggregated periods as link length over the mean travel time of the period."""
        ttime = self.get_from_hdf5(file_path, "link_travel_time", inter1, inter2, self.base_interval)
        # The negative lower bound makes get_from_hdf5 read the whole dataset, as lengths are not sliced by time
        llength = self.get_from_hdf5(file_path, "link_lengths", -1, inter2, self.base_interval)

        # Lengths are in meters, times are in seconds
        agg_ttime = np.nanmean(self.__group_intervals(ttime, per_agg), axis=1)
        # presumably link_lengths holds one row per link, so this yields one length per link - TODO confirm
        agg_llength = llength.mean(axis=1)
        # Links without a valid travel time get a speed of zero instead of NaN/inf
        return np.nan_to_num(agg_llength / agg_ttime, copy=True, nan=0.0, posinf=0.0, neginf=0.0)

    def __unit_conversion(self, arr, metric_name):
        """Converts ``link_speed`` values to ``self.speed_unit``. All other metrics are returned untouched.

        Raises:
            ValueError: If the speed unit is not "mph"
        """
        if metric_name == "link_speed":
            if self.speed_unit == "mph":
                # 1609.344 meters per mile and 3600 seconds per hour: meters/second to miles/hour
                return arr * (3600 / 1609.344)
            else:
                raise ValueError(f"Speed unit {self.speed_unit} not supported")
        return arr

    def __weighted_average_metrics(self, metric, arr, file_path, inter1, inter2, per_agg):
        """Aggregates a metric as the flow-weighted average of the intervals in each period.

        Periods without any flow get a value of zero.
        """
        # We grab inflows or outflows for our weighted average, depending on the metric
        var = "link_out_volume" if metric in avg_out else "link_in_volume"
        flows = self.get_from_hdf5(file_path, var, inter1, inter2, self.base_interval)

        # We multiply the flows by the metric values
        arr = arr * flows

        # Sum them up across the aggregation period
        agg_arr = np.nansum(self.__group_intervals(arr, per_agg), axis=1)
        agg_flows = np.nansum(self.__group_intervals(flows, per_agg), axis=1)

        # Compute the division
        return np.nan_to_num(agg_arr / agg_flows, copy=True, nan=0.0, posinf=0, neginf=0)

    def build_metric_layer(self, metric_name: str, from_hour, aggregation: int):
        """Loads the metric into ``self.data`` and builds ``self.layer`` from it."""
        self.data = self.load_metric(metric_name, from_hour, aggregation)

        self.layer = layer_from_dataframe(self.data, metric_name)

    def join_to_links(self, link_layer, interval=0, prefix=""):
        """Joins the AB and BA columns of one period of ``self.layer`` to the link layer, on the "link" field.

        ``build_metric_layer`` must have been called first. The names of the joined fields (including the
        prefix) are stored in ``self.joined_ab`` and ``self.joined_ba``.

        Args:
            link_layer: Layer that receives the join
            interval: Period to join, as it appears in the ``t_<period>_ab`` column names
            prefix: Prefix applied to the names of the joined fields
        """
        joined_ab = f"t_{interval:06}_ab"
        joined_ba = f"t_{interval:06}_ba"

        join_info = QgsVectorLayerJoinInfo()
        join_info.setJoinFieldName("link")
        join_info.setTargetFieldName("link")
        join_info.setJoinLayerId(self.layer.id())
        join_info.setUsingMemoryCache(True)
        join_info.setJoinLayer(self.layer)
        join_info.setPrefix(prefix)
        join_info.setJoinFieldNamesSubset([joined_ab, joined_ba])
        link_layer.addJoin(join_info)
        link_layer.updateFields()

        self.joined_ab = f"{prefix}{joined_ab}"
        self.joined_ba = f"{prefix}{joined_ba}"

    def __bidirectional_df(self, df):
        """Turns one row per directional link (``link_uid``) into one row per link with AB and BA columns.

        Even uids are the AB direction of link ``uid / 2`` and odd uids the BA direction of link
        ``(uid - 1) / 2``. Links that exist in a single direction have NaNs in the columns of the other one.
        """
        ab_flows = df.loc[df["link_uid"].mod(2).eq(0)]
        ba_flows = df.loc[~df.link_uid.isin(ab_flows.link_uid)]
        ab_flows = ab_flows.assign(link=(ab_flows.link_uid / 2).astype(int)).drop(columns=["link_uid"])
        # Column names come in as "..._ab", so the BA half needs its suffix swapped ("link_uid" is unaffected)
        ba_flows.columns = [x.replace("ab", "ba") for x in ba_flows.columns]
        ba_flows = ba_flows.assign(link=((ba_flows.link_uid - 1) / 2).astype(int)).drop(columns=["link_uid"])
        return ab_flows.merge(ba_flows, on="link", how="outer")
