Commit a956d3b by araffin
1 Parent(s): 878e2e4

Add usage code

Files changed (1)
  1. README.md +39 -0
README.md CHANGED
@@ -23,6 +23,45 @@ model-index:
  # **DQN** Agent playing **LunarLander-v2**
  This is a trained model of a **DQN** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
+ ## Usage (with Stable-Baselines3)
+
+ ```python
+ from huggingface_sb3 import load_from_hub
+ from stable_baselines3 import DQN
+ from stable_baselines3.common.env_util import make_vec_env
+ from stable_baselines3.common.evaluation import evaluate_policy
+
+ # Download checkpoint
+ checkpoint = load_from_hub("araffin/dqn-LunarLander-v2", "dqn-LunarLander-v2.zip")
+ # Remove warning
+ kwargs = dict(target_update_interval=30)
+ # Load the model
+ model = DQN.load(checkpoint, **kwargs)
+
+ env = make_vec_env("LunarLander-v2", n_envs=1)
+
+ # Evaluate
+ print("Evaluating model")
+ mean_reward, std_reward = evaluate_policy(
+     model,
+     env,
+     n_eval_episodes=20,
+     deterministic=True,
+ )
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
+
+ # Start a new episode
+ obs = env.reset()
+
+ try:
+     while True:
+         action, _states = model.predict(obs, deterministic=True)
+         obs, rewards, dones, info = env.step(action)
+         env.render()
+ except KeyboardInterrupt:
+     pass
+ ```
+
  ## Training Code (with Stable-baselines3)
 
  ```python