backend / src /backend.py
qgallouedec's picture
qgallouedec HF staff
fix generator
5028d3c
raw
history blame
No virus
6.01 kB
import fnmatch
import os
import random
import time
import pybullet_envs_gymnasium # noqa: F401 pylint: disable=unused-import
from datasets import load_dataset
from huggingface_hub import HfApi
from src.evaluation import evaluate
from src.logging import setup_logger
# Module-wide logger built by the project's logging helper.
logger = setup_logger(__name__)

# Hugging Face Hub client. The token comes from the ``TOKEN`` environment
# variable; it is None when that variable is unset.
API = HfApi(token=os.getenv("TOKEN"))

# Hub dataset repository that evaluation results are read from and pushed to.
RESULTS_REPO = "open-rl-leaderboard/results_v2"
# Environment IDs a model can be evaluated on. The backend matches a model's
# Hub tags against this list (via ``pattern_match``) to choose the environment
# passed to ``evaluate``.
ALL_ENV_IDS = [
    # Atari
    "AdventureNoFrameskip-v4",
    "AirRaidNoFrameskip-v4",
    "AlienNoFrameskip-v4",
    "AmidarNoFrameskip-v4",
    "AssaultNoFrameskip-v4",
    "AsterixNoFrameskip-v4",
    "AsteroidsNoFrameskip-v4",
    "AtlantisNoFrameskip-v4",
    "BankHeistNoFrameskip-v4",
    "BattleZoneNoFrameskip-v4",
    "BeamRiderNoFrameskip-v4",
    "BerzerkNoFrameskip-v4",
    "BowlingNoFrameskip-v4",
    "BoxingNoFrameskip-v4",
    "BreakoutNoFrameskip-v4",
    "CarnivalNoFrameskip-v4",
    "CentipedeNoFrameskip-v4",
    "ChopperCommandNoFrameskip-v4",
    "CrazyClimberNoFrameskip-v4",
    "DefenderNoFrameskip-v4",
    "DemonAttackNoFrameskip-v4",
    "DoubleDunkNoFrameskip-v4",
    "ElevatorActionNoFrameskip-v4",
    "EnduroNoFrameskip-v4",
    "FishingDerbyNoFrameskip-v4",
    "FreewayNoFrameskip-v4",
    "FrostbiteNoFrameskip-v4",
    "GopherNoFrameskip-v4",
    "GravitarNoFrameskip-v4",
    "HeroNoFrameskip-v4",
    "IceHockeyNoFrameskip-v4",
    "JamesbondNoFrameskip-v4",
    "JourneyEscapeNoFrameskip-v4",
    "KangarooNoFrameskip-v4",
    "KrullNoFrameskip-v4",
    "KungFuMasterNoFrameskip-v4",
    "MontezumaRevengeNoFrameskip-v4",
    "MsPacmanNoFrameskip-v4",
    "NameThisGameNoFrameskip-v4",
    "PhoenixNoFrameskip-v4",
    "PitfallNoFrameskip-v4",
    "PongNoFrameskip-v4",
    "PooyanNoFrameskip-v4",
    "PrivateEyeNoFrameskip-v4",
    "QbertNoFrameskip-v4",
    "RiverraidNoFrameskip-v4",
    "RoadRunnerNoFrameskip-v4",
    "RobotankNoFrameskip-v4",
    "SeaquestNoFrameskip-v4",
    "SkiingNoFrameskip-v4",
    "SolarisNoFrameskip-v4",
    "SpaceInvadersNoFrameskip-v4",
    "StarGunnerNoFrameskip-v4",
    "TennisNoFrameskip-v4",
    "TimePilotNoFrameskip-v4",
    "TutankhamNoFrameskip-v4",
    "UpNDownNoFrameskip-v4",
    "VentureNoFrameskip-v4",
    "VideoPinballNoFrameskip-v4",
    "WizardOfWorNoFrameskip-v4",
    "YarsRevengeNoFrameskip-v4",
    "ZaxxonNoFrameskip-v4",
    # Box2D
    "BipedalWalker-v3",
    "BipedalWalkerHardcore-v3",
    "CarRacing-v2",
    "LunarLander-v2",
    "LunarLanderContinuous-v2",
    # Toy text
    "Blackjack-v1",
    "CliffWalking-v0",
    "FrozenLake-v1",
    "FrozenLake8x8-v1",
    # Classic control
    "Acrobot-v1",
    "CartPole-v1",
    "MountainCar-v0",
    "MountainCarContinuous-v0",
    "Pendulum-v1",
    # MuJoCo
    "Ant-v4",
    "HalfCheetah-v4",
    "Hopper-v4",
    "Humanoid-v4",
    "HumanoidStandup-v4",
    "InvertedDoublePendulum-v4",
    "InvertedPendulum-v4",
    "Pusher-v4",
    "Reacher-v4",
    "Swimmer-v4",
    "Walker2d-v4",
    # PyBullet (presumably registered by the ``pybullet_envs_gymnasium``
    # side-effect import at the top of this file)
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "HumanoidBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
    "InvertedPendulumSwingupBulletEnv-v0",
    "MinitaurBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
]
def pattern_match(patterns, source_list):
    """Return the distinct items of *source_list* that match any glob pattern, sorted.

    Args:
        patterns: A single ``fnmatch``-style pattern string, an iterable of
            pattern strings, or ``None`` (treated as "no patterns").
        source_list: Candidate strings to filter. Any iterable is accepted,
            including one-shot iterators.

    Returns:
        A sorted list of the distinct elements of *source_list* that match at
        least one pattern. It is empty when *patterns* is ``None`` or empty.
    """
    if patterns is None:
        # Callers may pass missing metadata (e.g. a model without tags);
        # treat it as "nothing matches" instead of raising TypeError.
        return []
    if isinstance(patterns, str):
        patterns = [patterns]
    # Materialize once: fnmatch.filter is called once per pattern, and an
    # iterator would be exhausted after the first pattern.
    candidates = list(source_list)
    matches = set()
    for pattern in patterns:
        matches.update(fnmatch.filter(candidates, pattern))
    return sorted(matches)
def _list_compatible_models():
    """Return ``(repo_id, sha)`` pairs for Hub RL models that ship an ``agent.pt`` file."""
    # List only the reinforcement-learning models
    rl_models = list(API.list_models(filter=["reinforcement-learning"]))
    logger.info(f"Found {len(rl_models)} RL models")
    compatible_models = []
    for model in rl_models:
        # ``siblings`` may be None; treat that as a model with no files.
        filenames = {sib.rfilename for sib in model.siblings or []}
        if "agent.pt" in filenames:
            compatible_models.append((model.modelId, model.sha))
    logger.info(f"Found {len(compatible_models)} compatible models")
    return compatible_models


def _load_results():
    """Download the current results dataset from the Hub, bypassing the local cache."""
    return load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")


def _evaluate_row(repo_id, sha):
    """Evaluate one model revision and return the result row to append to the results.

    The row always has ``model_id``, ``user_id``, ``sha`` and ``status``.
    ``env_id`` and ``episodic_returns`` are added only when ``status`` is "DONE".
    """
    user_id, model_id = repo_id.split("/")
    row = {"model_id": model_id, "user_id": user_id, "sha": sha}
    model_info = API.model_info(repo_id, revision=sha)
    # Extract the environment IDs from the tags (usually only one); tags may be None.
    env_ids = pattern_match(model_info.tags or [], ALL_ENV_IDS)
    if not env_ids:
        logger.error(f"No environment found for {model_id}")
        row["status"] = "FAILED"
        return row
    env_id = env_ids[0]
    logger.info(f"Running evaluation on {user_id}/{model_id}")
    try:
        episodic_returns = evaluate(repo_id, sha, env_id)
    except Exception as e:
        # Any evaluation failure is recorded as FAILED so the model is not retried forever.
        logger.error(f"Error evaluating {repo_id}: {e}")
        logger.exception(e)
        row["status"] = "FAILED"
    else:
        row["status"] = "DONE"
        row["env_id"] = env_id
        row["episodic_returns"] = episodic_returns
    return row


def _backend_routine():
    """Evaluate one randomly chosen pending model and append its result to ``RESULTS_REPO``.

    A model is "compatible" when it carries the ``reinforcement-learning`` filter
    on the Hub and ships an ``agent.pt`` file. It is "pending" when its
    ``(repo_id, sha)`` pair does not appear in the results dataset yet.

    Returns:
        None. It returns early, without sleeping, when no model is pending.
        Otherwise it sleeps 60 seconds after pushing the new row.
    """
    compatible_models = _list_compatible_models()
    dataset = _load_results()
    evaluated_models = {("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset}
    pending_models = list(set(compatible_models) - evaluated_models)
    logger.info(f"Found {len(pending_models)} pending models")
    if not pending_models:
        return None

    repo_id, sha = random.choice(pending_models)
    row = _evaluate_row(repo_id, sha)

    # Reload the latest version right before writing, so the push starts from the
    # current Hub state rather than the copy loaded before the evaluation ran.
    dataset = _load_results()
    # ``Dataset.add_item`` is not in-place: it returns a new dataset, so rebind it.
    # Previously the result was discarded and the unchanged dataset was pushed.
    dataset = dataset.add_item(row)
    dataset.push_to_hub(RESULTS_REPO, split="train", token=API.token)
    time.sleep(60)  # Sleep for 1 minute to avoid rate limiting
def backend_routine():
    """Run one backend cycle, logging any exception instead of propagating it.

    Every ``Exception`` raised by ``_backend_routine`` is logged twice: once as a
    one-line ``<ExceptionClass>: <message>`` summary, and once with its traceback.
    It is then swallowed. The function always returns ``None``.
    """
    try:
        _backend_routine()
    except Exception as err:  # deliberate catch-all at the top-level boundary
        summary = f"{err.__class__.__name__}: {str(err)}"
        logger.error(summary)
        logger.exception(err)


if __name__ == "__main__":
    # Script entry point: run a single cycle.
    backend_routine()