araffin committed on
Commit 1a05ad3
1 Parent(s): 5992afa
Files changed (1)
  1. README.md +52 -3
README.md CHANGED
@@ -20,9 +20,58 @@ model-index:
  type: LunarLander-v2
  ---

- # **PPO** Agent playing **LunarLander-v2**
+ # **PPO** Agent playing **LunarLander-v2**
  This is a trained model of a **PPO** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).

- ## Usage (with Stable-baselines3)
- TODO: Add your code
+ ## Usage (with Stable-baselines3)
+
+ ```python
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.env_util import make_vec_env
+ from stable_baselines3.common.callbacks import EvalCallback
+
+ # Create the environment
+ env_id = "LunarLander-v2"
+ n_envs = 16
+ env = make_vec_env(env_id, n_envs=n_envs)
+
+ # Create the evaluation envs
+ eval_envs = make_vec_env(env_id, n_envs=5)
+
+ # Adjust evaluation interval depending on the number of envs
+ eval_freq = int(1e5)
+ eval_freq = max(eval_freq // n_envs, 1)
+
+ # Create evaluation callback to save best model
+ # and monitor agent performance
+ eval_callback = EvalCallback(
+     eval_envs,
+     best_model_save_path="./logs/",
+     eval_freq=eval_freq,
+     n_eval_episodes=10,
+ )
+
+ # Instantiate the agent
+ # Hyperparameters from https://github.com/DLR-RM/rl-baselines3-zoo
+ model = PPO(
+     "MlpPolicy",
+     env,
+     n_steps=1024,
+     batch_size=64,
+     gae_lambda=0.98,
+     gamma=0.999,
+     n_epochs=4,
+     ent_coef=0.01,
+     verbose=1,
+ )
+
+ # Train the agent (you can kill it before using ctrl+c)
+ try:
+     model.learn(total_timesteps=int(5e6), callback=eval_callback)
+ except KeyboardInterrupt:
+     pass
+
+ # Load best model
+ model = PPO.load("logs/best_model.zip")
+ ```
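Reviewer note: the "Adjust evaluation interval" step in the added snippet is easy to misread, so here is a rough sketch of the arithmetic behind it, using only the Python standard library. It assumes that the callback counts calls to the vectorized `env.step()` (each call advancing all `n_envs` environments by one timestep) and that every PPO update collects a full rollout of `n_steps * n_envs` transitions. The helper names `callback_eval_freq` and `training_budget` are hypothetical, introduced only for illustration; they are not part of stable-baselines3.

```python
import math

# Values copied from the README snippet in this commit.
N_ENVS = 16
N_STEPS = 1024
BATCH_SIZE = 64
N_EPOCHS = 4
TOTAL_TIMESTEPS = int(5e6)
TARGET_EVAL_INTERVAL = int(1e5)  # desired gap between evaluations, in timesteps


def callback_eval_freq(target_interval: int, n_envs: int) -> int:
    """Convert a timestep interval into a number of vectorized step calls.

    Hypothetical helper: mirrors `max(eval_freq // n_envs, 1)` from the README.
    """
    return max(target_interval // n_envs, 1)


def training_budget(
    total_timesteps: int,
    n_envs: int,
    n_steps: int,
    batch_size: int,
    n_epochs: int,
    eval_freq: int,
) -> dict:
    """Estimate rollout and update counts (hypothetical helper, an approximation).

    Assumes training continues until at least `total_timesteps` transitions
    have been collected, always in whole rollouts.
    """
    rollout_size = n_steps * n_envs                 # transitions per update
    minibatches = rollout_size // batch_size        # minibatches per epoch
    n_updates = math.ceil(total_timesteps / rollout_size)
    step_calls = n_updates * n_steps                # vectorized step calls
    return {
        "rollout_size": rollout_size,
        "minibatches_per_epoch": minibatches,
        "gradient_steps_per_update": minibatches * n_epochs,
        "n_updates": n_updates,
        "timesteps_collected": n_updates * rollout_size,
        "n_evaluations": step_calls // eval_freq,
    }


eval_freq = callback_eval_freq(TARGET_EVAL_INTERVAL, N_ENVS)
budget = training_budget(
    TOTAL_TIMESTEPS, N_ENVS, N_STEPS, BATCH_SIZE, N_EPOCHS, eval_freq
)

if __name__ == "__main__":
    print(f"eval_freq passed to the callback: {eval_freq}")
    for name, value in budget.items():
        print(f"{name}: {value}")
```

Under these assumptions, the 16 training environments make an evaluation every 6250 step calls correspond to roughly 100,000 timesteps, which is the interval the original `int(1e5)` was meant to express. Without the division, evaluations would be 16 times rarer than intended.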