File size: 3,365 Bytes
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e1c086
 
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import subprocess
import sys
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Dict, List, NamedTuple, Optional

import wandb
import wandb.apis.public


class RunGroup(NamedTuple):
    """Immutable, hashable key used to bucket benchmark runs.

    Runs that share the same algorithm and environment end up in the same
    group and are published together.
    """

    # The run config's "algo" value.
    algo: str
    # The run config's "env" value.
    env_id: str
    env_id: str


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of ``benchmark_publish``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    # Required: runs are selected by these tags, and omitting them used to
    # crash later with a TypeError from set(None).
    parser.add_argument(
        "--wandb-tags", type=str, nargs="+", required=True, help="WandB tags"
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from publishing",
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    return parser.parse_args()


def _group_run_paths(
    runs: "List[wandb.apis.public.Run]",
    envs: Optional[List[str]],
    exclude_envs: Optional[List[str]],
) -> "Dict[RunGroup, List[str]]":
    """Bucket the paths of finished runs by (algo, env).

    Args:
        runs: WandB runs to consider; only those whose ``state`` is
            ``"finished"`` are kept.
        envs: If non-empty, keep only runs whose ``env`` config is listed.
        exclude_envs: If non-empty, drop runs whose ``env`` config is listed.

    Returns:
        Mapping from ``RunGroup(algo, env)`` to the "/"-joined ``run.path`` of
        every kept run in that group, in the order the runs were given.
    """
    paths_by_group = defaultdict(list)
    for r in runs:
        if r.state != "finished":
            continue
        env = r.config["env"]
        if envs and env not in envs:
            continue
        if exclude_envs and env in exclude_envs:
            continue
        paths_by_group[RunGroup(r.config["algo"], env)].append("/".join(r.path))
    return paths_by_group


def _build_publish_command(run_paths: List[str], args: argparse.Namespace) -> List[str]:
    """Return the argv for one ``huggingface_publish.py`` invocation.

    Optional flags are only forwarded when set, so a missing
    ``--wandb-report-url`` never injects ``None`` into the argument list.
    """
    # Use the interpreter running this script rather than whatever "python"
    # resolves to on PATH; fall back when sys.executable is unavailable.
    cmd = [sys.executable or "python", "huggingface_publish.py"]
    cmd += ["--wandb-run-paths", *run_paths]
    if args.wandb_report_url:
        cmd += ["--wandb-report-url", args.wandb_report_url]
    if args.huggingface_user:
        cmd += ["--huggingface-user", args.huggingface_user]
    if args.virtual_display:
        cmd.append("--virtual-display")
    return cmd


def benchmark_publish() -> None:
    """Run ``huggingface_publish.py`` once per (algo, env) group of benchmark runs.

    Loads the runs of the WandB project and keeps those that:

    * carry every ``--wandb-tags`` tag;
    * are in the ``finished`` state;
    * pass the ``--envs`` / ``--exclude-envs`` filters.

    The kept run paths are grouped by (algo, env). One publish subprocess is
    launched per group, with at most ``--pool-size`` running concurrently.

    Jobs that exit with a non-zero status are reported on stderr. If a job
    could not be launched at all, the worker's exception is re-raised once
    every job has finished, instead of being silently dropped.
    """
    args = _parse_args()
    print(args)

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )

    required_tags = set(args.wandb_tags)
    runs = [
        r
        for r in all_runs
        # `or []` also covers a "wandb_tags" key that is present but None.
        if required_tags.issubset(r.config.get("wandb_tags") or [])
    ]

    paths_by_group = _group_run_paths(runs, args.envs, args.exclude_envs)
    commands = [
        (group, _build_publish_command(paths, args))
        for group, paths in paths_by_group.items()
    ]

    pool = ThreadPool(args.pool_size)
    pending = [
        (group, pool.apply_async(subprocess.run, (cmd,))) for group, cmd in commands
    ]
    pool.close()
    pool.join()

    failed = []
    for group, result in pending:
        # .get() re-raises any exception raised in the worker thread.
        completed = result.get()
        if completed.returncode != 0:
            failed.append((group, completed.returncode))
    for group, returncode in failed:
        print(
            f"huggingface_publish.py failed for {group.algo} on {group.env_id} "
            f"(exit code {returncode})",
            file=sys.stderr,
        )


if __name__ == "__main__":
    # Run the publishing CLI only when executed as a script, not on import.
    benchmark_publish()