PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
5c87b65
| import optuna | |
| from gym.spaces import Box | |
| from typing import Any, Dict | |
| from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
| VecEnv, | |
| single_action_space, | |
| ) | |
def sample_on_policy_hyperparams(
    trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Fill ``policy_hparams`` with values suggested by an Optuna trial.

    The dictionary is modified in place and also returned.

    Args:
        trial: Optuna trial used to suggest each hyperparameter value.
        policy_hparams: Policy hyperparameters to update. Existing entries
            other than the ones written or removed here are left untouched.
        env: Environment whose single action space decides whether the
            ``Box``-only parameters are sampled.

    Returns:
        The same ``policy_hparams`` object that was passed in.
    """
    # Suggestions are requested in a fixed order; keep it stable.
    policy_hparams["init_layers_orthogonal"] = trial.suggest_categorical(
        "init_layers_orthogonal", [True, False]
    )
    policy_hparams["activation_fn"] = trial.suggest_categorical(
        "activation_fn", ["tanh", "relu"]
    )

    action_space = single_action_space(env)
    if isinstance(action_space, Box):
        # These two parameters are only sampled for Box action spaces.
        policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
        policy_hparams["use_sde"] = trial.suggest_categorical("use_sde", [False, True])

    # "use_sde" may have been sampled above or may already be present in the
    # incoming dictionary; either way its truthiness gates "squash_output".
    if not policy_hparams.get("use_sde", False):
        # Drop any stale "squash_output" entry when use_sde is off.
        policy_hparams.pop("squash_output", None)
    else:
        policy_hparams["squash_output"] = trial.suggest_categorical(
            "squash_output", [False, True]
        )

    return policy_hparams