# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def _sql_literal(value):
    """Render *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled so the value cannot terminate the
    literal early (e.g. a model name such as ``O'Hare``).
    """
    return "'" + str(value).replace("'", "''") + "'"


def _like_any(column, patterns):
    """Build a parenthesised OR-chain of ``column LIKE LOWER(<pattern>)`` tests."""
    tests = [f"{column} LIKE LOWER({_sql_literal(pattern)})" for pattern in patterns]
    return "(" + " OR ".join(tests) + ")"


def transit_charts(engine, **kwargs):
    """Plot the daily mean (with a +/- one standard deviation band) of a transit metric.

    Rows are read from ``transit_vmt_pmt_occ_by_period`` joined to ``iterations``,
    filtered according to ``kwargs``, summed per iteration for every
    (agency, mode, peak) combination, resampled to daily mean/std and drawn as
    one line per combination.

    Args:
        engine: Connection/engine object handed to ``pandas.read_sql``.
        **kwargs: All of the following keys are read (a missing key raises ``KeyError``):

            * ``model_name``: value matched against ``iterations.model_name``.
            * ``date``: mapping with ``"from"`` and ``"to"`` bounds for ``created_at``.
            * ``peak_periods``: sized collection of integer-like values; when not empty,
              rows are limited to ``min <= peak_period <= max``.
            * ``agencies``: collection of ``LIKE`` patterns for ``agency`` (may be empty).
            * ``modes``: collection of ``LIKE`` patterns for ``mode`` (may be empty).
            * ``ci_only``: when truthy, only rows whose ``machine`` is in ``machines`` are kept.
            * ``machines``: collection of machine names (only read when ``ci_only`` is truthy).
            * ``dissolving``: mapping with boolean-like ``"peak_periods"``, ``"agencies"`` and
              ``"modes"`` entries; a truthy entry merges that dimension into a single series.
            * ``chart``: chart name; the metric column is this name with ``"Transit "``
              removed and must be one of ``VMT``, ``PMT`` or ``Occupancy``.

    Returns:
        The matplotlib figure. When the query (after machine filtering) yields no
        rows, the figure is returned without any axes.

    Raises:
        ValueError: If a peak period cannot be converted to ``int``.
    """
    # NOTE(review): binding these values as query parameters would be preferable, but the
    # placeholder style depends on the driver behind `engine`, which is not visible here.
    # Until then every string is embedded as an escaped literal and numbers are coerced.
    sql = (
        "SELECT pt.*, i.created_at, i.machine from transit_vmt_pmt_occ_by_period pt"
        " JOIN iterations i ON pt.iteration_uuid = i.iteration_uuid"
        f" WHERE i.model_name={_sql_literal(kwargs['model_name'])}"
    )

    start_date = kwargs["date"]["from"]
    end_date = kwargs["date"]["to"]
    sql += f" AND created_at BETWEEN {_sql_literal(start_date)} AND {_sql_literal(end_date)}"

    # len() rather than truthiness so array-like selections keep working
    if len(kwargs["peak_periods"]):
        periods = [int(period) for period in kwargs["peak_periods"]]
        sql += f" AND peak_period>={min(periods)} AND peak_period<={max(periods)}"

    if len(kwargs["agencies"]):
        sql += " AND " + _like_any("agency", kwargs["agencies"])

    if len(kwargs["modes"]):
        # "mode" is quoted in the original query as well; kept as an identifier, not a literal
        sql += " AND " + _like_any('"mode"', kwargs["modes"])

    data = pd.read_sql(sql, engine)
    data.created_at = pd.to_datetime(data.created_at)
    # One timestamp per iteration; re-attached after the aggregation below drops it
    dates = data[["created_at", "iteration_uuid"]].drop_duplicates(subset="iteration_uuid")

    if kwargs["ci_only"]:
        data = data[data["machine"].isin(kwargs["machines"])]

    data = data.set_index("created_at")

    # Every row is labelled "peak_0" unless its peak_period equals 1
    data = data.assign(peak="peak_0")
    data.loc[data.peak_period == 1, "peak"] = "peak_1"

    fig = plt.figure(figsize=(10, 6))

    if data.empty:
        return fig

    # "Dissolving" a dimension overwrites it with a constant so the groupby collapses it
    if kwargs["dissolving"]["peak_periods"]:
        data.loc[:, "peak"] = "peaks selected"

    if kwargs["dissolving"]["agencies"]:
        data.loc[:, "agency"] = "Agencies selected"

    if kwargs["dissolving"]["modes"]:
        data.loc[:, "mode"] = "Modes selected"

    metric = kwargs["chart"].replace("Transit ", "")

    # Aggregates all data
    data = data[["agency", "mode", "peak", "iteration_uuid", "VMT", "PMT", "Occupancy"]]
    data = data.groupby(["agency", "mode", "peak", "iteration_uuid"], as_index=False).sum()
    data = data.merge(dates, on="iteration_uuid").set_index("created_at")

    # Draw on an explicit Axes of the returned figure instead of pyplot's "current" state
    ax = fig.add_subplot(111)

    upper_bounds = [0]
    line_handles = []
    line_labels = []
    for label, group in data.groupby(["agency", "mode", "peak"]):
        daily = group[[metric]].resample("D").agg(["mean", "std"])
        # Days without runs are filled by time interpolation; rows that still hold a NaN
        # (mean or std) are discarded
        daily = daily.interpolate(method="time").dropna()
        if daily.empty:
            continue

        mean = daily[metric]["mean"]
        std = daily[metric]["std"]
        series_label = ",".join(label)

        (line,) = ax.plot(daily.index, mean, label=series_label)
        line_handles.append(line)
        line_labels.append(series_label)

        upper_bounds.append(1.05 * np.nanmax(mean + std))
        ax.fill_between(
            daily.index,
            mean - std,
            mean + std,
            alpha=0.2,
            label="Standard Deviation",
        )

    # Only the mean lines go in the legend; the standard deviation bands are left out
    ax.legend(line_handles, line_labels)

    ax.set_xlabel("Date")
    ax.set_ylabel(metric)
    ax.set_title(f"{kwargs['model_name']} {metric}")
    ax.set_xlim(data.index.min(), data.index.max())
    ax.set_ylim(0, np.max(upper_bounds))
    return fig
