Spaces:
Runtime error
Runtime error
Quentin Gallouédec
commited on
Commit
·
0660028
1
Parent(s):
02fb3fc
optimize
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ logging.getLogger("absl").setLevel(logging.WARNING)
|
|
19 |
|
20 |
API = HfApi(token=os.environ.get("TOKEN"))
|
21 |
RESULTS_REPO = "open-rl-leaderboard/results"
|
22 |
-
REFRESH_RATE =
|
23 |
ALL_ENV_IDS = {
|
24 |
"Atari": [
|
25 |
"AdventureNoFrameskip-v4",
|
@@ -126,6 +126,7 @@ def iqm(x):
|
|
126 |
|
127 |
|
128 |
def get_leaderboard_df():
|
|
|
129 |
dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
|
130 |
pattern = os.path.join(dir_path, "**", "results_*.json")
|
131 |
filenames = glob.glob(pattern, recursive=True)
|
@@ -169,70 +170,43 @@ def format_df(df: pd.DataFrame):
|
|
169 |
return df.values.tolist()
|
170 |
|
171 |
|
172 |
-
def
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
|
177 |
-
def
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
def refresh_videos(df):
|
183 |
-
outputs = []
|
184 |
-
for env_id in all_env_ids:
|
185 |
-
env_df = select_env(df, env_id)
|
186 |
-
if not env_df.empty:
|
187 |
-
user_id = env_df.iloc[0]["user_id"]
|
188 |
-
model_id = env_df.iloc[0]["model_id"]
|
189 |
-
model_sha = env_df.iloc[0]["model_sha"]
|
190 |
-
repo_id = f"{user_id}/{model_id}"
|
191 |
-
video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
|
192 |
-
outputs.append(video_path)
|
193 |
-
else:
|
194 |
-
outputs.append(None)
|
195 |
-
return outputs
|
196 |
-
|
197 |
-
|
198 |
-
def refresh_video(env_id):
|
199 |
-
def func(df):
|
200 |
-
env_df = select_env(df, env_id)
|
201 |
-
if not env_df.empty:
|
202 |
-
user_id = env_df.iloc[0]["user_id"]
|
203 |
-
model_id = env_df.iloc[0]["model_id"]
|
204 |
-
model_sha = env_df.iloc[0]["model_sha"]
|
205 |
-
repo_id = f"{user_id}/{model_id}"
|
206 |
-
video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
|
207 |
-
return video_path
|
208 |
-
return None
|
209 |
|
210 |
-
return
|
211 |
|
212 |
|
213 |
-
def
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
outputs.append(
|
222 |
-
f"""## {env_id}
|
223 |
|
224 |
### 🏆 [Best model]({url}) 🏆"""
|
225 |
-
|
226 |
-
|
227 |
-
outputs.append(
|
228 |
-
f"""## {env_id}
|
229 |
|
230 |
This leaderboard is quite empty... 😢
|
231 |
|
232 |
Be the first to [submit your model]()!
|
233 |
"""
|
234 |
-
)
|
235 |
-
return outputs
|
236 |
|
237 |
|
238 |
HEADING = """
|
@@ -332,61 +306,73 @@ h3 {
|
|
332 |
|
333 |
"""
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
with gr.Blocks(css=css) as demo:
|
336 |
gr.Markdown(HEADING)
|
337 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
338 |
with gr.TabItem("🏅 Leaderboard"):
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
all_gr_dfs = []
|
343 |
-
all_gr_videos = []
|
344 |
-
all_gr_winners = []
|
345 |
for env_domain, env_ids in ALL_ENV_IDS.items():
|
346 |
with gr.TabItem(env_domain):
|
347 |
for env_id in env_ids:
|
348 |
-
# If the env_id envs with "NoFrameskip-v4", we remove it
|
349 |
tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
|
350 |
with gr.TabItem(tab_env_id) as tab:
|
351 |
logger.info(f"Creating tab for {env_id}")
|
352 |
with gr.Row(equal_height=False):
|
353 |
with gr.Column(scale=3):
|
354 |
-
# Display the leaderboard
|
355 |
gr_df = gr.components.Dataframe(
|
356 |
headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
|
357 |
datatype=["number", "markdown", "markdown", "number"],
|
358 |
-
row_count=(20, "fixed"),
|
359 |
)
|
360 |
with gr.Column(scale=1):
|
361 |
with gr.Row(): # Display the env_id and the winner
|
362 |
gr_winner = gr.Markdown()
|
363 |
with gr.Row(): # Play the video of the best model
|
364 |
-
gr_video = gr.PlayableVideo( # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
|
365 |
min_width=50,
|
366 |
-
autoplay=True,
|
367 |
show_download_button=False,
|
368 |
show_share_button=False,
|
369 |
show_label=False,
|
|
|
370 |
)
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
|
|
376 |
|
377 |
-
|
|
|
378 |
|
379 |
with gr.TabItem("📝 About"):
|
380 |
gr.Markdown(ABOUT_TEXT)
|
381 |
|
382 |
-
demo.load(
|
383 |
-
demo.load(refresh_dataframes, inputs=hidden_df, outputs=all_gr_dfs, every=REFRESH_RATE)
|
384 |
-
demo.load(refresh_videos, inputs=hidden_df, outputs=all_gr_videos, every=REFRESH_RATE)
|
385 |
-
demo.load(refresh_winners, inputs=hidden_df, outputs=all_gr_winners, every=REFRESH_RATE)
|
386 |
-
|
387 |
|
388 |
scheduler = BackgroundScheduler()
|
389 |
scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
|
|
|
390 |
scheduler.start()
|
391 |
|
392 |
|
|
|
19 |
|
20 |
API = HfApi(token=os.environ.get("TOKEN"))
|
21 |
RESULTS_REPO = "open-rl-leaderboard/results"
|
22 |
+
REFRESH_RATE = 5 * 60 # 5 minutes
|
23 |
ALL_ENV_IDS = {
|
24 |
"Atari": [
|
25 |
"AdventureNoFrameskip-v4",
|
|
|
126 |
|
127 |
|
128 |
def get_leaderboard_df():
|
129 |
+
logger.info("Downloading results")
|
130 |
dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
|
131 |
pattern = os.path.join(dir_path, "**", "results_*.json")
|
132 |
filenames = glob.glob(pattern, recursive=True)
|
|
|
170 |
return df.values.tolist()
|
171 |
|
172 |
|
173 |
+
def refresh_video(df, env_id):
|
174 |
+
env_df = select_env(df, env_id)
|
175 |
+
if not env_df.empty:
|
176 |
+
user_id = env_df.iloc[0]["user_id"]
|
177 |
+
model_id = env_df.iloc[0]["model_id"]
|
178 |
+
model_sha = env_df.iloc[0]["model_sha"]
|
179 |
+
repo_id = f"{user_id}/{model_id}"
|
180 |
+
video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
|
181 |
+
return video_path
|
182 |
+
else:
|
183 |
+
return None
|
184 |
|
185 |
|
186 |
+
def refresh_one_video(df, env_id):
|
187 |
+
def inner():
|
188 |
+
return refresh_video(df, env_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
return inner
|
191 |
|
192 |
|
193 |
+
def refresh_winner(df, env_id):
|
194 |
+
# print("Refreshing winners")
|
195 |
+
env_df = select_env(df, env_id)
|
196 |
+
if not env_df.empty:
|
197 |
+
user_id = env_df.iloc[0]["user_id"]
|
198 |
+
model_id = env_df.iloc[0]["model_id"]
|
199 |
+
url = f"https://huggingface.co/{user_id}/{model_id}"
|
200 |
+
return f"""## {env_id}
|
|
|
|
|
201 |
|
202 |
### 🏆 [Best model]({url}) 🏆"""
|
203 |
+
else:
|
204 |
+
return f"""## {env_id}
|
|
|
|
|
205 |
|
206 |
This leaderboard is quite empty... 😢
|
207 |
|
208 |
Be the first to [submit your model]()!
|
209 |
"""
|
|
|
|
|
210 |
|
211 |
|
212 |
HEADING = """
|
|
|
306 |
|
307 |
"""
|
308 |
|
309 |
+
|
310 |
+
def update_globals():
|
311 |
+
global dataframes, winner_texts, video_pathes, df
|
312 |
+
df = get_leaderboard_df()
|
313 |
+
all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
|
314 |
+
dataframes = {env_id: format_df(select_env(df, env_id)) for env_id in all_env_ids}
|
315 |
+
winner_texts = {env_id: refresh_winner(df, env_id) for env_id in all_env_ids}
|
316 |
+
video_pathes = {env_id: refresh_video(df, env_id) for env_id in all_env_ids}
|
317 |
+
|
318 |
+
|
319 |
+
update_globals()
|
320 |
+
|
321 |
+
|
322 |
+
def refresh():
|
323 |
+
global dataframes, winner_texts, video_pathes
|
324 |
+
return list(dataframes.values()) + list(winner_texts.values()) + [list(video_pathes.values())[0]]
|
325 |
+
|
326 |
+
|
327 |
with gr.Blocks(css=css) as demo:
|
328 |
gr.Markdown(HEADING)
|
329 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
330 |
with gr.TabItem("🏅 Leaderboard"):
|
331 |
+
all_gr_dfs = {}
|
332 |
+
all_gr_winners = {}
|
333 |
+
all_gr_videos = {}
|
|
|
|
|
|
|
334 |
for env_domain, env_ids in ALL_ENV_IDS.items():
|
335 |
with gr.TabItem(env_domain):
|
336 |
for env_id in env_ids:
|
337 |
+
# If the env_id envs with "NoFrameskip-v4", we remove it to improve readability
|
338 |
tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
|
339 |
with gr.TabItem(tab_env_id) as tab:
|
340 |
logger.info(f"Creating tab for {env_id}")
|
341 |
with gr.Row(equal_height=False):
|
342 |
with gr.Column(scale=3):
|
|
|
343 |
gr_df = gr.components.Dataframe(
|
344 |
headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
|
345 |
datatype=["number", "markdown", "markdown", "number"],
|
|
|
346 |
)
|
347 |
with gr.Column(scale=1):
|
348 |
with gr.Row(): # Display the env_id and the winner
|
349 |
gr_winner = gr.Markdown()
|
350 |
with gr.Row(): # Play the video of the best model
|
351 |
+
gr_video = gr.PlayableVideo( # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
|
352 |
min_width=50,
|
|
|
353 |
show_download_button=False,
|
354 |
show_share_button=False,
|
355 |
show_label=False,
|
356 |
+
interactive=False,
|
357 |
)
|
358 |
|
359 |
+
all_gr_dfs[env_id] = gr_df
|
360 |
+
all_gr_winners[env_id] = gr_winner
|
361 |
+
all_gr_videos[env_id] = gr_video
|
362 |
+
|
363 |
+
tab.select(refresh_one_video(df, env_id), outputs=[gr_video])
|
364 |
|
365 |
+
# Load the first video of the first environment
|
366 |
+
demo.load(refresh_one_video(df, env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])
|
367 |
|
368 |
with gr.TabItem("📝 About"):
|
369 |
gr.Markdown(ABOUT_TEXT)
|
370 |
|
371 |
+
demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()))
|
|
|
|
|
|
|
|
|
372 |
|
373 |
scheduler = BackgroundScheduler()
|
374 |
scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
|
375 |
+
scheduler.add_job(func=update_globals, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
|
376 |
scheduler.start()
|
377 |
|
378 |
|