Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import logging | |
import os | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import scipy.stats | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from datasets import load_dataset | |
from huggingface_hub import HfApi | |
# Set up logging: records from the "app" logger go through a stream handler
# using a timestamped "time - name - level - message" layout.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# Disable the absl logger (annoying)
logging.getLogger("absl").setLevel(logging.WARNING)

# Hub client; the token is read from the TOKEN environment variable (None when unset).
API = HfApi(token=os.getenv("TOKEN"))

# Dataset repository holding the evaluation results.
RESULTS_REPO = "open-rl-leaderboard/results_v2"

# Interval, in seconds, between two refreshes of the cached results.
REFRESH_RATE = 5 * 60  # 5 minutes
# Environment ids shown on the leaderboard, grouped by domain.
# Insertion order matters: it is the order in which tabs are created.
ALL_ENV_IDS = {
    # Every Atari id is the game name followed by the same "NoFrameskip-v4" suffix.
    "Atari": [
        f"{game}NoFrameskip-v4"
        for game in (
            "Adventure",
            "AirRaid",
            "Alien",
            "Amidar",
            "Assault",
            "Asterix",
            "Asteroids",
            "Atlantis",
            "BankHeist",
            "BattleZone",
            "BeamRider",
            "Berzerk",
            "Bowling",
            "Boxing",
            "Breakout",
            "Carnival",
            "Centipede",
            "ChopperCommand",
            "CrazyClimber",
            "Defender",
            "DemonAttack",
            "DoubleDunk",
            "ElevatorAction",
            "Enduro",
            "FishingDerby",
            "Freeway",
            "Frostbite",
            "Gopher",
            "Gravitar",
            "Hero",
            "IceHockey",
            "Jamesbond",
            "JourneyEscape",
            "Kangaroo",
            "Krull",
            "KungFuMaster",
            "MontezumaRevenge",
            "MsPacman",
            "NameThisGame",
            "Phoenix",
            "Pitfall",
            "Pong",
            "Pooyan",
            "PrivateEye",
            "Qbert",
            "Riverraid",
            "RoadRunner",
            "Robotank",
            "Seaquest",
            "Skiing",
            "Solaris",
            "SpaceInvaders",
            "StarGunner",
            "Tennis",
            "TimePilot",
            "Tutankham",
            "UpNDown",
            "Venture",
            "VideoPinball",
            "WizardOfWor",
            "YarsRevenge",
            "Zaxxon",
        )
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}
def iqm(x):
    """Return the interquartile mean of all values in `x`.

    `x` is treated as one flat collection of values (`axis=None`).
    """
    # Cutting a quarter of the values from each end keeps the middle 50%.
    quarter = 0.25
    return scipy.stats.trim_mean(x, quarter, axis=None)
def get_leaderboard_df():
    """Download the results dataset and return the finished evaluations.

    Returns:
        A DataFrame restricted to rows whose ``status`` is ``"DONE"``, with an
        extra ``iqm_episodic_return`` column holding the interquartile mean of
        each row's ``episodic_returns``.
    """
    logger.info("Downloading results")
    dataset = load_dataset(RESULTS_REPO, split="train")  # split is not important, but we need to use "train"
    df = dataset.to_pandas()  # convert to pandas dataframe
    # Keep only the models that are done. The explicit copy makes the result an
    # independent frame, so adding a column below is not a chained assignment
    # on a slice (which pandas reports with SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df
def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of `df` for `env_id`, best first, with a 1-based `ranking` column.

    Rows are ordered by decreasing `iqm_episodic_return`; the input frame is
    left untouched.
    """
    matches_env = df["env_id"] == env_id
    ordered = df.loc[matches_env].sort_values(by="iqm_episodic_return", ascending=False)
    return ordered.assign(ranking=np.arange(1, len(ordered) + 1))
def format_df(df: pd.DataFrame):
    """Turn a ranked results frame into rows for the leaderboard table.

    Args:
        df: Frame with at least the ``ranking``, ``user_id``, ``model_id`` and
            ``iqm_episodic_return`` columns. It is not modified.

    Returns:
        A list of ``[ranking, user_link, model_link, iqm_episodic_return]``
        rows, in the frame's row order, where the two links are Markdown
        hyperlinks to the user and model pages on huggingface.co.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    hub_url = "https://huggingface.co"
    # Rows are built positionally in a single pass. Writing cells back with
    # `df.loc[index, ...]` while iterating would be slow, and would overwrite
    # every row sharing an index label if the index ever held duplicates.
    columns = (df["ranking"], df["user_id"], df["model_id"], df["iqm_episodic_return"])
    return [
        [
            ranking,
            f"[{user_id}]({hub_url}/{user_id})",
            f"[{model_id}]({hub_url}/{user_id}/{model_id})",
            score,
        ]
        for ranking, user_id, model_id, score in zip(*columns)
    ]
def refresh_video(df, env_id):
    """Return the local path of the replay video of the top-ranked model for `env_id`.

    Returns None when there is no result for the environment, or when the
    download fails (the error is logged).
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return None

    # select_env sorts best-first, so the first row is the current winner.
    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    sha = best["sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        # Best effort: any download problem is logged and reported as "no video".
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
    except Exception as e:
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Build a zero-argument callback returning the winner video for `env_id`.

    The callback closes over the `df` object passed here; it calls
    `refresh_video(df, env_id)` each time it is invoked.
    """

    def inner():
        # Deferred call: nothing is downloaded until the callback actually runs.
        video = refresh_video(df, env_id)
        return video

    return inner
def refresh_winner(df, env_id):
    """Return the Markdown header shown next to the leaderboard of `env_id`.

    With at least one result, the text links to the top-ranked model;
    otherwise it invites the visitor to submit one.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return (
            f"## {env_id}\n"
            "This leaderboard is quite empty... 😢\n"
            "Be the first to submit your model!\n"
            'Check the tab "🚀 Getting my agent evaluated"\n'
        )

    # select_env sorts best-first, so the first row is the current winner.
    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"## {env_id}\n### 🏆 [Best model]({url}) 🏆"
def refresh_num_models(df):
    """Return the sentence announcing how many entries `df` holds (thousands-separated)."""
    count = len(df)
    return f"The leaderboard currently contains {count:,} models."
# Custom stylesheet passed to gr.Blocks: removes the border of elements with
# the "generating" class and centers the h2/h3 headings.
css = (
    "\n"
    ".generating {\n"
    "border: none;\n"
    "}\n"
    "h2 {\n"
    "text-align: center;\n"
    "}\n"
    "h3 {\n"
    "text-align: center;\n"
    "}\n"
)
def update_globals():
    """Re-download the results and rebuild every cached, per-environment view.

    Rebinds the module-level `df`, `dataframes`, `winner_texts`,
    `video_pathes` and `num_models_str`. Everything is computed into local
    variables first and published only at the end, so that:

    * a failure part-way through (e.g. while downloading) leaves the
      previously published data untouched, and
    * readers never see a new `df` or new tables paired with an old model
      count while the video downloads are still running.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df

    fresh_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]

    new_dataframes = {env_id: format_df(select_env(fresh_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(fresh_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(fresh_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(fresh_df)

    # Publish the new state in one go, once everything has been built.
    df = fresh_df
    dataframes = new_dataframes
    winner_texts = new_winner_texts
    video_pathes = new_video_pathes
    num_models_str = new_num_models_str


# Populate the globals once at import time so the UI has data before the
# first scheduled refresh.
update_globals()
def refresh():
    """Return the cached values for every table, every winner text, and the model count.

    The order is: all entries of `dataframes`, then all entries of
    `winner_texts`, then `num_models_str`.
    """
    # The module-level caches are only read here, so no `global` declaration is needed.
    return [*dataframes.values(), *winner_texts.values(), num_models_str]
def _refresh_latest_video(env_id):
    """Build a zero-argument callback returning the winner video for `env_id`.

    Unlike binding `df` when the UI is built, the callback looks up the
    module-level `df` when the event fires, so it follows the data published
    by the periodic `update_globals` job instead of the startup snapshot.
    """

    def inner():
        return refresh_video(df, env_id)

    return inner


# Build the UI: a heading, the model count, then one tab per domain containing
# one tab per environment (results table on the left, winner text and replay
# video on the right), followed by the two documentation tabs.
with gr.Blocks(css=css) as demo:
    # The markdown files are read as UTF-8 explicitly rather than relying on
    # the platform's default encoding.
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # Components indexed by env_id; filled in the same order as ALL_ENV_IDS,
            # which is also the order of the values returned by `refresh`.
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}

            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )
                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video
                            # Fetch the video lazily, when the environment tab is opened.
                            tab.select(_refresh_latest_video(env_id), outputs=[gr_video])
                    # Load the first video of the first environment
                    # NOTE(review): placed once per domain, on the assumption that the
                    # first environment tab of a domain is shown without a select event.
                    demo.load(_refresh_latest_video(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])

        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

    # On page load, fill every table and winner text plus the model count from the cached values.
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Rebuild the cached leaderboard data every REFRESH_RATE seconds in a
# background thread; max_instances=1 prevents overlapping refreshes.
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()

if __name__ == "__main__":
    # Enable request queuing, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()