---
library_name: stable-baselines3
tags:
- CartPole-v1
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: A2C
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: CartPole-v1
      type: CartPole-v1
    metrics:
    - type: mean_reward
      value: 9.80 +/- 0.60
      name: mean_reward
      verified: false
---
# **A2C** Agent playing **CartPole-v1**

This is a trained model of an **A2C** agent playing **CartPole-v1** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
## Usage (with Stable-baselines3)
```python
import gym
from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub
import wandb
from wandb.integration.sb3 import WandbCallback

# Create the training and evaluation environments
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_id": "CartPole-v1",
}

run = wandb.init(
    project="cart_pole",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    # monitor_gym=True,  # auto-upload the videos of agents playing the game
    # save_code=True,  # optional
)

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()

model.save("a2c_Cart_Pole")
package_to_hub(
    model=model,
    model_name="a2c_Cart_Pole",
    model_architecture="A2C",
    env_id="CartPole-v1",
    eval_env=eval_env,
    repo_id="mRoszak/A2C_Cart_Pole",
    commit_message="Test commit",
)
```
...
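
Once uploaded, the checkpoint can be pulled back from the Hub and evaluated instead of retrained. A minimal sketch, assuming the repo id and filename produced by the script above (`mRoszak/A2C_Cart_Pole`, `a2c_Cart_Pole.zip`):

```python
import gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

# Download the checkpoint saved by package_to_hub above
checkpoint = load_from_hub(
    repo_id="mRoszak/A2C_Cart_Pole",
    filename="a2c_Cart_Pole.zip",
)
model = A2C.load(checkpoint)

# Evaluate the loaded agent over a few episodes
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```

`evaluate_policy` is the same helper `package_to_hub` uses when it scores the agent, so the printed value should be comparable to the `mean_reward` reported in the metadata above.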