"""Optuna-based hyperparameter tuning for rl_algo_impls training runs."""

import dataclasses
import gc
import inspect
import logging
import os
from dataclasses import asdict, dataclass
from typing import Callable, List, NamedTuple, Optional, Sequence, Union

import numpy as np
import optuna
import torch
from optuna.pruners import HyperbandPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances
from torch.utils.tensorboard.writer import SummaryWriter

import wandb
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    ALGOS,
    base_parser,
    get_device,
    hparam_dict,
    load_hyperparams,
    make_policy,
    set_seeds,
)
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
    MicrortsRewardDecayCallback,
)
from rl_algo_impls.shared.callbacks.optimize_callback import (
    Evaluation,
    OptimizeCallback,
    evaluation,
)
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
from rl_algo_impls.shared.stats import EpisodesStats
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper


@dataclass
class StudyArgs:
    load_study: bool
    study_name: Optional[str] = None
    storage_path: Optional[str] = None
    n_trials: int = 100
    n_jobs: int = 1
    n_evaluations: int = 4
    n_eval_envs: int = 8
    n_eval_episodes: int = 16
    timeout: Union[int, float, None] = None
    wandb_project_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
    wandb_group: Optional[str] = None
    virtual_display: bool = False


class Args(NamedTuple):
    train_args: Sequence[RunArgs]
    study_args: StudyArgs


def parse_args() -> Args:
    """Split CLI arguments into per-run training args and study-level tuning args."""
    parser = base_parser()
    parser.add_argument(
        "--load-study",
        action="store_true",
        help="Load a preexisting study, useful for parallelization",
    )
    parser.add_argument("--study-name", type=str, help="Optuna study name")
    parser.add_argument(
        "--storage-path",
        type=str,
        help="Path of database for Optuna to persist to",
    )
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-tuning",
        help="WandB project name to upload tuning data to. If none, won't upload",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        help="WandB team. None uses the default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--wandb-group", type=str, help="WandB group to group trials under"
    )
    parser.add_argument(
        "--n-trials", type=int, default=100, help="Maximum number of trials"
    )
    parser.add_argument(
        "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
    )
    parser.add_argument(
        "--n-evaluations",
        type=int,
        default=4,
        help="Number of evaluations during the training",
    )
    parser.add_argument(
        "--n-eval-envs",
        type=int,
        default=8,
        help="Number of envs in vectorized eval environment",
    )
    parser.add_argument(
        "--n-eval-episodes",
        type=int,
        default=16,
        help="Number of episodes to complete for evaluation",
    )
    parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # parser.set_defaults(
    #     algo=["a2c"],
    #     env=["CartPole-v1"],
    #     seed=[100, 200, 300],
    #     n_trials=5,
    #     virtual_display=True,
    # )

    # Route each parsed argument to either StudyArgs or the per-run training args
    train_dict, study_dict = {}, {}
    for k, v in vars(parser.parse_args()).items():
        if k in inspect.signature(StudyArgs).parameters:
            study_dict[k] = v
        else:
            train_dict[k] = v

    study_args = StudyArgs(**study_dict)
    # Hyperparameter tuning across algos and envs not supported
    assert len(train_dict["algo"]) == 1
    assert len(train_dict["env"]) == 1
    train_args = RunArgs.expand_from_dict(train_dict)

    if not all((study_args.study_name, study_args.storage_path)):
        hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
        config = Config(train_args[0], hyperparams, os.getcwd())
        if study_args.study_name is None:
            study_args.study_name = config.run_name(include_seed=False)
        if study_args.storage_path is None:
            study_args.storage_path = (
                f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
            )
    # Default set group name to study name
    study_args.wandb_group = study_args.wandb_group or study_args.study_name

    return Args(train_args, study_args)


def objective_fn(
    args: Sequence[RunArgs], study_args: StudyArgs
) -> Callable[[optuna.Trial], float]:
    """Return an Optuna objective: a single run uses simple_optimize, several use stepwise_optimize."""

    def objective(trial: optuna.Trial) -> float:
        if len(args) == 1:
            return simple_optimize(trial, args[0], study_args)
        else:
            return stepwise_optimize(trial, args, study_args)

    return objective


def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
    """Train a single run for this trial, evaluating periodically so Optuna can prune."""
    base_hyperparams = load_hyperparams(args.algo, args.env)
    base_config = Config(args, base_hyperparams, os.getcwd())
    if args.algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {args.algo} isn't supported")
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{config.model_name()}-{str(trial.number)}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            sync_tensorboard=True,
            monitor_gym=True,
            save_code=True,
            reinit=True,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)
    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config, env)
    policy_factory = lambda: make_policy(
        args.algo, env, device, **config.policy_hyperparams
    )
    policy = policy_factory()
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

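    # A separate vectorized eval env (sized by --n-eval-envs) backs the OptimizeCallback
    # below, which periodically evaluates the policy and reports scores to Optuna for pruning.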
    eval_env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_hparams={"n_envs": study_args.n_eval_envs},
    )
    optimize_callback = OptimizeCallback(
        policy,
        eval_env,
        trial,
        tb_writer,
        step_freq=config.n_timesteps // study_args.n_evaluations,
        n_episodes=study_args.n_eval_episodes,
        deterministic=config.eval_hyperparams.get("deterministic", True),
    )
    callbacks: List[Callback] = [optimize_callback]
    if config.hyperparams.microrts_reward_decay_callback:
        callbacks.append(MicrortsRewardDecayCallback(config, env))
    selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
    if selfPlayWrapper:
        callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
    try:
        algo.learn(config.n_timesteps, callbacks=callbacks)

        if not optimize_callback.is_pruned:
            optimize_callback.evaluate()
            if not optimize_callback.is_pruned:
                policy.save(config.model_dir_path(best=False))

        eval_stat: EpisodesStats = optimize_callback.last_eval_stat  # type: ignore
        train_stat: EpisodesStats = optimize_callback.last_train_stat  # type: ignore

        tb_writer.add_hparams(
            hparam_dict(hyperparams, vars(args)),
            {
                "hparam/last_mean": eval_stat.score.mean,
                "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
                "hparam/train_mean": train_stat.score.mean,
                "hparam/train_result": train_stat.score.mean - train_stat.score.std,
                "hparam/score": optimize_callback.last_score,
                "hparam/is_pruned": optimize_callback.is_pruned,
            },
            None,
            config.run_name(),
        )
        tb_writer.close()

        if wandb_enabled:
            wandb.run.summary["state"] = (  # type: ignore
                "Pruned" if optimize_callback.is_pruned else "Complete"
            )
            wandb.finish(quiet=True)

        if optimize_callback.is_pruned:
            raise optuna.exceptions.TrialPruned()

        return optimize_callback.last_score
    except AssertionError as e:
        logging.warning(e)
        return np.nan
    finally:
        env.close()
        eval_env.close()
        gc.collect()
        torch.cuda.empty_cache()


def stepwise_optimize(
    trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
) -> float:
    """Train each run in alternating segments, reporting the mean score across runs
    after each segment so Optuna can prune."""
    algo = args[0].algo
    env_id = args[0].env
    base_hyperparams = load_hyperparams(algo, env_id)
    base_config = Config(args[0], base_hyperparams, os.getcwd())
    if algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {algo} isn't supported")

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{str(trial.number)}-S{base_config.seed()}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            save_code=True,
            reinit=True,
        )

    score = -np.inf

    for i in range(study_args.n_evaluations):
        evaluations: List[Evaluation] = []
        for arg in args:
            config = Config(arg, hyperparams, os.getcwd())

            tb_writer = SummaryWriter(config.tensorboard_summary_path)
            set_seeds(arg.seed, arg.use_deterministic_algorithms)

            env = make_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                tb_writer=tb_writer,
            )
            device = get_device(config, env)
            policy_factory = lambda: make_policy(
                arg.algo, env, device, **config.policy_hyperparams
            )
            policy = policy_factory()
            if i > 0:
                policy.load(config.model_dir_path())
            algo = ALGOS[arg.algo](
                policy, env, device, tb_writer, **config.algo_hyperparams
            )

            eval_env = make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                override_hparams={"n_envs": study_args.n_eval_envs},
            )

            start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
            train_timesteps = (
                int((i + 1) * config.n_timesteps / study_args.n_evaluations)
                - start_timesteps
            )
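            # Unlike simple_optimize, no OptimizeCallback is attached here: each training
            # segment is evaluated explicitly below and its score reported to Optuna.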
            callbacks = []
            if config.hyperparams.microrts_reward_decay_callback:
                callbacks.append(
                    MicrortsRewardDecayCallback(
                        config, env, start_timesteps=start_timesteps
                    )
                )
            selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
            if selfPlayWrapper:
                callbacks.append(
                    SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
                )
            try:
                algo.learn(
                    train_timesteps,
                    callbacks=callbacks,
                    total_timesteps=config.n_timesteps,
                    start_timesteps=start_timesteps,
                )

                evaluations.append(
                    evaluation(
                        policy,
                        eval_env,
                        tb_writer,
                        study_args.n_eval_episodes,
                        config.eval_hyperparams.get("deterministic", True),
                        start_timesteps + train_timesteps,
                    )
                )

                policy.save(config.model_dir_path())

                tb_writer.close()
            except AssertionError as e:
                logging.warning(e)
                if wandb_enabled:
                    wandb_finish("Error")
                return np.nan
            finally:
                env.close()
                eval_env.close()
                gc.collect()
                torch.cuda.empty_cache()

        d = {}
        for idx, e in enumerate(evaluations):
            d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
            d[f"{idx}/train_mean"] = e.train_stat.score.mean
            d[f"{idx}/score"] = e.score
        d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
        d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
        score = np.mean([e.score for e in evaluations]).item()
        d["score"] = score
        step = i + 1
        if wandb_enabled:
            # Only log to W&B when a project was specified (wandb.init was called)
            wandb.log(d, step=step)
        print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
        trial.report(score, step)
        if trial.should_prune():
            if wandb_enabled:
                wandb_finish("Pruned")
            raise optuna.exceptions.TrialPruned()

    if wandb_enabled:
        wandb_finish("Complete")

    return score


def wandb_finish(state: str) -> None:
    wandb.run.summary["state"] = state  # type: ignore
    wandb.finish(quiet=True)


def optimize() -> None:
    """Create or load the Optuna study, run the trials, and report the results."""
    train_args, study_args = parse_args()
    if study_args.virtual_display:
        # Imported lazily so pyvirtualdisplay is only required when a headless display is requested
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()

    sampler = TPESampler(**TPESampler.hyperopt_parameters())
    pruner = HyperbandPruner()
    if study_args.load_study:
        assert study_args.study_name
        assert study_args.storage_path
        study = optuna.load_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
        )
    else:
        study = optuna.create_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
            direction="maximize",
        )

    try:
        study.optimize(
            objective_fn(train_args, study_args),
            n_trials=study_args.n_trials,
            n_jobs=study_args.n_jobs,
            timeout=study_args.timeout,
        )
    except KeyboardInterrupt:
        pass

    best = study.best_trial
    print(f"Best Trial Value: {best.value}")
    print("Attributes:")
    for key, value in list(best.params.items()) + list(best.user_attrs.items()):
        print(f" {key}: {value}")

    df = study.trials_dataframe()
    df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
    print(df.to_markdown(index=False))

    fig1 = plot_optimization_history(study)
    fig1.write_image("opt_history.png")
    fig2 = plot_param_importances(study)
    fig2.write_image("param_importances.png")


if __name__ == "__main__":
    optimize()