araffin committed on
Commit
4d63d30
1 Parent(s): 4e5e4dd
Files changed (1) hide show
  1. README.md +55 -1
README.md CHANGED
@@ -24,5 +24,59 @@ model-index:
24
  This is a trained model of a **DQN** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
26
  ## Usage (with Stable-baselines3)
27
- TODO: Add your code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
24
  This is a trained model of a **DQN** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
26
  ## Usage (with Stable-baselines3)
27
+
28
+ ```python
29
+ from stable_baselines3 import DQN
30
+ from stable_baselines3.common.env_util import make_vec_env
31
+ from stable_baselines3.common.callbacks import EvalCallback
32
+
33
+ # Create the environment
34
+ env_id = "LunarLander-v2"
35
+ n_envs = 1
36
+ env = make_vec_env(env_id, n_envs=n_envs)
37
+
38
+ # Create the evaluation envs
39
+ eval_envs = make_vec_env(env_id, n_envs=5)
40
+
41
+ # Adjust evaluation interval depending on the number of envs
42
+ eval_freq = int(1e5)
43
+ eval_freq = max(eval_freq // n_envs, 1)
44
+
45
+ # Create evaluation callback to save best model
46
+ # and monitor agent performance
47
+ eval_callback = EvalCallback(
48
+ eval_envs,
49
+ best_model_save_path="./logs/",
50
+ eval_freq=eval_freq,
51
+ n_eval_episodes=10,
52
+ )
53
+
54
+ # Instantiate the agent
55
+ # Hyperparameters from https://github.com/DLR-RM/rl-baselines3-zoo
56
+ model = DQN(
57
+ "MlpPolicy",
58
+ env,
59
+ learning_starts=0,
60
+ batch_size=128,
61
+ buffer_size=50000,
62
+ learning_rate=1e-3,
63
+ target_update_interval=250,
64
+ train_freq=4,
65
+ gradient_steps=-1,
66
+ # Explore for 20_000 timesteps (exploration_fraction * total_timesteps = 0.04 * 5e5)
67
+ exploration_fraction=0.04,
68
+ exploration_final_eps=0.1,
69
+ policy_kwargs=dict(net_arch=[256, 256]),
70
+ verbose=1,
71
+ )
72
+
73
+ # Train the agent (you can stop training early with Ctrl+C)
74
+ try:
75
+ model.learn(total_timesteps=int(5e5), callback=eval_callback)
76
+ except KeyboardInterrupt:
77
+ pass
78
+
79
+ # Load the best model saved by the evaluation callback
80
+ model = DQN.load("logs/best_model.zip")
81
+ ```
82