"""Gradio demo: extrapolate a partially observed learning curve with LC-PFN."""

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import lcpfn
import torch

# Number of time steps per curve; the input table has exactly this many rows.
N_POINTS = 50
# Number of synthetic example curves offered in the UI.
N_EXAMPLES = 14

model = lcpfn.LCPFN()


def _observed_curve(data):
    """Return the filled-in leading part of the ``y`` column as floats.

    Cells holding ``""`` or a missing value (NaN/None) count as empty. Empty
    cells are only accepted as one trailing run, i.e. the curve may stop
    early but may not contain gaps.

    Args:
        data: DataFrame with a ``y`` column.

    Returns:
        A float Series with the non-empty leading ``y`` values.

    Raises:
        gr.Error: If an empty cell precedes a filled one, or a value cannot
            be converted to float.
    """
    is_empty = data["y"].isna() | (data["y"] == "")
    n_observed = int((~is_empty).sum())
    # All empty cells must come after the last filled one; a gap in the
    # middle would shift the later values onto the wrong time steps.
    if is_empty.iloc[:n_observed].any():
        raise gr.Error("Please enter a valid learning curve.")
    try:
        # Convert into a new Series so the caller's frame is left untouched.
        return data["y"].iloc[:n_observed].astype(float)
    except (ValueError, TypeError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err


def line_plot_fn(data, cutoff, ci_form):
    """Plot the entered curve together with the model's extrapolation.

    The first ``cutoff`` points are given to the model as training data; the
    model is then queried for three quantiles (lower bound, median, upper
    bound of the requested interval) at every remaining step up to
    ``N_POINTS``.

    Args:
        data: DataFrame with a ``y`` column holding the curve values.
        cutoff: Number of leading points used as observations.
        ci_form: Width of the plotted interval in percent (e.g. 90).

    Returns:
        The matplotlib figure with target, extrapolation and interval.

    Raises:
        gr.Error: If the curve, the cutoff or the interval is invalid.
    """
    try:
        cutoff = int(cutoff)
        ci = int(ci_form)
    except (TypeError, ValueError) as err:
        # A cleared number field or dropdown arrives as None.
        raise gr.Error(
            "Please enter a valid cutoff and confidence interval."
        ) from err

    curve = _observed_curve(data)
    n_observed = len(curve)
    if cutoff < 1:
        # cutoff - 1 is used as an index below; 0 would wrap to the last point.
        raise gr.Error("Cutoff must be at least 1.")
    if n_observed < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({n_observed}).")

    x = torch.arange(1, N_POINTS + 1).unsqueeze(1)
    y = torch.tensor(curve.to_numpy(), dtype=torch.float32).unsqueeze(1)

    # Probability mass left outside the interval on each side.
    rest_prob = (1 - (ci / 100)) / 2
    predictions = model.predict_quantiles(
        x_train=x[:cutoff],
        y_train=y[:cutoff],
        x_test=x[cutoff:],
        qs=[rest_prob, 0.5, 1 - rest_prob],
    ).numpy()

    # Prepend the last observed value to all three quantile columns so the
    # extrapolation starts exactly where the observations end.
    anchor = np.full((1, 3), y[cutoff - 1].item())
    predictions = np.concatenate([anchor, predictions], axis=0)
    future_steps = x[(cutoff - 1):].flatten().numpy()

    fig, ax = plt.subplots()
    # Only the observed steps have a target value; the curve may stop early.
    ax.plot(np.arange(1, n_observed + 1), curve.to_numpy(), "black", label="target")

    # plot extrapolation
    ax.plot(future_steps, predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        future_steps,
        predictions[:, 0],
        predictions[:, 2],
        color="blue",
        alpha=0.2,
        label=f"CI of {ci}%",
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")

    ax.set_ylim(0, 1)
    ax.set_xlim(0, 50)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    # Unregister the figure from pyplot so repeated callbacks do not pile up
    # open figures; the returned Figure object itself can still be rendered.
    plt.close(fig)
    return fig


# Build the synthetic example curves shown in the UI.
examples = []
for _ in range(N_EXAMPLES):
    prior = lcpfn.sample_from_prior(np.random)
    # The sampler returns two curves; one of them is picked at random.
    curve, other_curve = prior()
    if np.random.rand() < 0.5:
        curve = other_curve
    df = pd.DataFrame(
        {
            "t": np.arange(1, N_POINTS + 1),
            "y": np.asarray(curve[:N_POINTS]),
        }
    )
    # Each example is [table, cutoff]; only the table is bound to an input.
    examples.append([df, 10])

# NOTE(review): this column is created outside any Blocks context and
# `components` is not referenced again in this file - confirm it is still
# needed before removing it.
with gr.Column() as components:
    gr.Number(value=10)
    gr.Number(value=10)

# NOTE(review): the nesting of rows/columns below was reconstructed from a
# whitespace-collapsed source - confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )

            with gr.Row():
                cutoffform = gr.Number(label="cutoff", value=10)
                ci_form = gr.Dropdown(
                    label="Confidence Interval",
                    choices=[("90%", 90), ("95%", 95), ("99%", 99)],
                    value=90,
                )

        outputform = gr.Plot()

    gr.Examples(
        examples,
        inputs=[dataform],
        label="Examples of synthetic learning curves",
        examples_per_page=14,
    )

    # Redraw the plot on page load and whenever any input changes.
    plot_inputs = [dataform, cutoffform, ci_form]
    for register in (dataform.change, cutoffform.change, ci_form.change, demo.load):
        register(fn=line_plot_fn, inputs=plot_inputs, outputs=outputform)

if __name__ == "__main__":
    demo.launch()