File size: 1,351 Bytes
2a30b4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import optuna

from typing import Any, Dict

from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space


def sample_env_hyperparams(
    trial: optuna.Trial, env_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    obs_space = single_observation_space(env)

    n_envs = 2 ** trial.suggest_int("n_envs_exp", 1, 5)
    trial.set_user_attr("n_envs", n_envs)
    env_hparams["n_envs"] = n_envs

    normalize = trial.suggest_categorical("normalize", [False, True])
    env_hparams["normalize"] = normalize
    if normalize:
        normalize_kwargs = env_hparams.get("normalize_kwargs", {})
        if len(obs_space.shape) == 3:
            normalize_kwargs.update(
                {
                    "norm_obs": False,
                    "norm_reward": True,
                }
            )
        else:
            norm_obs = trial.suggest_categorical("norm_obs", [True, False])
            norm_reward = trial.suggest_categorical("norm_reward", [True, False])
            normalize_kwargs.update(
                {
                    "norm_obs": norm_obs,
                    "norm_reward": norm_reward,
                }
            )
        env_hparams["normalize_kwargs"] = normalize_kwargs
    elif "normalize_kwargs" in env_hparams:
        del env_hparams["normalize_kwargs"]

    return env_hparams