import fnmatch
import importlib
import json
import os
import random
import shutil
import sys
import time
import zipfile
from pathlib import Path
from typing import Optional

import numpy as np
import rl_zoo3.import_envs  # noqa: F401 pylint: disable=unused-import
import torch as th
import yaml
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.utils import EntryNotFoundError
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub
from requests.exceptions import HTTPError
from rl_zoo3 import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams
from rl_zoo3.exp_manager import ExperimentManager
from rl_zoo3.utils import get_model_path
from stable_baselines3.common.utils import set_random_seed

from src.logging import setup_logger
ALL_ENV_IDS = [
    "AdventureNoFrameskip-v4",
    "AirRaidNoFrameskip-v4",
    "AlienNoFrameskip-v4",
    "AmidarNoFrameskip-v4",
    "AssaultNoFrameskip-v4",
    "AsterixNoFrameskip-v4",
    "AsteroidsNoFrameskip-v4",
    "AtlantisNoFrameskip-v4",
    "BankHeistNoFrameskip-v4",
    "BattleZoneNoFrameskip-v4",
    "BeamRiderNoFrameskip-v4",
    "BerzerkNoFrameskip-v4",
    "BowlingNoFrameskip-v4",
    "BoxingNoFrameskip-v4",
    "BreakoutNoFrameskip-v4",
    "CarnivalNoFrameskip-v4",
    "CentipedeNoFrameskip-v4",
    "ChopperCommandNoFrameskip-v4",
    "CrazyClimberNoFrameskip-v4",
    "DefenderNoFrameskip-v4",
    "DemonAttackNoFrameskip-v4",
    "DoubleDunkNoFrameskip-v4",
    "ElevatorActionNoFrameskip-v4",
    "EnduroNoFrameskip-v4",
    "FishingDerbyNoFrameskip-v4",
    "FreewayNoFrameskip-v4",
    "FrostbiteNoFrameskip-v4",
    "GopherNoFrameskip-v4",
    "GravitarNoFrameskip-v4",
    "HeroNoFrameskip-v4",
    "IceHockeyNoFrameskip-v4",
    "JamesbondNoFrameskip-v4",
    "JourneyEscapeNoFrameskip-v4",
    "KangarooNoFrameskip-v4",
    "KrullNoFrameskip-v4",
    "KungFuMasterNoFrameskip-v4",
    "MontezumaRevengeNoFrameskip-v4",
    "MsPacmanNoFrameskip-v4",
    "NameThisGameNoFrameskip-v4",
    "PhoenixNoFrameskip-v4",
    "PitfallNoFrameskip-v4",
    "PongNoFrameskip-v4",
    "PooyanNoFrameskip-v4",
    "PrivateEyeNoFrameskip-v4",
    "QbertNoFrameskip-v4",
    "RiverraidNoFrameskip-v4",
    "RoadRunnerNoFrameskip-v4",
    "RobotankNoFrameskip-v4",
    "SeaquestNoFrameskip-v4",
    "SkiingNoFrameskip-v4",
    "SolarisNoFrameskip-v4",
    "SpaceInvadersNoFrameskip-v4",
    "StarGunnerNoFrameskip-v4",
    "TennisNoFrameskip-v4",
    "TimePilotNoFrameskip-v4",
    "TutankhamNoFrameskip-v4",
    "UpNDownNoFrameskip-v4",
    "VentureNoFrameskip-v4",
    "VideoPinballNoFrameskip-v4",
    "WizardOfWorNoFrameskip-v4",
    "YarsRevengeNoFrameskip-v4",
    "ZaxxonNoFrameskip-v4",
    # Box2D
    "BipedalWalker-v3",
    "BipedalWalkerHardcore-v3",
    "CarRacing-v2",
    "LunarLander-v2",
    "LunarLanderContinuous-v2",
    # Toy text
    "Blackjack-v1",
    "CliffWalking-v0",
    "FrozenLake-v1",
    "FrozenLake8x8-v1",
    # Classic control
    "Acrobot-v1",
    "CartPole-v1",
    "MountainCar-v0",
    "MountainCarContinuous-v0",
    "Pendulum-v1",
    # MuJoCo
    "Ant-v4",
    "HalfCheetah-v4",
    "Hopper-v4",
    "Humanoid-v4",
    "HumanoidStandup-v4",
    "InvertedDoublePendulum-v4",
    "InvertedPendulum-v4",
    "Pusher-v4",
    "Reacher-v4",
    "Swimmer-v4",
    "Walker2d-v4",
    # PyBullet
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "HumanoidBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
    "InvertedPendulumSwingupBulletEnv-v0",
    "MinitaurBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
]

def download_from_hub(
    algo: str,
    env_name: EnvironmentName,
    exp_id: int,
    folder: str,
    organization: str,
    repo_name: Optional[str] = None,
    force: bool = False,
) -> None:
    """
    Try to load a model from the Huggingface hub
    and save it following the RL Zoo structure.

    Default repo name is {organization}/{algo}-{env_id}
    where repo_name = {algo}-{env_id}

    :param algo: Algorithm
    :param env_name: Environment name
    :param exp_id: Experiment id
    :param folder: Log folder
    :param organization: Huggingface organization
    :param repo_name: Overwrite default repository name
    :param force: Allow overwriting the folder
        if it already exists.
    """
    model_name = ModelName(algo, env_name)

    if repo_name is None:
        repo_name = model_name  # Note: model name is {algo}-{env_name}

    # Note: repo id is {organization}/{repo_name}
    repo_id = ModelRepoId(organization, repo_name)
    logger.info(f"Downloading from https://huggingface.co/{repo_id}")

    checkpoint = load_from_hub(repo_id, model_name.filename)
    try:
        config_path = load_from_hub(repo_id, "config.yml")
    except EntryNotFoundError:  # hotfix for old models
        config_path = load_from_hub(repo_id, "config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        config_path = config_path.replace(".json", ".yml")
        with open(config_path, "w") as f:
            yaml.dump(config, f)

    # If VecNormalize, download
    try:
        vec_normalize_stats = load_from_hub(repo_id, "vec_normalize.pkl")
    except HTTPError:
        logger.info("No normalization file")
        vec_normalize_stats = None

    try:
        saved_args = load_from_hub(repo_id, "args.yml")
    except EntryNotFoundError:
        logger.info("No args file")
        saved_args = None

    try:
        env_kwargs = load_from_hub(repo_id, "env_kwargs.yml")
    except EntryNotFoundError:
        logger.info("No env_kwargs file")
        env_kwargs = None

    try:
        train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip")
    except EntryNotFoundError:
        logger.info("No train_eval_metrics file")
        train_eval_metrics = None

    if exp_id == 0:
        exp_id = get_latest_run_id(os.path.join(folder, algo), env_name) + 1

    # Sanity checks
    if exp_id > 0:
        log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}")
    else:
        log_path = os.path.join(folder, algo)

    # Check that the folder does not exist
    log_folder = Path(log_path)
    if log_folder.is_dir():
        if force:
            logger.info(f"The folder {log_path} already exists, overwriting")
            # Delete the current one to avoid errors
            shutil.rmtree(log_path)
        else:
            raise ValueError(
                f"The folder {log_path} already exists, use --force to overwrite it, "
                "or choose '--exp-id 0' to create a new folder"
            )

    logger.info(f"Saving to {log_path}")
    # Create folder structure
    os.makedirs(log_path, exist_ok=True)
    config_folder = os.path.join(log_path, env_name)
    os.makedirs(config_folder, exist_ok=True)

    # Copy config files and saved stats
    shutil.copy(checkpoint, os.path.join(log_path, f"{env_name}.zip"))
    if saved_args is not None:
        shutil.copy(saved_args, os.path.join(config_folder, "args.yml"))
    shutil.copy(config_path, os.path.join(config_folder, "config.yml"))
    if env_kwargs is not None:
        shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml"))
    if vec_normalize_stats is not None:
        shutil.copy(vec_normalize_stats, os.path.join(config_folder, "vecnormalize.pkl"))

    # Extract monitor file and evaluation file
    if train_eval_metrics is not None:
        with zipfile.ZipFile(train_eval_metrics, "r") as zip_ref:
            zip_ref.extractall(log_path)
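
# Usage sketch for `download_from_hub` (values are illustrative, not part of the
# pipeline): downloading a hypothetical repo "sb3/ppo-CartPole-v1" into the RL Zoo
# folder layout would look like
#
#     download_from_hub(
#         algo="ppo",
#         env_name=EnvironmentName("CartPole-v1"),
#         exp_id=0,
#         folder="rl-trained-agents",
#         organization="sb3",
#     )
#
# which saves the checkpoint as CartPole-v1.zip under
# rl-trained-agents/ppo/CartPole-v1_<exp_id>/ and the config/args/normalization files
# in a CartPole-v1/ subfolder next to it.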

def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    env_ids = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            env_ids.add(matching)
    return sorted(list(env_ids))
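
# Usage sketch for `pattern_match` (illustrative values): the model card tags act as
# glob patterns matched against the known environment IDs, e.g.
#
#     pattern_match("Pong*", ALL_ENV_IDS)
#     # -> ["PongNoFrameskip-v4"]
#     pattern_match(["CartPole-v1", "LunarLander*"], ALL_ENV_IDS)
#     # -> ["CartPole-v1", "LunarLander-v2", "LunarLanderContinuous-v2"]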

def evaluate(
    user_id,
    repo_name,
    env="CartPole-v1",
    folder="rl-trained-agents",
    algo="ppo",
    # n_timesteps=1000,
    n_episodes=50,
    num_threads=-1,
    n_envs=1,
    exp_id=0,
    verbose=1,
    no_render=False,
    deterministic=False,
    device="auto",
    load_best=False,
    load_checkpoint=None,
    load_last_checkpoint=False,
    stochastic=False,
    norm_reward=False,
    seed=0,
    reward_log="",
    gym_packages=(),
    env_kwargs=None,
    custom_objects=False,
    progress=False,
):
    """
    Evaluate a trained agent and return its episodic returns.

    :param user_id: (str) Huggingface user or organization that hosts the model
    :param repo_name: (str) Name of the model repository on the Huggingface hub
    :param env: (str) Environment ID
    :param folder: (str) Log folder
    :param algo: (str) RL Algorithm
    :param n_episodes: (int) Number of evaluation episodes
    :param num_threads: (int) Number of threads for PyTorch
    :param n_envs: (int) Number of environments
    :param exp_id: (int) Experiment ID (default: 0: latest, -1: no exp folder)
    :param verbose: (int) Verbose mode (0: no output, 1: INFO)
    :param no_render: (bool) Do not render the environment (useful for tests)
    :param deterministic: (bool) Use deterministic actions
    :param device: (str) PyTorch device to be used (ex: cpu, cuda...)
    :param load_best: (bool) Load best model instead of last model if available
    :param load_checkpoint: (int) Load checkpoint instead of last model if available
    :param load_last_checkpoint: (bool) Load last checkpoint instead of last model if available
    :param stochastic: (bool) Use stochastic actions
    :param norm_reward: (bool) Normalize reward if applicable (trained with VecNormalize)
    :param seed: (int) Random generator seed
    :param reward_log: (str) Where to log reward
    :param gym_packages: (List[str]) Additional external Gym environment package modules to import
    :param env_kwargs: (Dict[str, Any]) Optional keyword arguments to pass to the env constructor
    :param custom_objects: (bool) Use custom objects to solve loading issues
    :param progress: (bool) if toggled, display a progress bar using tqdm and rich
    """
    # Go through custom gym packages to let them register in the global registry
    for env_module in gym_packages:
        importlib.import_module(env_module)

    env_name = EnvironmentName(env)

    # try:
    #     _, model_path, log_path = get_model_path(
    #         exp_id,
    #         folder,
    #         algo,
    #         env_name,
    #         load_best,
    #         load_checkpoint,
    #         load_last_checkpoint,
    #     )
    # except (AssertionError, ValueError) as e:
    #     # Special case for rl-trained agents
    #     # auto-download from the hub
    #     if "rl-trained-agents" not in folder:
    #         raise e
    #     else:
    #         logger.info("Pretrained model not found, trying to download it from sb3 Huggingface hub: https://huggingface.co/sb3")

    # Auto-download
    download_from_hub(
        algo=algo,
        env_name=env_name,
        exp_id=exp_id,
        folder=folder,
        # organization="sb3",
        organization=user_id,
        # repo_name=None,
        repo_name=repo_name,
        force=False,
    )
    # Try again
    _, model_path, log_path = get_model_path(
        exp_id,
        folder,
        algo,
        env_name,
        load_best,
        load_checkpoint,
        load_last_checkpoint,
    )

    logger.info(f"Loading {model_path}")

    # Off-policy algorithms only support one env for now
    off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]

    set_random_seed(seed)

    if num_threads > 0:
        if verbose > 1:
            logger.info(f"Setting torch.num_threads to {num_threads}")
        th.set_num_threads(num_threads)

    is_atari = ExperimentManager.is_atari(env_name.gym_id)
    is_minigrid = ExperimentManager.is_minigrid(env_name.gym_id)

    stats_path = os.path.join(log_path, env_name)
    hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True)
    # Load env_kwargs saved at training time, if any
    loaded_env_kwargs = {}
    args_path = os.path.join(log_path, env_name, "args.yml")
    if os.path.isfile(args_path):
        with open(args_path) as f:
            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
            if loaded_args["env_kwargs"] is not None:
                loaded_env_kwargs = loaded_args["env_kwargs"]
    # Overwrite the saved kwargs with the ones passed to this function
    if env_kwargs is not None:
        loaded_env_kwargs.update(env_kwargs)
    env_kwargs = loaded_env_kwargs
    log_dir = reward_log if reward_log != "" else None

    env = create_test_env(
        env_name.gym_id,
        n_envs=n_envs,
        stats_path=maybe_stats_path,
        seed=seed,
        log_dir=log_dir,
        should_render=not no_render,
        hyperparams=hyperparams,
        env_kwargs=env_kwargs,
    )

    kwargs = dict(seed=seed)
    if algo in off_policy_algos:
        # Dummy buffer size as we don't need memory to enjoy the trained agent
        kwargs.update(dict(buffer_size=1))
        # Hack due to breaking change in v1.6
        # handle_timeout_termination cannot be at the same time
        # with optimize_memory_usage
        if "optimize_memory_usage" in hyperparams:
            kwargs.update(optimize_memory_usage=False)
    # Check if we are running python 3.8+
    # we need to patch saved models under python 3.6/3.7 to load them
    newer_python_version = sys.version_info.major == 3 and sys.version_info.minor >= 8

    # `custom_objects` is passed to this function as a boolean flag;
    # build the actual dict of patched objects here
    if newer_python_version or custom_objects:
        custom_objects = {
            "learning_rate": 0.0,
            "lr_schedule": lambda _: 0.0,
            "clip_range": lambda _: 0.0,
        }
    else:
        custom_objects = {}

    if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
        kwargs["env"] = env

    model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=device, **kwargs)
    obs = env.reset()

    # Deterministic by default except for atari games
    stochastic = stochastic or ((is_atari or is_minigrid) and not deterministic)
    deterministic = not stochastic

    episode_reward = 0.0
    episode_rewards, episode_lengths = [], []
    ep_len = 0
    # For HER, monitor success rate
    successes = []
    lstm_states = None
    episode_start = np.ones((env.num_envs,), dtype=bool)

    # generator = range(n_timesteps)
    # if progress:
    #     if tqdm is None:
    #         raise ImportError("Please install tqdm and rich to use the progress bar")
    #     generator = tqdm(generator)
    try:
        # for _ in generator:
        while len(episode_rewards) < n_episodes:
            action, lstm_states = model.predict(
                obs,  # type: ignore[arg-type]
                state=lstm_states,
                episode_start=episode_start,
                deterministic=deterministic,
            )
            obs, reward, done, infos = env.step(action)

            episode_start = done

            if not no_render:
                env.render("human")

            episode_reward += reward[0]
            ep_len += 1

            if n_envs == 1:
                # For atari the returned reward is not the atari score
                # so we have to get it from the infos dict
                if is_atari and infos is not None and verbose >= 1:
                    episode_infos = infos[0].get("episode")
                    if episode_infos is not None:
                        logger.info(f"Atari Episode Score: {episode_infos['r']:.2f}")
                        logger.info(f"Atari Episode Length {episode_infos['l']}")
                        episode_rewards.append(episode_infos["r"])
                        episode_lengths.append(episode_infos["l"])

                if done and not is_atari and verbose > 0:
                    # NOTE: for env using VecNormalize, the mean reward
                    # is a normalized reward when `--norm_reward` flag is passed
                    logger.info(f"Episode Reward: {episode_reward:.2f}")
                    logger.info(f"Episode Length {ep_len}")
                    episode_rewards.append(episode_reward)
                    episode_lengths.append(ep_len)
                    episode_reward = 0.0
                    ep_len = 0

                # Reset also when the goal is achieved when using HER
                if done and infos[0].get("is_success") is not None:
                    if verbose > 1:
                        logger.info(f"Success? {infos[0].get('is_success', False)}")

                    if infos[0].get("is_success") is not None:
                        successes.append(infos[0].get("is_success", False))
                        episode_reward, ep_len = 0.0, 0

    except KeyboardInterrupt:
        pass

    if verbose > 0 and len(successes) > 0:
        logger.info(f"Success rate: {100 * np.mean(successes):.2f}%")

    if verbose > 0 and len(episode_rewards) > 0:
        logger.info(f"{len(episode_rewards)} Episodes")
        logger.info(f"Mean reward: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}")

    if verbose > 0 and len(episode_lengths) > 0:
        logger.info(f"Mean episode length: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}")

    env.close()
    return episode_rewards
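
# Usage sketch for `evaluate` (illustrative values; assumes a repo like
# "<user>/ppo-CartPole-v1" exists on the hub): the backend calls it with the hub user,
# the repo name, the environment ID extracted from the model tags and the lowercased
# algorithm name, e.g.
#
#     returns = evaluate("sb3", "ppo-CartPole-v1", env="CartPole-v1",
#                        folder="rl-trained-agents", algo="ppo",
#                        no_render=True, verbose=1)
#     print(np.mean(returns), np.std(returns))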

logger = setup_logger(__name__)

API = HfApi(token=os.environ.get("TOKEN"))
RESULTS_REPO = "open-rl-leaderboard/results_v2"
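
# Each evaluation is appended to the results dataset as one row. Based on the fields
# set in `_backend_routine` below, a successful row has roughly this shape
# (values are illustrative):
#
#     {
#         "model_id": "ppo-CartPole-v1",
#         "user_id": "sb3",
#         "sha": "<commit sha of the evaluated revision>",
#         "status": "DONE",  # or "FAILED"
#         "env_id": "CartPole-v1",
#         "episodic_returns": [500.0, 500.0, ...],  # one return per evaluated episode
#     }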

def _backend_routine():
    # List the SB3 reinforcement-learning models on the hub
    sb3_models = [(model.modelId, model.sha) for model in API.list_models(filter=["reinforcement-learning", "stable-baselines3"])]
    logger.info(f"Found {len(sb3_models)} SB3 models")
    dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
    evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
    pending_models = list(set(sb3_models) - set(evaluated_models))
    logger.info(f"Found {len(pending_models)} pending models")
    if len(pending_models) == 0:
        return None

    # Shuffle the pending models
    random.shuffle(pending_models)
    # Select a random model
    repo_id, sha = pending_models.pop()
    user_id, model_id = repo_id.split("/")
    row = {"model_id": model_id, "user_id": user_id, "sha": sha}

    # Run an evaluation on the selected model
    try:
        model_info = API.model_info(repo_id, revision=sha)

        # Extract the environment IDs from the tags (usually only one)
        env_ids = pattern_match(model_info.tags, ALL_ENV_IDS)
    except Exception as e:
        logger.error(f"Error fetching model info for {repo_id}: {e}")
        logger.exception(e)
        env_ids = []

    if len(env_ids) > 0:
        env = env_ids[0]
        logger.info(f"Running evaluation on {user_id}/{model_id}")
        algo = model_info.model_index[0]["name"].lower()
        try:
            episodic_returns = evaluate(user_id, model_id, env, "rl-trained-agents", algo, no_render=True, verbose=1)
            row["status"] = "DONE"
            row["env_id"] = env
            row["episodic_returns"] = episodic_returns
        except Exception as e:
            logger.error(f"Error evaluating {model_id}: {e}")
            logger.exception(e)
            row["status"] = "FAILED"
    else:
        logger.error(f"No environment found for {model_id}")
        row["status"] = "FAILED"

    # Load the latest version of the dataset and append the new row
    dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
    dataset = dataset.add_item(row)
    dataset.push_to_hub(RESULTS_REPO, split="train", token=API.token)
    time.sleep(60)  # Sleep for 1 minute to avoid rate limiting

def backend_routine():
    try:
        _backend_routine()
    except Exception as e:
        logger.error(f"{e.__class__.__name__}: {str(e)}")
        logger.exception(e)


if __name__ == "__main__":
    while True:
        backend_routine()