Spaces:
Running
Running
Parameter selection added
Browse files- gym-minigrid/gym_minigrid/curriculums/expertcurriculumsocialaiparamenv.py +21 -0
- gym-minigrid/gym_minigrid/parametric_env.py +25 -0
- gym-minigrid/gym_minigrid/social_ai_envs/case_studies_envs/LLMcasestudyenvs.py +11 -11
- gym-minigrid/gym_minigrid/social_ai_envs/informationseekingenv.py +4 -2
- gym-minigrid/gym_minigrid/social_ai_envs/socialaiparamenv.py +19 -5
- web_demo/app.py +33 -3
- web_demo/templates/index.html +86 -6
gym-minigrid/gym_minigrid/curriculums/expertcurriculumsocialaiparamenv.py
CHANGED
@@ -3,6 +3,27 @@ import warnings
|
|
3 |
import numpy as np
|
4 |
import random
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class ScaffoldingExpertCurriculum:
|
7 |
|
8 |
def __init__(self, type, minimum_episodes=1000, average_interval=500, phase_thresholds=(0.75, 0.75)):
|
|
|
3 |
import numpy as np
|
4 |
import random
|
5 |
|
6 |
+
class SelectedParametersOrRandomCurriculum():
|
7 |
+
def __init__(self, selected_parameters):
|
8 |
+
|
9 |
+
self.selected_parameters = selected_parameters
|
10 |
+
|
11 |
+
def choose(self, node, chosen_parameters):
|
12 |
+
# if in selected_parameters choose the selected one
|
13 |
+
# else choose a random child
|
14 |
+
|
15 |
+
assert node.type == 'param'
|
16 |
+
|
17 |
+
if node in self.selected_parameters:
|
18 |
+
chosen = self.selected_parameters[node]
|
19 |
+
assert chosen in node.children
|
20 |
+
return chosen
|
21 |
+
|
22 |
+
else:
|
23 |
+
return random.choice(node.children)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
class ScaffoldingExpertCurriculum:
|
28 |
|
29 |
def __init__(self, type, minimum_episodes=1000, average_interval=500, phase_thresholds=(0.75, 0.75)):
|
gym-minigrid/gym_minigrid/parametric_env.py
CHANGED
@@ -3,6 +3,7 @@ from graphviz import Digraph
|
|
3 |
import re
|
4 |
import random
|
5 |
from termcolor import cprint
|
|
|
6 |
|
7 |
|
8 |
class Node:
|
@@ -18,6 +19,13 @@ class Node:
|
|
18 |
self.children = []
|
19 |
self.type = type
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def __repr__(self):
|
22 |
return f"{self.id}({self.type})-'{self.label}'"
|
23 |
|
@@ -37,6 +45,9 @@ class ParameterTree(ABC):
|
|
37 |
self.tree = Digraph("unix", format='svg')
|
38 |
self.tree.attr(size='30,100')
|
39 |
|
|
|
|
|
|
|
40 |
def add_node(self, label, parent=None, type="param"):
|
41 |
"""
|
42 |
All children of this node must be set
|
@@ -125,6 +136,20 @@ class ParameterTree(ABC):
|
|
125 |
|
126 |
nodes.extend(node.children)
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
def draw_tree(self, filename, selected_parameters={}, ignore_labels=[], folded_nodes=[], label_parser={}, save=True):
|
129 |
|
130 |
self.create_digraph()
|
|
|
3 |
import re
|
4 |
import random
|
5 |
from termcolor import cprint
|
6 |
+
from collections import defaultdict
|
7 |
|
8 |
|
9 |
class Node:
|
|
|
19 |
self.children = []
|
20 |
self.type = type
|
21 |
|
22 |
+
# calculate node's level
|
23 |
+
parent_ = self.parent
|
24 |
+
self.level = 1
|
25 |
+
while parent_ is not None:
|
26 |
+
self.level += 1
|
27 |
+
parent_ = parent_.parent
|
28 |
+
|
29 |
def __repr__(self):
|
30 |
return f"{self.id}({self.type})-'{self.label}'"
|
31 |
|
|
|
45 |
self.tree = Digraph("unix", format='svg')
|
46 |
self.tree.attr(size='30,100')
|
47 |
|
48 |
+
def get_node_for_id(self, id):
|
49 |
+
return self.nodes[id]
|
50 |
+
|
51 |
def add_node(self, label, parent=None, type="param"):
|
52 |
"""
|
53 |
All children of this node must be set
|
|
|
136 |
|
137 |
nodes.extend(node.children)
|
138 |
|
139 |
+
def get_all_params(self):
|
140 |
+
all_params = defaultdict(list)
|
141 |
+
|
142 |
+
nodes = [self.root]
|
143 |
+
while nodes:
|
144 |
+
node = nodes.pop(0)
|
145 |
+
|
146 |
+
if node.type == "value":
|
147 |
+
all_params[node.parent].append(node)
|
148 |
+
|
149 |
+
nodes.extend(node.children)
|
150 |
+
|
151 |
+
return all_params
|
152 |
+
|
153 |
def draw_tree(self, filename, selected_parameters={}, ignore_labels=[], folded_nodes=[], label_parser={}, save=True):
|
154 |
|
155 |
self.create_digraph()
|
gym-minigrid/gym_minigrid/social_ai_envs/case_studies_envs/LLMcasestudyenvs.py
CHANGED
@@ -33,17 +33,17 @@ class AsocialBoxInformationSeekingParamEnv(SocialAIParamEnv):
|
|
33 |
# Information seeking
|
34 |
inf_seeking_nd = tree.add_node("Information_seeking", parent=env_type_nd, type="value")
|
35 |
|
36 |
-
prag_fr_compl_nd = tree.add_node("Pragmatic_frame_complexity", parent=inf_seeking_nd, type="param")
|
37 |
-
tree.add_node("No", parent=prag_fr_compl_nd, type="value")
|
38 |
-
|
39 |
-
# scaffolding
|
40 |
-
scaffolding_nd = tree.add_node("Scaffolding", parent=inf_seeking_nd, type="param")
|
41 |
-
scaffolding_N_nd = tree.add_node("N", parent=scaffolding_nd, type="value")
|
42 |
-
|
43 |
-
cue_type_nd = tree.add_node("Cue_type", parent=scaffolding_N_nd, type="param")
|
44 |
-
tree.add_node("Language_Color", parent=cue_type_nd, type="value")
|
45 |
-
tree.add_node("Language_Feedback", parent=cue_type_nd, type="value")
|
46 |
-
tree.add_node("Pointing", parent=cue_type_nd, type="value")
|
47 |
|
48 |
problem_nd = tree.add_node("Problem", parent=inf_seeking_nd, type="param")
|
49 |
|
|
|
33 |
# Information seeking
|
34 |
inf_seeking_nd = tree.add_node("Information_seeking", parent=env_type_nd, type="value")
|
35 |
|
36 |
+
# prag_fr_compl_nd = tree.add_node("Pragmatic_frame_complexity", parent=inf_seeking_nd, type="param")
|
37 |
+
# tree.add_node("No", parent=prag_fr_compl_nd, type="value")
|
38 |
+
#
|
39 |
+
# # scaffolding
|
40 |
+
# scaffolding_nd = tree.add_node("Scaffolding", parent=inf_seeking_nd, type="param")
|
41 |
+
# scaffolding_N_nd = tree.add_node("N", parent=scaffolding_nd, type="value")
|
42 |
+
#
|
43 |
+
# cue_type_nd = tree.add_node("Cue_type", parent=scaffolding_N_nd, type="param")
|
44 |
+
# tree.add_node("Language_Color", parent=cue_type_nd, type="value")
|
45 |
+
# tree.add_node("Language_Feedback", parent=cue_type_nd, type="value")
|
46 |
+
# tree.add_node("Pointing", parent=cue_type_nd, type="value")
|
47 |
|
48 |
problem_nd = tree.add_node("Problem", parent=inf_seeking_nd, type="param")
|
49 |
|
gym-minigrid/gym_minigrid/social_ai_envs/informationseekingenv.py
CHANGED
@@ -78,6 +78,7 @@ class Caretaker(NPC):
|
|
78 |
self.distractor_obj = self.env.distractor_generator
|
79 |
|
80 |
if self.env.ja_recursive:
|
|
|
81 |
if int(self.env.parameters["N"]) == 1:
|
82 |
self.ja_decoy = self.env._rand_elem([self.target_obj])
|
83 |
else:
|
@@ -101,6 +102,7 @@ class Caretaker(NPC):
|
|
101 |
if self.env.hidden_npc:
|
102 |
return reply, info
|
103 |
|
|
|
104 |
scaffolding = self.env.parameters.get("Scaffolding", "N") == "Y"
|
105 |
language_color = False
|
106 |
language_feedback = False
|
@@ -135,7 +137,7 @@ class Caretaker(NPC):
|
|
135 |
assert action is None
|
136 |
|
137 |
if self.env.ja_recursive:
|
138 |
-
# look at the center of the room (this makes the cue giving
|
139 |
action = self.look_at_action([self.env.current_width // 2, self.env.current_height // 2])
|
140 |
else:
|
141 |
# look at the agent
|
@@ -550,7 +552,7 @@ class InformationSeekingEnv(MultiModalMiniGridEnv):
|
|
550 |
num_of_colors = self.n_colors
|
551 |
|
552 |
# additional test for recursivness of joint attention -> cues are given outside of JA
|
553 |
-
self.ja_recursive = self.parameters.get("JA_recursive", False) if self.parameters else False
|
554 |
|
555 |
self.add_obstacles()
|
556 |
if self.obstacles != "No":
|
|
|
78 |
self.distractor_obj = self.env.distractor_generator
|
79 |
|
80 |
if self.env.ja_recursive:
|
81 |
+
# how many objects
|
82 |
if int(self.env.parameters["N"]) == 1:
|
83 |
self.ja_decoy = self.env._rand_elem([self.target_obj])
|
84 |
else:
|
|
|
102 |
if self.env.hidden_npc:
|
103 |
return reply, info
|
104 |
|
105 |
+
|
106 |
scaffolding = self.env.parameters.get("Scaffolding", "N") == "Y"
|
107 |
language_color = False
|
108 |
language_feedback = False
|
|
|
137 |
assert action is None
|
138 |
|
139 |
if self.env.ja_recursive:
|
140 |
+
# look at the center of the room (this makes the cue giving inside and outisde JA different)
|
141 |
action = self.look_at_action([self.env.current_width // 2, self.env.current_height // 2])
|
142 |
else:
|
143 |
# look at the agent
|
|
|
552 |
num_of_colors = self.n_colors
|
553 |
|
554 |
# additional test for recursivness of joint attention -> cues are given outside of JA
|
555 |
+
self.ja_recursive = self.parameters.get("JA_recursive", False) == "Y" if self.parameters else False
|
556 |
|
557 |
self.add_obstacles()
|
558 |
if self.obstacles != "No":
|
gym-minigrid/gym_minigrid/social_ai_envs/socialaiparamenv.py
CHANGED
@@ -297,6 +297,16 @@ class SocialAIParamEnv(gym.Env):
|
|
297 |
|
298 |
return tree
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
def construct_env_from_params(self, params):
|
301 |
params_labels = {k.label: v.label for k, v in params.items()}
|
302 |
if params_labels['Env_type'] == "Collaboration":
|
@@ -329,14 +339,19 @@ class SocialAIParamEnv(gym.Env):
|
|
329 |
|
330 |
return env, reset_kwargs
|
331 |
|
332 |
-
def reset(self, with_info=False):
|
333 |
# select a new social environment at random, for each new episode
|
334 |
|
335 |
old_window = None
|
336 |
if self.current_env: # a previous env exists, save old window
|
337 |
old_window = self.current_env.window
|
338 |
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
340 |
|
341 |
self.current_env, reset_kwargs = self.construct_env_from_params(self.current_params)
|
342 |
assert reset_kwargs is not {}
|
@@ -360,9 +375,8 @@ class SocialAIParamEnv(gym.Env):
|
|
360 |
else:
|
361 |
return obs
|
362 |
|
363 |
-
def reset_with_info(self):
|
364 |
-
return self.reset(with_info=True)
|
365 |
-
|
366 |
|
367 |
def seed(self, seed=1337):
|
368 |
# Seed the random number generator
|
|
|
297 |
|
298 |
return tree
|
299 |
|
300 |
+
def get_potential_params(self):
|
301 |
+
all_params = self.parameter_tree.get_all_params()
|
302 |
+
potential_params = {
|
303 |
+
k: v for k, v in all_params.items() if k in self.current_params.keys()
|
304 |
+
}
|
305 |
+
return potential_params
|
306 |
+
|
307 |
+
def get_all_params(self):
|
308 |
+
return self.parameter_tree.get_all_params()
|
309 |
+
|
310 |
def construct_env_from_params(self, params):
|
311 |
params_labels = {k.label: v.label for k, v in params.items()}
|
312 |
if params_labels['Env_type'] == "Collaboration":
|
|
|
339 |
|
340 |
return env, reset_kwargs
|
341 |
|
342 |
+
def reset(self, with_info=False, selected_params=None, ACL=None):
|
343 |
# select a new social environment at random, for each new episode
|
344 |
|
345 |
old_window = None
|
346 |
if self.current_env: # a previous env exists, save old window
|
347 |
old_window = self.current_env.window
|
348 |
|
349 |
+
if selected_params is not None:
|
350 |
+
self.current_params = selected_params
|
351 |
+
elif ACL is not None:
|
352 |
+
self.current_params = self.parameter_tree.sample_env_params(ACL=ACL)
|
353 |
+
else:
|
354 |
+
self.current_params = self.parameter_tree.sample_env_params(ACL=self.curriculum)
|
355 |
|
356 |
self.current_env, reset_kwargs = self.construct_env_from_params(self.current_params)
|
357 |
assert reset_kwargs is not {}
|
|
|
375 |
else:
|
376 |
return obs
|
377 |
|
378 |
+
def reset_with_info(self, selected_params=None, ACL=None):
|
379 |
+
return self.reset(with_info=True, selected_params=selected_params, ACL=ACL)
|
|
|
380 |
|
381 |
def seed(self, seed=1337):
|
382 |
# Seed the random number generator
|
web_demo/app.py
CHANGED
@@ -8,6 +8,7 @@ import gym
|
|
8 |
import gym_minigrid
|
9 |
import numpy as np
|
10 |
from gym_minigrid.window import Window
|
|
|
11 |
|
12 |
from textworld_utils.utils import generate_text_obs
|
13 |
|
@@ -34,8 +35,8 @@ env_label_to_env_name = {
|
|
34 |
"Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
|
35 |
"Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
|
36 |
"Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
|
37 |
-
"Joint Attention Language Color (Train)": "SocialAI-
|
38 |
-
"Joint Attention Language Color (Test)": "SocialAI-
|
39 |
"Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
|
40 |
"Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
|
41 |
"Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
|
@@ -96,6 +97,12 @@ global obs, info
|
|
96 |
obs, info = env.reset(with_info=True)
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def create_bubble_text(obs, info, full_conversation, textual_observations):
|
100 |
if textual_observations:
|
101 |
bubble_text = "Textual observation\n\n"+ \
|
@@ -110,6 +117,7 @@ def create_bubble_text(obs, info, full_conversation, textual_observations):
|
|
110 |
|
111 |
def update_tree():
|
112 |
selected_parameters = env.current_env.parameters
|
|
|
113 |
selected_env_type = selected_parameters["Env_type"]
|
114 |
|
115 |
assert selected_env_type in env_types, f"Env_type {selected_env_type} not in {env_types}"
|
@@ -145,6 +153,27 @@ def format_bubble_text(text):
|
|
145 |
return "\n".join(lines)
|
146 |
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
@app.route('/set_env', methods=['POST'])
|
149 |
def set_env():
|
150 |
global env_name # Declare the variable as global to modify it
|
@@ -217,7 +246,6 @@ def perform_action():
|
|
217 |
|
218 |
obs, reward, done, info = env.step(action)
|
219 |
|
220 |
-
|
221 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
222 |
image_data = np_img_to_base64(image)
|
223 |
|
@@ -253,6 +281,8 @@ def index():
|
|
253 |
current_env_label=env_label,
|
254 |
grammar_templates=grammar_templates,
|
255 |
grammar_words=grammar_words,
|
|
|
|
|
256 |
)
|
257 |
|
258 |
|
|
|
8 |
import gym_minigrid
|
9 |
import numpy as np
|
10 |
from gym_minigrid.window import Window
|
11 |
+
from gym_minigrid.curriculums import SelectedParametersOrRandomCurriculum
|
12 |
|
13 |
from textworld_utils.utils import generate_text_obs
|
14 |
|
|
|
35 |
"Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
|
36 |
"Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
|
37 |
"Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
|
38 |
+
"Joint Attention Language Color (Train)": "SocialAI-JAELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
|
39 |
+
"Joint Attention Language Color (Test)": "SocialAI-JAELangColorDoorsTestInformationSeekingParamEnv-v1",
|
40 |
"Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
|
41 |
"Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
|
42 |
"Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
|
|
|
97 |
obs, info = env.reset(with_info=True)
|
98 |
|
99 |
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
def get_parameter_options(env):
|
104 |
+
return env.get_potential_params()
|
105 |
+
|
106 |
def create_bubble_text(obs, info, full_conversation, textual_observations):
|
107 |
if textual_observations:
|
108 |
bubble_text = "Textual observation\n\n"+ \
|
|
|
117 |
|
118 |
def update_tree():
|
119 |
selected_parameters = env.current_env.parameters
|
120 |
+
print("sel param:", selected_parameters)
|
121 |
selected_env_type = selected_parameters["Env_type"]
|
122 |
|
123 |
assert selected_env_type in env_types, f"Env_type {selected_env_type} not in {env_types}"
|
|
|
153 |
return "\n".join(lines)
|
154 |
|
155 |
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
@app.route('/set_env_params', methods=['POST'])
|
160 |
+
def set_env_params():
|
161 |
+
global env
|
162 |
+
selected_params_ids = request.get_json()
|
163 |
+
|
164 |
+
selected_parameters = {
|
165 |
+
env.parameter_tree.get_node_for_id(k): env.parameter_tree.get_node_for_id(v) for k,v in selected_params_ids.items()
|
166 |
+
}
|
167 |
+
global obs, info
|
168 |
+
|
169 |
+
selected_parameters_curriuclum = SelectedParametersOrRandomCurriculum(selected_parameters)
|
170 |
+
|
171 |
+
obs, info = env.reset(with_info=True, ACL=selected_parameters_curriuclum)
|
172 |
+
update_tree() # Update the tree for the new environment
|
173 |
+
return jsonify({"success": True}), 200
|
174 |
+
# return redirect(url_for('index')) # Redirect back to the main page
|
175 |
+
|
176 |
+
|
177 |
@app.route('/set_env', methods=['POST'])
|
178 |
def set_env():
|
179 |
global env_name # Declare the variable as global to modify it
|
|
|
246 |
|
247 |
obs, reward, done, info = env.step(action)
|
248 |
|
|
|
249 |
image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
|
250 |
image_data = np_img_to_base64(image)
|
251 |
|
|
|
281 |
current_env_label=env_label,
|
282 |
grammar_templates=grammar_templates,
|
283 |
grammar_words=grammar_words,
|
284 |
+
parameter_options=get_parameter_options(env),
|
285 |
+
current_parameters=env.current_params
|
286 |
)
|
287 |
|
288 |
|
web_demo/templates/index.html
CHANGED
@@ -54,10 +54,10 @@
|
|
54 |
margin-bottom: 10px; /* Adds a space between the two rows */
|
55 |
}
|
56 |
|
57 |
-
.
|
58 |
width: 60%;
|
59 |
min-width: 1000px;
|
60 |
-
max-height:
|
61 |
margin-left: 5px;
|
62 |
margin-right: 5px;
|
63 |
margin-bottom: 5px;
|
@@ -204,8 +204,24 @@
|
|
204 |
.then(data => {
|
205 |
document.getElementById('envImage').src = `data:image/jpeg;base64,${data.image_data}`;
|
206 |
if (actionName === "done") {
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
}
|
210 |
|
211 |
// Add this to handle the caretaker's utterance
|
@@ -236,6 +252,7 @@
|
|
236 |
.catch(error => {
|
237 |
console.error('Error:', error);
|
238 |
});
|
|
|
239 |
}
|
240 |
|
241 |
document.addEventListener("DOMContentLoaded", function() {
|
@@ -327,7 +344,49 @@
|
|
327 |
});
|
328 |
}
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
</script>
|
332 |
</head>
|
333 |
<body>
|
@@ -360,8 +419,29 @@
|
|
360 |
</div>
|
361 |
|
362 |
<p>This is the sampling tree. The current sampled parameters are highlighted blue</p>
|
363 |
-
<div class="
|
364 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
</div>
|
366 |
|
367 |
<p>This is the environment.</p>
|
|
|
54 |
margin-bottom: 10px; /* Adds a space between the two rows */
|
55 |
}
|
56 |
|
57 |
+
.tree-container {
|
58 |
width: 60%;
|
59 |
min-width: 1000px;
|
60 |
+
max-height: 600px;
|
61 |
margin-left: 5px;
|
62 |
margin-right: 5px;
|
63 |
margin-bottom: 5px;
|
|
|
204 |
.then(data => {
|
205 |
document.getElementById('envImage').src = `data:image/jpeg;base64,${data.image_data}`;
|
206 |
if (actionName === "done") {
|
207 |
+
|
208 |
+
// Save the scroll position in sessionStorage before reloading
|
209 |
+
sessionStorage.setItem('scrollPosition', window.scrollY);
|
210 |
+
console.log(window.scrollY);
|
211 |
+
|
212 |
+
// Function to reload the page with updated URL parameter
|
213 |
+
function reloadPage() {
|
214 |
+
let currentUrl = new URL(window.location.href);
|
215 |
+
currentUrl.searchParams.set('_', new Date().getTime());
|
216 |
+
window.location.href = currentUrl.href;
|
217 |
+
}
|
218 |
+
|
219 |
+
// Call the function to reload the page
|
220 |
+
reloadPage();
|
221 |
+
|
222 |
+
<!-- const svgTree = document.getElementById('svgTree');-->
|
223 |
+
<!-- svgTree.data = `./static/current_tree.svg?_=${data.timestamp}`;-->
|
224 |
+
<!-- // todo: update the dropdown lists for parameters-->
|
225 |
}
|
226 |
|
227 |
// Add this to handle the caretaker's utterance
|
|
|
252 |
.catch(error => {
|
253 |
console.error('Error:', error);
|
254 |
});
|
255 |
+
|
256 |
}
|
257 |
|
258 |
document.addEventListener("DOMContentLoaded", function() {
|
|
|
344 |
});
|
345 |
}
|
346 |
|
347 |
+
function setEnvParams() {
|
348 |
+
// Collect data from dropdowns
|
349 |
+
var data = {};
|
350 |
+
|
351 |
+
{% for key in parameter_options.keys() %}
|
352 |
+
data['{{ key.id }}'] = document.getElementById('{{ key.id }}').value;
|
353 |
+
{% endfor %}
|
354 |
|
355 |
+
// Send data to Flask backend
|
356 |
+
fetch('/set_env_params', {
|
357 |
+
method: 'POST',
|
358 |
+
headers: {
|
359 |
+
'Content-Type': 'application/json',
|
360 |
+
},
|
361 |
+
body: JSON.stringify(data),
|
362 |
+
})
|
363 |
+
.then(response => {
|
364 |
+
if(response.ok) {
|
365 |
+
// If the request was successful, reload the page without cache
|
366 |
+
let currentUrl = new URL(window.location.href);
|
367 |
+
currentUrl.searchParams.set('_', new Date().getTime());
|
368 |
+
window.location.href = currentUrl.href;
|
369 |
+
} else {
|
370 |
+
// Handle errors
|
371 |
+
console.error('Error in setEnvParams:', response);
|
372 |
+
}
|
373 |
+
})
|
374 |
+
.catch((error) => {
|
375 |
+
console.error('Error:', error);
|
376 |
+
});
|
377 |
+
|
378 |
+
}
|
379 |
+
|
380 |
+
function restoreScrollPosition() {
|
381 |
+
let savedScrollPosition = parseInt(sessionStorage.getItem('scrollPosition'));
|
382 |
+
if (!isNaN(savedScrollPosition)) {
|
383 |
+
window.scrollTo(0, savedScrollPosition);
|
384 |
+
sessionStorage.removeItem('scrollPosition');
|
385 |
+
}
|
386 |
+
}
|
387 |
+
|
388 |
+
// Attach the restoreScrollPosition function to the appropriate event
|
389 |
+
window.addEventListener('load', restoreScrollPosition);
|
390 |
</script>
|
391 |
</head>
|
392 |
<body>
|
|
|
419 |
</div>
|
420 |
|
421 |
<p>This is the sampling tree. The current sampled parameters are highlighted blue</p>
|
422 |
+
<div class="tree-container" >
|
423 |
+
<div class="form-container">
|
424 |
+
<span class="form-label">Select Parameters:</span>
|
425 |
+
<form method="post" id="paramsEnvForm">
|
426 |
+
{% set ns = namespace(lvl=1) %}
|
427 |
+
{% for key, options in parameter_options.items() %}
|
428 |
+
{% if key.level != ns.lvl %}
|
429 |
+
{% if ns.lvl >= 0 %}<hr>{% endif %}
|
430 |
+
{% set ns.lvl = key.level %}
|
431 |
+
{% endif %}
|
432 |
+
|
433 |
+
<label>{{ key.label }}</label>
|
434 |
+
<select name="{{ key }}" id="{{ key.id }}" onchange="setEnvParams();">
|
435 |
+
{% for option in options %}
|
436 |
+
<option value="{{ option.id }}" {% if current_parameters[key].id == option.id %}selected{% endif %} >
|
437 |
+
{{ option.label }}
|
438 |
+
</option>
|
439 |
+
{% endfor %}
|
440 |
+
</select>
|
441 |
+
{% endfor %}
|
442 |
+
</form>
|
443 |
+
</div>
|
444 |
+
<object id="svgTree" type="image/svg+xml" data="{{ url_for('static', filename='current_tree.svg', _=timestamp) }}" width="95%" height="70%">Your browser does not support SVG</object>
|
445 |
</div>
|
446 |
|
447 |
<p>This is the environment.</p>
|