import dataclasses
import gc
import inspect
import logging
import os

import numpy as np
import optuna
import torch
import wandb
from dataclasses import asdict, dataclass
from optuna.pruners import HyperbandPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Callable, List, NamedTuple, Optional, Sequence, Union

from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
from rl_algo_impls.runner.env import make_env, make_eval_env
from rl_algo_impls.runner.running_utils import (
    base_parser,
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
    ALGOS,
    hparam_dict,
)
from rl_algo_impls.shared.callbacks.optimize_callback import (
    Evaluation,
    OptimizeCallback,
    evaluation,
)
from rl_algo_impls.shared.stats import EpisodesStats


@dataclass
class StudyArgs:
    load_study: bool
    study_name: Optional[str] = None
    storage_path: Optional[str] = None
    n_trials: int = 100
    n_jobs: int = 1
    n_evaluations: int = 4
    n_eval_envs: int = 8
    n_eval_episodes: int = 16
    timeout: Union[int, float, None] = None
    wandb_project_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
    wandb_group: Optional[str] = None
    virtual_display: bool = False


class Args(NamedTuple):
    train_args: Sequence[RunArgs]
    study_args: StudyArgs


def parse_args() -> Args:
    parser = base_parser()
    parser.add_argument(
        "--load-study",
        action="store_true",
        help="Load a preexisting study, useful for parallelization",
    )
    parser.add_argument("--study-name", type=str, help="Optuna study name")
    parser.add_argument(
        "--storage-path",
        type=str,
        help="Path of database for Optuna to persist to",
    )
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-tuning",
        help="WandB project name to upload tuning data to. If none, won't upload",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        help="WandB team. None uses the default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--wandb-group", type=str, help="WandB group to group trials under"
    )
    parser.add_argument(
        "--n-trials", type=int, default=100, help="Maximum number of trials"
    )
    parser.add_argument(
        "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
    )
    parser.add_argument(
        "--n-evaluations",
        type=int,
        default=4,
        help="Number of evaluations during the training",
    )
    parser.add_argument(
        "--n-eval-envs",
        type=int,
        default=8,
        help="Number of envs in vectorized eval environment",
    )
    parser.add_argument(
        "--n-eval-episodes",
        type=int,
        default=16,
        help="Number of episodes to complete for evaluation",
    )
    parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # parser.set_defaults(
    #     algo=["a2c"],
    #     env=["CartPole-v1"],
    #     seed=[100, 200, 300],
    #     n_trials=5,
    #     virtual_display=True,
    # )

    # Split parsed args between StudyArgs fields and the per-run training args
    train_dict, study_dict = {}, {}
    for k, v in vars(parser.parse_args()).items():
        if k in inspect.signature(StudyArgs).parameters:
            study_dict[k] = v
        else:
            train_dict[k] = v

    study_args = StudyArgs(**study_dict)
    # Hyperparameter tuning across algos and envs isn't supported
    assert len(train_dict["algo"]) == 1
    assert len(train_dict["env"]) == 1
    train_args = RunArgs.expand_from_dict(train_dict)

    if not all((study_args.study_name, study_args.storage_path)):
        hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
        config = Config(train_args[0], hyperparams, os.getcwd())
        if study_args.study_name is None:
            study_args.study_name = config.run_name(include_seed=False)
        if study_args.storage_path is None:
            study_args.storage_path = (
                f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
            )
    # Default the group name to the study name
    study_args.wandb_group = study_args.wandb_group or study_args.study_name

    return Args(train_args, study_args)


def objective_fn(
    args: Sequence[RunArgs], study_args: StudyArgs
) -> Callable[[optuna.Trial], float]:
    # A single RunArgs trains once with pruning handled by OptimizeCallback; multiple
    # RunArgs (e.g. multiple seeds) are trained stepwise, reporting the mean score
    # across runs after each evaluation step.
    def objective(trial: optuna.Trial) -> float:
        if len(args) == 1:
            return simple_optimize(trial, args[0], study_args)
        else:
            return stepwise_optimize(trial, args, study_args)

    return objective


def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
    base_hyperparams = load_hyperparams(args.algo, args.env)
    base_config = Config(args, base_hyperparams, os.getcwd())
    if args.algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {args.algo} isn't supported")
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{config.model_name()}-{str(trial.number)}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            sync_tensorboard=True,
            monitor_gym=True,
            save_code=True,
            reinit=True,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)
    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config.device, env)
    policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    eval_env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=study_args.n_eval_envs,
    )
    callback = OptimizeCallback(
        policy,
        eval_env,
        trial,
        tb_writer,
        step_freq=config.n_timesteps // study_args.n_evaluations,
        n_episodes=study_args.n_eval_episodes,
        deterministic=config.eval_params.get("deterministic", True),
    )
    try:
        algo.learn(config.n_timesteps, callback=callback)

        if not callback.is_pruned:
            callback.evaluate()
            if not callback.is_pruned:
                policy.save(config.model_dir_path(best=False))

        eval_stat: EpisodesStats = callback.last_eval_stat  # type: ignore
        train_stat: EpisodesStats = callback.last_train_stat  # type: ignore

        tb_writer.add_hparams(
            hparam_dict(hyperparams, vars(args)),
            {
                "hparam/last_mean": eval_stat.score.mean,
                "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
                "hparam/train_mean": train_stat.score.mean,
                "hparam/train_result": train_stat.score.mean - train_stat.score.std,
                "hparam/score": callback.last_score,
                "hparam/is_pruned": callback.is_pruned,
            },
            None,
            config.run_name(),
        )
        tb_writer.close()

        if wandb_enabled:
            wandb.run.summary["state"] = "Pruned" if callback.is_pruned else "Complete"
            wandb.finish(quiet=True)

        if callback.is_pruned:
            raise optuna.exceptions.TrialPruned()

        return callback.last_score
    except AssertionError as e:
        logging.warning(e)
        return np.nan
    finally:
        env.close()
        eval_env.close()
        gc.collect()
        torch.cuda.empty_cache()


def stepwise_optimize(
    trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
) -> float:
    algo = args[0].algo
    env_id = args[0].env
    base_hyperparams = load_hyperparams(algo, env_id)
    base_config = Config(args[0], base_hyperparams, os.getcwd())
    if algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {algo} isn't supported")

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{study_args.study_name}-{str(trial.number)}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            save_code=True,
            reinit=True,
        )

    score = -np.inf

    for i in range(study_args.n_evaluations):
        evaluations: List[Evaluation] = []
        for arg in args:
            config = Config(arg, hyperparams, os.getcwd())

            tb_writer = SummaryWriter(config.tensorboard_summary_path)
            set_seeds(arg.seed, arg.use_deterministic_algorithms)

            env = make_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                tb_writer=tb_writer,
            )
            device = get_device(config.device, env)
            policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
            if i > 0:
                policy.load(config.model_dir_path())
            algo = ALGOS[arg.algo](
                policy, env, device, tb_writer, **config.algo_hyperparams
            )

            eval_env = make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=config.model_dir_path() if i > 0 else None,
                override_n_envs=study_args.n_eval_envs,
            )

            start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
            train_timesteps = (
                int((i + 1) * config.n_timesteps / study_args.n_evaluations)
                - start_timesteps
            )

            try:
                algo.learn(
                    train_timesteps,
                    callback=None,
                    total_timesteps=config.n_timesteps,
                    start_timesteps=start_timesteps,
                )

                evaluations.append(
                    evaluation(
                        policy,
                        eval_env,
                        tb_writer,
                        study_args.n_eval_episodes,
                        config.eval_params.get("deterministic", True),
                        start_timesteps + train_timesteps,
                    )
                )

                policy.save(config.model_dir_path())

                tb_writer.close()
            except AssertionError as e:
                logging.warning(e)
                if wandb_enabled:
wandb_finish("Error") return np.nan finally: env.close() eval_env.close() gc.collect() torch.cuda.empty_cache() d = {} for idx, e in enumerate(evaluations): d[f"{idx}/eval_mean"] = e.eval_stat.score.mean d[f"{idx}/train_mean"] = e.train_stat.score.mean d[f"{idx}/score"] = e.score d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item() d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item() score = np.mean([e.score for e in evaluations]).item() d["score"] = score step = i + 1 wandb.log(d, step=step) print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}") trial.report(score, step) if trial.should_prune(): if wandb_enabled: wandb_finish("Pruned") raise optuna.exceptions.TrialPruned() if wandb_enabled: wandb_finish("Complete") return score def wandb_finish(state: str) -> None: wandb.run.summary["state"] = state wandb.finish(quiet=True) def optimize() -> None: from pyvirtualdisplay.display import Display train_args, study_args = parse_args() if study_args.virtual_display: virtual_display = Display(visible=False, size=(1400, 900)) virtual_display.start() sampler = TPESampler(**TPESampler.hyperopt_parameters()) pruner = HyperbandPruner() if study_args.load_study: assert study_args.study_name assert study_args.storage_path study = optuna.load_study( study_name=study_args.study_name, storage=study_args.storage_path, sampler=sampler, pruner=pruner, ) else: study = optuna.create_study( study_name=study_args.study_name, storage=study_args.storage_path, sampler=sampler, pruner=pruner, direction="maximize", ) try: study.optimize( objective_fn(train_args, study_args), n_trials=study_args.n_trials, n_jobs=study_args.n_jobs, timeout=study_args.timeout, ) except KeyboardInterrupt: pass best = study.best_trial print(f"Best Trial Value: {best.value}") print("Attributes:") for key, value in list(best.params.items()) + list(best.user_attrs.items()): print(f" {key}: {value}") df = study.trials_dataframe() df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False) print(df.to_markdown(index=False)) fig1 = plot_optimization_history(study) fig1.write_image("opt_history.png") fig2 = plot_param_importances(study) fig2.write_image("param_importances.png") if __name__ == "__main__": optimize()