metadata
tags:
  - deep-reinforcement-learning
  - reinforcement-learning
  - stable-baselines3
  - atari
model-index:
  - name: PPO Agent
    results:
      - task:
          type: reinforcement-learning
        dataset:
          type: SeaquestNoFrameskip-v4
          name: SeaquestNoFrameskip-v4
        metrics:
          - type: mean_reward
            value: 1820.00 +/- 20.0

PPO Agent playing SeaquestNoFrameskip-v4

This is a trained model of a PPO agent playing SeaquestNoFrameskip-v4 using the stable-baselines3 library.

The training report: https://wandb.ai/simoninithomas/HFxSB3/reports/Atari-HFxSB3-Benchmark--VmlldzoxNjI3NTIy

Evaluation Results

Mean_reward: 1820.00 +/- 20.0
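The card does not say how the ± figure was computed or over how many episodes. A reasonable assumption is that it is the mean and standard deviation of the game score across several evaluation episodes. That is what SB3's `evaluate_policy` helper returns, and it uses `np.std`, which is a population standard deviation. Below is a minimal standard-library sketch of that summary. The helper name `summarize_returns` is hypothetical and not part of any library.

```python
import statistics
from typing import Sequence


def summarize_returns(episode_returns: Sequence[float]) -> str:
    """Format episode returns as 'mean +/- std' (population std, ddof=0, like np.std)."""
    mean = statistics.mean(episode_returns)
    std = statistics.pstdev(episode_returns)
    return f"{mean:.2f} +/- {std:.2f}"


# Toy returns for illustration only, not this agent's evaluation data.
print(summarize_returns([1.0, 2.0, 3.0]))  # 2.00 +/- 0.82
```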

Usage (with Stable-baselines3)

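  • You need gym==0.19, since that version includes the Atari ROMs.
  • The action space is Discrete(18): we only use the actions that are valid in this game (its minimal action set), which for Seaquest is the full set of ALE actions.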
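A discrete Atari action space is an integer index into the game's minimal action set, which is a subset of the 18 joystick commands defined by the Arcade Learning Environment (ALE). The sketch below lists those commands in ALE action-id order and decodes a minimal set into readable names. The helper `action_meanings` and the example subset are hypothetical and only illustrate the lookup.

```python
from typing import List, Sequence

# The ALE's full joystick command set, listed in ALE action-id order (0..17).
ALE_ACTIONS: List[str] = [
    "NOOP", "FIRE", "UP", "RIGHT", "LEFT", "DOWN",
    "UPRIGHT", "UPLEFT", "DOWNRIGHT", "DOWNLEFT",
    "UPFIRE", "RIGHTFIRE", "LEFTFIRE", "DOWNFIRE",
    "UPRIGHTFIRE", "UPLEFTFIRE", "DOWNRIGHTFIRE", "DOWNLEFTFIRE",
]


def action_meanings(minimal_set: Sequence[int]) -> List[str]:
    """Map each policy output index i to the name of ALE command minimal_set[i]."""
    return [ALE_ACTIONS[ale_id] for ale_id in minimal_set]


# Hypothetical minimal set, for illustration only (not Seaquest's).
example = action_meanings([0, 1, 3, 4])
print(len(example), example)  # 4 ['NOOP', 'FIRE', 'RIGHT', 'LEFT']
```

With gym==0.19 and the Atari extras installed, `env.unwrapped.get_action_meanings()` on a `SeaquestNoFrameskip-v4` env returns the real list, and its length should match `env.action_space.n`.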

Watch your agent interact:

# Import the libraries
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

from huggingface_sb3 import load_from_hub

# Load the model
checkpoint = load_from_hub("ThomasSimonini/ppo-SeaquestNoFrameskip-v4", "ppo-SeaquestNoFrameskip-v4.zip")

# This agent was trained with Python 3.8, but Colab uses Python 3.7. To avoid
# pickle errors, replace the saved schedules with constants (they are only
# used during training, not for inference):
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = PPO.load(checkpoint, custom_objects=custom_objects)

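# Recreate the training preprocessing: Atari wrappers + 4 stacked frames
env = make_atari_env('SeaquestNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    # The VecEnv resets itself automatically when an episode ends
    obs, rewards, dones, info = env.step(action)
    env.render()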
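The loop above runs forever. To get a score you can compare with the reported mean reward, a bounded loop is handier. One caveat is that with SB3's default `AtariWrapper` settings, the rewards returned by `env.step` are clipped and `dones` fire on every lost life. Full-game scores instead come from the `"episode"` entry that SB3's `Monitor` wrapper adds to `info` when a game ends. The sketch below collects those scores. `collect_episode_scores` is a hypothetical helper, and `FakeVecEnv` is a hypothetical stand-in so the sketch runs without Atari installed.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def collect_episode_scores(
    predict: Callable[[Any], Any], env: Any, n_episodes: int
) -> List[float]:
    """Step a VecEnv until n_episodes games have finished and return their scores.

    Scores are read from the "episode" entry that SB3's Monitor wrapper adds to
    info at the end of a game, so they are unclipped, full-game returns.
    """
    scores: List[float] = []
    obs = env.reset()
    while len(scores) < n_episodes:
        actions = predict(obs)
        obs, _rewards, _dones, infos = env.step(actions)
        for info in infos:
            episode = info.get("episode")
            if episode is not None:
                scores.append(float(episode["r"]))
    # Several envs can finish on the same step, so trim any extra episodes.
    return scores[:n_episodes]


class FakeVecEnv:
    """Hypothetical 1-env VecEnv stand-in: every game lasts 3 steps and scores 10."""

    def __init__(self) -> None:
        self.t = 0

    def reset(self) -> List[int]:
        self.t = 0
        return [0]

    def step(
        self, actions: Sequence[int]
    ) -> Tuple[List[int], List[float], List[bool], List[Dict[str, Any]]]:
        self.t += 1
        if self.t < 3:
            return [0], [1.0], [False], [{}]
        self.t = 0  # auto-reset, as a real VecEnv does
        return [0], [1.0], [True], [{"episode": {"r": 10.0, "l": 3}}]


print(collect_episode_scores(lambda obs: [0], FakeVecEnv(), n_episodes=2))  # [10.0, 10.0]
```

With the loaded agent, you would pass `lambda o: model.predict(o, deterministic=True)[0]` as `predict` and the `VecFrameStack`-wrapped `env` from above.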

Training Code

import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

from huggingface_sb3 import push_to_hub

config = {
    "env_name": "SeaquestNoFrameskip-v4",
    "num_envs": 8,
    "total_timesteps": int(10e6),  # 10 million timesteps
    "seed": 2862830927,
}

run = wandb.init(
    project="HFxSB3",
    config=config,
    sync_tensorboard=True,  # Auto-upload SB3's TensorBoard metrics
    monitor_gym=True,  # Auto-upload the videos of the agent playing
    save_code=True,  # Save the code to W&B
)

# make_atari_env already creates and wraps Atari environments correctly.
# We also train with multiple workers (n_envs=8, i.e. 8 parallel environments).
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])

print("ENV ACTION SPACE: ", env.action_space.n)

# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
# Record a 2000-step video every 100,000 VecEnv steps
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)

model = PPO(
    policy="CnnPolicy",
    env=env,
    batch_size=256,
    clip_range=0.1,
    ent_coef=0.01,
    gae_lambda=0.9,
    gamma=0.99,
    learning_rate=2.5e-4,
    max_grad_norm=0.5,
    n_epochs=4,
    n_steps=128,
    vf_coef=0.5,
    tensorboard_log="runs",
    verbose=1,
)

model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        CheckpointCallback(
            save_freq=10000,
            save_path="./seaquest",
            name_prefix=config["env_name"],
        ),
    ],
)

model.save("ppo-SeaquestNoFrameskip-v4.zip")
push_to_hub(
    repo_id="ThomasSimonini/ppo-SeaquestNoFrameskip-v4",
    filename="ppo-SeaquestNoFrameskip-v4.zip",
    commit_message="Added Seaquest trained agent",
)
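The hyperparameters above also fix how much work a run does. Each PPO update consumes n_steps × n_envs transitions, split into minibatches of `batch_size` for `n_epochs` passes, and `learn()` keeps collecting whole rollouts until the timestep budget is reached. The back-of-the-envelope sketch below follows from those rules. It assumes no `target_kl` early stopping, which is the default. It also assumes `AtariWrapper`'s default frame skip of 4 and that `CheckpointCallback` counts one call per `VecEnv.step()`. The function `ppo_budget` is hypothetical.

```python
import math
from typing import Dict


def ppo_budget(
    total_timesteps: int,
    n_envs: int,
    n_steps: int,
    batch_size: int,
    n_epochs: int,
    checkpoint_save_freq: int,
    frame_skip: int = 4,  # assumption: AtariWrapper's default frame skip
) -> Dict[str, int]:
    """Back-of-the-envelope PPO training budget (assumes no target_kl early stop)."""
    rollout_size = n_steps * n_envs  # transitions gathered before each update
    minibatches_per_epoch = math.ceil(rollout_size / batch_size)
    # learn() keeps collecting whole rollouts until num_timesteps >= total_timesteps
    n_rollouts = math.ceil(total_timesteps / rollout_size)
    env_steps = n_rollouts * rollout_size
    vec_env_calls = n_rollouts * n_steps  # callbacks fire once per VecEnv.step()
    return {
        "rollout_size": rollout_size,
        "minibatches_per_epoch": minibatches_per_epoch,
        "n_rollouts": n_rollouts,
        "gradient_steps": n_rollouts * n_epochs * minibatches_per_epoch,
        "env_steps": env_steps,
        "emulator_frames": env_steps * frame_skip,
        "timesteps_per_checkpoint": checkpoint_save_freq * n_envs,
        "checkpoints": vec_env_calls // checkpoint_save_freq,
    }


budget = ppo_budget(
    total_timesteps=int(10e6),
    n_envs=8,
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    checkpoint_save_freq=10000,
)
for key, value in budget.items():
    print(f"{key}: {value:,}")
```

Under these assumptions, the run performs 9,766 updates of 1,024 transitions each (156,256 gradient steps). Because `save_freq` counts `VecEnv.step()` calls, a checkpoint lands every 80,000 timesteps rather than every 10,000. SB3's callback documentation suggests dividing `save_freq` by the number of environments if the latter was intended.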