"""Gradio demo that streams frames of a pretrained PPO agent playing Atari."""

import time

import cv2  # noqa: F401  -- not referenced directly in this file; kept as-is
import gradio as gr
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Maximum number of frames streamed to the UI per replay request.
MAX_STEPS = 500


# Loading functions were taken from Edward Beeching code
def load_env(env_name: str) -> VecFrameStack:
    """Build a single vectorised Atari environment with 4 stacked frames.

    Args:
        env_name: Gym id of the Atari game, e.g. ``"PongNoFrameskip-v4"``.

    Returns:
        The environment wrapped in ``VecFrameStack(n_stack=4)``.
    """
    env = make_atari_env(env_name, n_envs=1)
    return VecFrameStack(env, n_stack=4)


def load_model(env_name: str) -> PPO:
    """Download the PPO checkpoint for ``env_name`` from the Hub and load it.

    The checkpoint is fetched from the ``ThomasSimonini/ppo-{env_name}``
    repository, file ``ppo-{env_name}.zip``.

    Args:
        env_name: Gym id of the Atari game; also used to build the repo and
            file names.

    Returns:
        The loaded ``PPO`` model.
    """
    # These entries replace the corresponding objects stored in the checkpoint.
    # NOTE(review): presumably done to sidestep deserialisation problems with
    # pickled schedules; they are irrelevant for inference -- confirm.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    checkpoint = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )
    return PPO.load(checkpoint, custom_objects=custom_objects)


def replay(env_name: str, time_sleep: float):
    """Play one episode with the pretrained agent, yielding each frame.

    This is a generator: Gradio displays every yielded frame in turn, which
    requires the interface queue to be enabled.

    Args:
        env_name: Gym id of the Atari game to play.
        time_sleep: Delay in seconds inserted after each environment step.

    Yields:
        The frame returned by ``env.render(mode="rgb_array")`` before each
        step. Streaming stops when the environment reports ``done`` or after
        ``MAX_STEPS`` frames, whichever comes first.
    """
    env = load_env(env_name)
    try:
        model = load_model(env_name)
        obs = env.reset()
        for _ in range(MAX_STEPS):
            frame = env.render(mode="rgb_array")
            action, _states = model.predict(obs)
            # `predict` on a vectorised observation already returns one action
            # per environment, so it is passed to `step` without re-wrapping.
            obs, _rewards, dones, _infos = env.step(action)
            time.sleep(time_sleep)
            yield frame
            # Only one environment is created (n_envs=1), so index 0 is the
            # sole done flag.
            if dones[0]:
                break
    finally:
        # Also runs when the generator is closed before the episode ends.
        env.close()


demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ]
        ),
        gr.Slider(0.01, 1, value=0.05),
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.05 by default).",
)

# The queue must be enabled *before* launching: `replay` is a generator and
# `launch()` does not return the Interface, so `.launch().queue()` never
# configures the queue on `demo`.
demo.queue().launch()