File size: 2,948 Bytes
8bf4dee
 
b638440
 
 
8bf4dee
b638440
8bf4dee
 
 
 
 
 
 
 
 
 
 
b638440
 
 
 
 
8bf4dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from copy import deepcopy

import optuna

from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams


def sample_params(
    trial: optuna.Trial,
    base_hyperparams: Hyperparams,
    base_config: Config,
) -> Hyperparams:
    """Sample a full hyperparameter set for one Optuna trial.

    Works on a deep copy of ``base_hyperparams``, so the caller's object is
    never modified. Env and policy hyperparameters are delegated to
    ``sample_env_hyperparams`` / ``sample_on_policy_hyperparams``. The
    algorithm hyperparameters are sampled here and merged into the copy's
    ``algo_hyperparams`` dict.

    A single-env evaluation environment (``n_envs`` overridden to 1) is built
    from the *unsampled* env hyperparameters. It is only passed to the two
    helper samplers (presumably so they can inspect the env, e.g. its spaces;
    confirm in those helpers). It is always closed before returning, even if
    sampling raises (e.g. an Optuna pruning/validation error).

    Args:
        trial: The Optuna trial that records the suggested values.
        base_hyperparams: Starting hyperparameters; left unmodified.
        base_config: Run configuration used to build the evaluation env.

    Returns:
        A new ``Hyperparams`` with env, policy and algo hyperparameters
        filled in from the trial's suggestions.
    """
    hyperparams = deepcopy(base_hyperparams)

    base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
    env = make_eval_env(
        base_config,
        base_env_hyperparams,
        override_hparams={"n_envs": 1},
    )
    # try/finally: previously env.close() only ran on the success path, so any
    # exception raised while sampling leaked the environment.
    try:
        # env_hyperparams: the return value was never used, so the helper is
        # presumably expected to update hyperparams.env_hyperparams in place.
        # NOTE(review): verify; if it returns a new dict, assign it back.
        sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)

        # policy_hyperparams
        policy_hyperparams = sample_on_policy_hyperparams(
            trial, hyperparams.policy_hyperparams, env
        )

        # algo_hyperparams: mutated in place on the deep copy.
        algo_hyperparams = hyperparams.algo_hyperparams

        learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
        learning_rate_decay = trial.suggest_categorical(
            "learning_rate_decay", ["none", "linear"]
        )
        # n_steps is searched as a power of two: 2**1 .. 2**10 (2 .. 1024).
        n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
        n_steps = 2**n_steps_exp
        trial.set_user_attr("n_steps", n_steps)
        # gamma is sampled as its "one-minus" (om) on a log scale, which puts
        # more resolution near 1: gamma lies in [0.9, 0.9999].
        gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
        trial.set_user_attr("gamma", gamma)
        # NOTE(review): unlike gamma_om, this one-minus parameter is sampled
        # linearly, so most samples land with gae_lambda near 0.9. Switching it
        # to log=True changes the parameter's distribution, which is
        # incompatible with studies that already recorded it; left as-is.
        gae_lambda = 1 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
        trial.set_user_attr("gae_lambda", gae_lambda)
        ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
        ent_coef_decay = trial.suggest_categorical(
            "ent_coef_decay", ["none", "linear"]
        )
        vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
        max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
        use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
        normalize_advantage = trial.suggest_categorical(
            "normalize_advantage", [True, False]
        )

        algo_hyperparams.update(
            {
                "learning_rate": learning_rate,
                "learning_rate_decay": learning_rate_decay,
                "n_steps": n_steps,
                "gamma": gamma,
                "gae_lambda": gae_lambda,
                "ent_coef": ent_coef,
                "ent_coef_decay": ent_coef_decay,
                "vf_coef": vf_coef,
                "max_grad_norm": max_grad_norm,
                "use_rms_prop": use_rms_prop,
                "normalize_advantage": normalize_advantage,
            }
        )

        if policy_hyperparams.get("use_sde", False):
            # Resample frequency is a power of two in [1, n_steps].
            sde_sample_freq = 2 ** trial.suggest_int(
                "sde_sample_freq_exp", 0, n_steps_exp
            )
            trial.set_user_attr("sde_sample_freq", sde_sample_freq)
            algo_hyperparams["sde_sample_freq"] = sde_sample_freq
        elif "sde_sample_freq" in algo_hyperparams:
            # Drop any inherited setting that is meaningless without SDE.
            del algo_hyperparams["sde_sample_freq"]
    finally:
        env.close()

    return hyperparams