grg commited on
Commit
b41e61f
β€’
1 Parent(s): 11bd154

Adding textual observations to demo

Browse files
Files changed (1) hide show
  1. web_demo/app.py +25 -5
web_demo/app.py CHANGED
@@ -9,6 +9,8 @@ import gym_minigrid
9
  import numpy as np
10
  from gym_minigrid.window import Window
11
 
 
 
12
  import os
13
 
14
  app = Flask(__name__)
@@ -46,11 +48,27 @@ global env_label
46
  env_label = list(env_label_to_env_name.keys())[0]
47
  env_name = env_label_to_env_name[env_label]
48
 
 
 
 
49
  global mask_unobserved
50
  mask_unobserved = False
51
 
52
  env = gym.make(env_name)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def update_tree():
55
  selected_parameters = env.current_env.parameters
56
  selected_env_type = selected_parameters["Env_type"]
@@ -116,10 +134,9 @@ def set_mask_unobserved():
116
  def update_image():
117
  action_name = request.form.get('action')
118
 
119
-
120
  if action_name == 'done':
121
  # reset the env and update the tree image
122
- obs = env.reset()
123
  update_tree()
124
 
125
  else:
@@ -145,21 +162,24 @@ def update_image():
145
 
146
  obs, reward, done, info = env.step(action)
147
 
 
148
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
149
  image_data = np_img_to_base64(image)
150
 
151
-
152
- bubble_text = format_bubble_text(env.current_env.full_conversation)
153
 
154
  return jsonify({'image_data': image_data, "bubble_text": bubble_text})
155
 
156
 
 
157
  @app.route('/', methods=['GET', 'POST'])
158
  def index():
 
159
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
160
  image_data = np_img_to_base64(image)
161
 
162
- bubble_text = format_bubble_text(env.current_env.full_conversation)
 
163
 
164
  available_env_labels = env_label_to_env_name.keys()
165
 
 
9
  import numpy as np
10
  from gym_minigrid.window import Window
11
 
12
+ from textworld_utils.utils import generate_text_obs
13
+
14
  import os
15
 
16
  app = Flask(__name__)
 
48
  env_label = list(env_label_to_env_name.keys())[0]
49
  env_name = env_label_to_env_name[env_label]
50
 
51
+
52
+ textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-ColorBoxesLLMCSParamEnv-v1"]
53
+
54
  global mask_unobserved
55
  mask_unobserved = False
56
 
57
  env = gym.make(env_name)
58
 
59
+
60
+ def create_bubble_text(env_name, obs, info, full_conversation, textworld_envs):
61
+ if env_name in textworld_envs:
62
+ text_obs = generate_text_obs(obs, info)
63
+ # bubble_text = "Textworld state:\n" + text_obs
64
+ bubble_text = text_obs
65
+
66
+ else:
67
+ bubble_text = format_bubble_text(full_conversation)
68
+
69
+ return bubble_text
70
+
71
+
72
  def update_tree():
73
  selected_parameters = env.current_env.parameters
74
  selected_env_type = selected_parameters["Env_type"]
 
134
  def update_image():
135
  action_name = request.form.get('action')
136
 
 
137
  if action_name == 'done':
138
  # reset the env and update the tree image
139
+ obs, info = env.reset(with_info=True)
140
  update_tree()
141
 
142
  else:
 
162
 
163
  obs, reward, done, info = env.step(action)
164
 
165
+
166
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
167
  image_data = np_img_to_base64(image)
168
 
169
+ bubble_text = create_bubble_text(env_name, obs, info, env.current_env.full_conversation, textworld_envs)
 
170
 
171
  return jsonify({'image_data': image_data, "bubble_text": bubble_text})
172
 
173
 
174
+
175
  @app.route('/', methods=['GET', 'POST'])
176
  def index():
177
+ obs, info = env.reset(with_info=True)
178
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
179
  image_data = np_img_to_base64(image)
180
 
181
+ # bubble_text = format_bubble_text(env.current_env.full_conversation)
182
+ bubble_text = create_bubble_text(env_name, obs, info, env.current_env.full_conversation, textworld_envs)
183
 
184
  available_env_labels = env_label_to_env_name.keys()
185