Update README.md
Browse filesAdd training code
README.md
CHANGED
```python
# Print the results
print('mean_reward={:.2f} +/- {:.2f}'.format(mean_reward, std_reward))
```

## Training (with Stable-baselines3)

```python
from huggingface_sb3 import load_from_hub

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env

# Create the evaluation envs
env = make_vec_env('LunarLander-v2', n_envs=16)

# Instantiate the agent
model = PPO(
    policy='MlpPolicy',
    env=env,
    n_steps=1024,
    batch_size=32,
    n_epochs=8,
    gamma=0.99,
    gae_lambda=0.95,
    ent_coef=0.01,
    verbose=1,
    seed=2022)

# Train
model.learn(total_timesteps=1500000)

# Save model
model_name = "Any-Name"
model.save(model_name)
```