Gregor Betz
commited on
Commit
•
f2d4743
1
Parent(s):
2c4b95c
add regex filter
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr # type: ignore
|
2 |
import plotly.express as px # type: ignore
|
3 |
|
4 |
-
from backend.data import load_cot_data
|
5 |
from backend.envs import API, REPO_ID, TOKEN
|
6 |
|
7 |
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
|
@@ -27,18 +27,26 @@ except Exception as err:
|
|
27 |
restart_space()
|
28 |
|
29 |
|
30 |
-
def plot_evals_init(model_id, plotly_mode, request: gr.Request):
|
31 |
if request and "model" in request.query_params:
|
32 |
model_param = request.query_params["model"]
|
33 |
if model_param in df_cot_err.model.to_list():
|
34 |
model_id = model_param
|
35 |
-
return plot_evals(model_id, plotly_mode)
|
36 |
|
37 |
|
38 |
-
def plot_evals(model_id, plotly_mode):
|
39 |
df = df_cot_err.copy()
|
40 |
df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
#df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
|
43 |
template = "plotly_dark" if plotly_mode=="dark" else "plotly"
|
44 |
fig = px.scatter(df, x="base accuracy", y="marginal acc. gain", color="selected", symbol="model",
|
@@ -49,13 +57,7 @@ def plot_evals(model_id, plotly_mode):
|
|
49 |
error_y="acc_gain-err", hover_data=['model', "cot accuracy"],
|
50 |
custom_data=['visibility'],
|
51 |
width=1200, height=700)
|
52 |
-
|
53 |
-
# TODO: doesn't work, needs to be fixed
|
54 |
-
fig.update_traces(
|
55 |
-
visible="legendonly",
|
56 |
-
selector=dict(visibility=False)
|
57 |
-
)
|
58 |
-
|
59 |
fig.update_layout(
|
60 |
title={"automargin": True},
|
61 |
)
|
@@ -112,17 +114,18 @@ with demo:
|
|
112 |
gr.HTML(TITLE)
|
113 |
gr.Markdown(INTRODUCTION_TEXT)
|
114 |
with gr.Row():
|
115 |
-
|
|
|
116 |
plotly_mode = gr.Radio(["dark","light"], value="light", label="Plot theme", scale=1)
|
117 |
submit = gr.Button("Update", scale=1)
|
118 |
table = gr.DataFrame()
|
119 |
plot = gr.Plot(label="evals")
|
120 |
|
121 |
|
122 |
-
submit.click(plot_evals, [
|
123 |
-
submit.click(styled_model_table,
|
124 |
|
125 |
-
demo.load(plot_evals_init, [
|
126 |
-
demo.load(styled_model_table_init,
|
127 |
|
128 |
demo.launch()
|
|
|
1 |
import gradio as gr # type: ignore
|
2 |
import plotly.express as px # type: ignore
|
3 |
|
4 |
+
from backend.data import load_cot_data
|
5 |
from backend.envs import API, REPO_ID, TOKEN
|
6 |
|
7 |
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
|
|
|
27 |
restart_space()
|
28 |
|
29 |
|
30 |
+
def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
|
31 |
if request and "model" in request.query_params:
|
32 |
model_param = request.query_params["model"]
|
33 |
if model_param in df_cot_err.model.to_list():
|
34 |
model_id = model_param
|
35 |
+
return plot_evals(model_id, regex_model_filter, plotly_mode)
|
36 |
|
37 |
|
38 |
+
def plot_evals(model_id, regex_model_filter, plotly_mode):
|
39 |
df = df_cot_err.copy()
|
40 |
df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
|
41 |
+
|
42 |
+
try:
|
43 |
+
df_filter = df.model.str.contains(regex_model_filter)
|
44 |
+
except Exception as err:
|
45 |
+
gr.Warning("Failed to apply regex filter", duration=4)
|
46 |
+
print("Failed to apply regex filter" + err)
|
47 |
+
df_filter = df.model.str.contains(".*")
|
48 |
+
df = df[df_filter | df.selected.eq("selected")]
|
49 |
+
|
50 |
#df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
|
51 |
template = "plotly_dark" if plotly_mode=="dark" else "plotly"
|
52 |
fig = px.scatter(df, x="base accuracy", y="marginal acc. gain", color="selected", symbol="model",
|
|
|
57 |
error_y="acc_gain-err", hover_data=['model', "cot accuracy"],
|
58 |
custom_data=['visibility'],
|
59 |
width=1200, height=700)
|
60 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
fig.update_layout(
|
62 |
title={"automargin": True},
|
63 |
)
|
|
|
114 |
gr.HTML(TITLE)
|
115 |
gr.Markdown(INTRODUCTION_TEXT)
|
116 |
with gr.Row():
|
117 |
+
selected_model = gr.Dropdown(list(df_cot_err.model.unique()), value="allenai/tulu-2-70b", label="Model", scale=2)
|
118 |
+
regex_model_filter = gr.Textbox(".*", label="Regex", info="to filter models shown in plots", scale=2)
|
119 |
plotly_mode = gr.Radio(["dark","light"], value="light", label="Plot theme", scale=1)
|
120 |
submit = gr.Button("Update", scale=1)
|
121 |
table = gr.DataFrame()
|
122 |
plot = gr.Plot(label="evals")
|
123 |
|
124 |
|
125 |
+
submit.click(plot_evals, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
|
126 |
+
submit.click(styled_model_table, selected_model, table)
|
127 |
|
128 |
+
demo.load(plot_evals_init, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
|
129 |
+
demo.load(styled_model_table_init, selected_model, table)
|
130 |
|
131 |
demo.launch()
|