Edit model card

A2C Agent playing Pendulum-v1

This is a trained model of an A2C agent playing Pendulum-v1, using the stable-baselines3 library.

Usage (with Stable-Baselines3)

from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize

# Download the trained checkpoint and the VecNormalize statistics from the Hub
env_id = "Pendulum-v1"
repo_id = f"araffin/a2c-{env_id}"
checkpoint = load_from_hub(repo_id, f"a2c-{env_id}.zip")
# Plain string literal: this filename has no placeholders to interpolate
vec_normalize_stats = load_from_hub(repo_id, "vec_normalize.pkl")

# Load the model
model = A2C.load(checkpoint)

# Recreate the environment and restore the normalization statistics saved
# during training (the agent was trained inside a VecNormalize wrapper)
env = make_vec_env(env_id, n_envs=1)
env = VecNormalize.load(vec_normalize_stats, env)
# Freeze the running normalization statistics: do not update them at test time
env.training = False
# Reward normalization is not needed at test time (we want the raw rewards)
env.norm_reward = False

# Evaluate
print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")

# Start a new episode and render the agent until interrupted with Ctrl+C
obs = env.reset()

try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        # SB3 VecEnvs reset finished episodes automatically, so no manual reset here
        obs, rewards, dones, info = env.step(action)
        env.render()
except KeyboardInterrupt:
    # Ctrl+C is the intended way to stop this demo loop
    pass
finally:
    # Release the environment's resources however the loop ends
    env.close()

Training Code

from huggingface_sb3 import package_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, sync_envs_normalization

# Create the environment
env_id = "Pendulum-v1"
# Discount factor shared by reward normalization and the agent.
# These values must agree, so they are defined once here.
gamma = 0.9
env = make_vec_env(env_id, n_envs=8)
# Normalize observations and rewards; `gamma` is used for the discounted
# return estimate that reward normalization is based on
env = VecNormalize(env, gamma=gamma)

# Create the evaluation env (could be used in `EvalCallback`).
# Its statistics are frozen and rewards are left unnormalized for reporting.
eval_env = make_vec_env(env_id, n_envs=1)
eval_env = VecNormalize(eval_env, gamma=gamma, training=False, norm_reward=False)

# Instantiate the agent
model = A2C(
    "MlpPolicy",
    env,
    n_steps=8,
    gamma=gamma,
    gae_lambda=0.9,
    use_sde=True,
    policy_kwargs=dict(log_std_init=-2),
    verbose=1,
)

# Train the agent. Ctrl+C deliberately stops training early; the partially
# trained model is still synchronized and uploaded below.
try:
    model.learn(total_timesteps=int(1e6))
except KeyboardInterrupt:
    pass


# Copy the training normalization statistics to the evaluation env
# (done automatically in `EvalCallback`)
sync_envs_normalization(env, eval_env)

package_to_hub(
    model=model,
    model_name=f"a2c-{env_id}",
    model_architecture="A2C",
    env_id=env_id,
    eval_env=eval_env,
    repo_id=f"araffin/a2c-{env_id}",
    commit_message="Initial commit",
)

# The training environments are no longer needed once the model is uploaded
env.close()
Downloads last month
4
Video Preview
loading

Evaluation results