# File size: 4,441 Bytes | 1cde088
import copy
import dataclasses
import os
import shutil
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
import numpy as np
import wandb
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
from rl_algo_impls.runner.evaluate import Evaluation
from rl_algo_impls.runner.running_utils import (
get_device,
load_hyperparams,
make_policy,
set_seeds,
)
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
@dataclasses.dataclass
class SelfplayEvalArgs(RunArgs):
    """Arguments for evaluating one trained policy against another.

    Exactly one source of models is used: ``wandb_run_paths`` when it is
    non-empty, otherwise ``model_file_paths``. Whichever is used must hold 2
    elements: the first is player 1 (the evaluated policy), the second is
    player 2 (the opponent).
    """

    # W&B run paths to download the two players' models from.
    wandb_run_paths: List[str] = dataclasses.field(default_factory=list)
    # Local model paths for the two players; consulted only when
    # wandb_run_paths is empty.
    model_file_paths: List[str] = dataclasses.field(default_factory=list)
    # Forwarded to both the eval environment and the evaluate() call.
    render: bool = False
    # Selects the "best" model directory/archive when downloading from W&B.
    best: bool = True
    # Number of evaluation environments.
    n_envs: int = 1
    # Number of episodes to evaluate (and to record, if video_path is set).
    n_episodes: int = 1
    # When None, the config's eval "deterministic" setting (default True) is
    # used instead.
    deterministic_eval: Optional[bool] = None
    # Suppresses printing of returns during evaluation.
    no_print_returns: bool = False
    # When set, episodes are recorded to this path.
    video_path: Optional[str] = None
def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation:
    """Evaluate player 1's policy against player 2 in a self-play environment.

    Player 1's model is loaded as the policy under evaluation; player 2's
    model path is handed to the environment as its only self-play bot.

    Args:
        args: Evaluation arguments. ``wandb_run_paths`` takes precedence over
            ``model_file_paths`` when both are set. With ``wandb_run_paths``
            the algo/env/seed in ``args`` are replaced by those of player 1's
            run; entries beyond the first two are ignored.
        root_dir: Root directory forwarded to ``Config``.

    Returns:
        An ``Evaluation`` of the player 1 policy, the result of ``evaluate``,
        and the config used (player 1's config when loading from W&B).

    Raises:
        ValueError: If neither path list is given, if ``wandb_run_paths`` has
            fewer than 2 entries, or if ``model_file_paths`` does not have
            exactly 2 entries.
    """
    if args.wandb_run_paths:
        # Validate before touching the network so a missing second run is
        # reported up front instead of as an IndexError after player 1's
        # model has already been downloaded.
        if len(args.wandb_run_paths) < 2:
            raise ValueError(
                "Must specify 2 wandb_run_paths or 2 model_file_paths "
                f"(got {len(args.wandb_run_paths)} wandb_run_paths)"
            )
        api = wandb.Api()
        args, config, player_1_model_path = load_player(
            api, args.wandb_run_paths[0], args, root_dir
        )
        # Only the model path is needed for player 2; its args/config are
        # discarded in favor of player 1's.
        _, _, player_2_model_path = load_player(
            api, args.wandb_run_paths[1], args, root_dir
        )
    elif args.model_file_paths:
        # Tuple-unpacking below needs exactly two paths; fail with a clear
        # message rather than "too many/not enough values to unpack".
        if len(args.model_file_paths) != 2:
            raise ValueError(
                "Must specify 2 wandb_run_paths or 2 model_file_paths "
                f"(got {len(args.model_file_paths)} model_file_paths)"
            )
        hyperparams = load_hyperparams(args.algo, args.env)
        config = Config(args, hyperparams, root_dir)
        player_1_model_path, player_2_model_path = args.model_file_paths
    else:
        raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths")

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    # Copy so the config's own make_kwargs dict is not mutated.
    env_make_kwargs = (
        config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy()
    )
    env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2
    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_hparams={
            "n_envs": args.n_envs,
            # Player 2's model is the only opponent, in every environment.
            "selfplay_bots": {
                player_2_model_path: args.n_envs,
            },
            "self_play_kwargs": {
                "num_old_policies": 0,
                "save_steps": np.inf,
                "swap_steps": np.inf,
                "bot_always_player_2": True,
            },
            "bots": None,
            "make_kwargs": env_make_kwargs,
        },
        render=args.render,
        normalize_load_path=player_1_model_path,
    )
    if args.video_path:
        env = VecEpisodeRecorder(
            env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes
        )
    device = get_device(config, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=player_1_model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit command-line choice wins; otherwise defer to the config.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_hyperparams.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )
# Bundle describing one loaded player: the run-specific copy of the
# evaluation args, the Config built for that player, and the path of the
# player's model.
PlayerData = NamedTuple(
    "PlayerData",
    [
        ("args", SelfplayEvalArgs),
        ("config", Config),
        ("model_path", str),
    ],
)
def load_player(
    api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str
) -> PlayerData:
    """Download a W&B run's model archive and unpack it for loading.

    Args:
        api: W&B API client used to look up the run.
        run_path: Path of the W&B run to load.
        args: Base evaluation arguments. Not mutated; a shallow copy is
            updated with the run's algo, env, seed and
            use_deterministic_algorithms settings.
        root_dir: Root directory forwarded to ``Config``.

    Returns:
        ``PlayerData`` with the run-specific args copy, the ``Config`` built
        from the run's stored parameters, and the directory the model was
        unpacked into.
    """
    # Shallow copy so the caller's args object keeps its own algo/env/seed.
    args = copy.copy(args)

    run = api.run(run_path)
    params = run.config
    args.algo = params["algo"]
    args.env = params["env"]
    args.seed = params.get("seed", None)
    args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True)

    config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
    model_path = config.model_dir_path(best=args.best, downloaded=True)

    model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
    run.file(model_archive_name).download()
    try:
        # Replace any previously unpacked copy of this model.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        shutil.unpack_archive(model_archive_name, model_path)
    finally:
        # Remove the downloaded archive even when unpacking fails, so it is
        # not left behind in the working directory.
        os.remove(model_archive_name)

    return PlayerData(args, config, model_path)
|