File size: 2,785 Bytes
ff8c6a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import shutil
import tempfile
from dataclasses import dataclass
from typing import NamedTuple, Optional

from runner.env import make_eval_env
from runner.config import Config, RunArgs
from runner.running_utils import (
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
)
from shared.callbacks.eval_callback import evaluate
from shared.policy.policy import Policy
from shared.stats import EpisodesStats

@dataclass
class EvalArgs(RunArgs):
    """Options for evaluating a trained model with ``evaluate_model``.

    Extends ``RunArgs`` with evaluation-specific settings. ``evaluate_model``
    also reads ``algo``, ``env``, ``seed`` and ``use_deterministic_algorithms``
    from these args (presumably declared on ``RunArgs`` — confirm).
    """

    # Forwarded as `render` to both make_eval_env and evaluate.
    render: bool = True
    # Forwarded as `best` to config.model_dir_path / model_dir_name when
    # choosing which saved model directory to load (presumably best vs. last
    # checkpoint — confirm in Config).
    best: bool = True
    # Forwarded to make_eval_env as `override_n_envs`.
    n_envs: Optional[int] = 1
    # Number of episodes passed to evaluate.
    n_episodes: int = 3
    # Explicit deterministic-action setting; None defers to the config's
    # eval_params["deterministic"] (which itself defaults to True).
    deterministic_eval: Optional[bool] = None
    # When True, evaluate is called with print_returns=False.
    no_print_returns: bool = False
    # W&B run path; when set, hyperparameters and the model archive are
    # fetched from that run instead of being loaded locally.
    wandb_run_path: Optional[str] = None
class Evaluation(NamedTuple):
    """Result bundle returned by ``evaluate_model``."""

    # The loaded policy that was evaluated.
    policy: Policy
    # Episode statistics returned by `evaluate`.
    stats: EpisodesStats
    # Config the environment and policy were built from.
    config: Config
def _download_wandb_model(run, archive_name: str, model_path: str) -> None:
    """Download a zipped model directory from a W&B run and unpack it.

    The archive is downloaded into a private temporary directory rather than
    the current working directory. A leftover archive from an interrupted
    earlier call therefore cannot block the download, and the archive is
    always cleaned up, even if unpacking fails.

    Args:
        run: W&B public-API run object (from ``wandb.Api().run``).
        archive_name: Name of the archive file stored on the run.
        model_path: Destination directory. Any existing directory at this path
            is removed first so stale files don't mix with the new model.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        downloaded = run.file(archive_name).download(root=tmp_dir, replace=True)
        # wandb's File.download returns an open handle to the downloaded file;
        # close it so the temporary directory can be removed (Windows refuses
        # to delete files that are still open). Guarded in case a wandb
        # version returns something without `close`.
        close = getattr(downloaded, "close", None)
        if close is not None:
            close()
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        shutil.unpack_archive(os.path.join(tmp_dir, archive_name), model_path)


def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a trained policy and evaluate it for ``args.n_episodes`` episodes.

    If ``args.wandb_run_path`` is set, hyperparameters are read from that W&B
    run's config, and ``args.algo``, ``args.env``, ``args.seed`` and
    ``args.use_deterministic_algorithms`` are overwritten in place from it.
    The run's zipped model directory is then downloaded and unpacked.
    Otherwise hyperparameters come from ``load_hyperparams`` and the model is
    loaded from the local model directory.

    Args:
        args: Evaluation arguments; mutated in place when loading from W&B.
        root_dir: Root directory passed to ``load_hyperparams`` and ``Config``.

    Returns:
        An ``Evaluation`` with the loaded policy (``.eval()`` applied), the
        statistics returned by ``evaluate``, and the config used.
    """
    if args.wandb_run_path:
        import wandb  # imported lazily: only needed on this branch

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        hyperparams = run.config

        # The run's own config is authoritative for what model is loaded.
        args.algo = hyperparams["algo"]
        args.env = hyperparams["env"]
        args.seed = hyperparams.get("seed", None)
        args.use_deterministic_algorithms = hyperparams.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)
        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        _download_wandb_model(run, model_archive_name, model_path)
    else:
        hyperparams = load_hyperparams(args.algo, args.env, root_dir)
        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
        **config.env_hyperparams,
    )
    device = get_device(config.device, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit CLI choice wins; otherwise fall back to the config's eval
    # setting, defaulting to deterministic actions.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_params.get("deterministic", True)
    )
    stats = evaluate(
        env,
        policy,
        args.n_episodes,
        render=args.render,
        deterministic=deterministic,
        print_returns=not args.no_print_returns,
    )
    return Evaluation(policy, stats, config)
|