a2c-MountainCar-v0 / benchmark_publish.py
sgoodfriend's picture
A2C playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
0bbce05
raw history blame
No virus
3.31 kB
import argparse
import subprocess
import wandb
import wandb.apis.public
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import List, NamedTuple
class RunGroup(NamedTuple):
algo: str
env_id: str
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--wandb-project-name",
type=str,
default="rl-algo-impls-benchmarks",
help="WandB project name to load runs from",
)
parser.add_argument(
"--wandb-entity",
type=str,
default=None,
help="WandB team of project. None uses default entity",
)
parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
parser.add_argument(
"--envs", type=str, nargs="*", help="Optional filter down to these envs"
)
parser.add_argument(
"--exclude-envs",
type=str,
nargs="*",
help="Environments to exclude from publishing",
)
parser.add_argument(
"--huggingface-user",
type=str,
default=None,
help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
)
parser.add_argument(
"--pool-size",
type=int,
default=3,
help="How many publish jobs can run in parallel",
)
parser.add_argument(
"--virtual-display", action="store_true", help="Use headless virtual display"
)
# parser.set_defaults(
# wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
# envs=[],
# exclude_envs=[],
# )
args = parser.parse_args()
print(args)
api = wandb.Api()
all_runs = api.runs(
f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
)
required_tags = set(args.wandb_tags)
runs: List[wandb.apis.public.Run] = [
r
for r in all_runs
if required_tags.issubset(set(r.config.get("wandb_tags", [])))
]
runs_paths_by_group = defaultdict(list)
for r in runs:
if r.state != "finished":
continue
algo = r.config["algo"]
env = r.config["env"]
if args.envs and env not in args.envs:
continue
if args.exclude_envs and env in args.exclude_envs:
continue
run_group = RunGroup(algo, env)
runs_paths_by_group[run_group].append("/".join(r.path))
def run(run_paths: List[str]) -> None:
publish_args = ["python", "huggingface_publish.py"]
publish_args.append("--wandb-run-paths")
publish_args.extend(run_paths)
publish_args.append("--wandb-report-url")
publish_args.append(args.wandb_report_url)
if args.huggingface_user:
publish_args.append("--huggingface-user")
publish_args.append(args.huggingface_user)
if args.virtual_display:
publish_args.append("--virtual-display")
subprocess.run(publish_args)
tp = ThreadPool(args.pool_size)
for run_paths in runs_paths_by_group.values():
tp.apply_async(run, (run_paths,))
tp.close()
tp.join()