import time

import cv2  # opencv is needed by the Atari frame-preprocessing wrappers
import streamlit as st
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
st.subheader("Atari 2600 Deep RL Environments Live AI")
# @st.cache cannot be used here: the VecEnv object is not hashable/serializable :(
def load_env(env_name):
    # Build a single vectorized Atari env and stack the last 4 frames,
    # matching the preprocessing the pretrained PPO checkpoints expect.
    env = make_atari_env(env_name, n_envs=1)
    env = VecFrameStack(env, n_stack=4)
    return env
# @st.cache cannot be used here either: the PPO model is not hashable/serializable :(
def load_model(env_name):
    # Zero out the optimizer schedules so checkpoints saved with older
    # stable-baselines3 versions can still be loaded for inference.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    # Download the pretrained PPO checkpoint for this environment from the Hugging Face Hub.
    checkpoint = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )
    model = PPO.load(checkpoint, custom_objects=custom_objects)
    return model
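# A possible workaround for the caching caveat above (a sketch, not part of the
# original app): newer Streamlit releases offer st.cache_resource, which can
# memoize unhashable objects such as the downloaded PPO model.
# @st.cache_resource
# def load_model_cached(env_name):
#     return load_model(env_name)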
st.write(
    "In game-theoretic terms, these Deep RL agents start out playing essentially at random. "
    "By repeatedly observing, deciding (predicting), acting, and then observing the outcome, "
    "they learn the ratio of action success to action reward and improve from roughly 50% "
    "efficiency to 98-99% efficiency as their decision quality approaches an equilibrium policy. "
    "A good reference for these environments and their benchmark scores is "
    "https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md"
)
#st.write("Deep RL models: https://huggingface.co/sb3")
env_name = st.selectbox(
    "Select environment",
    (
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        # "AsteroidsNoFrameskip-v4",
        # "BeamRiderNoFrameskip-v4",
        # "BreakoutNoFrameskip-v4",
        # "EnduroNoFrameskip-v4",
        # "MsPacmanNoFrameskip-v4",
        # "RoadRunnerNoFrameskip-v4",
        # "Swimmer-v3",
        # "Walker2d-v3",
    ),
)
num_episodes = st.slider("Number of Episodes", 1, 20, 5)
env = load_env(env_name)
model = load_model(env_name)
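# Optional sanity check (a sketch, not part of the original app): SB3's
# evaluate_policy reports the mean episode reward of the loaded agent, which is
# what the rl-baselines3-zoo benchmark table linked above summarizes.
# from stable_baselines3.common.evaluation import evaluate_policy
# mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
# st.write(f"Mean reward over 5 episodes: {mean_reward:.1f} +/- {std_reward:.1f}")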
# Play the selected number of episodes, streaming frames into a single
# st.empty placeholder so the image updates in place like a live video feed.
with st.empty():
    for _ in range(num_episodes):
        obs = env.reset()
        done = False
        while not done:
            # Render the current frame of the vectorized env and display it.
            frame = env.render(mode="rgb_array")
            st.image(frame, width=400)
            # Predict the next action from the stacked observation and step the env.
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
            done = done[0]
            time.sleep(0.1)