---
library_name: stable-baselines3
tags:
  - LunarLander-v2
  - deep-reinforcement-learning
  - reinforcement-learning
  - stable-baselines3
model-index:
  - name: PPO
    results:
      - task:
          type: reinforcement-learning
          name: reinforcement-learning
        dataset:
          name: LunarLander-v2
          type: LunarLander-v2
        metrics:
          - type: mean_reward
            value: 330.99 +/- 20.23
            name: mean_reward
            verified: true
---

PPO Agent playing LunarLander-v2

A trained model of a PPO agent playing LunarLander-v2 using the stable-baselines3 library.

Usage (with Stable-Baselines3)

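```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

from huggingface_sb3 import load_from_hub

# Download the model checkpoint from the Hugging Face Hub (returns a local file path)
model_checkpoint = load_from_hub("DeathReaper0965/ppo-mlp-LunarLander-v2", "ppo-mlp-LunarLander-v2.zip")

# Create a vectorized environment holding a single LunarLander-v2 instance.
# LunarLander needs the Box2D extra of gym/gymnasium. Recent Gymnasium releases
# register LunarLander-v3 instead of v2, so an older version may be required.
env = make_vec_env("LunarLander-v2", n_envs=1)

# Load the trained PPO agent and attach the environment to it
model = PPO.load(model_checkpoint, env=env)

# Evaluate the deterministic policy over 30 episodes
print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=30,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")

# Watch the agent play until interrupted with Ctrl+C.
# The vectorized env resets itself automatically whenever an episode ends.
# With Stable-Baselines3 2.x, a window may only appear if the env was created with
# make_vec_env("LunarLander-v2", n_envs=1, env_kwargs={"render_mode": "human"}).
obs = env.reset()

try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = env.step(action)
        env.render()
except KeyboardInterrupt:
    pass
finally:
    # Release the environment and close any render window
    env.close()
```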
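`evaluate_policy` summarizes performance as the mean and standard deviation of the return (total reward) of each evaluation episode. The sketch below illustrates that summary in plain Python; `episode_returns` and `summarize` are hypothetical helpers, not Stable-Baselines3 functions, and the reward values are invented for illustration. It assumes the population standard deviation, matching NumPy's default `np.std` (ddof=0), and treats each `True` in `dones` as the step on which an episode ended, as with the auto-resetting vectorized environment in the loop above.

```python
from statistics import fmean, pstdev


def episode_returns(rewards, dones):
    """Split one environment's step rewards into per-episode returns (hypothetical helper).

    `rewards` and `dones` are parallel sequences with one entry per env.step() call;
    an episode ends on every step where the done flag is True.
    """
    returns, running = [], 0.0
    for reward, done in zip(rewards, dones):
        running += reward
        if done:
            returns.append(running)
            running = 0.0
    # An unfinished trailing episode is ignored because it has no final return yet
    return returns


def summarize(returns):
    """Return (mean, population standard deviation) of the episode returns."""
    return fmean(returns), pstdev(returns)


# Invented rewards covering two short episodes (not real LunarLander data)
rewards = [1.0, 2.0, 3.0, 4.0, 6.0]
dones = [False, False, True, False, True]

returns = episode_returns(rewards, dones)
mean_reward, std_reward = summarize(returns)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```

Here the two episodes return 6.0 and 10.0, so the printed summary is `Mean reward = 8.00 +/- 2.00`.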

Conclusion

The steps above download the trained agent from the Hugging Face Hub, evaluate it over 30 episodes, and render it playing.
To resume training from the supplied checkpoint and run it for more steps, you may first need to install additional libraries and packages specific to your operating system (for example, the Box2D dependency that LunarLander requires).
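A minimal sketch of resuming training, assuming a recent Stable-Baselines3 together with the same `load_from_hub` helper used in the usage example. `resume_training` is a hypothetical helper, and the step budget and save path you pass to it are placeholders, not settings used by the author. The imports sit inside the function so the helper can be defined even where the RL stack is not installed.

```python
def resume_training(repo_id, filename, extra_timesteps, save_path):
    """Continue training a PPO checkpoint from the Hugging Face Hub (hypothetical helper)."""
    from huggingface_sb3 import load_from_hub
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    # Download the checkpoint and rebuild the environment it was trained on
    checkpoint = load_from_hub(repo_id, filename)
    env = make_vec_env("LunarLander-v2", n_envs=1)
    model = PPO.load(checkpoint, env=env)

    # reset_num_timesteps=False keeps the checkpoint's step counter, so
    # extra_timesteps is added on top of the steps already trained
    model.learn(total_timesteps=extra_timesteps, reset_num_timesteps=False)

    # Save the continued model (Stable-Baselines3 writes a .zip archive)
    model.save(save_path)
    env.close()
    return model
```

For example, `resume_training("DeathReaper0965/ppo-mlp-LunarLander-v2", "ppo-mlp-LunarLander-v2.zip", 100_000, "ppo-mlp-LunarLander-v2-continued")` would train for 100,000 further steps; that budget is an arbitrary illustration. Keeping the timestep counter (rather than resetting it) keeps logging continuous across the original and resumed runs.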