# A2C playing BreakoutNoFrameskip-v4, from
# https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 233c511
import dataclasses | |
import gc | |
import inspect | |
import logging | |
import os | |
from dataclasses import asdict, dataclass | |
from typing import Callable, List, NamedTuple, Optional, Sequence, Union | |
import numpy as np | |
import optuna | |
import torch | |
from optuna.pruners import HyperbandPruner | |
from optuna.samplers import TPESampler | |
from optuna.visualization import plot_optimization_history, plot_param_importances | |
from torch.utils.tensorboard.writer import SummaryWriter | |
import wandb | |
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params | |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs | |
from rl_algo_impls.runner.running_utils import ( | |
ALGOS, | |
base_parser, | |
get_device, | |
hparam_dict, | |
load_hyperparams, | |
make_policy, | |
set_seeds, | |
) | |
from rl_algo_impls.shared.callbacks import Callback | |
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import ( | |
MicrortsRewardDecayCallback, | |
) | |
from rl_algo_impls.shared.callbacks.optimize_callback import ( | |
Evaluation, | |
OptimizeCallback, | |
evaluation, | |
) | |
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback | |
from rl_algo_impls.shared.stats import EpisodesStats | |
from rl_algo_impls.shared.vec_env import make_env, make_eval_env | |
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper | |
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper | |
@dataclass
class StudyArgs:
    """Study-level settings for a hyperparameter tuning session.

    ``parse_args`` builds this from the command-line options whose names match
    the constructor parameters, so the class must be a dataclass: it is created
    with keyword arguments and introspected with ``inspect.signature``.
    """

    # Load a preexisting study instead of creating a new one.
    load_study: bool
    # Optuna study name.
    study_name: Optional[str] = None
    # Path (URL) of the database Optuna persists to.
    storage_path: Optional[str] = None
    # Maximum number of trials.
    n_trials: int = 100
    # Number of jobs to run in parallel.
    n_jobs: int = 1
    # Number of evaluations during the training.
    n_evaluations: int = 4
    # Number of envs in the vectorized eval environment.
    n_eval_envs: int = 8
    # Number of episodes to complete for each evaluation.
    n_eval_episodes: int = 16
    # Seconds after which optimization times out; None means no timeout is passed.
    timeout: Union[int, float, None] = None
    # WandB project to upload tuning data to; falsy disables WandB logging.
    wandb_project_name: Optional[str] = None
    # WandB entity (team) passed to wandb.init.
    wandb_entity: Optional[str] = None
    # WandB tags added to each run; default_factory avoids a shared mutable default.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
    # WandB group that trials are grouped under.
    wandb_group: Optional[str] = None
    # Whether to start a headless virtual display.
    virtual_display: bool = False
class Args(NamedTuple):
    """Result of ``parse_args``: the command line split into two parts."""

    # Training run arguments, as produced by RunArgs.expand_from_dict.
    train_args: Sequence[RunArgs]
    # Study-level settings (Optuna, WandB, evaluation).
    study_args: StudyArgs
def parse_args() -> Args:
    """Parse the command line into training arguments and study arguments.

    Tuning options are registered on top of ``base_parser()``. Every parsed
    value whose name is a ``StudyArgs`` constructor parameter goes into
    ``StudyArgs``; the remaining values are expanded into ``RunArgs`` with
    ``RunArgs.expand_from_dict``.

    Returns:
        ``Args`` with the expanded training arguments and the study arguments.
        ``study_name``, ``storage_path`` and ``wandb_group`` are filled with
        derived values when they were not given.

    Raises:
        SystemExit: raised by argparse for invalid input, including (via
            ``parser.error``) a request for more than one algo or env.
    """
    parser = base_parser()
    parser.add_argument(
        "--load-study",
        action="store_true",
        help="Load a preexisting study, useful for parallelization",
    )
    parser.add_argument("--study-name", type=str, help="Optuna study name")
    parser.add_argument(
        "--storage-path",
        type=str,
        help="Path of database for Optuna to persist to",
    )
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-tuning",
        help="WandB project name to upload tuning data to. If none, won't upload",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        help="WandB team. None uses the default entity",
    )
    # default=[] keeps the value a sequence (as StudyArgs.wandb_tags declares)
    # when the option is omitted; argparse would otherwise yield None.
    parser.add_argument(
        "--wandb-tags",
        type=str,
        nargs="*",
        default=[],
        help="WandB tags to add to run",
    )
    parser.add_argument(
        "--wandb-group", type=str, help="WandB group to group trials under"
    )
    parser.add_argument(
        "--n-trials", type=int, default=100, help="Maximum number of trials"
    )
    parser.add_argument(
        "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
    )
    parser.add_argument(
        "--n-evaluations",
        type=int,
        default=4,
        help="Number of evaluations during the training",
    )
    parser.add_argument(
        "--n-eval-envs",
        type=int,
        default=8,
        help="Number of envs in vectorized eval environment",
    )
    parser.add_argument(
        "--n-eval-episodes",
        type=int,
        default=16,
        help="Number of episodes to complete for evaluation",
    )
    parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # Disabled default overrides, kept for reference:
    # parser.set_defaults(
    #     algo=["a2c"],
    #     env=["CartPole-v1"],
    #     seed=[100, 200, 300],
    #     n_trials=5,
    #     virtual_display=True,
    # )

    # Looked up once: the set of StudyArgs constructor parameters does not
    # change while the parsed values are being partitioned.
    study_fields = inspect.signature(StudyArgs).parameters
    train_dict, study_dict = {}, {}
    for k, v in vars(parser.parse_args()).items():
        if k in study_fields:
            study_dict[k] = v
        else:
            train_dict[k] = v

    study_args = StudyArgs(**study_dict)
    # Hyperparameter tuning across algos and envs not supported. Reported
    # through the parser (not `assert`) so the check survives `python -O`.
    if len(train_dict["algo"]) != 1:
        parser.error(
            "Hyperparameter tuning across algos is not supported: "
            f"expected exactly 1 algo, got {len(train_dict['algo'])}"
        )
    if len(train_dict["env"]) != 1:
        parser.error(
            "Hyperparameter tuning across envs is not supported: "
            f"expected exactly 1 env, got {len(train_dict['env'])}"
        )
    train_args = RunArgs.expand_from_dict(train_dict)

    if not all((study_args.study_name, study_args.storage_path)):
        # Derive the missing study name and/or storage from the first run's config.
        hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
        config = Config(train_args[0], hyperparams, os.getcwd())
        if study_args.study_name is None:
            study_args.study_name = config.run_name(include_seed=False)
        if study_args.storage_path is None:
            study_args.storage_path = (
                f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
            )
    # Default set group name to study name
    study_args.wandb_group = study_args.wandb_group or study_args.study_name

    return Args(train_args, study_args)
def objective_fn(
    args: Sequence[RunArgs], study_args: StudyArgs
) -> Callable[[optuna.Trial], float]:
    """Build the objective callable handed to ``study.optimize``.

    The returned function tunes a single ``RunArgs`` with ``simple_optimize``
    and any other number of runs with ``stepwise_optimize``.
    """

    def objective(trial: optuna.Trial) -> float:
        # Guard clause for the multi-run case; the single-run case falls through.
        if len(args) != 1:
            return stepwise_optimize(trial, args, study_args)
        return simple_optimize(trial, args[0], study_args)

    return objective
def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
    """Train one sampled configuration for a single run and return its score.

    Hyperparameters are sampled for ``trial`` on top of the base
    hyperparameters for ``args.algo``/``args.env``. Training runs for
    ``config.n_timesteps`` with an ``OptimizeCallback`` whose ``step_freq`` is
    ``config.n_timesteps // study_args.n_evaluations``.

    Returns:
        ``optimize_callback.last_score``, or ``nan`` when an
        ``AssertionError`` is raised inside the training/reporting block.

    Raises:
        ValueError: if ``args.algo`` has no parameter sampler.
        optuna.exceptions.TrialPruned: if the callback reports the trial as
            pruned.
    """
    base_hyperparams = load_hyperparams(args.algo, args.env)
    base_config = Config(args, base_hyperparams, os.getcwd())
    if args.algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {args.algo} isn't supported")
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{config.model_name()}-{str(trial.number)}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            sync_tensorboard=True,
            monitor_gym=True,
            save_code=True,
            reinit=True,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)
    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config, env)
    policy_factory = lambda: make_policy(
        args.algo, env, device, **config.policy_hyperparams
    )
    policy = policy_factory()
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    eval_env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_hparams={"n_envs": study_args.n_eval_envs},
    )
    optimize_callback = OptimizeCallback(
        policy,
        eval_env,
        trial,
        tb_writer,
        step_freq=config.n_timesteps // study_args.n_evaluations,
        n_episodes=study_args.n_eval_episodes,
        deterministic=config.eval_hyperparams.get("deterministic", True),
    )
    callbacks: List[Callback] = [optimize_callback]
    if config.hyperparams.microrts_reward_decay_callback:
        callbacks.append(MicrortsRewardDecayCallback(config, env))
    selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
    if selfPlayWrapper:
        callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
    try:
        algo.learn(config.n_timesteps, callbacks=callbacks)

        # Final evaluation; the model is only saved if that evaluation did
        # not prune the trial either.
        if not optimize_callback.is_pruned:
            optimize_callback.evaluate()
            if not optimize_callback.is_pruned:
                policy.save(config.model_dir_path(best=False))

        # Read the stats from the callback instance created above (the
        # previous `callback` name was undefined and raised NameError).
        eval_stat: EpisodesStats = optimize_callback.last_eval_stat  # type: ignore
        train_stat: EpisodesStats = optimize_callback.last_train_stat  # type: ignore

        tb_writer.add_hparams(
            hparam_dict(hyperparams, vars(args)),
            {
                "hparam/last_mean": eval_stat.score.mean,
                "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
                "hparam/train_mean": train_stat.score.mean,
                "hparam/train_result": train_stat.score.mean - train_stat.score.std,
                "hparam/score": optimize_callback.last_score,
                "hparam/is_pruned": optimize_callback.is_pruned,
            },
            None,
            config.run_name(),
        )
        # Closed before finishing the WandB run, as in the original ordering.
        tb_writer.close()

        if wandb_enabled:
            wandb_finish("Pruned" if optimize_callback.is_pruned else "Complete")

        if optimize_callback.is_pruned:
            raise optuna.exceptions.TrialPruned()

        return optimize_callback.last_score
    except AssertionError as e:
        logging.warning(e)
        # Mirror stepwise_optimize: mark the WandB run as failed and finish it.
        if wandb_enabled:
            wandb_finish("Error")
        return np.nan
    finally:
        # Also closes the writer on failure paths; SummaryWriter.close()
        # ignores a second call after the success-path close above.
        tb_writer.close()
        env.close()
        eval_env.close()
        gc.collect()
        torch.cuda.empty_cache()
def stepwise_optimize(
    trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
) -> float:
    """Tune one sampled configuration across several runs, step by step.

    One set of hyperparameters is sampled for ``trial`` (using the algo and env
    of ``args[0]``). Training is split into ``study_args.n_evaluations`` steps.
    In each step every run in ``args`` trains for its share of
    ``config.n_timesteps``, is evaluated, and saves its policy to
    ``config.model_dir_path()``; on later steps the policy is loaded back from
    that path. The mean evaluation score over the runs is reported to
    ``trial`` after each step so the pruner can stop the trial early.

    Returns:
        The mean score of the last completed step, or ``nan`` when an
        ``AssertionError`` is raised while training or evaluating a run.

    Raises:
        ValueError: if the algo has no parameter sampler.
        optuna.exceptions.TrialPruned: if ``trial.should_prune()`` is true
            after a step.
    """
    # Named algo_name so the algorithm instance created in the loop below does
    # not shadow the algorithm identifier.
    algo_name = args[0].algo
    env_id = args[0].env
    base_hyperparams = load_hyperparams(algo_name, env_id)
    base_config = Config(args[0], base_hyperparams, os.getcwd())
    if algo_name == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {algo_name} isn't supported")

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{str(trial.number)}-S{base_config.seed()}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            save_code=True,
            reinit=True,
        )

    score = -np.inf

    for i in range(study_args.n_evaluations):
        evaluations: List[Evaluation] = []

        for arg in args:
            config = Config(arg, hyperparams, os.getcwd())

            tb_writer = SummaryWriter(config.tensorboard_summary_path)
            set_seeds(arg.seed, arg.use_deterministic_algorithms)

            # After the first step, the previously saved model directory is
            # passed as normalize_load_path and the policy is reloaded from it.
            env = make_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                tb_writer=tb_writer,
            )
            device = get_device(config, env)
            policy_factory = lambda: make_policy(
                arg.algo, env, device, **config.policy_hyperparams
            )
            policy = policy_factory()
            if i > 0:
                policy.load(config.model_dir_path())
            algo = ALGOS[arg.algo](
                policy, env, device, tb_writer, **config.algo_hyperparams
            )

            eval_env = make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                override_hparams={"n_envs": study_args.n_eval_envs},
            )

            # Timestep window [start_timesteps, start_timesteps + train_timesteps)
            # covered by this step.
            start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
            train_timesteps = (
                int((i + 1) * config.n_timesteps / study_args.n_evaluations)
                - start_timesteps
            )

            callbacks = []
            if config.hyperparams.microrts_reward_decay_callback:
                callbacks.append(
                    MicrortsRewardDecayCallback(
                        config, env, start_timesteps=start_timesteps
                    )
                )
            selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
            if selfPlayWrapper:
                callbacks.append(
                    SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
                )
            try:
                algo.learn(
                    train_timesteps,
                    callbacks=callbacks,
                    total_timesteps=config.n_timesteps,
                    start_timesteps=start_timesteps,
                )

                evaluations.append(
                    evaluation(
                        policy,
                        eval_env,
                        tb_writer,
                        study_args.n_eval_episodes,
                        config.eval_hyperparams.get("deterministic", True),
                        start_timesteps + train_timesteps,
                    )
                )

                policy.save(config.model_dir_path())

                tb_writer.close()

            except AssertionError as e:
                logging.warning(e)
                if wandb_enabled:
                    wandb_finish("Error")
                return np.nan
            finally:
                # Also closes the writer when training fails;
                # SummaryWriter.close() ignores a second call.
                tb_writer.close()
                env.close()
                eval_env.close()
                gc.collect()
                torch.cuda.empty_cache()

        d = {}
        for idx, e in enumerate(evaluations):
            d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
            d[f"{idx}/train_mean"] = e.train_stat.score.mean
            d[f"{idx}/score"] = e.score
        d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
        d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
        score = np.mean([e.score for e in evaluations]).item()
        d["score"] = score

        step = i + 1
        # Only log when a run was started above; wandb.log requires a prior
        # wandb.init call.
        if wandb_enabled:
            wandb.log(d, step=step)

        print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
        trial.report(score, step)
        if trial.should_prune():
            if wandb_enabled:
                wandb_finish("Pruned")
            raise optuna.exceptions.TrialPruned()

    if wandb_enabled:
        wandb_finish("Complete")
    return score
def wandb_finish(state: str) -> None:
    """Store ``state`` under the ``"state"`` summary key of the current WandB
    run, then finish that run quietly."""
    active_run = wandb.run
    active_run.summary["state"] = state  # type: ignore
    wandb.finish(quiet=True)
def optimize() -> None:
    """Run the tuning study described by the command line and report results.

    Creates the study, or loads an existing one when ``--load-study`` is
    given, using a TPE sampler and a Hyperband pruner. Runs the objective
    until the trial count or timeout is reached, or until interrupted with
    Ctrl-C. Then prints the best trial and a table of completed trials, and
    writes ``opt_history.png`` and ``param_importances.png``.

    If no trial has completed, a short message is printed and the report and
    plots are skipped.
    """
    train_args, study_args = parse_args()
    if study_args.virtual_display:
        # Imported only when requested so that pyvirtualdisplay is not
        # required for runs that do not use a virtual display.
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()

    sampler = TPESampler(**TPESampler.hyperopt_parameters())
    pruner = HyperbandPruner()
    if study_args.load_study:
        assert study_args.study_name
        assert study_args.storage_path
        study = optuna.load_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
        )
    else:
        study = optuna.create_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
            direction="maximize",
        )

    try:
        study.optimize(
            objective_fn(train_args, study_args),
            n_trials=study_args.n_trials,
            n_jobs=study_args.n_jobs,
            timeout=study_args.timeout,
        )
    except KeyboardInterrupt:
        # Ctrl-C stops the search but still reports what has finished so far.
        pass

    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError when the study has no completed trial
        # (e.g. interrupted early, or every trial pruned/failed).
        print("No completed trials; nothing to report.")
        return

    print(f"Best Trial Value: {best.value}")
    print("Attributes:")
    for key, value in list(best.params.items()) + list(best.user_attrs.items()):
        print(f"  {key}: {value}")

    df = study.trials_dataframe()
    df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
    print(df.to_markdown(index=False))

    fig1 = plot_optimization_history(study)
    fig1.write_image("opt_history.png")
    fig2 = plot_param_importances(study)
    fig2.write_image("param_importances.png")
# Script entry point: run the tuning study only when executed directly.
if __name__ == "__main__":
    optimize()