DeathReaper0965 commited on
Commit
2c43073
1 Parent(s): 563fb61

Add model usage code

Browse files
Files changed (1) hide show
  1. README.md +35 -6
README.md CHANGED
@@ -22,16 +22,45 @@ model-index:
22
  ---
23
 
24
  # **PPO** Agent playing **LunarLander-v2**
25
- This is a trained model of a **PPO** agent playing **LunarLander-v2**
26
- using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
27
 
28
  ## Usage (with Stable-baselines3)
29
- TODO: Add your code
30
-
31
 
32
  ```python
33
- from stable_baselines3 import ...
 
 
 
34
  from huggingface_sb3 import load_from_hub
35
 
36
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ```
 
22
  ---
23
 
24
  # **PPO** Agent playing **LunarLander-v2**
25
+ A trained model of a **PPO** agent playing **LunarLander-v2** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
26
 
27
  ## Usage (with Stable-baselines3)
 
 
28
 
29
  ```python
30
+ from stable_baselines3 import PPO
31
+ from stable_baselines3.common.env_util import make_vec_env
32
+ from stable_baselines3.common.evaluation import evaluate_policy
33
+
34
  from huggingface_sb3 import load_from_hub
35
 
36
+
37
+ # Download the model checkpoint
38
+ model_checkpoint = load_from_hub("deathReaper0965/ppo-mlp-LunarLander-v2", "ppo-mlp-LunarLander-v2.zip")
39
+ # Create a vectorized environment
40
+ env = make_vec_env("LunarLander-v2", n_envs=1)
41
+
42
+ # Load the model
43
+ model = PPO.load(model_checkpoint, env=env)
44
+
45
+ # Evaluate
46
+ print("Evaluating model")
47
+ mean_reward, std_reward = evaluate_policy(
48
+ model,
49
+ env,
50
+ n_eval_episodes=30,
51
+ deterministic=True,
52
+ )
53
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward}")
54
+
55
+ # Start a new episode
56
+ obs = env.reset()
57
+
58
+ try:
59
+ while True:
60
+ action, state = model.predict(obs, deterministic=True)
61
+ obs, reward, done, info = env.step(action)
62
+ env.render()
63
+
64
+ except KeyboardInterrupt:
65
+ pass
66
  ```