grg committed on
Commit
f397ead
·
1 Parent(s): 9cf63fa

Parameter selection added

Browse files
gym-minigrid/gym_minigrid/curriculums/expertcurriculumsocialaiparamenv.py CHANGED
@@ -3,6 +3,27 @@ import warnings
3
  import numpy as np
4
  import random
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class ScaffoldingExpertCurriculum:
7
 
8
  def __init__(self, type, minimum_episodes=1000, average_interval=500, phase_thresholds=(0.75, 0.75)):
 
3
  import numpy as np
4
  import random
5
 
6
class SelectedParametersOrRandomCurriculum():
    """Curriculum that returns a pre-selected child for some parameters.

    For a parameter node present in ``selected_parameters`` the mapped
    child is returned; for any other parameter node a child is drawn
    uniformly at random.
    """

    def __init__(self, selected_parameters):
        """Store the parameter-node -> chosen-child mapping.

        Args:
            selected_parameters: mapping whose keys are parameter nodes and
                whose values are the child node to return for that key.
        """
        self.selected_parameters = selected_parameters

    def choose(self, node, chosen_parameters):
        """Pick the child of ``node`` to use.

        Args:
            node: object with ``type == 'param'`` and a ``children`` sequence.
            chosen_parameters: not used by this curriculum.
                NOTE(review): presumably kept so the signature matches what
                the parameter-tree sampler passes in - confirm against caller.

        Returns:
            The pre-selected child if ``node`` is a key of
            ``selected_parameters``, otherwise ``random.choice(node.children)``.

        Raises:
            AssertionError: if ``node`` is not a ``'param'`` node, or the
                pre-selected value is not one of ``node.children``.
        """
        # Raised explicitly (rather than via ``assert``) so the checks still
        # run under ``python -O``; the exception type is unchanged.
        if node.type != 'param':
            raise AssertionError(
                f"expected a 'param' node, got type {node.type!r}"
            )

        if node in self.selected_parameters:
            chosen = self.selected_parameters[node]
            if chosen not in node.children:
                raise AssertionError(
                    f"selected value {chosen!r} is not a child of {node!r}"
                )
            return chosen

        # No explicit selection for this parameter: sample a random child.
        return random.choice(node.children)
24
+
25
+
26
+
27
  class ScaffoldingExpertCurriculum:
28
 
29
  def __init__(self, type, minimum_episodes=1000, average_interval=500, phase_thresholds=(0.75, 0.75)):
gym-minigrid/gym_minigrid/parametric_env.py CHANGED
@@ -3,6 +3,7 @@ from graphviz import Digraph
3
  import re
4
  import random
5
  from termcolor import cprint
 
6
 
7
 
8
  class Node:
@@ -18,6 +19,13 @@ class Node:
18
  self.children = []
19
  self.type = type
20
 
 
 
 
 
 
 
 
21
  def __repr__(self):
22
  return f"{self.id}({self.type})-'{self.label}'"
23
 
@@ -37,6 +45,9 @@ class ParameterTree(ABC):
37
  self.tree = Digraph("unix", format='svg')
38
  self.tree.attr(size='30,100')
39
 
 
 
 
40
  def add_node(self, label, parent=None, type="param"):
41
  """
42
  All children of this node must be set
@@ -125,6 +136,20 @@ class ParameterTree(ABC):
125
 
126
  nodes.extend(node.children)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def draw_tree(self, filename, selected_parameters={}, ignore_labels=[], folded_nodes=[], label_parser={}, save=True):
129
 
130
  self.create_digraph()
 
3
  import re
4
  import random
5
  from termcolor import cprint
6
+ from collections import defaultdict
7
 
8
 
9
  class Node:
 
19
  self.children = []
20
  self.type = type
21
 
22
+ # calculate node's level
23
+ parent_ = self.parent
24
+ self.level = 1
25
+ while parent_ is not None:
26
+ self.level += 1
27
+ parent_ = parent_.parent
28
+
29
  def __repr__(self):
30
  return f"{self.id}({self.type})-'{self.label}'"
31
 
 
45
  self.tree = Digraph("unix", format='svg')
46
  self.tree.attr(size='30,100')
47
 
48
def get_node_for_id(self, id):
    """Return the entry of ``self.nodes`` stored under ``id``."""
    node = self.nodes[id]
    return node
50
+
51
  def add_node(self, label, parent=None, type="param"):
52
  """
53
  All children of this node must be set
 
136
 
137
  nodes.extend(node.children)
138
 
139
def get_all_params(self):
    """Group every ``"value"`` node of the tree under its parent node.

    The tree is walked breadth-first starting at ``self.root``.

    Returns:
        A ``defaultdict(list)`` mapping each parent node to the list of its
        ``"value"``-typed children, in traversal order.
    """
    # Local import: a deque gives O(1) pops from the left, whereas
    # ``list.pop(0)`` shifts the whole queue on every iteration.
    from collections import deque

    all_params = defaultdict(list)

    queue = deque([self.root])
    while queue:
        node = queue.popleft()

        if node.type == "value":
            all_params[node.parent].append(node)

        queue.extend(node.children)

    return all_params
152
+
153
  def draw_tree(self, filename, selected_parameters={}, ignore_labels=[], folded_nodes=[], label_parser={}, save=True):
154
 
155
  self.create_digraph()
gym-minigrid/gym_minigrid/social_ai_envs/case_studies_envs/LLMcasestudyenvs.py CHANGED
@@ -33,17 +33,17 @@ class AsocialBoxInformationSeekingParamEnv(SocialAIParamEnv):
33
  # Information seeking
34
  inf_seeking_nd = tree.add_node("Information_seeking", parent=env_type_nd, type="value")
35
 
36
- prag_fr_compl_nd = tree.add_node("Pragmatic_frame_complexity", parent=inf_seeking_nd, type="param")
37
- tree.add_node("No", parent=prag_fr_compl_nd, type="value")
38
-
39
- # scaffolding
40
- scaffolding_nd = tree.add_node("Scaffolding", parent=inf_seeking_nd, type="param")
41
- scaffolding_N_nd = tree.add_node("N", parent=scaffolding_nd, type="value")
42
-
43
- cue_type_nd = tree.add_node("Cue_type", parent=scaffolding_N_nd, type="param")
44
- tree.add_node("Language_Color", parent=cue_type_nd, type="value")
45
- tree.add_node("Language_Feedback", parent=cue_type_nd, type="value")
46
- tree.add_node("Pointing", parent=cue_type_nd, type="value")
47
 
48
  problem_nd = tree.add_node("Problem", parent=inf_seeking_nd, type="param")
49
 
 
33
  # Information seeking
34
  inf_seeking_nd = tree.add_node("Information_seeking", parent=env_type_nd, type="value")
35
 
36
+ # prag_fr_compl_nd = tree.add_node("Pragmatic_frame_complexity", parent=inf_seeking_nd, type="param")
37
+ # tree.add_node("No", parent=prag_fr_compl_nd, type="value")
38
+ #
39
+ # # scaffolding
40
+ # scaffolding_nd = tree.add_node("Scaffolding", parent=inf_seeking_nd, type="param")
41
+ # scaffolding_N_nd = tree.add_node("N", parent=scaffolding_nd, type="value")
42
+ #
43
+ # cue_type_nd = tree.add_node("Cue_type", parent=scaffolding_N_nd, type="param")
44
+ # tree.add_node("Language_Color", parent=cue_type_nd, type="value")
45
+ # tree.add_node("Language_Feedback", parent=cue_type_nd, type="value")
46
+ # tree.add_node("Pointing", parent=cue_type_nd, type="value")
47
 
48
  problem_nd = tree.add_node("Problem", parent=inf_seeking_nd, type="param")
49
 
gym-minigrid/gym_minigrid/social_ai_envs/informationseekingenv.py CHANGED
@@ -78,6 +78,7 @@ class Caretaker(NPC):
78
  self.distractor_obj = self.env.distractor_generator
79
 
80
  if self.env.ja_recursive:
 
81
  if int(self.env.parameters["N"]) == 1:
82
  self.ja_decoy = self.env._rand_elem([self.target_obj])
83
  else:
@@ -101,6 +102,7 @@ class Caretaker(NPC):
101
  if self.env.hidden_npc:
102
  return reply, info
103
 
 
104
  scaffolding = self.env.parameters.get("Scaffolding", "N") == "Y"
105
  language_color = False
106
  language_feedback = False
@@ -135,7 +137,7 @@ class Caretaker(NPC):
135
  assert action is None
136
 
137
  if self.env.ja_recursive:
138
- # look at the center of the room (this makes the cue giving in side and outisde JA different)
139
  action = self.look_at_action([self.env.current_width // 2, self.env.current_height // 2])
140
  else:
141
  # look at the agent
@@ -550,7 +552,7 @@ class InformationSeekingEnv(MultiModalMiniGridEnv):
550
  num_of_colors = self.n_colors
551
 
552
  # additional test for recursiveness of joint attention -> cues are given outside of JA
553
- self.ja_recursive = self.parameters.get("JA_recursive", False) if self.parameters else False
554
 
555
  self.add_obstacles()
556
  if self.obstacles != "No":
 
78
  self.distractor_obj = self.env.distractor_generator
79
 
80
  if self.env.ja_recursive:
81
+ # how many objects
82
  if int(self.env.parameters["N"]) == 1:
83
  self.ja_decoy = self.env._rand_elem([self.target_obj])
84
  else:
 
102
  if self.env.hidden_npc:
103
  return reply, info
104
 
105
+
106
  scaffolding = self.env.parameters.get("Scaffolding", "N") == "Y"
107
  language_color = False
108
  language_feedback = False
 
137
  assert action is None
138
 
139
  if self.env.ja_recursive:
140
+ # look at the center of the room (this makes the cue giving inside and outside JA different)
141
  action = self.look_at_action([self.env.current_width // 2, self.env.current_height // 2])
142
  else:
143
  # look at the agent
 
552
  num_of_colors = self.n_colors
553
 
554
  # additional test for recursiveness of joint attention -> cues are given outside of JA
555
+ self.ja_recursive = self.parameters.get("JA_recursive", False) == "Y" if self.parameters else False
556
 
557
  self.add_obstacles()
558
  if self.obstacles != "No":
gym-minigrid/gym_minigrid/social_ai_envs/socialaiparamenv.py CHANGED
@@ -297,6 +297,16 @@ class SocialAIParamEnv(gym.Env):
297
 
298
  return tree
299
 
 
 
 
 
 
 
 
 
 
 
300
  def construct_env_from_params(self, params):
301
  params_labels = {k.label: v.label for k, v in params.items()}
302
  if params_labels['Env_type'] == "Collaboration":
@@ -329,14 +339,19 @@ class SocialAIParamEnv(gym.Env):
329
 
330
  return env, reset_kwargs
331
 
332
- def reset(self, with_info=False):
333
  # select a new social environment at random, for each new episode
334
 
335
  old_window = None
336
  if self.current_env: # a previous env exists, save old window
337
  old_window = self.current_env.window
338
 
339
- self.current_params = self.parameter_tree.sample_env_params(ACL=self.curriculum)
 
 
 
 
 
340
 
341
  self.current_env, reset_kwargs = self.construct_env_from_params(self.current_params)
342
  assert reset_kwargs is not {}
@@ -360,9 +375,8 @@ class SocialAIParamEnv(gym.Env):
360
  else:
361
  return obs
362
 
363
- def reset_with_info(self):
364
- return self.reset(with_info=True)
365
-
366
 
367
  def seed(self, seed=1337):
368
  # Seed the random number generator
 
297
 
298
  return tree
299
 
300
def get_potential_params(self):
    """Return the value options of the currently sampled parameters.

    Takes the full parameter -> values mapping from
    ``self.parameter_tree.get_all_params()`` and keeps only the entries whose
    parameter is a key of ``self.current_params``.

    Returns:
        A plain ``dict`` mapping each retained parameter to its list of
        values.
    """
    all_params = self.parameter_tree.get_all_params()
    # Membership is tested on the mapping itself; ``.keys()`` is redundant.
    return {
        param: values
        for param, values in all_params.items()
        if param in self.current_params
    }
306
+
307
def get_all_params(self):
    """Delegate to ``self.parameter_tree.get_all_params()`` and return its result."""
    tree = self.parameter_tree
    return tree.get_all_params()
309
+
310
  def construct_env_from_params(self, params):
311
  params_labels = {k.label: v.label for k, v in params.items()}
312
  if params_labels['Env_type'] == "Collaboration":
 
339
 
340
  return env, reset_kwargs
341
 
342
+ def reset(self, with_info=False, selected_params=None, ACL=None):
343
  # select a new social environment at random, for each new episode
344
 
345
  old_window = None
346
  if self.current_env: # a previous env exists, save old window
347
  old_window = self.current_env.window
348
 
349
+ if selected_params is not None:
350
+ self.current_params = selected_params
351
+ elif ACL is not None:
352
+ self.current_params = self.parameter_tree.sample_env_params(ACL=ACL)
353
+ else:
354
+ self.current_params = self.parameter_tree.sample_env_params(ACL=self.curriculum)
355
 
356
  self.current_env, reset_kwargs = self.construct_env_from_params(self.current_params)
357
  assert reset_kwargs is not {}
 
375
  else:
376
  return obs
377
 
378
def reset_with_info(self, selected_params=None, ACL=None):
    """Call ``self.reset`` with ``with_info=True``, forwarding both arguments.

    Args:
        selected_params: passed through to ``reset`` unchanged.
        ACL: passed through to ``reset`` unchanged.

    Returns:
        Whatever ``self.reset`` returns.
    """
    reset_kwargs = {
        "with_info": True,
        "selected_params": selected_params,
        "ACL": ACL,
    }
    return self.reset(**reset_kwargs)
 
380
 
381
  def seed(self, seed=1337):
382
  # Seed the random number generator
web_demo/app.py CHANGED
@@ -8,6 +8,7 @@ import gym
8
  import gym_minigrid
9
  import numpy as np
10
  from gym_minigrid.window import Window
 
11
 
12
  from textworld_utils.utils import generate_text_obs
13
 
@@ -34,8 +35,8 @@ env_label_to_env_name = {
34
  "Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
35
  "Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
36
  "Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
37
- "Joint Attention Language Color (Train)": "SocialAI-ELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
38
- "Joint Attention Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
39
  "Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
40
  "Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
41
  "Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
@@ -96,6 +97,12 @@ global obs, info
96
  obs, info = env.reset(with_info=True)
97
 
98
 
 
 
 
 
 
 
99
  def create_bubble_text(obs, info, full_conversation, textual_observations):
100
  if textual_observations:
101
  bubble_text = "Textual observation\n\n"+ \
@@ -110,6 +117,7 @@ def create_bubble_text(obs, info, full_conversation, textual_observations):
110
 
111
  def update_tree():
112
  selected_parameters = env.current_env.parameters
 
113
  selected_env_type = selected_parameters["Env_type"]
114
 
115
  assert selected_env_type in env_types, f"Env_type {selected_env_type} not in {env_types}"
@@ -145,6 +153,27 @@ def format_bubble_text(text):
145
  return "\n".join(lines)
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  @app.route('/set_env', methods=['POST'])
149
  def set_env():
150
  global env_name # Declare the variable as global to modify it
@@ -217,7 +246,6 @@ def perform_action():
217
 
218
  obs, reward, done, info = env.step(action)
219
 
220
-
221
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
222
  image_data = np_img_to_base64(image)
223
 
@@ -253,6 +281,8 @@ def index():
253
  current_env_label=env_label,
254
  grammar_templates=grammar_templates,
255
  grammar_words=grammar_words,
 
 
256
  )
257
 
258
 
 
8
  import gym_minigrid
9
  import numpy as np
10
  from gym_minigrid.window import Window
11
+ from gym_minigrid.curriculums import SelectedParametersOrRandomCurriculum
12
 
13
  from textworld_utils.utils import generate_text_obs
14
 
 
35
  "Language Color (Test)": "SocialAI-ELangColorDoorsTestInformationSeekingParamEnv-v1",
36
  "Language Feedback (Train)": "SocialAI-ELangFeedbackHeldoutDoorsTrainInformationSeekingParamEnv-v1",
37
  "Language Feedback (Test)": "SocialAI-ELangFeedbackDoorsTestInformationSeekingParamEnv-v1",
38
+ "Joint Attention Language Color (Train)": "SocialAI-JAELangColorHeldoutDoorsTrainInformationSeekingParamEnv-v1",
39
+ "Joint Attention Language Color (Test)": "SocialAI-JAELangColorDoorsTestInformationSeekingParamEnv-v1",
40
  "Apple stealing": "SocialAI-AppleStealingObst_NoParamEnv-v1",
41
  "Apple stealing (Occlusions)": "SocialAI-AppleStealingObst_MediumParamEnv-v1",
42
  "Scaffolding (train - scaf_8: Phase 1)": "SocialAI-AELangFeedbackTrainScaffoldingCSParamEnv-v1",
 
97
  obs, info = env.reset(with_info=True)
98
 
99
 
100
+
101
+
102
+
103
def get_parameter_options(env):
    """Return ``env.get_potential_params()`` for the given environment."""
    options = env.get_potential_params()
    return options
105
+
106
  def create_bubble_text(obs, info, full_conversation, textual_observations):
107
  if textual_observations:
108
  bubble_text = "Textual observation\n\n"+ \
 
117
 
118
  def update_tree():
119
  selected_parameters = env.current_env.parameters
120
+ print("sel param:", selected_parameters)
121
  selected_env_type = selected_parameters["Env_type"]
122
 
123
  assert selected_env_type in env_types, f"Env_type {selected_env_type} not in {env_types}"
 
153
  return "\n".join(lines)
154
 
155
 
156
+
157
+
158
+
159
@app.route('/set_env_params', methods=['POST'])
def set_env_params():
    """Reset the environment with the parameter values posted by the client.

    The request body is a JSON object mapping parameter-node ids to the ids
    of the value nodes to select. Parameters absent from the object are
    left to ``SelectedParametersOrRandomCurriculum``.

    Returns:
        ``({"success": True}, 200)`` after the environment has been reset and
        the tree redrawn, or ``({"success": False, ...}, 400)`` when the body
        is not a JSON object or references an id the tree cannot resolve.
    """
    global obs, info

    # silent=True yields None instead of raising on a missing/invalid body,
    # so malformed requests get a uniform 400 below.
    selected_params_ids = request.get_json(silent=True)
    if not isinstance(selected_params_ids, dict):
        return jsonify({"success": False, "error": "expected a JSON object"}), 400

    tree = env.parameter_tree
    try:
        selected_parameters = {
            tree.get_node_for_id(k): tree.get_node_for_id(v)
            for k, v in selected_params_ids.items()
        }
    except (LookupError, TypeError):
        # Unknown id, or a value that cannot be used as a lookup key.
        return jsonify({"success": False, "error": "unknown parameter id"}), 400

    selected_parameters_curriculum = SelectedParametersOrRandomCurriculum(selected_parameters)

    obs, info = env.reset(with_info=True, ACL=selected_parameters_curriculum)
    update_tree()  # Update the tree for the new environment
    return jsonify({"success": True}), 200
174
+ # return redirect(url_for('index')) # Redirect back to the main page
175
+
176
+
177
  @app.route('/set_env', methods=['POST'])
178
  def set_env():
179
  global env_name # Declare the variable as global to modify it
 
246
 
247
  obs, reward, done, info = env.step(action)
248
 
 
249
  image = env.render('rgb_array', tile_size=32, mask_unobserved=mask_unobserved)
250
  image_data = np_img_to_base64(image)
251
 
 
281
  current_env_label=env_label,
282
  grammar_templates=grammar_templates,
283
  grammar_words=grammar_words,
284
+ parameter_options=get_parameter_options(env),
285
+ current_parameters=env.current_params
286
  )
287
 
288
 
web_demo/templates/index.html CHANGED
@@ -54,10 +54,10 @@
54
  margin-bottom: 10px; /* Adds a space between the two rows */
55
  }
56
 
57
- .svg-container {
58
  width: 60%;
59
  min-width: 1000px;
60
- max-height: 300px;
61
  margin-left: 5px;
62
  margin-right: 5px;
63
  margin-bottom: 5px;
@@ -204,8 +204,24 @@
204
  .then(data => {
205
  document.getElementById('envImage').src = `data:image/jpeg;base64,${data.image_data}`;
206
  if (actionName === "done") {
207
- const svgTree = document.getElementById('svgTree');
208
- svgTree.data = `./static/current_tree.svg?_=${data.timestamp}`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  }
210
 
211
  // Add this to handle the caretaker's utterance
@@ -236,6 +252,7 @@
236
  .catch(error => {
237
  console.error('Error:', error);
238
  });
 
239
  }
240
 
241
  document.addEventListener("DOMContentLoaded", function() {
@@ -327,7 +344,49 @@
327
  });
328
  }
329
 
 
 
 
 
 
 
 
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  </script>
332
  </head>
333
  <body>
@@ -360,8 +419,29 @@
360
  </div>
361
 
362
  <p>This is the sampling tree. The current sampled parameters are highlighted blue</p>
363
- <div class="svg-container">
364
- <object id="svgTree" type="image/svg+xml" data="{{ url_for('static', filename='current_tree.svg', _=timestamp) }}" width="95%" height="95%">Your browser does not support SVG</object>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  </div>
366
 
367
  <p>This is the environment.</p>
 
54
  margin-bottom: 10px; /* Adds a space between the two rows */
55
  }
56
 
57
+ .tree-container {
58
  width: 60%;
59
  min-width: 1000px;
60
+ max-height: 600px;
61
  margin-left: 5px;
62
  margin-right: 5px;
63
  margin-bottom: 5px;
 
204
  .then(data => {
205
  document.getElementById('envImage').src = `data:image/jpeg;base64,${data.image_data}`;
206
  if (actionName === "done") {
207
+
208
+ // Save the scroll position in sessionStorage before reloading
209
+ sessionStorage.setItem('scrollPosition', window.scrollY);
210
+ console.log(window.scrollY);
211
+
212
+ // Function to reload the page with updated URL parameter
213
+ function reloadPage() {
214
+ let currentUrl = new URL(window.location.href);
215
+ currentUrl.searchParams.set('_', new Date().getTime());
216
+ window.location.href = currentUrl.href;
217
+ }
218
+
219
+ // Call the function to reload the page
220
+ reloadPage();
221
+
222
+ <!-- const svgTree = document.getElementById('svgTree');-->
223
+ <!-- svgTree.data = `./static/current_tree.svg?_=${data.timestamp}`;-->
224
+ <!-- // todo: update the dropdown lists for parameters-->
225
  }
226
 
227
  // Add this to handle the caretaker's utterance
 
252
  .catch(error => {
253
  console.error('Error:', error);
254
  });
255
+
256
  }
257
 
258
  document.addEventListener("DOMContentLoaded", function() {
 
344
  });
345
  }
346
 
347
+ function setEnvParams() {
348
+ // Collect data from dropdowns
349
+ var data = {};
350
+
351
+ {% for key in parameter_options.keys() %}
352
+ data['{{ key.id }}'] = document.getElementById('{{ key.id }}').value;
353
+ {% endfor %}
354
 
355
+ // Send data to Flask backend
356
+ fetch('/set_env_params', {
357
+ method: 'POST',
358
+ headers: {
359
+ 'Content-Type': 'application/json',
360
+ },
361
+ body: JSON.stringify(data),
362
+ })
363
+ .then(response => {
364
+ if(response.ok) {
365
+ // If the request was successful, reload the page without cache
366
+ let currentUrl = new URL(window.location.href);
367
+ currentUrl.searchParams.set('_', new Date().getTime());
368
+ window.location.href = currentUrl.href;
369
+ } else {
370
+ // Handle errors
371
+ console.error('Error in setEnvParams:', response);
372
+ }
373
+ })
374
+ .catch((error) => {
375
+ console.error('Error:', error);
376
+ });
377
+
378
+ }
379
+
380
/**
 * Scroll back to the vertical offset stored under 'scrollPosition' in
 * sessionStorage, then remove that entry. Does nothing when no numeric
 * value is stored.
 */
function restoreScrollPosition() {
    const stored = sessionStorage.getItem('scrollPosition');
    // Explicit radix: always parse the stored offset as decimal.
    const savedScrollPosition = Number.parseInt(stored, 10);
    if (!Number.isNaN(savedScrollPosition)) {
        window.scrollTo(0, savedScrollPosition);
        sessionStorage.removeItem('scrollPosition');
    }
}
387
+
388
+ // Attach the restoreScrollPosition function to the appropriate event
389
+ window.addEventListener('load', restoreScrollPosition);
390
  </script>
391
  </head>
392
  <body>
 
419
  </div>
420
 
421
  <p>This is the sampling tree. The current sampled parameters are highlighted blue</p>
422
+ <div class="tree-container" >
423
+ <div class="form-container">
424
+ <span class="form-label">Select Parameters:</span>
425
+ <form method="post" id="paramsEnvForm">
426
+ {% set ns = namespace(lvl=1) %}
427
+ {% for key, options in parameter_options.items() %}
428
+ {% if key.level != ns.lvl %}
429
+ {% if ns.lvl >= 0 %}<hr>{% endif %}
430
+ {% set ns.lvl = key.level %}
431
+ {% endif %}
432
+
433
+ <label>{{ key.label }}</label>
434
+ <select name="{{ key }}" id="{{ key.id }}" onchange="setEnvParams();">
435
+ {% for option in options %}
436
+ <option value="{{ option.id }}" {% if current_parameters[key].id == option.id %}selected{% endif %} >
437
+ {{ option.label }}
438
+ </option>
439
+ {% endfor %}
440
+ </select>
441
+ {% endfor %}
442
+ </form>
443
+ </div>
444
+ <object id="svgTree" type="image/svg+xml" data="{{ url_for('static', filename='current_tree.svg', _=timestamp) }}" width="95%" height="70%">Your browser does not support SVG</object>
445
  </div>
446
 
447
  <p>This is the environment.</p>