File size: 10,353 Bytes
be5548b
 
 
 
 
 
 
 
 
 
f397ead
be5548b
b41e61f
 
be5548b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e46ecc1
 
be5548b
 
 
 
f397ead
 
be5548b
 
 
 
 
9cf63fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be5548b
 
 
 
 
 
b41e61f
 
 
be5548b
 
 
077ff6c
 
 
be5548b
 
077ff6c
 
b41e61f
 
f397ead
 
 
 
 
 
077ff6c
 
 
 
b41e61f
077ff6c
 
 
b41e61f
 
 
 
be5548b
 
f397ead
be5548b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f397ead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be5548b
 
 
 
 
 
 
 
 
 
077ff6c
 
be5548b
 
 
 
 
 
 
077ff6c
 
 
 
be5548b
077ff6c
be5548b
077ff6c
 
 
 
be5548b
077ff6c
be5548b
077ff6c
 
 
 
 
 
be5548b
 
077ff6c
 
be5548b
 
b41e61f
ece8924
be5548b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
077ff6c
be5548b
ece8924
 
 
 
 
 
be5548b
 
b41e61f
be5548b
 
 
 
 
b41e61f
077ff6c
be5548b
 
 
 
 
 
 
 
 
 
 
 
 
 
f397ead
 
be5548b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
from flask import Flask, render_template, request, session, redirect, url_for, send_from_directory, jsonify
from PIL import Image
import io
import base64
import time

import gym
import gym_minigrid
import numpy as np
from gym_minigrid.window import Window
from gym_minigrid.curriculums import SelectedParametersOrRandomCurriculum

from textworld_utils.utils import generate_text_obs

import os

# Flask application object; the route handlers below are registered on it.
app = Flask(__name__)

# Accepted values of the "Env_type" parameter. update_tree() asserts that the
# current value is one of these and folds the tree nodes of the other ones.
env_types = ["Information_seeking", "Collaboration", "AppleStealing"]

# Maps the label submitted by the UI (form field 'env_label') to the gym
# environment id that is passed to gym.make().
env_label_to_env_name = {
    "Full SocialAI environment": "SocialAI-SocialAIParamEnv-v1",  # all
    "Pointing (Train)": "SocialAI-EPointingHeldoutDoorsTrainInformationSeekingParamEnv-v1",  # Pointing Train
    "Pointing (Test)": "SocialAI-EPointingBoxesTestInformationSeekingParamEnv-v1",  # Pointing Test
    "Role Reversal Single Role B (Pretrain - experimental)": "SocialAI-MarblePassBCollaborationParamEnv-v1",
    "Role Reversal Single Asocial (Pretrain - control)": "SocialAI-AsocialMarbleCollaborationParamEnv-v1",
    "Role Reversal Group Role B (Pretrain - experimental)": "SocialAI-RoleReversalGroupExperimentalCollaborationParamEnv-v1",
    "Role Reversal Group Asocial (Pretrain - control)": "SocialAI-RoleReversalGroupControlCollaborationParamEnv-v1",
    "Role Reversal Role A (Finetune - test)": "SocialAI-MarblePassACollaborationParamEnv-v1",
    "Imitation (Train)": "SocialAI-EEmulationNoDistrInformationSeekingParamEnv-v1",
    "Imitation (Test)": "SocialAI-EEmulationNoDistrDoorsInformationSeekingParamEnv-v1",
    "AsocialBox (textworld)": "SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
    "ColorBoxes (textworld)": "SocialAI-ColorBoxesLLMCSParamEnv-v1",
    "Language Color (Train)": "SocialAI-ELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
    "Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
    "Joint Attention Language Color (Train)": "SocialAI-JAELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Joint Attention Language Color (Test)": "SocialAI-JAELangColorDoorsTestInformationSeekingParamEnv-v1",
    "Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
    "Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
    "Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
    "Scaffolding/Formats (test)":"SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1",
}
# Labels handed to the index.html template, in list order. The "---- ... ----"
# entries have no key in env_label_to_env_name; presumably they are rendered
# as section headings in the selector -- confirm against the template.
available_env_labels = [
    "Full SocialAI environment",
    "---- Pointing ----",
    "Pointing (Train)",
    "Pointing (Test)",
    "---- Role Reversal ----",
    "Role Reversal Single Role B (Pretrain - experimental)",
    "Role Reversal Single Asocial (Pretrain - control)",
    "Role Reversal Group Role B (Pretrain - experimental)",
    "Role Reversal Group Asocial (Pretrain - control)",
    "Role Reversal Role A (Finetune - test)",
    "---- Imitation ----",
    "Imitation (Train)",
    "Imitation (Test)",
    "---- TextWorld (LLM experiments)  ----",
    "AsocialBox (textworld)",
    "ColorBoxes (textworld)",
    "---- Language Color ----",
    "Language Color (Train)",
    "Language Color (Test)",
    "---- Language Feedback ----",
    "Language Feedback (Train)",
    "Language Feedback (Test)",
    "---- Joint Attention Language Color ----",
    "Joint Attention Language Color (Train)",
    "Joint Attention Language Color (Test)",
    "---- Apple Stealing ----",
    "Apple stealing",
    "Apple stealing (Occlusions)",
    "---- Scaffolding/Formats ----",
    "Scaffolding (train - scaf_8: Phase 1)",
    "Scaffolding/Formats (test)"
]
# Import-time sanity check: every label that maps to an environment must also
# appear in available_env_labels (the reverse is not required).
assert all([l in available_env_labels for l in env_label_to_env_name.keys()])

# Mutable module-level state shared by all request handlers. Names bound at
# module scope are already global, so no `global` declarations are needed
# here; the handlers that rebind these names declare `global` themselves.

# Start with the first entry of the label -> environment-id mapping.
env_label = next(iter(env_label_to_env_name))
env_name = env_label_to_env_name[env_label]


# Environment ids of the two entries labelled "(textworld)" in the mapping.
textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-ColorBoxesLLMCSParamEnv-v1"]

# UI toggle updated by /set_mask_unobserved and forwarded to env.render().
mask_unobserved = False

# UI toggle updated by /set_textual_observations; selects which text
# create_bubble_text() produces.
textual_observations = False

# The single environment instance used by every handler in this process.
env = gym.make(env_name)

# Most recent observation and info returned by the environment; rebound by
# the handlers on every reset and step.
obs, info = env.reset(with_info=True)





def get_parameter_options(env):
    """Return whatever ``env.get_potential_params()`` reports for *env*.

    Thin wrapper so callers do not depend on the environment method's name.
    """
    potential_params = env.get_potential_params()
    return potential_params

def create_bubble_text(obs, info, full_conversation, textual_observations):
    """Build the bubble text sent to the page.

    When ``textual_observations`` is truthy the text is the output of
    ``generate_text_obs(obs, info)`` under a "Textual observation" heading;
    otherwise it is ``full_conversation``. Either way the result is passed
    through ``format_bubble_text`` before being returned.
    """
    if not textual_observations:
        return format_bubble_text(full_conversation)

    text_obs = generate_text_obs(obs, info)
    return format_bubble_text("Textual observation\n\n" + text_obs)


def update_tree():
    """Redraw the parameter tree for the global environment's current parameters.

    Reads ``env.current_env.parameters``, checks that its "Env_type" entry is
    one of ``env_types`` and calls ``env.parameter_tree.draw_tree`` with every
    other environment type folded.

    Raises:
        AssertionError: if the current "Env_type" is not in ``env_types``.
    """
    current_params = env.current_env.parameters
    print("sel param:", current_params)
    active_type = current_params["Env_type"]

    assert active_type in env_types, f"Env_type {active_type} not in {env_types}"

    # Fold every environment-type node except the one currently selected.
    other_types = []
    for candidate in env_types:
        if candidate != active_type:
            other_types.append(candidate)

    env.parameter_tree.draw_tree(
        filename="./web_demo/static/current_tree",
        ignore_labels=["Num_of_colors"],
        selected_parameters=current_params,
        folded_nodes=other_types,
    )

# Draw the tree for the initial environment once, when the module is loaded.
update_tree()


def np_img_to_base64(np_image):
    """Encode an image array as a base64 string of its JPEG bytes.

    The array is converted with ``PIL.Image.fromarray``, saved as JPEG at
    quality 70 into an in-memory buffer, and the buffer contents are returned
    base64-encoded as a UTF-8 ``str``.
    """
    buffer = io.BytesIO()
    Image.fromarray(np_image).save(buffer, 'JPEG', quality=70)
    # getvalue() returns the whole buffer regardless of the stream position.
    jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode('utf-8')


def format_bubble_text(text):
    """Shorten *text* to at most ten lines.

    Text of ten lines or fewer is returned unchanged. Longer text is reduced
    to its first line, a "...." marker line, and its last eight lines.
    """
    rows = text.split("\n")
    if len(rows) <= 10:
        return "\n".join(rows)

    first_row, last_rows = rows[0], rows[-8:]
    return "\n".join([first_row, "....", *last_rows])





@app.route('/set_env_params', methods=['POST'])
def set_env_params():
    """Reset the environment with the parameters chosen in the UI.

    The JSON body maps parameter ids to value ids. Both sides are resolved
    through ``env.parameter_tree.get_node_for_id``, wrapped in a
    ``SelectedParametersOrRandomCurriculum`` and passed to ``env.reset`` as
    ``ACL``. The global ``obs``/``info`` are rebound and the tree is redrawn.

    Returns:
        A ``{"success": true}`` JSON response with status 200.
    """
    global obs, info

    id_pairs = request.get_json()

    selected_parameters = {}
    for param_id, value_id in id_pairs.items():
        param_node = env.parameter_tree.get_node_for_id(param_id)
        selected_parameters[param_node] = env.parameter_tree.get_node_for_id(value_id)

    curriculum = SelectedParametersOrRandomCurriculum(selected_parameters)
    obs, info = env.reset(with_info=True, ACL=curriculum)

    # Redraw the tree so it reflects the newly selected parameters.
    update_tree()
    return jsonify({"success": True}), 200


@app.route('/set_env', methods=['POST'])
def set_env():
    """Switch the demo to the environment selected in the submitted form.

    Reads the ``env_label`` form field, looks up the matching environment id
    in ``env_label_to_env_name``, creates and resets that environment, then
    rebinds the global ``env_label``, ``env_name``, ``env``, ``obs`` and
    ``info`` and redraws the parameter tree.

    Returns:
        A redirect to the index page on success, or a JSON error with status
        400 when the label is missing or has no associated environment.
    """
    global env_name, env_label, env, obs, info

    requested_label = request.form.get('env_label')
    requested_name = env_label_to_env_name.get(requested_label)
    if requested_name is None:
        # Missing field, or a label with no environment behind it (e.g. one of
        # the "---- ... ----" entries of available_env_labels). Reject it
        # before any global is touched so the current state stays consistent.
        return jsonify({"success": False, "error": "Unknown environment label"}), 400

    # Create and reset the new environment before rebinding any global, so a
    # failure here leaves the previous environment fully in place.
    new_env = gym.make(requested_name)
    new_obs, new_info = new_env.reset(with_info=True)

    env_label, env_name = requested_label, requested_name
    env = new_env
    obs, info = new_obs, new_info

    update_tree()  # Update the tree for the new environment
    return redirect(url_for('index'))  # Redirect back to the main page


@app.route('/set_mask_unobserved', methods=['POST'])
def set_mask_unobserved():
    """Update the ``mask_unobserved`` toggle and return a fresh render.

    The global flag becomes ``True`` only when the ``mask_unobserved`` form
    field equals the string 'true'. The environment is then rendered with the
    new flag and returned as base64 image data in a JSON response.
    """
    global mask_unobserved

    submitted = request.form.get('mask_unobserved')
    mask_unobserved = (submitted == 'true')

    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    return jsonify({'image_data': np_img_to_base64(frame)})

@app.route('/set_textual_observations', methods=['POST'])
def set_textual_observations():
    """Update the ``textual_observations`` toggle and return the bubble text.

    The global flag becomes ``True`` only when the ``textual_observations``
    form field equals the string 'true'. The bubble text is rebuilt from the
    current ``obs``/``info`` and conversation and returned as JSON.
    """
    global textual_observations

    submitted = request.form.get('textual_observations')
    textual_observations = (submitted == 'true')

    conversation = env.current_env.full_conversation
    return jsonify({
        "bubble_text": create_bubble_text(obs, info, conversation, textual_observations),
    })



@app.route('/perform_action', methods=['POST'])
def perform_action():
    """Apply the action named in the form to the environment.

    The ``action`` form field selects what happens:

    * 'done' resets the environment and redraws the parameter tree;
    * 'speak' resolves the ``template`` and ``word`` form fields through
      ``env.grammar.get_action`` and steps with that utterance;
    * 'left', 'right', 'forward' and 'toggle' step with the matching member
      of ``env.actions``;
    * 'noop' and any other value step with an all-NaN action.

    The global ``obs``/``info`` are rebound to the result.

    Returns:
        JSON with the rendered frame (``image_data``), ``info["success"]``,
        the ``done`` flag and the current bubble text.
    """
    global obs, info

    action_name = request.form.get('action')

    if action_name == 'done':
        # reset the env and update the tree image
        obs, info = env.reset(with_info=True)
        done = False
        update_tree()
    else:
        # An action is [primitive action, template index, word index];
        # NaN fills the slots that are not used.
        if action_name == "speak":
            template = request.form.get('template')
            word = request.form.get('word')
            temp_ind, word_ind = env.grammar.get_action(template, word)
            action = [np.nan, temp_ind, word_ind]
        elif action_name in ('left', 'right', 'forward', 'toggle'):
            primitive = getattr(env.actions, action_name)
            action = [int(primitive), np.nan, np.nan]
        else:
            # 'noop' and every unrecognised name do nothing.
            action = [np.nan, np.nan, np.nan]

        obs, _reward, done, info = env.step(action)

    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    image_data = np_img_to_base64(frame)

    conversation = env.current_env.full_conversation
    bubble_text = create_bubble_text(obs, info, conversation, textual_observations)

    payload = {
        'image_data': image_data,
        'success': info["success"],
        'done': done,
        'bubble_text': bubble_text,
    }
    return jsonify(payload)



@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the main page for the current environment.

    Passes to ``index.html`` the rendered frame, the bubble text, the current
    toggles and environment label, the grammar's templates and words, and the
    environment's parameter options and current parameters.
    """
    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    encoded_frame = np_img_to_base64(frame)

    conversation = env.current_env.full_conversation
    bubble = create_bubble_text(obs, info, conversation, textual_observations)

    templates = env.grammar.templates
    words = env.grammar.things

    template_context = dict(
        image_data=encoded_frame,
        bubble_text=bubble,
        mask_unobserved=mask_unobserved,
        timestamp=time.time(),
        available_env_labels=available_env_labels,
        current_env_label=env_label,
        grammar_templates=templates,
        grammar_words=words,
        parameter_options=get_parameter_options(env),
        current_parameters=env.current_params,
    )
    return render_template('index.html', **template_context)


if __name__ == '__main__':
    # The server listens on all interfaces, and the Werkzeug debugger lets a
    # browser client execute arbitrary Python. Debug mode is therefore opt-in:
    # set FLASK_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)