SocialAISchool / scripts /generate_hp_tuning_script.py
grg's picture
Cleaned old git history
be5548b
import sys
import itertools
if __name__ == '__main__':
'''
Generate scripts to perform grid search on agents's hyperparameters.
Defines the values to test for each hyperparameter.
'''
tuning_dict = {
# "PPO": {
# "frames-per-proc": [20, 40, 80],
# "lr": [1e-4, 1e-3],
# "entropy-coef": [0, 0.01, 0.05],
# "recurrence": [5, 10],
# "epochs": [4, 8, 12],
# "batch-size": [640, 1280, 2560],
# "env": ["MiniGrid-CoinThief-8x8-v0 --env_args few_actions True", "MiniGrid-TalkItOutPolite-8x8-v0"]
# },
"PPO-RND": {
# rnd and ride
"optim-eps": [1e-5, 1e-7],
"entropy-coef": [0.01, 0.0001, 0.0005],
"intrinsic-reward-learning-rate": [0.0001, 0.0004, 0.001],
# "intrinsic-reward-coef": [0.1],
# "intrinisc-reward-momentum": [0],
"intrinsic-reward-epsilon": [0.01, 0.001, 0.0001],
# "intrinsic-reward-alpha": [0.99],
"intrinsic-reward-max-grad-norm": [1000, 40, 20, 1],
# rnd
# "intrinsic-reward-loss-coef": [0.1],
"env": ["MiniGrid-CoinThief-8x8-v0 --env_args few_actions True", "MiniGrid-TalkItOutPolite-8x8-v0"]
}
}
with open("hp_tuning_agent.txt", 'w') as f:
for agent in tuning_dict:
f.write('## {}\n'.format(agent))
current_agent_parameters = list(tuning_dict[agent].keys())
current_agent_hyperparams = tuning_dict[agent].values()
for point in itertools.product(*current_agent_hyperparams):
current_arguments = ''
for i in range(len(current_agent_parameters)):
current_arguments += ' --*' + current_agent_parameters[i]
current_arguments += ' ' + str(point[i]) if point[i] is not None else ''
if agent == "PPO-RND":
f.write(
'--slurm_conf jz_short_2gpus_32g --nb_seeds 8 --model PPO_RND_tuning --algo ppo -cs --frames 10000000 '
'--save-interval 100 --log-interval 100 '
'--dialogue --multi-modal-babyai11-agent '
'--exploration-bonus --exploration-bonus-type rnd --clipped-rewards '
'--arch original_endpool_res {}\n'.format(current_arguments))
else:
f.write(
'--slurm_conf jz_short_2gpus_32g --nb_seeds 8 --model PPO_RND_tuning --algo ppo -cs --frames 10000000 '
'--save-interval 100 --log-interval 100 '
'--dialogue --multi-modal-babyai11-agent '
'--arch original_endpool_res {}\n'.format(current_arguments))