sw_test / app.py
Edward Beeching
fsa
bee7740
raw history blame
No virus
1.25 kB
import numpy as np
import time
import streamlit as st
from scienceworld import ScienceWorldEnv
st.title("ScienceWorld interactive demo")


def hash_env(_):
    """Return None for any argument.

    NOTE(review): not referenced anywhere in this file; presumably intended as a
    ``hash_funcs`` entry for ``st.cache`` so the env object is never hashed —
    confirm or remove.
    """
    return None


import os

# Show which Java runtime is available (presumably a sanity check that the
# JVM needed by ScienceWorldEnv is installed — confirm).
# `java -version` prints its banner to stderr, not stdout, so merge stderr into
# the captured stream; the `with` block closes the pipe after reading.
with os.popen('java -version 2>&1') as stream:
    output = stream.read()
# Display the captured text itself (previously the literal string 'output').
st.write(output)
@st.cache(allow_output_mutation=True)
def load_env(task_idx=None, simplification_str='easy', step_limit=100):
    """Create a ScienceWorld env, load one task at variation 0, and reset it.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the same env object;
    ``allow_output_mutation=True`` because the returned env is mutated later
    by ``env.step``.

    Args:
        task_idx: Task selector. ``None`` (the default) means task index 13,
            the task this demo originally hard-coded. An ``int`` indexes into
            ``env.getTaskNames()``; any other value is used directly as the
            task name.
        simplification_str: Passed to both ``env.load`` and
            ``env.resetWithVariation``. Defaults to ``'easy'``.
        step_limit: Third positional argument to ``ScienceWorldEnv`` (the
            step limit). Defaults to 100.

    Returns:
        Tuple ``(env, obs, info)``: the env plus the observation and info
        returned by ``env.resetWithVariation(0, simplification_str)``.
    """
    print('Loading envs')
    env = ScienceWorldEnv("", None, step_limit, 0)
    if task_idx is None:
        task_idx = 13
    if isinstance(task_idx, int):
        task_name = env.getTaskNames()[task_idx]
    else:
        task_name = task_idx
    # Just reset to variation 0, as another call (e.g. reset_with_variation...) will setup
    # an appropriate variation (train/dev/test)
    env.load(task_name, 0, simplification_str)
    obs, info = env.resetWithVariation(0, simplification_str)
    return env, obs, info
class RandomAgent:
    """Baseline agent that samples one of the currently valid actions."""

    def act(self, info):
        """Pick an action uniformly at random from ``info['valid']``.

        NOTE(review): ``info`` is assumed to be the env's info mapping whose
        ``'valid'`` entry lists candidate actions — confirm against the caller.
        """
        candidates = info['valid']
        return np.random.choice(candidates)
num_episodes = 10  # NOTE(review): not used anywhere in this file

env, initial_obs, initial_info = load_env()

act = st.text_input('action to perform')
st.write(f'Action: {act}')

# st.text_input returns '' until the user enters something. Only step the
# cached env for a real action; otherwise an empty command would be sent to
# it on every fresh render. With no action, show the initial reset observation.
if act:
    obs, reward, done, info = env.step(act)
else:
    obs, info = initial_obs, initial_info
st.write(f'Observation: {obs.strip()}')