PPO playing MicrortsAttackShapedReward-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/fb34ab86707f5e2db85e821ff7dbdc624072d640
```python
import os
import shutil
from dataclasses import dataclass
from typing import NamedTuple, Optional

from rl_algo_impls.runner.env import make_eval_env
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
)
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import EpisodesStats


@dataclass
class EvalArgs(RunArgs):
    render: bool = True
    best: bool = True
    n_envs: Optional[int] = 1
    n_episodes: int = 3
    deterministic_eval: Optional[bool] = None
    no_print_returns: bool = False
    wandb_run_path: Optional[str] = None


class Evaluation(NamedTuple):
    policy: Policy
    stats: EpisodesStats
    config: Config


def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    if args.wandb_run_path:
        # Pull the run config and trained model archive from Weights & Biases.
        import wandb

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        params = run.config

        args.algo = params["algo"]
        args.env = params["env"]
        args.seed = params.get("seed", None)
        args.use_deterministic_algorithms = params.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)

        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        run.file(model_archive_name).download()
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        shutil.unpack_archive(model_archive_name, model_path)
        os.remove(model_archive_name)
    else:
        # No W&B run given: load hyperparameters and the model from local files.
        hyperparams = load_hyperparams(args.algo, args.env)
        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)
    set_seeds(args.seed, args.use_deterministic_algorithms)

    # Build the evaluation env, restoring any normalization state saved
    # alongside the model.
    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
    )
    device = get_device(config, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit deterministic_eval flag takes precedence over the value in
    # the eval hyperparams.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_params.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )
```