|
|
import os |
|
|
import shutil |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import NamedTuple, Optional |
|
|
|
|
|
from runner.env import make_eval_env |
|
|
from runner.config import Config, EnvHyperparams, RunArgs |
|
|
from runner.running_utils import ( |
|
|
load_hyperparams, |
|
|
set_seeds, |
|
|
get_device, |
|
|
make_policy, |
|
|
) |
|
|
from shared.callbacks.eval_callback import evaluate |
|
|
from shared.policy.policy import Policy |
|
|
from shared.stats import EpisodesStats |
|
|
|
|
|
|
|
|
@dataclass
class EvalArgs(RunArgs):
    """Command-line style options for evaluating a saved model.

    Extends ``RunArgs`` with the evaluation-only settings that
    ``evaluate_model`` reads.
    """

    # Forwarded as ``render`` to both ``make_eval_env`` and ``evaluate``.
    render: bool = True
    # Forwarded as ``best`` when resolving the model directory/archive name.
    best: bool = True
    # Passed to ``make_eval_env`` as ``override_n_envs``.
    n_envs: Optional[int] = 1
    # Episode count handed to ``evaluate``.
    n_episodes: int = 3
    # When None, ``evaluate_model`` falls back to the config's
    # ``eval_params["deterministic"]`` (defaulting to True).
    deterministic_eval: Optional[bool] = None
    # ``evaluate`` is called with ``print_returns=not no_print_returns``.
    no_print_returns: bool = False
    # When set, hyperparameters and the model archive are fetched from this
    # wandb run instead of being loaded from ``root_dir``.
    wandb_run_path: Optional[str] = None
|
|
|
|
|
|
|
|
class Evaluation(NamedTuple):
    """Result triple returned by ``evaluate_model``."""

    # The policy that was loaded from the model path and evaluated.
    policy: Policy
    # Whatever ``evaluate`` returned for the evaluation episodes.
    stats: EpisodesStats
    # The ``Config`` built from the run arguments and hyperparameters.
    config: Config
|
|
|
|
|
|
|
|
def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a saved policy and evaluate it for ``args.n_episodes`` episodes.

    The model comes from one of two places:

    * ``args.wandb_run_path`` set: hyperparameters are read from the wandb
      run's config and the model archive is downloaded from that run and
      unpacked into the config's "downloaded" model directory.
    * otherwise: hyperparameters are loaded with ``load_hyperparams`` and the
      model is read from the config's regular model directory.

    Args:
        args: Evaluation options. When ``wandb_run_path`` is set this object
            is mutated in place: ``algo``, ``env``, ``seed`` and
            ``use_deterministic_algorithms`` are overwritten from the run's
            config.
        root_dir: Forwarded to ``Config`` and ``load_hyperparams``.

    Returns:
        An ``Evaluation`` holding the loaded policy, the result of
        ``evaluate`` and the ``Config`` that was used.
    """
    if args.wandb_run_path:
        # Imported here so wandb is only loaded when a run path is given.
        import wandb

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        hyperparams = run.config

        # The run's recorded settings take precedence over the passed args.
        args.algo = hyperparams["algo"]
        args.env = hyperparams["env"]
        args.seed = hyperparams.get("seed", None)
        args.use_deterministic_algorithms = hyperparams.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)

        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        try:
            # replace=True: a stale archive left by an interrupted earlier
            # call must not make this download raise "file already exists".
            archive_handle = run.file(model_archive_name).download(replace=True)
            # download() returns an open handle on the local copy; close it
            # so it is not leaked and cannot block the removal below.
            if archive_handle is not None:
                archive_handle.close()
            # Drop any previously unpacked copy so stale files do not linger.
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            shutil.unpack_archive(model_archive_name, model_path)
        finally:
            # Always clean up the downloaded archive, even if unpacking fails.
            if os.path.exists(model_archive_name):
                os.remove(model_archive_name)
    else:
        hyperparams = load_hyperparams(args.algo, args.env, root_dir)

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
    )
    device = get_device(config.device, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit deterministic_eval wins; otherwise use the config's eval
    # params, defaulting to deterministic evaluation.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_params.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )
|
|
|