# Source: ppo-BreakoutNoFrameskip-v4 / rl_algo_impls/huggingface_publish.py
# Hugging Face file-viewer header (author: sgoodfriend, commit 20d9758, 6.28 kB).
# Commit message: PPO playing BreakoutNoFrameskip-v4 from
# https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import argparse
import requests
import shutil
import subprocess
import tempfile
import wandb
import wandb.apis.public
from typing import List, Optional
from huggingface_hub.hf_api import HfApi, upload_folder
from huggingface_hub.repocard import metadata_save
from pyvirtualdisplay.display import Display
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
from rl_algo_impls.runner.config import EnvHyperparams
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
from rl_algo_impls.runner.env import make_eval_env
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
def _model_card_metadata(algo: str, env_id: str, score) -> dict:
    """Build the metadata dict that is handed to ``metadata_save`` for the README.

    ``score`` is stringified and reported as the ``mean_reward`` metric.
    """
    return {
        "library_name": "rl-algo-impls",
        "tags": [
            env_id,
            algo,
            "deep-reinforcement-learning",
            "reinforcement-learning",
        ],
        "model-index": [
            {
                "name": algo,
                "results": [
                    {
                        "metrics": [
                            {
                                "type": "mean_reward",
                                "value": str(score),
                                "name": "mean_reward",
                            }
                        ],
                        "task": {
                            "type": "reinforcement-learning",
                            "name": "reinforcement-learning",
                        },
                        "dataset": {
                            "name": env_id,
                            "type": env_id,
                        },
                    }
                ],
            }
        ],
    }


def _write_readme(readme_filepath: str, card_text: str) -> None:
    """Replace the cloned README with the generated model card text."""
    # The clone normally ships a README.md; tolerate its absence instead of
    # crashing right before the replacement is written.
    try:
        os.remove(readme_filepath)
    except FileNotFoundError:
        pass
    # Explicit encoding so the card is written identically on every platform.
    with open(readme_filepath, "w", encoding="utf-8") as f:
        f.write(card_text)


def _publish_runs(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str],
    huggingface_token: Optional[str],
) -> None:
    """Evaluate the runs, assemble the repo folder, and upload it to the Hub."""
    wandb_api = wandb.Api()
    runs = [wandb_api.run(rp) for rp in wandb_run_paths]
    # The first run supplies the algo/env identifiers used for every run.
    algo = runs[0].config["algo"]
    hyperparam_id = runs[0].config["env"]
    evaluations = [
        evaluate_model(
            EvalArgs(
                algo,
                hyperparam_id,
                seed=r.config.get("seed", None),
                render=False,
                best=True,
                n_envs=None,
                n_episodes=10,
                no_print_returns=True,
                wandb_run_path="/".join(r.path),
            ),
            os.getcwd(),
        )
        for r in runs
    ]

    # Fail loudly on an HTTP error instead of trying to JSON-decode an error
    # page, and never hang forever on a stalled connection.
    metadata_response = requests.get(
        runs[0].file("wandb-metadata.json").url, timeout=60
    )
    metadata_response.raise_for_status()
    run_metadata = metadata_response.json()

    table_data = [EvalTableData(r, e) for r, e in zip(runs, evaluations)]
    # max() returns the first highest-scoring entry, matching what a stable
    # descending sort followed by [0] would select.
    best_eval = max(table_data, key=lambda d: d.evaluation.stats.score)

    with tempfile.TemporaryDirectory() as tmpdirname:
        _, (policy, stats, config) = best_eval

        repo_name = config.model_name(include_seed=False)
        repo_dir_path = os.path.join(tmpdirname, repo_name)
        # Locally clone this repo to a temp directory. check=True surfaces a
        # failed clone here rather than as a confusing rmtree error below.
        subprocess.run(["git", "clone", ".", repo_dir_path], check=True)
        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
        model_path = config.model_dir_path(best=True, downloaded=True)
        shutil.copytree(
            model_path,
            os.path.join(
                repo_dir_path, "saved_models", config.model_dir_name(best=True)
            ),
        )

        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
        # commit_hash is None when the run metadata has no git commit entry.
        commit_hash = run_metadata.get("git", {}).get("commit", None)
        env_id = runs[0].config.get("env_id") or runs[0].config["env"]
        card_text = model_card_text(
            algo,
            env_id,
            github_url,
            commit_hash,
            wandb_report_url,
            table_data,
            best_eval,
        )
        readme_filepath = os.path.join(repo_dir_path, "README.md")
        _write_readme(readme_filepath, card_text)
        metadata_save(
            readme_filepath, _model_card_metadata(algo, env_id, stats.score)
        )

        # Record a single-environment replay into the repo's "replay" path.
        video_env = VecEpisodeRecorder(
            make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                override_n_envs=1,
                normalize_load_path=model_path,
            ),
            os.path.join(repo_dir_path, "replay"),
            max_video_length=3600,
        )
        evaluate(
            video_env,
            policy,
            1,
            deterministic=config.eval_params.get("deterministic", True),
        )

        hf_api = HfApi()
        # Resolve the user with the same token used for the repo operations so
        # an explicitly supplied token is honored consistently.
        huggingface_user = (
            huggingface_user or hf_api.whoami(token=huggingface_token)["name"]
        )
        huggingface_repo = f"{huggingface_user}/{repo_name}"
        hf_api.create_repo(
            token=huggingface_token,
            repo_id=huggingface_repo,
            private=False,
            exist_ok=True,
        )
        repo_url = upload_folder(
            repo_id=huggingface_repo,
            folder_path=repo_dir_path,
            path_in_repo="",
            commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
            token=huggingface_token,
        )
        print(f"Pushed model to the hub: {repo_url}")


def publish(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str] = None,
    huggingface_token: Optional[str] = None,
    virtual_display: bool = False,
) -> None:
    """Evaluate WandB runs and push the best-scoring model to the Hugging Face Hub.

    Each run is evaluated with ``evaluate_model`` (``best=True``, 10 episodes).
    The run with the highest evaluation score is copied into a clone of the
    current git repository together with a generated model card and a
    recorded replay, and that folder is uploaded to
    ``<huggingface_user>/<model name>``.

    Args:
        wandb_run_paths: WandB run paths; the first run provides the algo and
            env identifiers and the git commit metadata.
        wandb_report_url: Link to the WandB report, passed to the model card.
        huggingface_user: Target user or team; defaults to the name returned
            by ``whoami`` for the token in use.
        huggingface_token: Token used for the Hub calls (``None`` leaves
            credential lookup to huggingface_hub).
        virtual_display: Start a headless virtual display for the duration of
            the call.

    Raises:
        ValueError: If ``wandb_run_paths`` is empty or ``None``.
        subprocess.CalledProcessError: If cloning the local repository fails.
        requests.HTTPError: If the run metadata file cannot be downloaded.
    """
    if not wandb_run_paths:
        raise ValueError("wandb_run_paths must contain at least one run path")

    display = None
    if virtual_display:
        display = Display(visible=False, size=(1400, 900))
        display.start()
    try:
        _publish_runs(
            wandb_run_paths, wandb_report_url, huggingface_user, huggingface_token
        )
    finally:
        # Always shut the virtual display down, even when publishing fails.
        if display is not None:
            display.stop()
def huggingface_publish():
    """Parse command-line arguments and forward them to ``publish``.

    The parsed argument names map one-to-one onto ``publish`` keyword
    arguments (``wandb_run_paths``, ``wandb_report_url``,
    ``huggingface_user``, ``virtual_display``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-run-paths",
        type=str,
        nargs="+",
        # publish() cannot do anything without at least one run, so let
        # argparse report the omission as a usage error.
        required=True,
        help="Run paths of the form entity/project/run_id",
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--huggingface-user",
        type=str,
        help="Huggingface user or team to upload model cards",
        default=None,
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)
    publish(**vars(args))
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    huggingface_publish()