araffin commited on
Commit
52818c4
1 Parent(s): 504c2ab

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +51 -1
README.md CHANGED
@@ -1,3 +1,53 @@
1
  ---
2
- {}
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ tags:
3
+ - deep-reinforcement-learning
4
+ - reinforcement-learning
5
+ - stable-baselines3
6
  ---
7
+
8
+ This is a pre-trained model of a PPO agent playing CartPole-v1 using the [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) library.
9
+
10
+ ### Usage (with Stable-baselines3)
11
+ Using this model becomes easy when you have stable-baselines3 and huggingface_sb3 installed:
12
+
13
+ ```
14
+ pip install stable-baselines3
15
+ pip install huggingface_sb3
16
+ ```
17
+
18
+ Then, you can use the model like this:
19
+
20
+ ```python
21
+ import os
22
+
23
+ import gymnasium as gym
24
+
25
+ from huggingface_sb3 import load_from_hub
26
+ from stable_baselines3 import PPO
27
+ from stable_baselines3.common.evaluation import evaluate_policy
28
+
29
+ # Allow the use of `pickle.load()` when downloading model from the hub
30
+ # Please make sure that the organization from which you download can be trusted
31
+ os.environ["TRUST_REMOTE_CODE"] = "True"
32
+
33
+ # Retrieve the model from the hub
34
+ ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
35
+ ## filename = name of the model zip file from the repository
36
+ checkpoint = load_from_hub(
37
+ repo_id="sb3/demo-hf-CartPole-v1",
38
+ filename="ppo-CartPole-v1",
39
+ )
40
+ model = PPO.load(checkpoint)
41
+
42
+ # Evaluate the agent and watch it
43
+ eval_env = gym.make("CartPole-v1")
44
+ mean_reward, std_reward = evaluate_policy(
45
+ model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
46
+ )
47
+ print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
48
+ ```
49
+
50
+ ### Evaluation Results
51
+ Mean_reward: 500.0
52
+
53
+