---
library_name: stable-baselines3
tags:
  - CartPole-v1
  - deep-reinforcement-learning
  - reinforcement-learning
  - stable-baselines3
model-index:
  - name: A2C
    results:
      - task:
          type: reinforcement-learning
          name: reinforcement-learning
        dataset:
          name: CartPole-v1
          type: CartPole-v1
        metrics:
          - type: mean_reward
            value: 9.80 +/- 0.60
            name: mean_reward
            verified: false
---

# A2C Agent playing CartPole-v1

This is a trained model of an **A2C** agent playing **CartPole-v1**, using the stable-baselines3 library.
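The `mean_reward` reported in the metadata (9.80 +/- 0.60) follows Stable-Baselines3's convention of mean ± standard deviation over evaluation episodes. As a stdlib-only sketch of that computation (the episode returns below are made up for illustration):

```python
import statistics

# Hypothetical per-episode returns from an evaluation run (illustrative only)
episode_returns = [9.0, 10.0, 10.0, 9.0, 11.0]

mean_reward = statistics.mean(episode_returns)
# SB3's evaluate_policy reports np.std, i.e. the population standard deviation
std_reward = statistics.pstdev(episode_returns)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
# → mean_reward=9.80 +/- 0.75
```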

## Usage (with Stable-Baselines3)

```python
import gym

from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub
import wandb
from wandb.integration.sb3 import WandbCallback

# Training and evaluation environments
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_id": "CartPole-v1",
}

run = wandb.init(
    project="cart_pole",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    # monitor_gym=True,  # auto-upload videos of the agent playing the game
    # save_code=True,  # optional
)

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()

model.save("a2c_Cart_Pole")

# Evaluate the trained model and push it to the Hugging Face Hub
package_to_hub(
    model=model,
    model_name="a2c_Cart_Pole",
    model_architecture="A2C",
    env_id="CartPole-v1",
    eval_env=eval_env,
    repo_id="mRoszak/A2C_Cart_Pole",
    commit_message="Test commit",
)
```
