Quentin Gallouédec commited on
Commit
0660028
·
1 Parent(s): 02fb3fc
Files changed (1) hide show
  1. app.py +60 -74
app.py CHANGED
@@ -19,7 +19,7 @@ logging.getLogger("absl").setLevel(logging.WARNING)
19
 
20
  API = HfApi(token=os.environ.get("TOKEN"))
21
  RESULTS_REPO = "open-rl-leaderboard/results"
22
- REFRESH_RATE = 0.5 * 60 # 5 minutes
23
  ALL_ENV_IDS = {
24
  "Atari": [
25
  "AdventureNoFrameskip-v4",
@@ -126,6 +126,7 @@ def iqm(x):
126
 
127
 
128
  def get_leaderboard_df():
 
129
  dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
130
  pattern = os.path.join(dir_path, "**", "results_*.json")
131
  filenames = glob.glob(pattern, recursive=True)
@@ -169,70 +170,43 @@ def format_df(df: pd.DataFrame):
169
  return df.values.tolist()
170
 
171
 
172
- def refresh_hidden_df():
173
- df = get_leaderboard_df()
174
- return df
 
 
 
 
 
 
 
 
175
 
176
 
177
- def refresh_dataframes(df):
178
- all_dfs = [format_df(select_env(df, env_id)) for env_id in all_env_ids]
179
- return all_dfs
180
-
181
-
182
- def refresh_videos(df):
183
- outputs = []
184
- for env_id in all_env_ids:
185
- env_df = select_env(df, env_id)
186
- if not env_df.empty:
187
- user_id = env_df.iloc[0]["user_id"]
188
- model_id = env_df.iloc[0]["model_id"]
189
- model_sha = env_df.iloc[0]["model_sha"]
190
- repo_id = f"{user_id}/{model_id}"
191
- video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
192
- outputs.append(video_path)
193
- else:
194
- outputs.append(None)
195
- return outputs
196
-
197
-
198
- def refresh_video(env_id):
199
- def func(df):
200
- env_df = select_env(df, env_id)
201
- if not env_df.empty:
202
- user_id = env_df.iloc[0]["user_id"]
203
- model_id = env_df.iloc[0]["model_id"]
204
- model_sha = env_df.iloc[0]["model_sha"]
205
- repo_id = f"{user_id}/{model_id}"
206
- video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
207
- return video_path
208
- return None
209
 
210
- return func
211
 
212
 
213
- def refresh_winners(df):
214
- outputs = []
215
- for env_id in all_env_ids:
216
- env_df = select_env(df, env_id)
217
- if not env_df.empty:
218
- user_id = env_df.iloc[0]["user_id"]
219
- model_id = env_df.iloc[0]["model_id"]
220
- url = f"https://huggingface.co/{user_id}/{model_id}"
221
- outputs.append(
222
- f"""## {env_id}
223
 
224
  ### 🏆 [Best model]({url}) 🏆"""
225
- )
226
- else:
227
- outputs.append(
228
- f"""## {env_id}
229
 
230
  This leaderboard is quite empty... 😢
231
 
232
  Be the first to [submit your model]()!
233
  """
234
- )
235
- return outputs
236
 
237
 
238
  HEADING = """
@@ -332,61 +306,73 @@ h3 {
332
 
333
  """
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  with gr.Blocks(css=css) as demo:
336
  gr.Markdown(HEADING)
337
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
338
  with gr.TabItem("🏅 Leaderboard"):
339
- df = get_leaderboard_df()
340
- hidden_df = gr.components.Dataframe(df, visible=False)
341
- all_env_ids = []
342
- all_gr_dfs = []
343
- all_gr_videos = []
344
- all_gr_winners = []
345
  for env_domain, env_ids in ALL_ENV_IDS.items():
346
  with gr.TabItem(env_domain):
347
  for env_id in env_ids:
348
- # If the env_id envs with "NoFrameskip-v4", we remove it
349
  tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
350
  with gr.TabItem(tab_env_id) as tab:
351
  logger.info(f"Creating tab for {env_id}")
352
  with gr.Row(equal_height=False):
353
  with gr.Column(scale=3):
354
- # Display the leaderboard
355
  gr_df = gr.components.Dataframe(
356
  headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
357
  datatype=["number", "markdown", "markdown", "number"],
358
- row_count=(20, "fixed"),
359
  )
360
  with gr.Column(scale=1):
361
  with gr.Row(): # Display the env_id and the winner
362
  gr_winner = gr.Markdown()
363
  with gr.Row(): # Play the video of the best model
364
- gr_video = gr.PlayableVideo( # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
365
  min_width=50,
366
- autoplay=True,
367
  show_download_button=False,
368
  show_share_button=False,
369
  show_label=False,
 
370
  )
371
 
372
- all_env_ids.append(env_id)
373
- all_gr_dfs.append(gr_df)
374
- all_gr_winners.append(gr_winner)
375
- all_gr_videos.append(gr_video)
 
376
 
377
- tab.select(refresh_video(env_id), inputs=hidden_df, outputs=gr_video)
 
378
 
379
  with gr.TabItem("📝 About"):
380
  gr.Markdown(ABOUT_TEXT)
381
 
382
- demo.load(refresh_hidden_df, outputs=hidden_df, every=REFRESH_RATE)
383
- demo.load(refresh_dataframes, inputs=hidden_df, outputs=all_gr_dfs, every=REFRESH_RATE)
384
- demo.load(refresh_videos, inputs=hidden_df, outputs=all_gr_videos, every=REFRESH_RATE)
385
- demo.load(refresh_winners, inputs=hidden_df, outputs=all_gr_winners, every=REFRESH_RATE)
386
-
387
 
388
  scheduler = BackgroundScheduler()
389
  scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
 
390
  scheduler.start()
391
 
392
 
 
19
 
20
  API = HfApi(token=os.environ.get("TOKEN"))
21
  RESULTS_REPO = "open-rl-leaderboard/results"
22
+ REFRESH_RATE = 5 * 60 # 5 minutes
23
  ALL_ENV_IDS = {
24
  "Atari": [
25
  "AdventureNoFrameskip-v4",
 
126
 
127
 
128
  def get_leaderboard_df():
129
+ logger.info("Downloading results")
130
  dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
131
  pattern = os.path.join(dir_path, "**", "results_*.json")
132
  filenames = glob.glob(pattern, recursive=True)
 
170
  return df.values.tolist()
171
 
172
 
173
+ def refresh_video(df, env_id):
174
+ env_df = select_env(df, env_id)
175
+ if not env_df.empty:
176
+ user_id = env_df.iloc[0]["user_id"]
177
+ model_id = env_df.iloc[0]["model_id"]
178
+ model_sha = env_df.iloc[0]["model_sha"]
179
+ repo_id = f"{user_id}/{model_id}"
180
+ video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
181
+ return video_path
182
+ else:
183
+ return None
184
 
185
 
186
+ def refresh_one_video(df, env_id):
187
+ def inner():
188
+ return refresh_video(df, env_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ return inner
191
 
192
 
193
+ def refresh_winner(df, env_id):
194
+ # print("Refreshing winners")
195
+ env_df = select_env(df, env_id)
196
+ if not env_df.empty:
197
+ user_id = env_df.iloc[0]["user_id"]
198
+ model_id = env_df.iloc[0]["model_id"]
199
+ url = f"https://huggingface.co/{user_id}/{model_id}"
200
+ return f"""## {env_id}
 
 
201
 
202
  ### 🏆 [Best model]({url}) 🏆"""
203
+ else:
204
+ return f"""## {env_id}
 
 
205
 
206
  This leaderboard is quite empty... 😢
207
 
208
  Be the first to [submit your model]()!
209
  """
 
 
210
 
211
 
212
  HEADING = """
 
306
 
307
  """
308
 
309
+
310
+ def update_globals():
311
+ global dataframes, winner_texts, video_pathes, df
312
+ df = get_leaderboard_df()
313
+ all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
314
+ dataframes = {env_id: format_df(select_env(df, env_id)) for env_id in all_env_ids}
315
+ winner_texts = {env_id: refresh_winner(df, env_id) for env_id in all_env_ids}
316
+ video_pathes = {env_id: refresh_video(df, env_id) for env_id in all_env_ids}
317
+
318
+
319
+ update_globals()
320
+
321
+
322
+ def refresh():
323
+ global dataframes, winner_texts, video_pathes
324
+ return list(dataframes.values()) + list(winner_texts.values()) + [list(video_pathes.values())[0]]
325
+
326
+
327
  with gr.Blocks(css=css) as demo:
328
  gr.Markdown(HEADING)
329
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
330
  with gr.TabItem("🏅 Leaderboard"):
331
+ all_gr_dfs = {}
332
+ all_gr_winners = {}
333
+ all_gr_videos = {}
 
 
 
334
  for env_domain, env_ids in ALL_ENV_IDS.items():
335
  with gr.TabItem(env_domain):
336
  for env_id in env_ids:
337
+ # If the env_id envs with "NoFrameskip-v4", we remove it to improve readability
338
  tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
339
  with gr.TabItem(tab_env_id) as tab:
340
  logger.info(f"Creating tab for {env_id}")
341
  with gr.Row(equal_height=False):
342
  with gr.Column(scale=3):
 
343
  gr_df = gr.components.Dataframe(
344
  headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
345
  datatype=["number", "markdown", "markdown", "number"],
 
346
  )
347
  with gr.Column(scale=1):
348
  with gr.Row(): # Display the env_id and the winner
349
  gr_winner = gr.Markdown()
350
  with gr.Row(): # Play the video of the best model
351
+ gr_video = gr.PlayableVideo( # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
352
  min_width=50,
 
353
  show_download_button=False,
354
  show_share_button=False,
355
  show_label=False,
356
+ interactive=False,
357
  )
358
 
359
+ all_gr_dfs[env_id] = gr_df
360
+ all_gr_winners[env_id] = gr_winner
361
+ all_gr_videos[env_id] = gr_video
362
+
363
+ tab.select(refresh_one_video(df, env_id), outputs=[gr_video])
364
 
365
+ # Load the first video of the first environment
366
+ demo.load(refresh_one_video(df, env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])
367
 
368
  with gr.TabItem("📝 About"):
369
  gr.Markdown(ABOUT_TEXT)
370
 
371
+ demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()))
 
 
 
 
372
 
373
  scheduler = BackgroundScheduler()
374
  scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
375
+ scheduler.add_job(func=update_globals, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
376
  scheduler.start()
377
 
378