bobobert4 committed
Commit dc2feff
1 Parent(s): 4924db4

Update README.md

Files changed (1)
  1. README.md +57 -2
README.md CHANGED
@@ -30,8 +30,63 @@ TODO: Add your code
 
 
 ```python
- from stable_baselines3 import ...
- from huggingface_sb3 import load_from_hub
+ import panda_gym
+ import gym
 
+ from huggingface_sb3 import package_to_hub
+
+ from stable_baselines3 import A2C
+ from stable_baselines3.common.env_util import make_vec_env
+ from stable_baselines3.common.vec_env import SubprocVecEnv
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+ from stable_baselines3.common.evaluation import evaluate_policy
+
+ env_id = "PandaReachDense-v2"
+ model_name = "PandaReachDenseA2C-n8"
+ env_name = f"{env_id}_vec_normalize.pkl"
+
+ if __name__ == "__main__":
+     env = make_vec_env(env_id, n_envs=6, vec_env_cls=SubprocVecEnv)
+     # 3. Normalize observations only; rewards are left unnormalized
+     env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)
+
+     def linear_scheduler(progress_remaining: float):
+         # from https://github.com/DLR-RM/rl-baselines3-zoo/blob/33eba22eb36128412a5b22b57a7a10bfe71e6278/rl_zoo3/utils.py
+         return progress_remaining * 0.0009
+
+     # 4. Instantiate the A2C agent with a MultiInputPolicy for dict observations
+     model = A2C(policy="MultiInputPolicy",
+                 env=env,
+                 verbose=1,
+                 device='cpu',
+                 learning_rate=linear_scheduler,
+                 use_rms_prop=True,
+                 gae_lambda=0.9,
+                 use_sde=True,
+                 n_steps=8,
+                 )
+     # 5. Train for 1.5M timesteps
+     model.learn(1_500_000)
+
+     model.save(model_name)
+     env.save(env_name)
+     del env
+
+     # Rebuild an evaluation environment with the saved normalization statistics
+     eval_env = DummyVecEnv([lambda: gym.make("PandaReachDense-v2")])
+     eval_env = VecNormalize.load(env_name, eval_env)
+
+     eval_env.training = False
+     eval_env.norm_reward = False
+
+     mean_reward, std_reward = evaluate_policy(model, eval_env)
+     print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
+
+     # Package and push the trained agent to the Hugging Face Hub
+     package_to_hub(
+         model=model,
+         model_name=model_name,
+         model_architecture="A2C",
+         env_id=env_id,
+         eval_env=eval_env,
+         repo_id=f"bobobert4/a2c-{env_id}",
+         commit_message="Another commit",
+     )
 ...
 ```
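
For reference, here is a minimal sketch of how the pushed agent could be loaded back with `load_from_hub`, the helper the old README stub imported. The filename `PandaReachDenseA2C-n8.zip` is an assumption about how `package_to_hub` names the checkpoint inside `bobobert4/a2c-PandaReachDense-v2`; check the repo's file list if it differs.

```python
# Minimal sketch: download the uploaded agent from the Hub and evaluate it.
# The checkpoint filename is an assumption, not confirmed by the commit above.
import gym
import panda_gym  # noqa: F401  (registers the PandaReachDense-v2 environment)

from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

checkpoint = load_from_hub(
    repo_id="bobobert4/a2c-PandaReachDense-v2",
    filename="PandaReachDenseA2C-n8.zip",  # assumed checkpoint name
)
model = A2C.load(checkpoint)

eval_env = gym.make("PandaReachDense-v2")
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```

Note that the training script normalizes observations with `VecNormalize`, so a faithful evaluation would also wrap `eval_env` with the saved normalization statistics, as the script in the diff does before calling `evaluate_policy`.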