"""Streamlit app that plays Atari games live with pretrained SB3 PPO agents.

The user picks an environment and a number of episodes. The PPO checkpoint
``ppo-<env_name>.zip`` is fetched from the ``ThomasSimonini/ppo-<env_name>``
Hugging Face Hub repo and rolled out, with every frame drawn into the page.
"""

import time

# Not referenced directly below; presumably imported so that a missing OpenCV
# install (used by SB3's Atari frame preprocessing) fails fast. Keep it.
import cv2  # noqa: F401
import streamlit as st
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Pause between rendered frames so the game plays at a watchable speed.
FRAME_DELAY_S = 0.1
# Display width, in pixels, of each rendered frame.
FRAME_WIDTH_PX = 400

st.title("Atari Environments Live Model")


# Not cached: decorating this with `@st.cache` did not work for the returned
# environment object.
def load_env(env_name):
    """Create one Atari environment wrapped with 4-frame stacking.

    Args:
        env_name: Gym Atari environment id, e.g. ``"PongNoFrameskip-v4"``.

    Returns:
        A ``VecFrameStack`` over a single-environment Atari ``VecEnv`` that
        stacks the last 4 frames into each observation.
    """
    env = make_atari_env(env_name, n_envs=1)
    return VecFrameStack(env, n_stack=4)


# Not cached: decorating this with `@st.cache` did not work for the loaded
# model object.
def load_model(env_name):
    """Download and load the pretrained PPO agent for ``env_name``.

    Args:
        env_name: Gym Atari environment id. It selects the Hub repo
            ``ThomasSimonini/ppo-<env_name>`` and the file
            ``ppo-<env_name>.zip`` inside it.

    Returns:
        The loaded ``PPO`` model.
    """
    # Replace the training-time schedules stored in the checkpoint with
    # constants. They are unused at inference time; presumably overriding them
    # avoids deserializing saved callables, which can break across versions.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    checkpoint = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )
    return PPO.load(checkpoint, custom_objects=custom_objects)


def _play_episodes(env, model, num_episodes):
    """Play ``num_episodes`` full games of ``model`` in ``env``, drawing each frame.

    Args:
        env: The single-environment VecEnv returned by ``load_env``.
        model: The agent returned by ``load_model``.
        num_episodes: Number of complete games (not lives) to play.
    """
    # Frames go into a single placeholder, so each new image replaces the
    # previous one instead of being appended below it.
    with st.empty():
        for _ in range(num_episodes):
            obs = env.reset()
            game_over = False
            while not game_over:
                st.image(env.render(mode="rgb_array"), width=FRAME_WIDTH_PX)
                action, _states = model.predict(obs)
                # `obs` is batched (one row per sub-env), so `action` is
                # already the per-env batch a VecEnv expects; don't re-wrap it.
                obs, _rewards, _dones, infos = env.step(action)
                # By default make_atari_env reports `done` on every lost life;
                # the Monitor wrapper it applies adds an "episode" entry to
                # `info` only when the real game ends.
                game_over = "episode" in infos[0]
                time.sleep(FRAME_DELAY_S)


st.subheader("In game theory and optimization Nash Equilibrium loss minimization starts playing randomly but then by understanding ratios of action success to action-reward with an action (observe, decide/predict, act and then observe outcome the Deep RL agents go from 50% efficiency to 98-99% efficiency based on quality of decision without making mistakes.")
st.subheader("list of agent environments https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md")
st.subheader("Deep RL models: https://huggingface.co/sb3")

env_name = st.selectbox(
    "Select environment",
    (
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        # Further candidates, currently disabled. Swimmer-v3 and Walker2d-v3
        # are MuJoCo tasks, not Atari games, so load_env() presumably could
        # not build them.
        # "AsteroidsNoFrameskip-v4",
        # "BeamRiderNoFrameskip-v4",
        # "BreakoutNoFrameskip-v4",
        # "EnduroNoFrameskip-v4",
        # "MsPacmanNoFrameskip-v4",
        # "RoadRunnerNoFrameskip-v4",
        # "Swimmer-v3",
        # "Walker2d-v3",
    ),
)
num_episodes = st.slider("Number of Episodes", 1, 20, 5)

env = load_env(env_name)
try:
    model = load_model(env_name)
    _play_episodes(env, model, num_episodes)
finally:
    # Streamlit re-executes this whole script on every widget interaction;
    # close the environment so each rerun does not leave an emulator open.
    env.close()