kinkpunk commited on
Commit
51e738b
1 Parent(s): c27712b

Update README.md

Browse files

Add training code

Files changed (1) hide show
  1. README.md +32 -0
README.md CHANGED
@@ -56,3 +56,35 @@ mean_reward, std_reward = evaluate_policy(model, env,
56
  # Print the results
57
  print('mean_reward={:.2f} +/- {:.2f}'.format(mean_reward, std_reward))
58
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Print the results
57
  print('mean_reward={:.2f} +/- {:.2f}'.format(mean_reward, std_reward))
58
  ```
59
+
60
+ ## Training (with Stable-baselines3)
61
+ ```python
62
+ from huggingface_sb3 import load_from_hub
63
+
64
+ from stable_baselines3 import PPO
65
+ from stable_baselines3.common.evaluation import evaluate_policy
66
+ from stable_baselines3.common.env_util import make_vec_env
67
+
68
+ # Create the evaluation envs
69
+ env = make_vec_env('LunarLander-v2', n_envs=16)
70
+ env = gym.make('LunarLander-v2')
71
+
72
+ # Instantiate the agent
73
+ model = PPO(
74
+ policy = 'MlpPolicy',
75
+ env = env,
76
+ n_steps = 1024,
77
+ batch_size = 32,
78
+ n_epochs = 8,
79
+ gamma = 0.99,
80
+ gae_lambda = 0.95,
81
+ ent_coef = 0.01,
82
+ verbose=1,
83
+ seed=2022)
84
+
85
+ # Train
86
+ model.learn(total_timesteps=1500000)
87
+ # Save model
88
+ model_name = "Any-Name"
89
+ model.save(model_name)
90
+ ```