araffin committed on
Commit
b80445f
1 Parent(s): 68ab42e
Files changed (1) hide show
  1. README.md +49 -1
README.md CHANGED
@@ -24,5 +24,53 @@ model-index:
24
  This is a trained model of an **A2C** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
26
  ## Usage (with Stable-baselines3)
27
- TODO: Add your code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
24
  This is a trained model of an **A2C** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
26
  ## Usage (with Stable-baselines3)
27
+ ```python
28
+ from stable_baselines3 import A2C
29
+ from stable_baselines3.common.env_util import make_vec_env
30
+ from stable_baselines3.common.callbacks import EvalCallback
31
+
32
+ # Create the environment
33
+ env_id = "LunarLander-v2"
34
+ n_envs = 8
35
+ env = make_vec_env(env_id, n_envs=n_envs)
36
+
37
+ # Create the evaluation envs
38
+ eval_envs = make_vec_env(env_id, n_envs=5)
39
+
40
+ # Adjust evaluation interval depending on the number of envs
41
+ eval_freq = int(1e5)
42
+ eval_freq = max(eval_freq // n_envs, 1)
43
+
44
+ # Create evaluation callback to save best model
45
+ # and monitor agent performance
46
+ eval_callback = EvalCallback(
47
+ eval_envs,
48
+ best_model_save_path="./logs/",
49
+ eval_freq=eval_freq,
50
+ n_eval_episodes=10,
51
+ )
52
+
53
+
54
+ # Instantiate the agent
55
+ # Hyperparameters from https://github.com/DLR-RM/rl-baselines3-zoo
56
+ linear_schedule = lambda progress_remaining: progress_remaining * 0.00083
57
+ model = A2C(
58
+ "MlpPolicy",
59
+ env,
60
+ n_steps=5,
61
+ gamma=0.995,
62
+ learning_rate=linear_schedule,
63
+ ent_coef=0.00001,
64
+ verbose=1,
65
+ )
66
+
67
+ # Train the agent (you can stop training early with Ctrl+C)
68
+ try:
69
+ model.learn(total_timesteps=int(5e5), callback=eval_callback)
70
+ except KeyboardInterrupt:
71
+ pass
72
+
73
+ # Load best model
74
+ model = A2C.load("logs/best_model.zip")
75
+ ```
76