araffin committed on
Commit
ff432d8
1 Parent(s): 1a05ad3
Files changed (1) hide show
  1. README.md +37 -0
README.md CHANGED
@@ -25,6 +25,43 @@ model-index:
25
 
26
  ## Usage (with Stable-baselines3)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ```python
29
  from stable_baselines3 import PPO
30
  from stable_baselines3.common.env_util import make_vec_env
 
25
 
26
  ## Usage (with Stable-baselines3)
27
 
28
+ ```python
29
+ from huggingface_sb3 import load_from_hub
30
+ from stable_baselines3 import PPO
31
+ from stable_baselines3.common.env_util import make_vec_env
32
+ from stable_baselines3.common.evaluation import evaluate_policy
33
+
34
+ # Download checkpoint
35
+ checkpoint = load_from_hub("araffin/ppo-LunarLander-v2", "ppo-LunarLander-v2.zip")
36
+ # Load the model
37
+ model = PPO.load(checkpoint)
38
+
39
+ env = make_vec_env("LunarLander-v2", n_envs=1)
40
+
41
+ # Evaluate
42
+ print("Evaluating model")
43
+ mean_reward, std_reward = evaluate_policy(
44
+ model,
45
+ env,
46
+ n_eval_episodes=20,
47
+ deterministic=True,
48
+ )
49
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
50
+
51
+ # Start a new episode
52
+ obs = env.reset()
53
+
54
+ try:
55
+ while True:
56
+ action, _states = model.predict(obs, deterministic=True)
57
+ obs, rewards, dones, info = env.step(action)
58
+ env.render()
59
+ except KeyboardInterrupt:
60
+ pass
61
+ ```
62
+
63
+ ## Training code (with SB3)
64
+
65
  ```python
66
  from stable_baselines3 import PPO
67
  from stable_baselines3.common.env_util import make_vec_env