Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import lcpfn | |
import torch | |
# LC-PFN model, built once at import time and shared by every call to
# ``line_plot_fn``.
model = lcpfn.LCPFN()

# Number of time steps in the editable table; the model is asked for a
# prediction at every step after the cutoff, up to and including this one.
_HORIZON = 50


def _clean_curve(data):
    """Return the ``y`` column of *data* as floats, minus trailing empty cells.

    A cell counts as empty when it is missing (NaN/None) or a blank string.
    Empty cells are only accepted as one run at the end of the table, which
    lets the user enter a curve shorter than the table.

    Raises:
        gr.Error: if an empty cell is followed by a filled one, or if a
            filled cell cannot be converted to ``float``.
    """
    y = data["y"].reset_index(drop=True)
    is_empty = y.isna() | y.astype(str).str.strip().eq("")
    n_filled = len(y) - int(is_empty.sum())
    # With only trailing blanks, the first ``n_filled`` rows are all filled;
    # any blank among them means a gap inside the curve.
    if is_empty.iloc[:n_filled].any():
        raise gr.Error("Please enter a valid learning curve.")
    try:
        # ``astype`` returns a new Series, so the caller's frame is untouched.
        return y.iloc[:n_filled].astype(float)
    except (TypeError, ValueError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err


def line_plot_fn(data, cutoff, ci_form):
    """Extrapolate a learning curve with LC-PFN and plot the result.

    Args:
        data: DataFrame with a ``y`` column holding the curve, one value per
            step ``t = 1, 2, ...``; trailing cells may be left empty.
        cutoff: number of leading points passed to the model as observed
            data; the remaining steps up to ``_HORIZON`` are predicted.
        ci_form: width of the central prediction interval, in percent.

    Returns:
        A matplotlib Figure with the entered curve, the median extrapolation,
        the shaded interval and a dashed line at the cutoff.

    Raises:
        gr.Error: if the cutoff is missing or below 1, exceeds the number of
            entered points, or the curve itself is invalid.
    """
    if cutoff is None:
        raise gr.Error("Please enter a cutoff.")
    cutoff = int(cutoff)
    ci = int(ci_form)
    # cutoff == 0 would give the model no training points and anchor the
    # band at y[-1]; negative values would slice from the end.
    if cutoff < 1:
        raise gr.Error(f"Cutoff ({cutoff}) must be at least 1.")

    y_obs = _clean_curve(data)
    n_obs = len(y_obs)
    if n_obs < cutoff:
        raise gr.Error(
            f"Cutoff ({cutoff}) cannot be greater than the number of data points ({n_obs})."
        )

    x = torch.arange(1, _HORIZON + 1).unsqueeze(1)
    y = torch.tensor(y_obs.to_numpy(), dtype=torch.float32).unsqueeze(1)

    # Central interval: split the excluded probability mass evenly between
    # the lower and upper tails.
    rest_prob = (1 - ci / 100) / 2
    predictions = model.predict_quantiles(
        x_train=x[:cutoff],
        y_train=y[:cutoff],
        x_test=x[cutoff:],
        qs=[rest_prob, 0.5, 1 - rest_prob],
    ).numpy()
    # Start all three quantile series at the last observed point so the
    # extrapolation connects to the curve at the cutoff.
    anchor = np.full((1, 3), y[cutoff - 1].item())
    predictions = np.concatenate([anchor, predictions], axis=0)
    t_pred = x[cutoff - 1:].squeeze(1).numpy()

    fig, ax = plt.subplots()
    # Plot only the entered points: a curve shorter than the horizon would
    # otherwise not match the length of ``x``.
    ax.plot(x[:n_obs].squeeze(1).numpy(), y_obs.to_numpy(), "black", label="target")
    # plot extrapolation
    ax.plot(t_pred, predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        t_pred, predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {ci}%"
    )
    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")
    ax.set_ylim(0, 1)
    ax.set_xlim(0, _HORIZON)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")
    # This runs on every UI change; unregister the figure from pyplot so
    # figures do not accumulate for the server's lifetime. The Figure object
    # itself remains valid for rendering (savefig) after being closed.
    plt.close(fig)
    return fig
# Synthetic learning curves sampled from the LC-PFN prior, used to populate
# the "Examples" gallery. Each entry is [DataFrame with columns t, y, 10];
# the trailing 10 presumably serves as the example's cutoff.
examples = []
for _ in range(14):
    prior = lcpfn.sample_from_prior(np.random)
    # prior() returns a pair of curves; one of the two is picked at random.
    # NOTE(review): presumably a noisy and a noise-free variant — confirm
    # against the lcpfn API before relying on this.
    first_curve, second_curve = prior()
    curve = second_curve if np.random.rand() < 0.5 else first_curve
    example_df = pd.DataFrame({"t": range(1, 51), "y": np.asarray(curve[:50])})
    examples.append([example_df, 10])
# Layout: an editable 50-row curve table with cutoff / interval controls on
# the left, the plot on the right, and the example gallery below. The plot is
# redrawn on page load and whenever any of the three inputs changes.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )
            with gr.Row():
                cutoffform = gr.Number(label="cutoff", value=10)
                ci_form = gr.Dropdown(
                    label="Confidence Interval",
                    choices=[("90%", 90), ("95%", 95), ("99%", 99)],
                    value=90,
                )
        outputform = gr.Plot()
    # Each example row is [curve, cutoff], so both target components are
    # listed; with only ``dataform`` the cutoff value would be dropped.
    gr.Examples(
        examples,
        inputs=[dataform, cutoffform],
        label="Examples of synthetic learning curves",
        examples_per_page=14,
    )
    plot_inputs = [dataform, cutoffform, ci_form]
    for listener in (dataform.change, cutoffform.change, ci_form.change, demo.load):
        listener(fn=line_plot_fn, inputs=plot_inputs, outputs=outputform)

if __name__ == "__main__":
    demo.launch()