lcpfn / app.py
herilalaina's picture
live demo
3f180fa
raw
history blame
3.68 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import lcpfn
import torch
# Build the lcpfn.LCPFN model once at import time; line_plot_fn reuses this
# single instance for every prediction instead of re-creating it per request.
model = lcpfn.LCPFN()
def line_plot_fn(data, cutoff, ci_form):
    """Plot a learning curve together with the model's extrapolation of it.

    The first ``cutoff`` points of the curve are passed to the module-level
    ``model`` as observed context. The quantiles it predicts for the
    remaining steps (up to step 50) are drawn as a median line with a shaded
    band, anchored at the last observed point.

    Args:
        data: pandas DataFrame with a ``y`` column holding the curve values.
            Blank cells (empty strings) are accepted only as a trailing run:
            the curve may be shorter than the table, but may not have gaps.
        cutoff: Number of leading points used as observed context.
        ci_form: Width of the confidence band in percent (e.g. ``90``).

    Returns:
        The matplotlib figure containing the plot.

    Raises:
        gr.Error: If the curve contains gaps or non-numeric values, if
            ``cutoff``/``ci_form`` cannot be converted to integers, or if
            ``cutoff`` is smaller than 1 or larger than the number of points.
    """
    horizon = 50  # number of time steps shown on the x-axis and extrapolated to

    try:
        cutoff = int(cutoff)
        ci = int(ci_form)
    except (TypeError, ValueError) as err:
        # e.g. the number field was cleared and arrives as None
        raise gr.Error("Please enter a valid cutoff and confidence interval.") from err

    blank = data["y"] == ""
    n_blank = int(blank.sum())
    if n_blank:
        # Blanks are only meaningful as "the curve ends early". A blank in the
        # middle would silently splice two curve segments together, so every
        # blank must sit in the last n_blank rows.
        if not blank.iloc[-n_blank:].all():
            raise gr.Error("Please enter a valid learning curve.")
        data = data[~blank]

    if len(data) < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({len(data)}).")
    if cutoff < 1:
        # cutoff - 1 is used as an index below; 0 or negative would wrap around.
        raise gr.Error(f"Cutoff ({cutoff}) must be at least 1.")

    try:
        # Convert into a fresh array instead of assigning back into ``data``:
        # that would write into a filtered slice or into the caller's frame.
        y_values = data["y"].astype(float).to_numpy()
    except (TypeError, ValueError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err

    x = torch.arange(1, horizon + 1).unsqueeze(1)
    y = torch.tensor(y_values, dtype=torch.float32).unsqueeze(1)

    # Split the probability mass outside the band equally between both tails.
    rest_prob = (1 - (ci / 100)) / 2
    predictions = model.predict_quantiles(
        x_train=x[:cutoff],
        y_train=y[:cutoff],
        x_test=x[cutoff:],
        qs=[rest_prob, 0.5, 1 - rest_prob],
    )
    predictions = predictions.numpy()

    # Prepend the last observed value to all three quantile columns so the
    # extrapolation and its band start exactly where the observed curve ends.
    anchor = np.full((1, 3), y[cutoff - 1].item(), dtype=predictions.dtype)
    predictions = np.concatenate([anchor, predictions], axis=0)

    steps = x.flatten().numpy()
    fig, ax = plt.subplots()
    # The curve may be shorter than the horizon when trailing cells are blank,
    # so only pair it with as many x positions as there are values.
    ax.plot(steps[: len(y_values)], y_values, "black", label="target")

    # plot extrapolation
    ax.plot(steps[cutoff - 1:], predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        steps[cutoff - 1:], predictions[:, 0], predictions[:, 2], color="blue", alpha=0.2, label=f"CI of {ci}%"
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")

    ax.set_ylim(0, 1)
    ax.set_xlim(0, horizon)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    # pyplot keeps a global reference to every figure it creates; detach this
    # one so repeated calls do not accumulate figures. The Figure object
    # itself stays valid and can still be rendered by the caller.
    plt.close(fig)
    return fig
def _build_examples(n_examples=14, n_steps=50, default_cutoff=10):
    """Sample synthetic learning curves to offer as clickable examples.

    For each example a fresh sampler is obtained from
    ``lcpfn.sample_from_prior`` and called once. The call returns a pair of
    curves, and one of the two is picked with equal probability.
    NOTE(review): presumably the pair is two variants of the same curve
    (e.g. with and without noise) - confirm against the lcpfn documentation.

    Args:
        n_examples: Number of example curves to generate.
        n_steps: Number of leading points kept from each sampled curve.
        default_cutoff: Value stored next to each table in the example row.

    Returns:
        A list of ``[DataFrame, default_cutoff]`` rows, where the DataFrame
        has the columns ``t`` (1..n_steps) and ``y`` (the curve values).
    """
    steps = list(range(1, n_steps + 1))
    rows = []
    for _ in range(n_examples):
        sampler = lcpfn.sample_from_prior(np.random)
        first_curve, second_curve = sampler()
        # Coin flip between the two returned curves.
        chosen = second_curve if np.random.rand() < 0.5 else first_curve
        frame = pd.DataFrame({"t": steps, "y": np.asarray(chosen)[:n_steps]})
        rows.append([frame, default_cutoff])
    return rows


examples = _build_examples()
# Column holding two number inputs that both default to 10.
# NOTE(review): `components` is not referenced anywhere else in the visible
# part of this file and is created outside the Blocks layout - confirm whether
# it is still needed.
with gr.Column() as components:
    for _ in range(2):
        gr.Number(value=10)
# Page layout and event wiring for the demo.
# NOTE(review): the nesting below was reconstructed from unindented source
# (inputs stacked in a column, plot beside it) - confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Editable two-column table, pre-filled with the first sampled curve.
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )

            with gr.Row():
                cutoffform = gr.Number(label="cutoff", value=10)
                ci_choices = [
                    ("90%", 90),
                    ("95%", 95),
                    ("99%", 99),
                ]
                ci_form = gr.Dropdown(label="Confidence Interval", choices=ci_choices, value=90)

        outputform = gr.Plot()

    gr.Examples(examples, inputs=[dataform], label="Examples of synthetic learning curves", examples_per_page=14)

    # Redraw the plot whenever any input changes, and once when the page loads.
    plot_inputs = (dataform, cutoffform, ci_form)
    for register in (dataform.change, cutoffform.change, ci_form.change, demo.load):
        register(fn=line_plot_fn, inputs=list(plot_inputs), outputs=outputform)

if __name__ == "__main__":
    demo.launch()