A2C Agent playing PandaReachJointsDense-v3

This is a trained model of an A2C agent playing PandaReachJointsDense-v3, built with the Stable-Baselines3 library.

Usage (with Stable-Baselines3)

import gymnasium as gym
import panda_gym
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from huggingface_sb3 import package_to_hub
import wandb
from wandb.integration.sb3 import WandbCallback

def main():
    """Train, evaluate and publish an A2C agent on PandaReachJointsDense-v3.

    Pipeline:
      1. Open a Weights & Biases run; TensorBoard logs under ``runs/<run id>``
         are synced to it.
      2. Train A2C (``MultiInputPolicy``) for ``TOTAL_TIMESTEPS`` steps.
         ``WandbCallback`` checkpoints the model under ``models/<run id>``.
      3. Save the trained model locally as ``a2c_sb3_panda_reach.zip``.
      4. Evaluate it over 50 episodes and print the mean/std episode reward.
      5. Package the model and push it to the Hugging Face Hub.

    The simulation environment is always closed, even when training or
    evaluation raises. The wandb run is always finished as well, and it is
    reported as failed in that case.
    """
    # CONSTANTS
    TOTAL_TIMESTEPS = 500_000
    ENV_ID = "PandaReachJointsDense-v3"
    HF_USERNAME = "Flavio0834"

    # WANDB: hyper-parameters recorded alongside the run's metrics.
    run_config = {
        "policy_type": "MultiInputPolicy",
        "total_timesteps": TOTAL_TIMESTEPS,
        "env_id": ENV_ID,
    }

    # Use the run as a context manager so it is finished on every exit path.
    # On an exception, wandb marks it as failed rather than successful.
    with wandb.init(
        project="mso_3_4-be1",
        name="panda-reach-a2c",
        config=run_config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    ) as run:
        # 1. Env setup. The Monitor wrapper records episode statistics, which
        # evaluate_policy uses to report unmodified episode rewards.
        env = Monitor(gym.make(ENV_ID))
        try:
            # 2. Model setup. The observation is a dict (goal-conditioned
            # env), hence MultiInputPolicy.
            model = A2C("MultiInputPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")

            print("Début de l'entraînement sur PandaReach...")

            # 3. Training
            model.learn(
                total_timesteps=TOTAL_TIMESTEPS,
                callback=WandbCallback(gradient_save_freq=100, model_save_path=f"models/{run.id}", verbose=2),
            )

            # 4. Save model (SB3 appends the .zip extension).
            model.save("a2c_sb3_panda_reach")
            print("Modèle sauvegardé sous a2c_sb3_panda_reach.zip !")

            # 5. Evaluation
            print("Évaluation du modèle entraîné...")
            mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=50, render=False)
            print(f"RÉSULTATS DE L'ÉVALUATION: Récompense moyenne = {mean_reward} +/- {std_reward}")
        finally:
            # Release the simulator even if training/evaluation failed.
            env.close()

    # 6. Upload the model to the Hugging Face Hub
    print("Préparation du modèle pour le Hub Hugging Face...")

    # Derived from ENV_ID so the repo name cannot drift from the trained env.
    model_name = f"a2c-{ENV_ID}"
    repo_id = f"{HF_USERNAME}/{model_name}"

    commit_message = "Upload du modèle A2C pour PandaReach"

    # A separate env with rgb_array rendering is passed to package_to_hub.
    # NOTE(review): presumably used to record the replay video shown on the
    # model card; the helper handles that env, so it is not closed here.
    package_to_hub(
        model=model,
        model_name=model_name,
        model_architecture="A2C",
        env_id=ENV_ID,
        eval_env=gym.make(ENV_ID, render_mode="rgb_array"),
        repo_id=repo_id,
        commit_message=commit_message,
    )

    print(f"Modèle envoyé avec succès sur : https://huggingface.co/{repo_id}")

# Run the train/evaluate/upload pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
Downloads last month
-
Video Preview
loading

Evaluation results

  • mean_reward on PandaReachJointsDense-v3
    self-reported
    -0.25 +/- 0.19