Gregor Betz commited on
Commit
c28665f
1 Parent(s): ca2e2c2

use query param only with demo.load

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -23,12 +23,16 @@ except Exception:
23
  restart_space()
24
 
25
 
26
- def plot_evals(model_id, plotly_mode, request: gr.Request):
27
- df = df_cot_err.copy()
28
  if request and "model" in request.query_params:
29
  model_param = request.query_params["model"]
30
- if model_param in df.model.to_list():
31
  model_id = model_param
 
 
 
 
 
32
  df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
33
  #df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
34
  template = "plotly_dark" if plotly_mode=="dark" else "plotly"
@@ -45,7 +49,16 @@ def plot_evals(model_id, plotly_mode, request: gr.Request):
45
  )
46
  return fig, model_id
47
 
48
- def get_model_table(model_id):
 
 
 
 
 
 
 
 
 
49
 
50
  def make_pretty(styler):
51
  styler.hide(axis="index")
@@ -79,13 +92,6 @@ def get_model_table(model_id):
79
 
80
  return df_cot_model.style.pipe(make_pretty)
81
 
82
- def styled_model_table(model_id, request: gr.Request):
83
- if request and "model" in request.query_params:
84
- model_param = request.query_params["model"]
85
- if model_param in df_cot_regimes.model.to_list():
86
- model_id = model_param
87
- return get_model_table(model_id)
88
-
89
 
90
  demo = gr.Blocks()
91
 
@@ -95,7 +101,7 @@ with demo:
95
  gr.Markdown(INTRODUCTION_TEXT)
96
  with gr.Row():
97
  model_list = gr.Dropdown(list(df_cot_err.model.unique()), value="allenai/tulu-2-70b", label="Model", scale=2)
98
- plotly_mode = gr.Radio(["dark","light"], value="dark", label="Plot theme", scale=1)
99
  submit = gr.Button("Update", scale=1)
100
  table = gr.DataFrame()
101
  plot = gr.Plot(label="evals")
@@ -103,7 +109,8 @@ with demo:
103
 
104
  submit.click(plot_evals, [model_list, plotly_mode], [plot, model_list])
105
  submit.click(styled_model_table, model_list, table)
106
- demo.load(plot_evals, [model_list, plotly_mode], [plot, model_list])
107
- demo.load(styled_model_table, model_list, table)
 
108
 
109
  demo.launch()
 
23
  restart_space()
24
 
25
 
26
+ def plot_evals_init(model_id, plotly_mode, request: gr.Request):
 
27
  if request and "model" in request.query_params:
28
  model_param = request.query_params["model"]
29
+ if model_param in df_cot_err.model.to_list():
30
  model_id = model_param
31
+ return plot_evals(model_id, plotly_mode)
32
+
33
+
34
+ def plot_evals(model_id, plotly_mode):
35
+ df = df_cot_err.copy()
36
  df["selected"] = df_cot_err.model.apply(lambda x: "selected" if x==model_id else "-")
37
  #df.sort_values(["selected", "model"], inplace=True, ascending=True) # has currently no effect with px.scatter
38
  template = "plotly_dark" if plotly_mode=="dark" else "plotly"
 
49
  )
50
  return fig, model_id
51
 
52
+
53
+ def styled_model_table_init(model_id, request: gr.Request):
54
+ if request and "model" in request.query_params:
55
+ model_param = request.query_params["model"]
56
+ if model_param in df_cot_regimes.model.to_list():
57
+ model_id = model_param
58
+ return styled_model_table(model_id)
59
+
60
+
61
+ def styled_model_table(model_id):
62
 
63
  def make_pretty(styler):
64
  styler.hide(axis="index")
 
92
 
93
  return df_cot_model.style.pipe(make_pretty)
94
 
 
 
 
 
 
 
 
95
 
96
  demo = gr.Blocks()
97
 
 
101
  gr.Markdown(INTRODUCTION_TEXT)
102
  with gr.Row():
103
  model_list = gr.Dropdown(list(df_cot_err.model.unique()), value="allenai/tulu-2-70b", label="Model", scale=2)
104
+ plotly_mode = gr.Radio(["dark","light"], value="light", label="Plot theme", scale=1)
105
  submit = gr.Button("Update", scale=1)
106
  table = gr.DataFrame()
107
  plot = gr.Plot(label="evals")
 
109
 
110
  submit.click(plot_evals, [model_list, plotly_mode], [plot, model_list])
111
  submit.click(styled_model_table, model_list, table)
112
+
113
+ demo.load(plot_evals_init, [model_list, plotly_mode], [plot, model_list])
114
+ demo.load(styled_model_table_init, model_list, table)
115
 
116
  demo.launch()