Commit 179a163 (parent 93c7cc6) by ThomasSimonini: Add Usage part

Files changed (1): README.md (+44, -0)
This is a pre-trained model of a PPO agent playing Walker2DBulletEnv-v0 using th…

<video src="https://huggingface.co/ThomasSimonini/ppo-Walker2DBulletEnv-v0/resolve/main/output.mp4" controls autoplay loop></video>
### Usage (with Stable-Baselines3)
Using this model is easy once you have stable-baselines3 and huggingface_sb3 installed:

```
pip install stable-baselines3
pip install huggingface_sb3
```

Then you can use the model like this:

```python
import gym
import pybullet_envs

from huggingface_sb3 import load_from_hub

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.evaluation import evaluate_policy

# Retrieve the model from the Hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
repo_id = "ThomasSimonini/ppo-Walker2DBulletEnv-v0"
checkpoint = load_from_hub(repo_id=repo_id, filename="ppo-Walker2DBulletEnv-v0.zip")
model = PPO.load(checkpoint)

# Load the saved VecNormalize statistics
stats_path = load_from_hub(repo_id=repo_id, filename="vec_normalize.pkl")

eval_env = DummyVecEnv([lambda: gym.make("Walker2DBulletEnv-v0")])
eval_env = VecNormalize.load(stats_path, eval_env)
# Do not update the normalization statistics at test time
eval_env.training = False
# Reward normalization is not needed at test time
eval_env.norm_reward = False

mean_reward, std_reward = evaluate_policy(model, eval_env)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```

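The vec_normalize.pkl statistics matter because the policy was trained on normalized observations. As a rough illustration (a sketch, not the Stable-Baselines3 source, and using toy statistics rather than the real Walker2D values), VecNormalize scales each observation by its running mean and variance and clips the result, with a default clip range of 10 and epsilon of 1e-8:

```python
import numpy as np

def normalize_obs(obs, mean, var, eps=1e-8, clip_obs=10.0):
    """Sketch of VecNormalize-style observation scaling:
    subtract the running mean, divide by the running std, then clip."""
    return np.clip((obs - mean) / np.sqrt(var + eps), -clip_obs, clip_obs)

# Toy running statistics, not the actual Walker2D values.
mean = np.array([0.0, 1.0])
var = np.array([4.0, 0.25])
obs = np.array([2.0, 0.0])

print(normalize_obs(obs, mean, var))  # roughly [1.0, -2.0]
```

Skipping this step and evaluating on raw observations would feed the policy inputs on a different scale than it saw during training, which is why the evaluation snippet above wraps the environment with `VecNormalize.load` before calling `evaluate_policy`.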
### Evaluation Results
Mean reward: 2371.90 +/- 16.50