Spaces:
Sleeping
Sleeping
Adding textual observations slider, decorative, changes fix for bubble newlines, mask unobserved slider only updates image
Browse files- web_demo/app.py +34 -14
- web_demo/templates/index.html +66 -10
web_demo/app.py
CHANGED
@@ -42,7 +42,6 @@ env_label_to_env_name = {
|
|
42 |
"Scaffolding/Formats (test)":"SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1",
|
43 |
}
|
44 |
|
45 |
-
# env = gym.make(args.env, **env_args_str_to_dict(args.env_args))
|
46 |
global env_name
|
47 |
global env_label
|
48 |
env_label = list(env_label_to_env_name.keys())[0]
|
@@ -54,16 +53,23 @@ textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-
|
|
54 |
global mask_unobserved
|
55 |
mask_unobserved = False
|
56 |
|
|
|
|
|
|
|
57 |
env = gym.make(env_name)
|
58 |
|
|
|
|
|
59 |
|
60 |
-
def create_bubble_text(env_name, obs, info, full_conversation, textworld_envs):
|
61 |
-
if env_name in textworld_envs:
|
62 |
-
text_obs = generate_text_obs(obs, info)
|
63 |
-
bubble_text = text_obs
|
64 |
|
|
|
|
|
|
|
|
|
65 |
else:
|
66 |
-
bubble_text =
|
|
|
|
|
67 |
|
68 |
return bubble_text
|
69 |
|
@@ -115,6 +121,8 @@ def set_env():
|
|
115 |
|
116 |
global env # Declare the env variable as global to modify it
|
117 |
env = gym.make(env_name) # Initialize the environment with the new name
|
|
|
|
|
118 |
update_tree() # Update the tree for the new environment
|
119 |
return redirect(url_for('index')) # Redirect back to the main page
|
120 |
|
@@ -122,17 +130,30 @@ def set_env():
|
|
122 |
@app.route('/set_mask_unobserved', methods=['POST'])
|
123 |
def set_mask_unobserved():
|
124 |
global mask_unobserved
|
125 |
-
|
126 |
-
|
|
|
|
|
127 |
|
128 |
-
return
|
129 |
|
|
|
|
|
|
|
|
|
130 |
|
|
|
131 |
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
134 |
action_name = request.form.get('action')
|
135 |
|
|
|
|
|
136 |
if action_name == 'done':
|
137 |
# reset the env and update the tree image
|
138 |
obs, info = env.reset(with_info=True)
|
@@ -165,7 +186,7 @@ def update_image():
|
|
165 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
166 |
image_data = np_img_to_base64(image)
|
167 |
|
168 |
-
bubble_text = create_bubble_text(
|
169 |
|
170 |
return jsonify({'image_data': image_data, "bubble_text": bubble_text})
|
171 |
|
@@ -173,12 +194,11 @@ def update_image():
|
|
173 |
|
174 |
@app.route('/', methods=['GET', 'POST'])
|
175 |
def index():
|
176 |
-
obs, info = env.reset(with_info=True)
|
177 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
178 |
image_data = np_img_to_base64(image)
|
179 |
|
180 |
# bubble_text = format_bubble_text(env.current_env.full_conversation)
|
181 |
-
bubble_text = create_bubble_text(
|
182 |
|
183 |
available_env_labels = env_label_to_env_name.keys()
|
184 |
|
|
|
42 |
"Scaffolding/Formats (test)":"SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1",
|
43 |
}
|
44 |
|
|
|
45 |
global env_name
|
46 |
global env_label
|
47 |
env_label = list(env_label_to_env_name.keys())[0]
|
|
|
53 |
global mask_unobserved
|
54 |
mask_unobserved = False
|
55 |
|
56 |
+
global textual_observations
|
57 |
+
textual_observations = False
|
58 |
+
|
59 |
env = gym.make(env_name)
|
60 |
|
61 |
+
global obs, info
|
62 |
+
obs, info = env.reset(with_info=True)
|
63 |
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
def create_bubble_text(obs, info, full_conversation, textual_observations):
|
66 |
+
if textual_observations:
|
67 |
+
bubble_text = "Textual observation\n\n"+ \
|
68 |
+
generate_text_obs(obs, info)
|
69 |
else:
|
70 |
+
bubble_text = full_conversation
|
71 |
+
|
72 |
+
bubble_text = format_bubble_text(bubble_text)
|
73 |
|
74 |
return bubble_text
|
75 |
|
|
|
121 |
|
122 |
global env # Declare the env variable as global to modify it
|
123 |
env = gym.make(env_name) # Initialize the environment with the new name
|
124 |
+
global obs, info
|
125 |
+
obs, info = env.reset(with_info=True)
|
126 |
update_tree() # Update the tree for the new environment
|
127 |
return redirect(url_for('index')) # Redirect back to the main page
|
128 |
|
|
|
130 |
@app.route('/set_mask_unobserved', methods=['POST'])
|
131 |
def set_mask_unobserved():
|
132 |
global mask_unobserved
|
133 |
+
mask_unobserved = request.form.get('mask_unobserved') == 'true'
|
134 |
+
|
135 |
+
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
136 |
+
image_data = np_img_to_base64(image)
|
137 |
|
138 |
+
return jsonify({'image_data': image_data})
|
139 |
|
140 |
+
@app.route('/set_textual_observations', methods=['POST'])
|
141 |
+
def set_textual_observations():
|
142 |
+
global textual_observations
|
143 |
+
textual_observations = request.form.get('textual_observations') == 'true'
|
144 |
|
145 |
+
bubble_text = create_bubble_text(obs, info, env.current_env.full_conversation, textual_observations)
|
146 |
|
147 |
+
return jsonify({"bubble_text": bubble_text})
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
@app.route('/perform_action', methods=['POST'])
|
152 |
+
def perform_action():
|
153 |
action_name = request.form.get('action')
|
154 |
|
155 |
+
global obs, info
|
156 |
+
|
157 |
if action_name == 'done':
|
158 |
# reset the env and update the tree image
|
159 |
obs, info = env.reset(with_info=True)
|
|
|
186 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
187 |
image_data = np_img_to_base64(image)
|
188 |
|
189 |
+
bubble_text = create_bubble_text(obs, info, env.current_env.full_conversation, textual_observations)
|
190 |
|
191 |
return jsonify({'image_data': image_data, "bubble_text": bubble_text})
|
192 |
|
|
|
194 |
|
195 |
@app.route('/', methods=['GET', 'POST'])
|
196 |
def index():
|
|
|
197 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
198 |
image_data = np_img_to_base64(image)
|
199 |
|
200 |
# bubble_text = format_bubble_text(env.current_env.full_conversation)
|
201 |
+
bubble_text = create_bubble_text(obs, info, env.current_env.full_conversation, textual_observations)
|
202 |
|
203 |
available_env_labels = env_label_to_env_name.keys()
|
204 |
|
web_demo/templates/index.html
CHANGED
@@ -186,7 +186,7 @@
|
|
186 |
bodyData += `&template=${template}&word=${word}`;
|
187 |
}
|
188 |
|
189 |
-
fetch('/
|
190 |
method: 'POST',
|
191 |
headers: {
|
192 |
'Content-Type': 'application/x-www-form-urlencoded',
|
@@ -204,7 +204,8 @@
|
|
204 |
// Add this to handle the caretaker's utterance
|
205 |
let bubble = document.getElementById('caretakerBubble');
|
206 |
if(data.bubble_text) {
|
207 |
-
|
|
|
208 |
bubble.style.display = 'block';
|
209 |
} else {
|
210 |
bubble.style.display = 'none';
|
@@ -235,6 +236,56 @@
|
|
235 |
});
|
236 |
});
|
237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
</script>
|
239 |
</head>
|
240 |
<body>
|
@@ -252,12 +303,17 @@
|
|
252 |
|
253 |
<div class="form-container">
|
254 |
<span class="form-label">Mask unobserved cells:</span>
|
255 |
-
<
|
256 |
-
<
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
261 |
</div>
|
262 |
</div>
|
263 |
|
@@ -284,12 +340,12 @@
|
|
284 |
Speaking actions:
|
285 |
<select id="actionTemplate">
|
286 |
{% for templ in grammar_templates %}
|
287 |
-
|
288 |
{% endfor %}
|
289 |
</select>
|
290 |
<select id="actionWord">
|
291 |
{% for word in grammar_words %}
|
292 |
-
|
293 |
{% endfor %}
|
294 |
</select>
|
295 |
<button onclick="performAction('speak')">Speak [s]</button>
|
|
|
186 |
bodyData += `&template=${template}&word=${word}`;
|
187 |
}
|
188 |
|
189 |
+
fetch('/perform_action', {
|
190 |
method: 'POST',
|
191 |
headers: {
|
192 |
'Content-Type': 'application/x-www-form-urlencoded',
|
|
|
204 |
// Add this to handle the caretaker's utterance
|
205 |
let bubble = document.getElementById('caretakerBubble');
|
206 |
if(data.bubble_text) {
|
207 |
+
let formattedText = data.bubble_text.replace(/\n/g, '<br>');
|
208 |
+
bubble.innerHTML = formattedText;
|
209 |
bubble.style.display = 'block';
|
210 |
} else {
|
211 |
bubble.style.display = 'none';
|
|
|
236 |
});
|
237 |
});
|
238 |
|
239 |
+
|
240 |
+
// handle the change event of the mask unobserved slider change
|
241 |
+
function updateMaskUnobserved() {
|
242 |
+
const maskUnobservedValue = document.querySelector('input[name="mask_unobserved"]').checked;
|
243 |
+
|
244 |
+
fetch('/set_mask_unobserved', {
|
245 |
+
method: 'POST',
|
246 |
+
headers: {
|
247 |
+
'Content-Type': 'application/x-www-form-urlencoded',
|
248 |
+
},
|
249 |
+
body: `mask_unobserved=${maskUnobservedValue}`
|
250 |
+
})
|
251 |
+
.then(response => response.json())
|
252 |
+
.then(data => {
|
253 |
+
// Update the image src with the new image data
|
254 |
+
document.getElementById('envImage').src = `data:image/jpeg;base64,${data.image_data}`;
|
255 |
+
})
|
256 |
+
.catch(error => {
|
257 |
+
console.error('Error:', error);
|
258 |
+
});
|
259 |
+
}
|
260 |
+
|
261 |
+
// handle the change event of the textual observations slider change
|
262 |
+
function updateTextualObservations() {
|
263 |
+
const textualObservationsValue = document.querySelector('input[name="textual_observations"]').checked;
|
264 |
+
|
265 |
+
fetch('/set_textual_observations', {
|
266 |
+
method: 'POST',
|
267 |
+
headers: {
|
268 |
+
'Content-Type': 'application/x-www-form-urlencoded',
|
269 |
+
},
|
270 |
+
body: `textual_observations=${textualObservationsValue}`
|
271 |
+
})
|
272 |
+
.then(response => response.json())
|
273 |
+
.then(data => {
|
274 |
+
let bubble = document.getElementById('caretakerBubble');
|
275 |
+
if(data.bubble_text) {
|
276 |
+
let formattedText = data.bubble_text.replace(/\n/g, '<br>');
|
277 |
+
bubble.innerHTML = formattedText;
|
278 |
+
bubble.style.display = 'block';
|
279 |
+
} else {
|
280 |
+
bubble.style.display = 'none';
|
281 |
+
}
|
282 |
+
})
|
283 |
+
.catch(error => {
|
284 |
+
console.error('Error:', error);
|
285 |
+
});
|
286 |
+
}
|
287 |
+
|
288 |
+
|
289 |
</script>
|
290 |
</head>
|
291 |
<body>
|
|
|
303 |
|
304 |
<div class="form-container">
|
305 |
<span class="form-label">Mask unobserved cells:</span>
|
306 |
+
<label class="switch">
|
307 |
+
<input type="checkbox" name="mask_unobserved" value="true" onchange="updateMaskUnobserved()" {% if mask_unobserved %}checked{% endif %}>
|
308 |
+
<span class="slider round"></span>
|
309 |
+
</label>
|
310 |
+
</div>
|
311 |
+
<div class="form-container">
|
312 |
+
<span class="form-label">Textual observation:</span>
|
313 |
+
<label class="switch">
|
314 |
+
<input type="checkbox" name="textual_observations" value="true" onchange="updateTextualObservations()" {% if textual_observations %}checked{% endif %}>
|
315 |
+
<span class="slider round"></span>
|
316 |
+
</label>
|
317 |
</div>
|
318 |
</div>
|
319 |
|
|
|
340 |
Speaking actions:
|
341 |
<select id="actionTemplate">
|
342 |
{% for templ in grammar_templates %}
|
343 |
+
<option value="{{ templ }}" {{ 'selected' if templ == "Help" else '' }}>{{ templ }}</option>
|
344 |
{% endfor %}
|
345 |
</select>
|
346 |
<select id="actionWord">
|
347 |
{% for word in grammar_words %}
|
348 |
+
<option value="{{ word }}" {{ 'selected' if word == "please" else '' }}>{{ word }}</option>
|
349 |
{% endfor %}
|
350 |
</select>
|
351 |
<button onclick="performAction('speak')">Speak [s]</button>
|