# Page header captured along with this source ("Spaces: Sleeping"); not part of the module.
import fnmatch | |
import os | |
import random | |
import time | |
import pybullet_envs_gymnasium # noqa: F401 pylint: disable=unused-import | |
from datasets import load_dataset | |
from huggingface_hub import HfApi | |
from src.evaluation import evaluate | |
from src.logging import setup_logger | |
# Dataset repository on the Hub that stores one result row per evaluated (model, sha) pair.
RESULTS_REPO = "open-rl-leaderboard/results_v2"

# Module logger built by the project's logging helper.
logger = setup_logger(__name__)

# Hub client; its token is read from the TOKEN environment variable (None when unset).
API = HfApi(token=os.getenv("TOKEN"))
# Atari games; every one is registered below in its "<Game>NoFrameskip-v4" form.
_ATARI_GAMES = [
    "Adventure", "AirRaid", "Alien", "Amidar", "Assault", "Asterix", "Asteroids",
    "Atlantis", "BankHeist", "BattleZone", "BeamRider", "Berzerk", "Bowling",
    "Boxing", "Breakout", "Carnival", "Centipede", "ChopperCommand",
    "CrazyClimber", "Defender", "DemonAttack", "DoubleDunk", "ElevatorAction",
    "Enduro", "FishingDerby", "Freeway", "Frostbite", "Gopher", "Gravitar",
    "Hero", "IceHockey", "Jamesbond", "JourneyEscape", "Kangaroo", "Krull",
    "KungFuMaster", "MontezumaRevenge", "MsPacman", "NameThisGame", "Phoenix",
    "Pitfall", "Pong", "Pooyan", "PrivateEye", "Qbert", "Riverraid",
    "RoadRunner", "Robotank", "Seaquest", "Skiing", "Solaris", "SpaceInvaders",
    "StarGunner", "Tennis", "TimePilot", "Tutankham", "UpNDown", "Venture",
    "VideoPinball", "WizardOfWor", "YarsRevenge", "Zaxxon",
]

# Every environment ID a submitted model may be tagged with; models are matched
# against this list via their Hub tags. Order: Atari, Box2D, toy text,
# classic control, MuJoCo, PyBullet.
ALL_ENV_IDS = [f"{game}NoFrameskip-v4" for game in _ATARI_GAMES] + [
    # Box2D
    "BipedalWalker-v3",
    "BipedalWalkerHardcore-v3",
    "CarRacing-v2",
    "LunarLander-v2",
    "LunarLanderContinuous-v2",
    # Toy text
    "Blackjack-v1",
    "CliffWalking-v0",
    "FrozenLake-v1",
    "FrozenLake8x8-v1",
    # Classic control
    "Acrobot-v1",
    "CartPole-v1",
    "MountainCar-v0",
    "MountainCarContinuous-v0",
    "Pendulum-v1",
    # MuJoCo
    "Ant-v4",
    "HalfCheetah-v4",
    "Hopper-v4",
    "Humanoid-v4",
    "HumanoidStandup-v4",
    "InvertedDoublePendulum-v4",
    "InvertedPendulum-v4",
    "Pusher-v4",
    "Reacher-v4",
    "Swimmer-v4",
    "Walker2d-v4",
    # PyBullet
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "HumanoidBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
    "InvertedPendulumSwingupBulletEnv-v0",
    "MinitaurBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
]
def pattern_match(patterns, source_list):
    """Return the sorted, de-duplicated entries of *source_list* matching any pattern.

    Args:
        patterns: A single shell-style (``fnmatch``) pattern, an iterable of
            patterns, or ``None``. ``None`` is treated as "no patterns", for
            example a Hub model that has no tags.
        source_list: Candidate strings to match against.

    Returns:
        A sorted list of the distinct entries of *source_list* that match at
        least one pattern. It is empty when nothing matches.
    """
    if patterns is None:
        # Iterating None would raise TypeError; an untagged model simply matches nothing.
        return []
    if isinstance(patterns, str):
        # A bare string would otherwise be iterated character by character.
        patterns = [patterns]
    matches = {match for pattern in patterns for match in fnmatch.filter(source_list, pattern)}
    return sorted(matches)
def _load_results():
    """Download the current results dataset (train split), bypassing any local cache."""
    return load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")


def _list_compatible_models():
    """Return ``(repo_id, sha)`` for each Hub model tagged reinforcement-learning that ships an ``agent.pt``."""
    rl_models = list(API.list_models(filter=["reinforcement-learning"]))
    logger.info(f"Found {len(rl_models)} RL models")
    compatible_models = []
    for model in rl_models:
        # Defensive: treat a missing file listing (siblings=None) as "no files".
        # NOTE(review): depends on list_models returning full model data.
        filenames = {sib.rfilename for sib in model.siblings or []}
        if "agent.pt" in filenames:
            compatible_models.append((model.modelId, model.sha))
    logger.info(f"Found {len(compatible_models)} compatible models")
    return compatible_models


def _evaluate_model(repo_id, sha):
    """Evaluate one model revision and return the result row to append to the results dataset.

    The row always has ``model_id``, ``user_id``, ``sha`` and ``status``
    ("DONE" or "FAILED"). On success it also has ``env_id`` and
    ``episodic_returns``.
    """
    user_id, model_id = repo_id.split("/")
    row = {"model_id": model_id, "user_id": user_id, "sha": sha}
    model_info = API.model_info(repo_id, revision=sha)
    # The environment ID comes from the model's tags (usually only one). Tags may
    # be None; in that case no environment is found and the model is marked FAILED
    # instead of aborting the whole cycle.
    env_ids = pattern_match(model_info.tags or [], ALL_ENV_IDS)
    if not env_ids:
        logger.error(f"No environment found for {model_id}")
        row["status"] = "FAILED"
        return row
    env_id = env_ids[0]
    logger.info(f"Running evaluation on {user_id}/{model_id}")
    try:
        episodic_returns = evaluate(repo_id, sha, env_id)
    except Exception as e:
        # Any evaluation error is recorded as FAILED so the model is not retried.
        logger.exception(f"Error evaluating {repo_id}: {e}")
        row["status"] = "FAILED"
    else:
        row["status"] = "DONE"
        row["env_id"] = env_id
        row["episodic_returns"] = episodic_returns
    return row


def _backend_routine():
    """Evaluate one randomly chosen pending RL model and push its result row to the Hub.

    A model is pending when its ``(repo_id, sha)`` appears among the compatible
    Hub models but not in the results dataset. Returns ``None`` right away when
    nothing is pending. Otherwise it appends one row to ``RESULTS_REPO``, pushes
    it, and sleeps 60 seconds.
    """
    compatible_models = _list_compatible_models()
    evaluated_models = {("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in _load_results()}
    pending_models = list(set(compatible_models) - evaluated_models)
    logger.info(f"Found {len(pending_models)} pending models")
    if not pending_models:
        return None

    # Pick one pending model uniformly at random.
    repo_id, sha = random.choice(pending_models)
    row = _evaluate_model(repo_id, sha)

    # Reload the latest dataset right before appending, presumably so rows pushed
    # since the first load (e.g. by another run) are not overwritten.
    dataset = _load_results()
    dataset.add_item(row)
    dataset.push_to_hub(RESULTS_REPO, split="train", token=API.token)
    time.sleep(60)  # Sleep for 1 minute to avoid rate limiting
def backend_routine():
    """Run one backend cycle, logging any exception instead of letting it propagate."""
    try:
        _backend_routine()
    except Exception as err:
        error_name = type(err).__name__
        logger.error(f"{error_name}: {err}")
        logger.exception(err)


if __name__ == "__main__":
    # When executed as a script, run a single cycle and exit.
    backend_routine()