ThomasSimonini (HF staff) committed
Commit 7b0b070
1 Parent(s): 1212adb

Update README.md

Files changed (1): README.md (+8 / -34)
README.md CHANGED
@@ -5,27 +5,19 @@ tags:
 - stable-baselines3
 ---
 # PPO Agent playing PongNoFrameskip-v4
-This is a trained model of a PPO agent playing PongNoFrameskip-v4 using the stable-baselines3 library (our agent is the 🟢 one).
+This is a trained model of a **PPO agent playing PongNoFrameskip-v4 using the stable-baselines3 library** (our agent is the 🟢 one).
 
 <video src="https://huggingface.co/ThomasSimonini/ppo-PongNoFrameskip-v4/resolve/main/output.mp4" controls autoplay loop></video>
 
 ## Evaluation Results
-Mean_reward = 21.00 +/- 0.0
+Mean reward: `21.00 +/- 0.0`
 
 # Usage (with Stable-baselines3)
-## Watch your agent interacts (in Google Colab)
 - You need to use `gym==0.19`, since it **includes the Atari ROMs**.
 - The action space is 6, since we use only the **legal actions**.
 
+Watch your agent interact:
+
-```python
-# Install these libraries in one cell (don't forget to restart the runtime after installing the librairies)
-!pip install stable-baselines3[extra]
-!pip install huggingface_sb3
-!pip install huggingface_hub
-!pip install pickle5
-```
-
-Don't forget to restart the runtime before running the code below:
 ```python
 # Import the libraries
 import os
@@ -37,16 +29,8 @@ from stable_baselines3.common.vec_env import VecNormalize
 
 from stable_baselines3.common.env_util import make_atari_env
 from stable_baselines3.common.vec_env import VecFrameStack
-from stable_baselines3 import PPO
-from stable_baselines3.common.callbacks import CheckpointCallback
-
 
 from huggingface_sb3 import load_from_hub, push_to_hub
-import gym
-from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
-
-from stable_baselines3.common.evaluation import evaluate_policy
 
 # Load the model
 checkpoint = load_from_hub("ThomasSimonini/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")
@@ -60,24 +44,14 @@ custom_objects = {
 
 model = PPO.load(checkpoint, custom_objects=custom_objects)
 
-## Evaluate the agent
 env = make_atari_env('PongNoFrameskip-v4', n_envs=1)
 env = VecFrameStack(env, n_stack=4)
 
-mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
-print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
-
-## Generate a video of your agent performing with Colab
-!pip install gym pyvirtualdisplay > /dev/null 2>&1
-!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
-!pip install colabgymrender==1.0.2
-
-observation = env.reset()
-terminal = False
-while not terminal:
-    action, _state = model.predict(observation)
-    observation, reward, terminal, info = env.step(action)
-env.play()
+obs = env.reset()
+while True:
+    action, _states = model.predict(obs)
+    obs, rewards, dones, info = env.step(action)
+    env.render()
 ```
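A few hedged notes and examples on the updated usage code follow. First, the commit removes the explicit Colab install cell. A minimal reconstruction from the package list in the previous revision (the exact cell is an assumption, not part of this commit), with `gym` pinned as the bullet above requires:

```python
# Colab install cell -- restart the runtime after it finishes.
# Package list taken from the pre-commit README revision; gym is pinned
# to 0.19 because that release still bundles the Atari ROMs.
!pip install stable-baselines3[extra]
!pip install huggingface_sb3 huggingface_hub pickle5
!pip install gym==0.19
```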
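Second, the claim that the action space is 6 is easy to verify once the environment is built; a quick sketch (assuming `gym==0.19`'s default reduced Atari action set):

```python
from stable_baselines3.common.env_util import make_atari_env

# Pong exposes only its 6 legal actions by default:
# NOOP, FIRE, RIGHT, LEFT, RIGHTFIRE, LEFTFIRE.
env = make_atari_env('PongNoFrameskip-v4', n_envs=1)
print(env.action_space)  # expected: Discrete(6)
```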
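Third, the body of the `custom_objects` dict falls between hunks, so it stays elided above. Purely for illustration (a hypothetical reconstruction, not necessarily what this README defines), a common pattern when a checkpoint was saved under a different stable-baselines3 or Python version is to override the pickled schedule callables:

```python
# Hypothetical example -- the dict actually used in the README is outside this diff's context.
# Overriding these entries lets PPO.load rebuild the checkpoint even when its
# pickled schedule callables cannot be deserialized in the current environment.
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}
```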
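Fourth, the commit also removes the `evaluate_policy` snippet that produced the reported `21.00 +/- 0.0`. To reproduce that figure with the `model` and `env` built in the snippet above, the evaluator from the previous revision still works:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the loaded model on the stacked-frame Pong env built above.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```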
 
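Finally, one caveat: the commit removes `from stable_baselines3 import PPO` even though `PPO.load` is still called, so the trimmed snippet no longer runs as-is. A self-contained sketch of the post-commit usage code with that import restored (`custom_objects` is left as a placeholder, since its body is not shown in this diff):

```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO  # still required: PPO.load is called below
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Download the trained checkpoint from the Hugging Face Hub.
checkpoint = load_from_hub("ThomasSimonini/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")

custom_objects = {}  # placeholder -- the real dict is defined in the README, outside this diff

model = PPO.load(checkpoint, custom_objects=custom_objects)

# Rebuild the training-time observation pipeline: Atari preprocessing + 4 stacked frames.
env = make_atari_env('PongNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)

# Watch the agent play.
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
```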