import copy
import dataclasses
import os
import shutil
from dataclasses import dataclass
from typing import List, NamedTuple, Optional

import numpy as np

import wandb
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
from rl_algo_impls.runner.evaluate import Evaluation
from rl_algo_impls.runner.running_utils import (
    get_device,
    load_hyperparams,
    make_policy,
    set_seeds,
)
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder


@dataclass
class SelfplayEvalArgs(RunArgs):
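    """Arguments for a head-to-head evaluation between two trained policies."""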
    # Either wandb_run_paths or model_file_paths must contain 2 entries
    # (player 1 first, then player 2).
    wandb_run_paths: List[str] = dataclasses.field(default_factory=list)
    model_file_paths: List[str] = dataclasses.field(default_factory=list)
    render: bool = False
    best: bool = True
    n_envs: int = 1
    n_episodes: int = 1
    deterministic_eval: Optional[bool] = None
    no_print_returns: bool = False
    video_path: Optional[str] = None


def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation:
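    """Evaluate player 1's policy against player 2 in a self-play eval environment.

    Players are loaded either from two wandb run paths or two local model file
    paths. Player 2 is injected into the environment as a fixed self-play bot and
    player 1's policy is evaluated against it for args.n_episodes episodes.
    """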
    if args.wandb_run_paths:
        api = wandb.Api()
        args, config, player_1_model_path = load_player(
            api, args.wandb_run_paths[0], args, root_dir
        )
        _, _, player_2_model_path = load_player(
            api, args.wandb_run_paths[1], args, root_dir
        )
    elif args.model_file_paths:
        hyperparams = load_hyperparams(args.algo, args.env)

        config = Config(args, hyperparams, root_dir)
        player_1_model_path, player_2_model_path = args.model_file_paths
    else:
        raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths")

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env_make_kwargs = (
        config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy()
    )
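    # Request two selfplay env slots per evaluated env (one per player side).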
    env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2
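    # Player 2's model is loaded as a fixed selfplay bot: no policy snapshots are
    # saved or swapped in, and the bot always occupies the player 2 seat.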
    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_hparams={
            "n_envs": args.n_envs,
            "selfplay_bots": {
                player_2_model_path: args.n_envs,
            },
            "self_play_kwargs": {
                "num_old_policies": 0,
                "save_steps": np.inf,
                "swap_steps": np.inf,
                "bot_always_player_2": True,
            },
            "bots": None,
            "make_kwargs": env_make_kwargs,
        },
        render=args.render,
        normalize_load_path=player_1_model_path,
    )
    if args.video_path:
        env = VecEpisodeRecorder(
            env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes
        )
    device = get_device(config, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=player_1_model_path,
        **config.policy_hyperparams,
    ).eval()

    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_hyperparams.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )


class PlayerData(NamedTuple):
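    """A player's resolved args, Config, and local model path."""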
    args: SelfplayEvalArgs
    config: Config
    model_path: str


def load_player(
    api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str
) -> PlayerData:
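    """Download a player's model from a wandb run and build its Config.

    args is copied so each player gets its own algo/env/seed taken from the run's
    config; the (best) model archive is downloaded, unpacked to the local model
    directory, and the archive then deleted.
    """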
    args = copy.copy(args)

    run = api.run(run_path)
    params = run.config
    args.algo = params["algo"]
    args.env = params["env"]
    args.seed = params.get("seed", None)
    args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True)
    config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
    model_path = config.model_dir_path(best=args.best, downloaded=True)

    model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
    run.file(model_archive_name).download()
    if os.path.isdir(model_path):
        shutil.rmtree(model_path)
    shutil.unpack_archive(model_archive_name, model_path)
    os.remove(model_archive_name)

    return PlayerData(args, config, model_path)