import argparse
import subprocess
import sys
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Dict, Iterable, List, NamedTuple, Optional, Set

import wandb
import wandb.apis.public


class RunGroup(NamedTuple):
    """Key under which benchmark runs are grouped for a single publish job."""

    algo: str
    env_id: str


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from publishing",
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. "
        "Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    return parser.parse_args()


def _group_run_paths(
    runs: Iterable[wandb.apis.public.Run],
    required_tags: Set[str],
    envs: Optional[List[str]] = None,
    exclude_envs: Optional[List[str]] = None,
) -> Dict[RunGroup, List[str]]:
    """Group eligible runs by (algo, env) and collect their run paths.

    A run is eligible when:
      * its ``config["wandb_tags"]`` contains every tag in ``required_tags``
        (an empty ``required_tags`` matches every run);
      * its ``state`` is ``"finished"``;
      * its ``config["env"]`` is in ``envs`` (when ``envs`` is non-empty);
      * its ``config["env"]`` is not in ``exclude_envs``.

    Returns a mapping from :class:`RunGroup` to the runs' ``Run.path``
    components joined with ``"/"``, in the order the runs were iterated.
    """
    runs_paths_by_group: Dict[RunGroup, List[str]] = defaultdict(list)
    for r in runs:
        if not required_tags.issubset(set(r.config.get("wandb_tags", []))):
            continue
        if r.state != "finished":
            continue
        algo = r.config["algo"]
        env = r.config["env"]
        if envs and env not in envs:
            continue
        if exclude_envs and env in exclude_envs:
            continue
        runs_paths_by_group[RunGroup(algo, env)].append("/".join(r.path))
    return runs_paths_by_group


def _build_publish_args(
    run_paths: List[str],
    wandb_report_url: Optional[str],
    huggingface_user: Optional[str],
    virtual_display: bool,
) -> List[str]:
    """Return the ``huggingface_publish.py`` command line for one run group."""
    # Use the interpreter running this script so the child process sees the
    # same environment; fall back to PATH lookup if it cannot be determined.
    publish_args = [
        sys.executable or "python",
        "huggingface_publish.py",
        "--wandb-run-paths",
    ]
    publish_args.extend(run_paths)
    # Appending None to argv would make subprocess.run raise TypeError, so the
    # flag is only forwarded when a report URL was actually given.
    if wandb_report_url:
        publish_args.extend(["--wandb-report-url", wandb_report_url])
    if huggingface_user:
        publish_args.extend(["--huggingface-user", huggingface_user])
    if virtual_display:
        publish_args.append("--virtual-display")
    return publish_args


def benchmark_publish() -> None:
    """Publish grouped WandB benchmark runs via ``huggingface_publish.py``.

    Loads runs from the configured WandB project and filters them by tags,
    state and environment. One ``huggingface_publish.py`` subprocess is
    launched per (algo, env) group, with at most ``--pool-size`` running
    concurrently.

    Failed jobs (a non-zero exit code, or an exception while launching) are
    reported on stderr. The script exits with status 1 if any job failed.
    """
    args = _parse_args()
    print(args)

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )
    # --wandb-tags is optional; treat "not given" as "no tag requirement"
    # instead of crashing on set(None).
    required_tags = set(args.wandb_tags or [])
    runs_paths_by_group = _group_run_paths(
        all_runs, required_tags, args.envs, args.exclude_envs
    )

    def run(run_paths: List[str]) -> int:
        completed = subprocess.run(
            _build_publish_args(
                run_paths,
                args.wandb_report_url,
                args.huggingface_user,
                args.virtual_display,
            )
        )
        return completed.returncode

    tp = ThreadPool(args.pool_size)
    pending = [
        (group, tp.apply_async(run, (run_paths,)))
        for group, run_paths in runs_paths_by_group.items()
    ]
    tp.close()
    tp.join()

    # apply_async stores worker exceptions in the AsyncResult; without calling
    # .get() they (and non-zero exit codes) would be silently discarded.
    failures: List[str] = []
    for group, result in pending:
        try:
            returncode = result.get()
        except Exception as e:  # top-level boundary: record and keep reporting
            failures.append(f"{group.algo}/{group.env_id}: {e!r}")
            continue
        if returncode != 0:
            failures.append(f"{group.algo}/{group.env_id}: exit code {returncode}")

    if failures:
        print(f"{len(failures)} publish job(s) failed:", file=sys.stderr)
        for failure in failures:
            print(f"  {failure}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    benchmark_publish()