# Copyright (c) 2026, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def _sql_literal(value):
    """Render ``value`` as a single-quoted SQL string literal.

    Embedded single quotes are doubled (the SQL-standard escape), so a value
    such as ``o'brien`` -- or a crafted ``x') OR 1=1 --`` -- stays inside the
    literal instead of terminating it and injecting SQL.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _sql_number(value):
    """Render ``value`` for splicing into SQL as a numeric literal.

    ``float()`` is used only as a validator: it raises ``ValueError`` /
    ``TypeError`` for anything that is not a number or numeric string, which
    rejects injected SQL. ``str(value)`` is emitted so that legitimate inputs
    render exactly as the original f-string did.
    """
    float(value)
    return str(value)


def _build_query(kwargs):
    """Return the SELECT over ``iterations`` for the filters in ``kwargs``.

    Absent, ``None`` or empty ``model_name`` / ``iteration_type`` selections
    add no condition (an empty list used to produce an invalid ``AND ()``).
    """
    # NOTE(review): values are escaped rather than bound as parameters because
    # the placeholder style depends on the DB driver behind ``engine``, which
    # is not visible here. Prefer pd.read_sql(..., params=...) once it is.
    date_range = kwargs["date"]
    start = _sql_literal(date_range["from"])
    end = _sql_literal(date_range["to"])
    conditions = [f"created_at BETWEEN {start} AND {end}"]

    for column in ("model_name", "iteration_type"):
        patterns = kwargs.get(column)
        patterns = [] if patterns is None else list(patterns)
        if patterns:
            # The *value* (not the column) is lowercased and used as a LIKE
            # pattern, so '%' and '_' inside it act as wildcards.
            alternatives = " OR ".join(f"{column} LIKE LOWER({_sql_literal(p)})" for p in patterns)
            conditions.append(f"({alternatives})")

    if "num_threads" in kwargs:
        bounds = kwargs["num_threads"]
        low = _sql_number(bounds["min_threads"])
        high = _sql_number(bounds["max_threads"])
        conditions.append(f"num_threads BETWEEN {low} AND {high}")

    return (
        "SELECT model_name, created_at, machine, iteration_type, num_threads, run_time/60 run_time "
        "FROM iterations WHERE " + " AND ".join(conditions)
    )


def _assign_machine_types(data, kwargs):
    """Return ``data`` with a ``machine_type`` column describing its machine group.

    With a non-empty ``kwargs["machine"]``, each selected group name keeps the
    rows whose ``machine`` appears in ``kwargs["machine_grouping"][name].machine``
    (presumably a DataFrame with a ``machine`` column -- confirm against the
    caller). A machine belonging to several selected groups is kept once per
    group. Otherwise every row is tagged ``"All"``.
    """
    machines = kwargs.get("machine")
    machines = [] if machines is None else list(machines)
    if not machines:
        # An empty selection used to crash in pd.concat([]); treat it like no filter.
        return data.assign(machine_type="All")

    grouping = kwargs["machine_grouping"]
    return pd.concat(
        [data[data["machine"].isin(grouping[name].machine)].assign(machine_type=name) for name in machines]
    )


def _plot_groups(ax, data):
    """Draw a daily mean line plus a +/-1 std band for each plotted group.

    Groups are keyed by (machine_type, model_name, iteration_type). Returns the
    suggested y upper bound: 5% above the highest mean + std, or 0.0 when
    nothing was drawn.
    """
    upper = 0.0
    for key, group in data.groupby(["machine_type", "model_name", "iteration_type"]):
        daily = group[["run_time"]].resample("D").agg(["mean", "std"])
        # Days without runs (and undefined single-run std values) are filled by
        # time-weighted interpolation; rows it cannot fill are dropped.
        daily = daily.interpolate(method="time").dropna()
        if daily.empty:
            continue

        mean = daily["run_time"]["mean"]
        std = daily["run_time"]["std"]
        (line,) = ax.plot(daily.index, mean, label=",".join(map(str, key)))
        # Unlabelled so the band stays out of the legend; colour pinned to its line.
        ax.fill_between(daily.index, mean - std, mean + std, alpha=0.2, color=line.get_color())
        upper = max(upper, 1.05 * float(np.nanmax(mean + std)))
    return upper


def run_time_chart(engine, **kwargs):
    """Plot daily run-time trends from the ``iterations`` table.

    Rows are filtered in SQL, optionally restricted to machine groups, then
    split by (machine_type, model_name, iteration_type). Each group is drawn as
    a daily mean line with a shaded +/-1 standard-deviation band. ``run_time``
    is divided by 60 in SQL and labelled in minutes (presumably stored in
    seconds -- confirm).

    Args:
        engine: Connection passed straight to ``pd.read_sql``.
        **kwargs: Filter and display options:
            date (required): ``{"from": ..., "to": ...}``; ``created_at`` must
                lie between them (SQL ``BETWEEN``, inclusive).
            model_name / iteration_type: iterables of values, each lowercased
                and matched with ``LIKE``; a row matching any value is kept.
                Absent or empty means no filter.
            num_threads: ``{"min_threads": ..., "max_threads": ...}`` numeric
                bounds (inclusive).
            machine: iterable of group names looked up in ``machine_grouping``.
                Absent or empty means all machines, labelled "All".
            machine_grouping: mapping used together with ``machine``.
            dissolving: optional ``{"machine_type": bool, "model": bool,
                "iteration": bool}``. A true flag merges that dimension into a
                single bucket. Missing flags default to False.

    Returns:
        The matplotlib Figure, created via pyplot (so the caller should close
        it when done). It has no axes when no rows match.

    Raises:
        KeyError: if ``date`` is missing, or ``machine_grouping`` is missing
            while ``machine`` is non-empty.
        ValueError, TypeError: if a thread bound is not numeric.
    """
    data = pd.read_sql(_build_query(kwargs), engine)
    data["created_at"] = pd.to_datetime(data["created_at"])
    data = _assign_machine_types(data.set_index("created_at"), kwargs)

    fig = plt.figure(figsize=(10, 6))
    if data.empty:
        return fig

    # Collapse a grouping dimension into one bucket so its groups plot as one line.
    dissolving = kwargs.get("dissolving") or {}
    for flag, column, merged_label in (
        ("machine_type", "machine_type", "Selected machines"),
        ("model", "model_name", "Selected models"),
        ("iteration", "iteration_type", "Iteration Types"),
    ):
        if dissolving.get(flag):
            data[column] = merged_label

    # Draw on this figure's own Axes instead of pyplot's implicit "current"
    # axes, which is process-global state.
    ax = fig.add_subplot()
    upper = _plot_groups(ax, data)

    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.set_xlabel("Date")
    ax.set_ylabel("Run Time (minutes)")
    ax.set_title("Run times")
    ax.set_xlim(data.index.min(), data.index.max())
    if upper > 0:
        # Skipped when nothing was drawn: ylim(0, 0) only triggers a warning.
        ax.set_ylim(0, upper)
    return fig
