Spaces:
Running
Running
from flask import Flask, render_template, request, session, redirect, url_for, send_from_directory, jsonify | |
from PIL import Image | |
import io | |
import base64 | |
import time | |
import gym | |
import gym_minigrid | |
import numpy as np | |
from gym_minigrid.window import Window | |
from textworld_utils.utils import generate_text_obs | |
import os | |
# Flask application instance for this demo (served by app.run in the main guard).
app = Flask(import_name=__name__)
# Top-level environment families. update_tree() checks that the current env's
# "Env_type" parameter is one of these and folds all the others in the tree.
env_types = ["Information_seeking", "Collaboration", "AppleStealing"]

# UI label -> registered gym environment id. The keys are passed to the
# template as `available_env_labels`; insertion order matters because the
# first entry is used as the default environment at start-up.
env_label_to_env_name = {
    # Whole parameter space
    "Full SocialAI environment": "SocialAI-SocialAIParamEnv-v1",
    # Pointing
    "Pointing (Train)": "SocialAI-EPointingHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Pointing (Test)": "SocialAI-EPointingBoxesTestInformationSeekingParamEnv-v1",
    # Role reversal
    "Role Reversal Single Role B (Pretrain - experimental)": "SocialAI-MarblePassBCollaborationParamEnv-v1",
    "Role Reversal Single Asocial (Pretrain - control)": "SocialAI-AsocialMarbleCollaborationParamEnv-v1",
    "Role Reversal Group Role B (Pretrain - experimental)": "SocialAI-RoleReversalGroupExperimentalCollaborationParamEnv-v1",
    "Role Reversal Group Asocial (Pretrain - control)": "SocialAI-RoleReversalGroupControlCollaborationParamEnv-v1",
    "Role Reversal Role A (Finetune - test)": "SocialAI-MarblePassACollaborationParamEnv-v1",
    # Imitation
    "Imitation (Train)": "SocialAI-EEmulationNoDistrInformationSeekingParamEnv-v1",
    "Imitation (Test)": "SocialAI-EEmulationNoDistrDoorsInformationSeekingParamEnv-v1",
    # Textworld variants
    "AsocialBox (textworld)": "SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
    "ColorBoxes (textworld)": "SocialAI-ColorBoxesLLMCSParamEnv-v1",
    # Language cues
    "Language Color (Train)": "SocialAI-ELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
    "Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
    # NOTE(review): these two map to the same env ids as the "Language Color"
    # entries above — confirm that is intended.
    "Joint Attention Language Color (Train)": "SocialAI-ELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
    "Joint Attention Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
    # Apple stealing
    "Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
    "Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
    # Scaffolding
    "Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
    "Scaffolding/Formats (test)": "SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1",
}
# Mutable demo state. The view functions below rebind these names with
# `global`; at module level a `global` statement is a no-op, so the names are
# simply assigned here.

# Start on the first label of the mapping.
env_label = list(env_label_to_env_name)[0]
env_name = env_label_to_env_name[env_label]

# Env ids of the two "(textworld)" entries in env_label_to_env_name.
# NOTE(review): not referenced elsewhere in this chunk.
textworld_envs = ["SocialAI-AsocialBoxInformationSeekingParamEnv-v1", "SocialAI-ColorBoxesLLMCSParamEnv-v1"]

# Forwarded to env.render(..., mask_unobserved=...) on every render.
mask_unobserved = False
# When True, create_bubble_text shows generate_text_obs output instead of the
# conversation.
textual_observations = False

env = gym.make(env_name)
obs, info = env.reset(with_info=True)
def create_bubble_text(obs, info, full_conversation, textual_observations):
    """Build the bubble text sent to the page, truncated by format_bubble_text.

    Args:
        obs: Current observation; only used when textual_observations is True.
        info: The info returned alongside obs by env.reset/env.step; only used
            when textual_observations is True.
        full_conversation: Conversation text shown when textual_observations
            is False.
        textual_observations: Choose the generate_text_obs rendering (prefixed
            with a "Textual observation" header) instead of the conversation.

    Returns:
        The selected text after format_bubble_text truncation.
    """
    text = (
        "Textual observation\n\n" + generate_text_obs(obs, info)
        if textual_observations
        else full_conversation
    )
    return format_bubble_text(text)
def update_tree():
    """Redraw the parameter-tree image for the current environment.

    Passes the current env's parameters to draw_tree as the selected
    parameters and folds every env type except the current one. The output is
    written to ./web_demo/static/current_tree (exact file extension is decided
    by draw_tree).

    Raises:
        AssertionError: if the current "Env_type" is not listed in env_types.
    """
    params = env.current_env.parameters
    active_type = params["Env_type"]
    assert active_type in env_types, f"Env_type {active_type} not in {env_types}"
    env.parameter_tree.draw_tree(
        filename="./web_demo/static/current_tree",
        ignore_labels=["Num_of_colors"],
        selected_parameters=params,
        folded_nodes=[t for t in env_types if t != active_type],
    )


# Draw the tree once for the environment created at import time.
update_tree()
def np_img_to_base64(np_image):
    """Encode a numpy image array as a base64 (UTF-8 str) JPEG at quality 70."""
    with io.BytesIO() as buffer:
        Image.fromarray(np_image).save(buffer, 'JPEG', quality=70)
        # getvalue() returns the whole buffer regardless of stream position.
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')
def format_bubble_text(text):
    """Shorten *text* to at most 10 lines for display.

    Text with 10 or fewer lines is returned as-is. Longer text becomes its
    first line, a "...." marker line, and its last 8 lines.
    """
    lines = text.split("\n")
    if len(lines) <= 10:
        return text
    head, tail = lines[0], lines[-8:]
    return "\n".join([head, "....", *tail])
def set_env():
    """Switch the demo to the environment chosen in the submitted form.

    Reads the ``env_label`` form field, creates and resets the matching gym
    environment, closes the one it replaces, redraws the parameter tree and
    redirects to the ``index`` view.

    Returns:
        A redirect to ``index``, or a ``("Unknown environment label", 400)``
        response when the field is missing or not a key of
        ``env_label_to_env_name``.

    NOTE(review): no ``@app.route`` decorator for this view is visible in this
    chunk; confirm it is registered.
    """
    global env_name, env_label, env, obs, info

    new_label = request.form.get('env_label')
    new_name = env_label_to_env_name.get(new_label)
    if new_name is None:
        # A bare dict lookup would raise KeyError (HTTP 500). The submitted
        # value is deliberately not echoed back into the (HTML) response.
        return "Unknown environment label", 400

    # Build the new env completely before touching shared state, so a failure
    # in gym.make/reset leaves the previous environment intact and usable.
    new_env = gym.make(new_name)
    new_obs, new_info = new_env.reset(with_info=True)

    old_env = env
    env_label, env_name = new_label, new_name
    env, obs, info = new_env, new_obs, new_info
    old_env.close()  # release resources held by the replaced environment

    update_tree()  # redraw the tree for the new environment
    return redirect(url_for('index'))
def set_mask_unobserved():
    """Update the mask_unobserved flag from the form and return a fresh render.

    The ``mask_unobserved`` form field counts as True only when it is exactly
    the string 'true'. Responds with JSON ``{'image_data': <base64 JPEG>}``.
    """
    global mask_unobserved
    flag = request.form.get('mask_unobserved')
    mask_unobserved = (flag == 'true')
    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    return jsonify({'image_data': np_img_to_base64(frame)})
def set_textual_observations():
    """Update the textual_observations flag from the form and return new bubble text.

    The ``textual_observations`` form field counts as True only when it is
    exactly the string 'true'. Responds with JSON ``{"bubble_text": ...}``
    built from the current obs/info and the env's full conversation.
    """
    global textual_observations
    requested = request.form.get('textual_observations')
    textual_observations = requested == 'true'
    conversation = env.current_env.full_conversation
    text = create_bubble_text(obs, info, conversation, textual_observations)
    return jsonify({"bubble_text": text})
def perform_action():
    """Apply the action named in the form to the env and return the new view.

    Form field ``action``:
        'done'     -- reset the env and redraw the parameter tree.
        'speak'    -- pick an utterance from the 'template' and 'word' fields
                      via env.grammar.get_action.
        'left' / 'right' / 'forward' / 'toggle'
                   -- step with the env.actions member of the same name.
        anything else (including 'noop' or a missing field)
                   -- step with an all-NaN action.

    Returns:
        JSON with 'image_data' (base64 JPEG render) and 'bubble_text'.
    """
    global obs, info
    action_name = request.form.get('action')

    if action_name == 'done':
        obs, info = env.reset(with_info=True)
        update_tree()
    else:
        nan = np.nan
        if action_name == "speak":
            template_idx, word_idx = env.grammar.get_action(
                request.form.get('template'), request.form.get('word')
            )
            action = [nan, template_idx, word_idx]
        elif action_name in ('left', 'right', 'forward', 'toggle'):
            # Primitive move: the enum member shares the action's name.
            action = [int(getattr(env.actions, action_name)), nan, nan]
        else:
            # 'noop' and unrecognised names both send an empty action.
            action = [nan, nan, nan]
        obs, _reward, _done, info = env.step(action)

    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    return jsonify({
        'image_data': np_img_to_base64(frame),
        "bubble_text": create_bubble_text(
            obs, info, env.current_env.full_conversation, textual_observations
        ),
    })
def index():
    """Render index.html for the current environment state.

    The template receives the current frame (base64 JPEG), bubble text, the
    mask flag, the current time, the selectable env labels, the current label
    and the env grammar's templates and words.

    NOTE(review): no ``@app.route`` decorator is visible in this chunk, but
    set_env redirects via ``url_for('index')`` — confirm it is registered.
    """
    frame = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
    context = dict(
        image_data=np_img_to_base64(frame),
        bubble_text=create_bubble_text(
            obs, info, env.current_env.full_conversation, textual_observations
        ),
        mask_unobserved=mask_unobserved,
        timestamp=time.time(),
        available_env_labels=env_label_to_env_name.keys(),
        current_env_label=env_label,
        grammar_templates=env.grammar.templates,
        grammar_words=env.grammar.things,
    )
    return render_template('index.html', **context)
if __name__ == '__main__':
    # The Werkzeug debugger lets a browser execute arbitrary Python. Do not
    # enable it by default on a server bound to every interface; opt in for
    # local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug)