Jacampo committed
Commit df9abba
1 Parent(s): 83df564

Update README.md

Files changed (1): README.md (+46, -0)
README.md CHANGED
@@ -5,3 +5,49 @@ tags:
 - stable-baselines3
 ---
 # TODO: Fill this model card
+ This is a pre-trained agent playing Asteroids-v0, trained with the [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) library.
+
+ ### Usage (with Stable-Baselines3)
+ Using this model is straightforward once stable-baselines3 and huggingface_sb3 are installed:
+
+ ```
+ pip install stable-baselines3
+ pip install huggingface_sb3
+ ```
+
+ Then, you can load and use the model like this:
+
+ ```python
+ import gym
+
+ from colabgymrender.recorder import Recorder  # assumes the colabgymrender package for video recording
+ from huggingface_sb3 import load_from_hub
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.evaluation import evaluate_policy
+
+ # Retrieve the model from the hub
+ ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
+ ## filename = name of the model zip file from the repository
+ checkpoint = load_from_hub(repo_id="TrabajoAprendizajeProfundo/Trabajo", filename="Asteroids-v0.zip")
+ model = PPO.load(checkpoint)
+
+ # Evaluate the agent
+ eval_env = gym.make('Asteroids-v0')
+ mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
+ print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+
+ # Watch the agent play: wrap the environment in a Recorder that saves a video to `directory`
+ directory = './video'
+ env = Recorder(gym.make('Asteroids-v0'), directory)
+
+ obs = env.reset()
+ done = False
+ while not done:
+     action, _state = model.predict(obs)
+     obs, reward, done, info = env.step(action)
+
+ env.play()
+ ```
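+
+ The video-recording step above appears to rely on the `Recorder` wrapper from the colabgymrender package (an assumption based on the `Recorder(env, directory)` / `env.play()` API used here); if so, it needs an extra install:
+
+ ```
+ pip install colabgymrender
+ ```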
+
+ ### Evaluation Results
+ The reported mean reward is computed over 10 evaluation episodes:
+
+ ```python
+ mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
+ print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+ ```
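+
+ If you retrain or fine-tune this agent, the updated zip can be shared back to the Hub with huggingface_sb3. A minimal sketch, assuming write access to the target repository and a configured Hugging Face token (the repo_id and commit message below are illustrative):
+
+ ```python
+ from huggingface_sb3 import push_to_hub
+
+ # Save the (re)trained agent to a local zip file (SB3 appends .zip automatically)
+ model.save("Asteroids-v0")
+
+ # Upload the zip to a model repository on the Hugging Face Hub
+ push_to_hub(
+     repo_id="TrabajoAprendizajeProfundo/Trabajo",  # illustrative: reuses this card's repo id
+     filename="Asteroids-v0.zip",
+     commit_message="Update PPO agent for Asteroids-v0",
+ )
+ ```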