Gregor Betz commited on
Commit
f2d4743
1 Parent(s): 2c4b95c

add regex filter

Browse files
Files changed (1) hide show
  1. app.py +20 -17
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr # type: ignore
2
  import plotly.express as px # type: ignore
3
 
4
- from backend.data import load_cot_data, is_visible_model
5
  from backend.envs import API, REPO_ID, TOKEN
6
 
7
  logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
@@ -27,18 +27,26 @@ except Exception as err:
27
  restart_space()
28
 
29
 
30
- def plot_evals_init(model_id, plotly_mode, request: gr.Request):
31
  if request and "model" in request.query_params:
32
  model_param = request.query_params["model"]
33
  if model_param in df_cot_err.model.to_list():
34
  model_id = model_param
35
- return plot_evals(model_id, plotly_mode)
36
 
37
 
38
- def plot_evals(model_id, plotly_mode):
39
  df = df_cot_err.copy()
40
  df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
41
- df["visibility"] = df_cot_err.model.apply(is_visible_model) | df.selected.eq("selected")
 
 
 
 
 
 
 
 
42
  #df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
43
  template = "plotly_dark" if plotly_mode=="dark" else "plotly"
44
  fig = px.scatter(df, x="base accuracy", y="marginal acc. gain", color="selected", symbol="model",
@@ -49,13 +57,7 @@ def plot_evals(model_id, plotly_mode):
49
  error_y="acc_gain-err", hover_data=['model', "cot accuracy"],
50
  custom_data=['visibility'],
51
  width=1200, height=700)
52
-
53
- # TODO: doesn't work, needs to be fixed
54
- fig.update_traces(
55
- visible="legendonly",
56
- selector=dict(visibility=False)
57
- )
58
-
59
  fig.update_layout(
60
  title={"automargin": True},
61
  )
@@ -112,17 +114,18 @@ with demo:
112
  gr.HTML(TITLE)
113
  gr.Markdown(INTRODUCTION_TEXT)
114
  with gr.Row():
115
- model_list = gr.Dropdown(list(df_cot_err.model.unique()), value="allenai/tulu-2-70b", label="Model", scale=2)
 
116
  plotly_mode = gr.Radio(["dark","light"], value="light", label="Plot theme", scale=1)
117
  submit = gr.Button("Update", scale=1)
118
  table = gr.DataFrame()
119
  plot = gr.Plot(label="evals")
120
 
121
 
122
- submit.click(plot_evals, [model_list, plotly_mode], [plot, model_list])
123
- submit.click(styled_model_table, model_list, table)
124
 
125
- demo.load(plot_evals_init, [model_list, plotly_mode], [plot, model_list])
126
- demo.load(styled_model_table_init, model_list, table)
127
 
128
  demo.launch()
 
1
  import gradio as gr # type: ignore
2
  import plotly.express as px # type: ignore
3
 
4
+ from backend.data import load_cot_data
5
  from backend.envs import API, REPO_ID, TOKEN
6
 
7
  logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
 
27
  restart_space()
28
 
29
 
30
+ def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
31
  if request and "model" in request.query_params:
32
  model_param = request.query_params["model"]
33
  if model_param in df_cot_err.model.to_list():
34
  model_id = model_param
35
+ return plot_evals(model_id, regex_model_filter, plotly_mode)
36
 
37
 
38
+ def plot_evals(model_id, regex_model_filter, plotly_mode):
39
  df = df_cot_err.copy()
40
  df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
41
+
42
+ try:
43
+ df_filter = df.model.str.contains(regex_model_filter)
44
+ except Exception as err:
45
+ gr.Warning("Failed to apply regex filter", duration=4)
46
+ print("Failed to apply regex filter" + err)
47
+ df_filter = df.model.str.contains(".*")
48
+ df = df[df_filter | df.selected.eq("selected")]
49
+
50
  #df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
51
  template = "plotly_dark" if plotly_mode=="dark" else "plotly"
52
  fig = px.scatter(df, x="base accuracy", y="marginal acc. gain", color="selected", symbol="model",
 
57
  error_y="acc_gain-err", hover_data=['model', "cot accuracy"],
58
  custom_data=['visibility'],
59
  width=1200, height=700)
60
+
 
 
 
 
 
 
61
  fig.update_layout(
62
  title={"automargin": True},
63
  )
 
114
  gr.HTML(TITLE)
115
  gr.Markdown(INTRODUCTION_TEXT)
116
  with gr.Row():
117
+ selected_model = gr.Dropdown(list(df_cot_err.model.unique()), value="allenai/tulu-2-70b", label="Model", scale=2)
118
+ regex_model_filter = gr.Textbox(".*", label="Regex", info="to filter models shown in plots", scale=2)
119
  plotly_mode = gr.Radio(["dark","light"], value="light", label="Plot theme", scale=1)
120
  submit = gr.Button("Update", scale=1)
121
  table = gr.DataFrame()
122
  plot = gr.Plot(label="evals")
123
 
124
 
125
+ submit.click(plot_evals, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
126
+ submit.click(styled_model_table, selected_model, table)
127
 
128
+ demo.load(plot_evals_init, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
129
+ demo.load(styled_model_table_init, selected_model, table)
130
 
131
  demo.launch()