# Streamlit app: pick an Atari environment, download the matching pretrained
# PPO agent from the Hugging Face Hub, and stream rendered frames of the agent
# playing for a user-chosen number of episodes.
import time

# NOTE(review): cv2 is not referenced directly in this script; presumably it is
# imported so a missing OpenCV install fails early -- confirm before removing.
import cv2  # noqa: F401
import streamlit as st
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Pause between rendered frames, in seconds.
FRAME_DELAY_SECONDS = 0.1

st.title("Atari Environments Live Model")


# NOTE: @st.cache was tried here, but this is not cachable.
def load_env(env_name: str) -> VecFrameStack:
    """Build a single vectorized Atari environment with 4-frame stacking.

    Args:
        env_name: Gym id of the Atari environment to create.

    Returns:
        The environment produced by ``make_atari_env`` (``n_envs=1``) wrapped
        in ``VecFrameStack`` with ``n_stack=4``.
    """
    env = make_atari_env(env_name, n_envs=1)
    env = VecFrameStack(env, n_stack=4)
    return env


# NOTE: @st.cache was tried here, but this is not cachable.
def load_model(env_name: str) -> PPO:
    """Download and load the pretrained PPO agent for ``env_name``.

    The checkpoint is fetched from the Hub repository
    ``ThomasSimonini/ppo-{env_name}`` (file ``ppo-{env_name}.zip``).

    Args:
        env_name: Gym id of the Atari environment the agent was trained on.

    Returns:
        The loaded ``PPO`` model.
    """
    # Replace the stored learning-rate / clip-range objects at load time; the
    # agent is only used for inference here, so their values are set to zero.
    # NOTE(review): presumably this also avoids unpickling the saved schedules
    # across Python versions -- confirm.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    checkpoint = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )
    model = PPO.load(checkpoint, custom_objects=custom_objects)
    return model


env_name = st.selectbox(
    "Select environment",
    (
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
    ),
)
num_episodes = st.slider("Number of Episodes", 1, 20, 5)

env = load_env(env_name)
try:
    model = load_model(env_name)

    # Writing inside the st.empty() placeholder replaces the previous frame
    # instead of appending a new image for every step.
    with st.empty():
        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                st.image(frame, width=400)
                # predict() on a vectorized observation already returns one
                # action per sub-environment, so it is passed to step() as-is.
                action, _states = model.predict(obs)
                obs, _rewards, dones, _infos = env.step(action)
                # Exactly one sub-environment (n_envs=1): read its done flag.
                done = bool(dones[0])
                time.sleep(FRAME_DELAY_SECONDS)
finally:
    # Streamlit re-executes the script on every widget change; close the
    # environment so each rerun does not leave an emulator behind.
    env.close()