araffin commited on
Commit
a9d03a3
1 Parent(s): f98e0ce
Files changed (1) hide show
  1. README.md +38 -0
README.md CHANGED
@@ -23,6 +23,44 @@ model-index:
23
  # **A2C** Agent playing **LunarLander-v2**
24
  This is a trained model of a **A2C** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  ## Training code (with Stable-baselines3)
27
  ```python
28
  from stable_baselines3 import A2C
 
23
  # **A2C** Agent playing **LunarLander-v2**
24
  This is a trained model of a **A2C** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
25
 
26
+ ## Usage (with Stable-Baselines3)
27
+
28
+ ```python
29
+ from huggingface_sb3 import load_from_hub
30
+ from stable_baselines3 import A2C
31
+ from stable_baselines3.common.env_util import make_vec_env
32
+ from stable_baselines3.common.evaluation import evaluate_policy
33
+
34
+ # Download checkpoint
35
+ checkpoint = load_from_hub("araffin/a2c-LunarLander-v2", "a2c-LunarLander-v2.zip")
36
+ # Load the model
37
+ model = A2C.load(checkpoint)
38
+
39
+ env = make_vec_env("LunarLander-v2", n_envs=1)
40
+
41
+ # Evaluate
42
+ print("Evaluating model")
43
+ mean_reward, std_reward = evaluate_policy(
44
+ model,
45
+ env,
46
+ n_eval_episodes=20,
47
+ deterministic=True,
48
+ )
49
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
50
+
51
+ # Start a new episode
52
+ obs = env.reset()
53
+
54
+ try:
55
+ while True:
56
+ action, _states = model.predict(obs, deterministic=True)
57
+ obs, rewards, dones, info = env.step(action)
58
+ env.render()
59
+ except KeyboardInterrupt:
60
+ pass
61
+
62
+ ```
63
+
64
  ## Training code (with Stable-baselines3)
65
  ```python
66
  from stable_baselines3 import A2C