leaderboard / app.py
Quentin Gallouédec
pybullet
74c08c9
raw
history blame contribute delete
No virus
10.4 kB
import logging
import os
import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi
# Set up logging
logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)
# getLogger("app") returns the same logger object every time this module is executed
# in a process, so only attach the handler when none is present; otherwise each
# re-execution would add another handler and every record would be printed repeatedly.
if not logger.handlers:
    logger.addHandler(ch)
# Silence absl's INFO/DEBUG output (annoying); WARNING and above are still emitted.
logging.getLogger("absl").setLevel(logging.WARNING)
# Hugging Face dataset holding the evaluation results.
RESULTS_REPO = "open-rl-leaderboard/results_v2"
# Seconds between two background reloads of the results (5 minutes).
REFRESH_RATE = 60 * 5
# Hub client used to download replay videos; the token comes from the TOKEN
# environment variable (None when it is not set).
API = HfApi(token=os.environ.get("TOKEN"))
# Environment ids shown on the leaderboard, grouped by domain (one domain tab each).
# Every Atari id follows the "<Game>NoFrameskip-v4" pattern, so those ids are built
# from the bare game names.
ALL_ENV_IDS = {
    "Atari": [
        f"{game}NoFrameskip-v4"
        for game in (
            "Adventure", "AirRaid", "Alien", "Amidar", "Assault", "Asterix",
            "Asteroids", "Atlantis", "BankHeist", "BattleZone", "BeamRider",
            "Berzerk", "Bowling", "Boxing", "Breakout", "Carnival", "Centipede",
            "ChopperCommand", "CrazyClimber", "Defender", "DemonAttack",
            "DoubleDunk", "ElevatorAction", "Enduro", "FishingDerby", "Freeway",
            "Frostbite", "Gopher", "Gravitar", "Hero", "IceHockey", "Jamesbond",
            "JourneyEscape", "Kangaroo", "Krull", "KungFuMaster",
            "MontezumaRevenge", "MsPacman", "NameThisGame", "Phoenix", "Pitfall",
            "Pong", "Pooyan", "PrivateEye", "Qbert", "Riverraid", "RoadRunner",
            "Robotank", "Seaquest", "Skiing", "Solaris", "SpaceInvaders",
            "StarGunner", "Tennis", "TimePilot", "Tutankham", "UpNDown", "Venture",
            "VideoPinball", "WizardOfWor", "YarsRevenge", "Zaxxon",
        )
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}
def iqm(x):
    """Return the interquartile mean (IQM) of ``x``.

    ``x`` is flattened (``axis=None``); the lowest and highest 25% of the values are
    dropped and the remaining values are averaged.
    """
    quartile_cut = 0.25  # fraction removed from each end
    return scipy.stats.trim_mean(x, quartile_cut, axis=None)
def get_leaderboard_df():
    """Download the results dataset and return the finished evaluations.

    Returns:
        A DataFrame with the rows of ``RESULTS_REPO`` whose ``status`` is ``"DONE"``,
        plus an ``iqm_episodic_return`` column holding the IQM of each row's
        ``episodic_returns``.
    """
    logger.info("Downloading results")
    dataset = load_dataset(RESULTS_REPO, split="train")  # split is not important, but we need to use "train")
    df = dataset.to_pandas()  # convert to pandas dataframe
    # Keep only the models that are done. ``.copy()`` makes the filtered frame
    # independent of the unfiltered one, so adding a column below is a plain
    # assignment rather than a write into a possible view (SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df
def select_env(df: pd.DataFrame, env_id: str):
    """Return the results for one environment, ranked best first.

    Args:
        df: Results with at least ``env_id`` and ``iqm_episodic_return`` columns.
        env_id: Environment whose rows are kept.

    Returns:
        A new DataFrame sorted by ``iqm_episodic_return`` in descending order, with a
        1-based ``ranking`` column. ``df`` itself is not modified.
    """
    env_df = df[df["env_id"] == env_id].sort_values(
        "iqm_episodic_return",
        ascending=False,
        # Stable sort: rows with equal scores keep their order from ``df`` instead of
        # an order that depends on the sorting algorithm's internals.
        kind="stable",
    )
    # ``assign`` returns a new frame, so no column is written into a derived object.
    return env_df.assign(ranking=np.arange(1, len(env_df) + 1))
def format_df(df: pd.DataFrame):
    """Turn a ranked results table into rows for the leaderboard Dataframe component.

    ``user_id`` and ``model_id`` are rendered as Markdown links to the user's page and
    to the model repository on the Hugging Face Hub.

    Args:
        df: Ranked results (e.g. from ``select_env``) with ``ranking``, ``user_id``,
            ``model_id`` and ``iqm_episodic_return`` columns. It is not modified.

    Returns:
        A list of ``[ranking, user link, model link, iqm_episodic_return]`` rows, in
        the row order of ``df``.
    """
    user_ids = df["user_id"].tolist()
    model_ids = df["model_id"].tolist()
    # Links are built positionally (not through index labels), so every row gets its
    # own link even if several rows share an index label.
    user_links = [f"[{user_id}](https://huggingface.co/{user_id})" for user_id in user_ids]
    model_links = [
        f"[{model_id}](https://huggingface.co/{user_id}/{model_id})"
        for user_id, model_id in zip(user_ids, model_ids)
    ]
    # Keep only the relevant columns
    table = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]].copy()
    table["user_id"] = user_links
    table["model_id"] = model_links
    return table.values.tolist()
def refresh_video(df, env_id):
    """Download the replay video of the top-ranked model for ``env_id``.

    Returns:
        The local path of ``replay.mp4`` from the best model's repository, at the
        revision stored in its ``sha`` column; None when the environment has no
        results or when the download raises (the error is logged).
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return None
    best = env_df.iloc[0]
    repo_id = "{}/{}".format(best["user_id"], best["model_id"])
    revision = best["sha"]
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=revision, repo_type="model")
    except Exception as e:
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Return a zero-argument callback that fetches the best replay for ``env_id``.

    ``df`` and ``env_id`` are bound when this factory is called, so each callback
    keeps its own environment instead of sharing a late-bound loop variable.
    """

    def inner():
        # Uses the df/env_id captured by refresh_one_video.
        return refresh_video(df=df, env_id=env_id)

    return inner
def refresh_winner(df, env_id):
    """Return the Markdown shown above the video for ``env_id``.

    It is a level-2 heading with the environment id, followed by a link to the
    top-ranked model, or by an invitation to submit when there are no results yet.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return f"""## {env_id}
This leaderboard is quite empty... 😢
Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""
    best = env_df.iloc[0]
    url = "https://huggingface.co/{}/{}".format(best["user_id"], best["model_id"])
    return f"""## {env_id}
### 🏆 [Best model]({url}) 🏆"""
def refresh_num_models(df):
    """Return a sentence stating how many models (rows of ``df``) are listed.

    The count uses a comma as the thousands separator.
    """
    count = len(df)
    return "The leaderboard currently contains {:,} models.".format(count)
# Extra CSS passed to gr.Blocks(css=...). It removes the border of elements with the
# `.generating` class (presumably the class Gradio adds while a component is
# updating — confirm against the Gradio version in use) and centers h2/h3 headings,
# such as the "## env_id" / "### Best model" lines built by refresh_winner.
css = """
.generating {
border: none;
}
h2 {
text-align: center;
}
h3 {
text-align: center;
}
"""
def update_globals():
    """Reload the results and recompute everything the UI displays.

    Called once at import and then periodically by the background scheduler. The new
    values are computed into locals first, then published to the module globals
    together. This avoids keeping a mix of old and new globals while the slow video
    downloads run, narrows the window in which a concurrent ``refresh()`` could see
    such a mix, and leaves the previous globals intact if any step raises.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df
    new_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
    new_dataframes = {env_id: format_df(select_env(new_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(new_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(new_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(new_df)
    # Publish everything at once, after all the slow work is done.
    df, dataframes, winner_texts, video_pathes, num_models_str = (
        new_df,
        new_dataframes,
        new_winner_texts,
        new_video_pathes,
        new_num_models_str,
    )


update_globals()
def refresh():
    """Return the current values for every leaderboard component.

    Order: one formatted table per environment, then one winner text per environment
    (both in the insertion order of their module-level dicts), then the model-count
    sentence.
    """
    tables = list(dataframes.values())
    winners = list(winner_texts.values())
    return tables + winners + [num_models_str]
def _current_video_callback(env_id):
    """Return a Gradio callback that shows the replay of the current best model for env_id.

    The callback looks up the module-level ``df`` when it fires. update_globals()
    rebinds that name periodically, so the video follows the refreshed results
    instead of the DataFrame that existed when the UI was built.
    """

    def inner():
        return refresh_video(df, env_id)

    return inner


with gr.Blocks(css=css) as demo:
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )
                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video
                            # Fetch the video lazily, when the user opens this tab.
                            tab.select(_current_video_callback(env_id), outputs=[gr_video])
                    # Load the video of the first environment of this domain on page load
                    demo.load(_current_video_callback(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])
        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())
        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())
    # Fill every table, winner text and the model count from the latest globals on page load.
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Re-run update_globals every REFRESH_RATE seconds in a background thread;
# max_instances=1 keeps the scheduler from starting a run while the previous
# one is still in progress.
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()

if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue().launch()