leaderboard / app.py
Quentin Gallouédec
pybullet
74c08c9
import logging
import os
import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi
# Configure the application logger: INFO level, one stream handler, timestamped format.
# Note that logger.debug(...) calls elsewhere in this file are filtered out at this level.
logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
# Quieten the noisy "absl" logger: only WARNING and above get through.
logging.getLogger("absl").setLevel(logging.WARNING)
# Hub client; the token is read from the TOKEN environment variable (None when unset).
API = HfApi(token=os.environ.get("TOKEN"))
# Dataset repository holding the evaluation results.
RESULTS_REPO = "open-rl-leaderboard/results_v2"
REFRESH_RATE = 5 * 60 # 5 minutes, expressed in seconds (passed as `seconds=` to the scheduler)
# Environment ids shown on the leaderboard, grouped by domain.
# The UI code iterates this mapping: each key becomes a domain tab and each id a sub-tab,
# so the order here is the display order.
ALL_ENV_IDS = {
# Atari ids all end with "NoFrameskip-v4"; that suffix is stripped from the tab labels.
"Atari": [
"AdventureNoFrameskip-v4",
"AirRaidNoFrameskip-v4",
"AlienNoFrameskip-v4",
"AmidarNoFrameskip-v4",
"AssaultNoFrameskip-v4",
"AsterixNoFrameskip-v4",
"AsteroidsNoFrameskip-v4",
"AtlantisNoFrameskip-v4",
"BankHeistNoFrameskip-v4",
"BattleZoneNoFrameskip-v4",
"BeamRiderNoFrameskip-v4",
"BerzerkNoFrameskip-v4",
"BowlingNoFrameskip-v4",
"BoxingNoFrameskip-v4",
"BreakoutNoFrameskip-v4",
"CarnivalNoFrameskip-v4",
"CentipedeNoFrameskip-v4",
"ChopperCommandNoFrameskip-v4",
"CrazyClimberNoFrameskip-v4",
"DefenderNoFrameskip-v4",
"DemonAttackNoFrameskip-v4",
"DoubleDunkNoFrameskip-v4",
"ElevatorActionNoFrameskip-v4",
"EnduroNoFrameskip-v4",
"FishingDerbyNoFrameskip-v4",
"FreewayNoFrameskip-v4",
"FrostbiteNoFrameskip-v4",
"GopherNoFrameskip-v4",
"GravitarNoFrameskip-v4",
"HeroNoFrameskip-v4",
"IceHockeyNoFrameskip-v4",
"JamesbondNoFrameskip-v4",
"JourneyEscapeNoFrameskip-v4",
"KangarooNoFrameskip-v4",
"KrullNoFrameskip-v4",
"KungFuMasterNoFrameskip-v4",
"MontezumaRevengeNoFrameskip-v4",
"MsPacmanNoFrameskip-v4",
"NameThisGameNoFrameskip-v4",
"PhoenixNoFrameskip-v4",
"PitfallNoFrameskip-v4",
"PongNoFrameskip-v4",
"PooyanNoFrameskip-v4",
"PrivateEyeNoFrameskip-v4",
"QbertNoFrameskip-v4",
"RiverraidNoFrameskip-v4",
"RoadRunnerNoFrameskip-v4",
"RobotankNoFrameskip-v4",
"SeaquestNoFrameskip-v4",
"SkiingNoFrameskip-v4",
"SolarisNoFrameskip-v4",
"SpaceInvadersNoFrameskip-v4",
"StarGunnerNoFrameskip-v4",
"TennisNoFrameskip-v4",
"TimePilotNoFrameskip-v4",
"TutankhamNoFrameskip-v4",
"UpNDownNoFrameskip-v4",
"VentureNoFrameskip-v4",
"VideoPinballNoFrameskip-v4",
"WizardOfWorNoFrameskip-v4",
"YarsRevengeNoFrameskip-v4",
"ZaxxonNoFrameskip-v4",
],
"Box2D": [
"BipedalWalker-v3",
"BipedalWalkerHardcore-v3",
"CarRacing-v2",
"LunarLander-v2",
"LunarLanderContinuous-v2",
],
"Toy text": [
"Blackjack-v1",
"CliffWalking-v0",
"FrozenLake-v1",
"FrozenLake8x8-v1",
],
"Classic control": [
"Acrobot-v1",
"CartPole-v1",
"MountainCar-v0",
"MountainCarContinuous-v0",
"Pendulum-v1",
],
"MuJoCo": [
"Ant-v4",
"HalfCheetah-v4",
"Hopper-v4",
"Humanoid-v4",
"HumanoidStandup-v4",
"InvertedDoublePendulum-v4",
"InvertedPendulum-v4",
"Pusher-v4",
"Reacher-v4",
"Swimmer-v4",
"Walker2d-v4",
],
"PyBullet": [
"AntBulletEnv-v0",
"HalfCheetahBulletEnv-v0",
"HopperBulletEnv-v0",
"HumanoidBulletEnv-v0",
"InvertedDoublePendulumBulletEnv-v0",
"InvertedPendulumSwingupBulletEnv-v0",
"MinitaurBulletEnv-v0",
"ReacherBulletEnv-v0",
"Walker2DBulletEnv-v0",
],
}
def iqm(x):
    """Return the 25%-trimmed mean of ``x``, taken over all of its elements.

    The lowest and highest quarter of the (flattened) values are discarded
    before averaging, which is how this file computes its "IQM" score.
    """
    cut_per_tail = 0.25
    return scipy.stats.trim_mean(x, cut_per_tail, axis=None)
def get_leaderboard_df():
    """Download the results dataset and return the finished evaluations.

    Loads the "train" split of ``RESULTS_REPO``, keeps only the rows whose
    ``status`` equals ``"DONE"`` and adds an ``iqm_episodic_return`` column
    obtained by applying ``iqm`` to each row's ``episodic_returns``.

    Returns:
        pd.DataFrame: the filtered results with the extra score column.
    """
    logger.info("Downloading results")
    # The split itself is not meaningful here, but one must be named; "train" is used.
    dataset = load_dataset(RESULTS_REPO, split="train")
    df = dataset.to_pandas()
    # Keep only completed evaluations. The explicit copy gives us an independent
    # frame, so adding a column below does not write into a filtered slice
    # (which would raise pandas' SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df
def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of ``df`` for ``env_id``, best score first, with a 1-based ``ranking``.

    Rows are ordered by descending ``iqm_episodic_return``; the input frame is
    left untouched.
    """
    matches = df["env_id"] == env_id
    ranked = df[matches].sort_values(by="iqm_episodic_return", ascending=False)
    # Ranks simply follow the sorted order: 1 for the top row, 2 for the next, ...
    ranked["ranking"] = np.arange(1, len(ranked) + 1)
    return ranked
def format_df(df: pd.DataFrame):
    """Turn a ranked results frame into rows ready for the leaderboard table.

    ``user_id`` and ``model_id`` are rewritten as Markdown links to the
    corresponding Hugging Face pages, then only the displayed columns are kept.

    Args:
        df: frame containing at least ``ranking``, ``user_id``, ``model_id``
            and ``iqm_episodic_return``. It is not modified.

    Returns:
        list[list]: one ``[ranking, user link, model link, score]`` list per row.
    """
    df = df.copy()
    # Read both id columns before overwriting either: the model link needs the raw user id.
    user_ids = df["user_id"].tolist()
    model_ids = df["model_id"].tolist()
    # Columns are rebuilt positionally rather than through label-based `.loc`
    # writes, so rows stay correct even if the index contains duplicate labels.
    df["user_id"] = [f"[{user_id}](https://huggingface.co/{user_id})" for user_id in user_ids]
    df["model_id"] = [
        f"[{model_id}](https://huggingface.co/{user_id}/{model_id})"
        for user_id, model_id in zip(user_ids, model_ids)
    ]
    # Keep only the relevant columns
    df = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]]
    return df.values.tolist()
def refresh_video(df, env_id):
    """Download the replay video of the top-ranked model for ``env_id``.

    Returns the path given back by ``API.hf_hub_download`` for ``replay.mp4``
    at the model's recorded ``sha``, or ``None`` when the environment has no
    entry or the download raises.
    """
    ranked = select_env(df, env_id)
    if ranked.empty:
        return None

    best = ranked.iloc[0]
    user_id, model_id, sha = best["user_id"], best["model_id"], best["sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
    except Exception as e:
        # Best effort: a missing or unreachable video must not break the leaderboard.
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Build a zero-argument callback that returns ``refresh_video(df, env_id)``.

    Both arguments are captured when this function is called, so the callback
    keeps using that exact ``df`` object.
    """

    def inner():
        video_path = refresh_video(df, env_id)
        return video_path

    return inner
def refresh_winner(df, env_id):
    """Return the Markdown shown next to the table for ``env_id``.

    When the environment has at least one entry, the text is a heading plus a
    link to the top-ranked model's Hugging Face page; otherwise it invites the
    reader to submit a model. The emoji are written as real Unicode characters.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        # The quoted tab name below matches the label of the submission tab in the UI.
        return f"""## {env_id}
This leaderboard is quite empty... 😢
Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""
    user_id = env_df.iloc[0]["user_id"]
    model_id = env_df.iloc[0]["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"""## {env_id}
### 🏆 [Best model]({url}) 🏆"""
def refresh_num_models(df):
    """Return the sentence stating how many entries ``df`` holds (thousands separated by commas)."""
    total = len(df)
    return f"The leaderboard currently contains {total:,} models."
# Custom stylesheet handed to gr.Blocks(css=...): removes the border on elements
# with the "generating" class and centers h2/h3 headings.
css = """
.generating {
border: none;
}
h2 {
text-align: center;
}
h3 {
text-align: center;
}
"""
def update_globals():
    """Rebuild every cached leaderboard artefact from a fresh download.

    Refreshes the module-level ``df``, ``dataframes`` (table rows per env),
    ``winner_texts`` (Markdown per env), ``video_pathes`` (replay path per env)
    and ``num_models_str``. Everything is computed into locals first and only
    published at the end, so an exception part-way through leaves the previous
    snapshot fully intact and readers do not see a half-updated mix while the
    (potentially slow) downloads are running.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df
    new_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
    new_dataframes = {env_id: format_df(select_env(new_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(new_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(new_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(new_df)

    # Publish the new snapshot only once all of it has been built successfully.
    df = new_df
    dataframes = new_dataframes
    winner_texts = new_winner_texts
    video_pathes = new_video_pathes
    num_models_str = new_num_models_str


# Populate the globals once at import time so the UI below has data to show.
update_globals()
def refresh():
    """Return the current values for all tables, all winner texts and the model count.

    The result is a flat list: every entry of ``dataframes``, then every entry
    of ``winner_texts`` (both in dict order), then ``num_models_str``.
    """
    outputs = [*dataframes.values(), *winner_texts.values()]
    outputs.append(num_models_str)
    return outputs
def _latest_video_loader(env_id):
    """Build a zero-argument callback returning the replay of the current best model for ``env_id``."""

    def inner():
        # The global `df` is looked up when the callback runs, so it sees the frame
        # most recently stored by update_globals() instead of the one that existed
        # when the UI was built.
        return refresh_video(df, env_id)

    return inner


# Build the Gradio UI: a heading, the model count, then one tab per domain and
# one sub-tab per environment (table + winner text + replay video), followed by
# two static documentation tabs.
with gr.Blocks(css=css) as demo:
    with open("texts/heading.md") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # Components are indexed by env_id so they can be wired to refresh() below.
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )
                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video
                            # Fetch the video only when the tab is opened.
                            tab.select(_latest_video_loader(env_id), outputs=[gr_video])
                # Load the first video of the first environment
                demo.load(_latest_video_loader(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])
        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md") as fp:
                gr.Markdown(fp.read())
        with gr.TabItem("📝 About"):
            with open("texts/about.md") as fp:
                gr.Markdown(fp.read())
    # On page load, fill every table and winner text plus the model count; the
    # output order must match the list returned by refresh().
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Re-run update_globals() in the background every REFRESH_RATE seconds, never
# allowing more than one run at a time.
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()

if __name__ == "__main__":
    demo.queue().launch()