Weiyu Liu committed
Commit f392320
1 Parent(s): 38a6100

add natural language model and app

Files changed (42)
  1. __pycache__/app.cpython-38.pyc +0 -0
  2. app.py +197 -45
  3. app_v0.py +282 -0
  4. app_v1.py +217 -0
  5. configs/conditional_pose_diffusion_language.yaml +92 -0
  6. data/template_sentence_data.pkl +3 -0
  7. requirements.txt +2 -1
  8. src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc +0 -0
  9. src/StructDiffusion/data/__pycache__/semantic_arrangement_language.cpython-38.pyc +0 -0
  10. src/StructDiffusion/data/__pycache__/semantic_arrangement_language_demo.cpython-38.pyc +0 -0
  11. src/StructDiffusion/data/pairwise_collision.py +19 -63
  12. src/StructDiffusion/data/semantic_arrangement.py +1 -44
  13. src/StructDiffusion/data/semantic_arrangement_language.py +633 -0
  14. src/StructDiffusion/data/semantic_arrangement_language_demo.py +693 -0
  15. src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc +0 -0
  16. src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc +0 -0
  17. src/StructDiffusion/diffusion/sampler.py +243 -235
  18. src/StructDiffusion/language/__pycache__/sentence_encoder.cpython-38.pyc +0 -0
  19. src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc +0 -0
  20. src/StructDiffusion/language/convert_to_natural_language.ipynb +773 -0
  21. src/StructDiffusion/language/sentence_encoder.py +23 -0
  22. src/StructDiffusion/language/test_parrot_paraphrase.py +38 -0
  23. src/StructDiffusion/language/tokenizer.py +1 -22
  24. src/StructDiffusion/models/__pycache__/models.cpython-38.pyc +0 -0
  25. src/StructDiffusion/models/models.py +17 -3
  26. src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc +0 -0
  27. src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc +0 -0
  28. src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc +0 -0
  29. src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc +0 -0
  30. src/StructDiffusion/utils/__pycache__/tra3d.cpython-38.pyc +0 -0
  31. src/StructDiffusion/utils/batch_inference.py +11 -247
  32. src/StructDiffusion/utils/files.py +9 -1
  33. src/StructDiffusion/utils/np_speed_test.py +41 -0
  34. src/StructDiffusion/utils/rearrangement.py +29 -4
  35. src/StructDiffusion/utils/tra3d.py +148 -0
  36. tmp_data/input_scene.glb +0 -0
  37. tmp_data/input_scene_102.glb +0 -0
  38. tmp_data/input_scene_None.glb +0 -0
  39. tmp_data/output_scene.glb +0 -0
  40. tmp_data/output_scene_102.glb +0 -0
  41. wandb_logs/StructDiffusion/CollisionDiscriminator/checkpoints/epoch=199-step=653400.ckpt +3 -0
  42. wandb_logs/StructDiffusion/ConditionalPoseDiffusionLanguage/checkpoints/epoch=199-step=100000.ckpt +3 -0
__pycache__/app.cpython-38.pyc CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
 
app.py CHANGED
@@ -10,13 +10,15 @@ from omegaconf import OmegaConf
import sys
sys.path.append('./src')

- from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
+ from StructDiffusion.data.semantic_arrangement_language_demo import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
- from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
- from StructDiffusion.diffusion.sampler import Sampler
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
+ from StructDiffusion.diffusion.sampler import Sampler, SamplerV2
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
- from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
+ from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table
+ import StructDiffusion.utils.transformations as tra
+ from StructDiffusion.language.sentence_encoder import SentenceBertEncoder
import StructDiffusion.utils.transformations as tra


@@ -65,23 +67,31 @@ class Infer_Wrapper:

def __init__(self, args, cfg):

+ self.num_pts = cfg.DATASET.num_pts
+
# load
pl.seed_everything(args.eval_random_seed)
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

- checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
- checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
+ diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
+ collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))

self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
# override ignore_rgb for visualization
cfg.DATASET.ignore_rgb = False
self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)

- self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)
+ self.sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
+ PairwiseCollisionModel, collision_checkpoint_path, self.device)
+
+ self.sentence_encoder = SentenceBertEncoder()
+
+ self.session_id_to_obj_xyzs = {}

def visualize_scene(self, di, session_id):
- raw_datum = self.dataset.get_raw_data(di)
- language_command = self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])
+
+ raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
+ language_command = raw_datum["template_sentence"]

obj_xyz = raw_datum["pcs"]
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
@@ -93,20 +103,77 @@ class Infer_Wrapper:

return language_command, scene_filename

- def infer(self, di, session_id, progress=gr.Progress()):
+ def build_scene(self, mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1,
+ mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2,
+ mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3,
+ mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4,
+ mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5, session_id):
+
+ object_list = [(mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1),
+ (mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2),
+ (mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3),
+ (mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4),
+ (mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5)]
+
+ scene = get_trimesh_scene_with_table()
+
+ obj_xyzs = []
+ for mesh_filename, x, y, z, ai, aj, ak, scale in object_list:
+ if mesh_filename is None:
+ continue
+ obj_mesh = trimesh.load(mesh_filename)
+ obj_mesh.apply_scale(scale)
+ z_min = obj_mesh.bounds[0, 2]
+ tform = tra.euler_matrix(ai, aj, ak)
+ tform[:3, 3] = [x, y, z - z_min]
+ obj_mesh.apply_transform(tform)
+ obj_xyz = obj_mesh.sample(self.num_pts)
+ obj = trimesh.PointCloud(obj_xyz)
+ scene.add_geometry(obj)
+
+ obj_xyzs.append(obj_xyz)
+
+ self.session_id_to_obj_xyzs[session_id] = obj_xyzs
+
+ # scene.show()
+
+ # obj_file = "/home/weiyu/data_drive/StructDiffusion/housekeep_custom_handpicked_small/visual/book_Eat_to_Live_The_Amazing_NutrientRich_Program_for_Fast_and_Sustained_Weight_Loss_Revised_Edition_Book_L/model.obj"
+ # obj = trimesh.load(obj_file)
+ #
+ # scene = get_trimesh_scene_with_table()
+ # scene.add_geometry(obj)
+ #
+ # scene.show()
+
+ # raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
+ # language_command = raw_datum["template_sentence"]
+ #
+ # obj_xyz = raw_datum["pcs"]
+ # scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz],
+ # return_scene=True)
+
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi / 2))
+ scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
+ scene.export(scene_filename)
+
+ return scene_filename
+
+ # return language_command, scene_filename
+
+ def infer(self, language_command, session_id, progress=gr.Progress()):
+
+ obj_xyzs = self.session_id_to_obj_xyzs[session_id]

- # di = np.random.choice(len(self.dataset))
+ sentence_embedding = self.sentence_encoder.encode([language_command]).flatten()

- raw_datum = self.dataset.get_raw_data(di)
- print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
- datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
+ raw_datum = self.dataset.build_data_from_xyzs(obj_xyzs, sentence_embedding)
+ datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer, use_sentence_embedding=True)
batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)

- num_poses = datum["goal_poses"].shape[0]
- xs = self.sampler.sample(batch, num_poses, progress)
+ num_poses = raw_datum["num_goal_poses"]
+ struct_pose, pc_poses_in_struct = self.sampler.sample(batch, num_poses, args.num_elites, args.discriminator_batch_size)

- struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
- new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"][:args.num_elites], struct_pose, pc_poses_in_struct)

# vis
vis_obj_xyzs = new_obj_xyzs[:3]
@@ -115,18 +182,11 @@ class Infer_Wrapper:
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
vis_obj_xyzs = vis_obj_xyzs.numpy()

- # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
- # if verbose:
- # print("example {}".format(bi))
- # print(vis_obj_xyz.shape)
- #
- # if trimesh:
- # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
vis_obj_xyz = vis_obj_xyzs[0]
- scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
-
+ # scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], obj_rgbs=None, return_scene=True)
+ scene.show()
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
-
scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
scene.export(scene_filename)

@@ -167,10 +227,13 @@ class Infer_Wrapper:

args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
- args.config_file = "./configs/conditional_pose_diffusion.yaml"
- args.checkpoint_id = "ConditionalPoseDiffusion"
+ args.config_file = "./configs/conditional_pose_diffusion_language.yaml"
+ args.diffusion_checkpoint_id = "ConditionalPoseDiffusionLanguage"
+ args.collision_checkpoint_id = "CollisionDiscriminator"
args.eval_random_seed = 42
- args.num_samples = 1
+ args.num_samples = 50
+ args.num_elites = 3
+ args.discriminator_batch_size = 10

base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
@@ -178,34 +241,123 @@ cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)

- # version 0
- # demo = gr.Interface(
- # fn=infer_wrapper.run,
- # inputs=gr.Slider(0, len(infer_wrapper.dataset)),
- # # clear color range [0-1.0]
- # outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
- # )
+ # # version 1
+ # demo = gr.Blocks(theme=gr.themes.Soft())
+ # with demo:
+ # gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
+ # # font-size:18px
+ # gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
#
+ # session_id = gr.State(value=np.random.randint(0, 1000))
+ # data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
+ # input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
+ # language_command = gr.Textbox(label="Input Language Command")
+ # output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
+ #
+ # b1 = gr.Button("Show Input Language and Scene")
+ # b2 = gr.Button("Generate 3D Structure")
+ #
+ # b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
+ # b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
+ #
+ # demo.queue(concurrency_count=10)
# demo.launch()

# version 1
- demo = gr.Blocks(theme=gr.themes.Soft())
+ # demo = gr.Blocks(theme=gr.themes.Soft())
+ demo = gr.Blocks()
with demo:
gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
# font-size:18px
gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")

session_id = gr.State(value=np.random.randint(0, 1000))
- data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
- input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
+ with gr.Tab("Object 1"):
+ with gr.Column(scale=1, min_width=600):
+ mesh_filename_1 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
+ with gr.Row():
+ x_1 = gr.Slider(0, 1, label="x")
+ y_1 = gr.Slider(-0.5, 0.5, label="y")
+ z_1 = gr.Slider(0, 0.5, label="z")
+ with gr.Row():
+ ai_1 = gr.Slider(0, np.pi * 2, label="roll")
+ aj_1 = gr.Slider(0, np.pi * 2, label="pitch")
+ ak_1 = gr.Slider(0, np.pi * 2, label="yaw")
+ scale_1 = gr.Slider(0, 1)
+ with gr.Tab("Object 2"):
+ with gr.Column(scale=1, min_width=600):
+ mesh_filename_2 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
+ with gr.Row():
+ x_2 = gr.Slider(0, 1, label="x")
+ y_2 = gr.Slider(-0.5, 0.5, label="y")
+ z_2 = gr.Slider(0, 0.5, label="z")
+ with gr.Row():
+ ai_2 = gr.Slider(0, np.pi * 2, label="roll")
+ aj_2 = gr.Slider(0, np.pi * 2, label="pitch")
+ ak_2 = gr.Slider(0, np.pi * 2, label="yaw")
+ scale_2 = gr.Slider(0, 1)
+ with gr.Tab("Object 3"):
+ with gr.Column(scale=1, min_width=600):
+ mesh_filename_3 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
+ with gr.Row():
+ x_3 = gr.Slider(0, 1, label="x")
+ y_3 = gr.Slider(-0.5, 0.5, label="y")
+ z_3 = gr.Slider(0, 0.5, label="z")
+ with gr.Row():
+ ai_3 = gr.Slider(0, np.pi * 2, label="roll")
+ aj_3 = gr.Slider(0, np.pi * 2, label="pitch")
+ ak_3 = gr.Slider(0, np.pi * 2, label="yaw")
+ scale_3 = gr.Slider(0, 1)
+ with gr.Tab("Object 4"):
+ with gr.Column(scale=1, min_width=600):
+ mesh_filename_4 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
+ with gr.Row():
+ x_4 = gr.Slider(0, 1, label="x")
+ y_4 = gr.Slider(-0.5, 0.5, label="y")
+ z_4 = gr.Slider(0, 0.5, label="z")
+ with gr.Row():
+ ai_4 = gr.Slider(0, np.pi * 2, label="roll")
+ aj_4 = gr.Slider(0, np.pi * 2, label="pitch")
+ ak_4 = gr.Slider(0, np.pi * 2, label="yaw")
+ scale_4 = gr.Slider(0, 1)
+ with gr.Tab("Object 5"):
+ with gr.Column(scale=1, min_width=600):
+ mesh_filename_5 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
+ with gr.Row():
+ x_5 = gr.Slider(0, 1, label="x")
+ y_5 = gr.Slider(-0.5, 0.5, label="y")
+ z_5 = gr.Slider(0, 0.5, label="z")
+ with gr.Row():
+ ai_5 = gr.Slider(0, np.pi * 2, label="roll")
+ aj_5 = gr.Slider(0, np.pi * 2, label="pitch")
+ ak_5 = gr.Slider(0, np.pi * 2, label="yaw")
+ scale_5 = gr.Slider(0, 1)
+
+ b1 = gr.Button("Build Initial Scene")
+
+ initial_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Initial 3D Scene")
language_command = gr.Textbox(label="Input Language Command")
- output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")

- b1 = gr.Button("Show Input Language and Scene")
b2 = gr.Button("Generate 3D Structure")

- b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
- b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
+ output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
+
+ # data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
+ # input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
+ # language_command = gr.Textbox(label="Input Language Command")
+ # output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
+ #
+ # b1 = gr.Button("Show Input Language and Scene")
+ # b2 = gr.Button("Generate 3D Structure")
+
+ b1.click(infer_wrapper.build_scene, inputs=[mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1,
+ mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2,
+ mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3,
+ mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4,
+ mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5,
+ session_id], outputs=[initial_scene])
+
+ b2.click(infer_wrapper.infer, inputs=[language_command, session_id], outputs=output_scene)

demo.queue(concurrency_count=10)
demo.launch()
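The app.py changes above replace the dataset-index demo with a two-step flow: build_scene samples point clouds from up to five user-supplied meshes and caches them under the session_id, and infer now takes a free-form command, embeds it with SentenceBertEncoder, draws num_samples diffusion samples with SamplerV2, and keeps the num_elites best according to the pairwise collision discriminator. A minimal sketch of driving that flow without the Gradio UI follows; the mesh path and the command string are hypothetical placeholders, and it assumes the configs and checkpoints referenced in args are available.

# Sketch only: exercises Infer_Wrapper.build_scene and Infer_Wrapper.infer directly.
wrapper = Infer_Wrapper(args, cfg)
empty_slot = (None, 0, 0, 0, 0, 0, 0, 1.0)  # mesh_filename=None makes build_scene skip the slot
input_glb = wrapper.build_scene("./tmp_data/example_object.glb", 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
                                *empty_slot, *empty_slot, *empty_slot, *empty_slot,
                                session_id=0)
output_glb = wrapper.infer("put the objects in a line on the table", session_id=0)  # hypothetical command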
app_v0.py ADDED
@@ -0,0 +1,282 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import gradio as gr
8
+ from omegaconf import OmegaConf
9
+
10
+ import sys
11
+ sys.path.append('./src')
12
+
13
+ from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
14
+ from StructDiffusion.language.tokenizer import Tokenizer
15
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
16
+ from StructDiffusion.diffusion.sampler import Sampler
17
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
18
+ from StructDiffusion.utils.files import get_checkpoint_path_from_dir
19
+ from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
20
+ import StructDiffusion.utils.transformations as tra
21
+
22
+
23
+ def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
24
+
25
+ device = obj_xyzs.device
26
+
27
+ # obj_xyzs: B, N, P, 3 or 6
28
+ # struct_pose: B, 1, 4, 4
29
+ # pc_poses_in_struct: B, N, 4, 4
30
+
31
+ B, N, _, _ = pc_poses_in_struct.shape
32
+ _, _, P, _ = obj_xyzs.shape
33
+
34
+ current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
35
+ # print(torch.mean(obj_xyzs, dim=2).shape)
36
+ current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
37
+ current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
38
+
39
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
40
+ struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
41
+ pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
42
+
43
+ goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
44
+ # print("goal pc poses")
45
+ # print(goal_pc_pose)
46
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
47
+
48
+ # # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
49
+ # transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
50
+ # new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
51
+ # new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
52
+
53
+ # a verision that does not rely on pytorch3d
54
+ new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3] # B x N, P, 3
55
+ new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
56
+ new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3] # # B x N, P, 3
57
+
58
+ # put it back to B, N, P, 3
59
+ obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)
60
+
61
+ return obj_xyzs
62
+
63
+
64
+ class Infer_Wrapper:
65
+
66
+ def __init__(self, args, cfg):
67
+
68
+ # load
69
+ pl.seed_everything(args.eval_random_seed)
70
+ self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
71
+
72
+ checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
73
+ checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
74
+
75
+ self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
76
+ # override ignore_rgb for visualization
77
+ cfg.DATASET.ignore_rgb = False
78
+ self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
79
+
80
+ self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)
81
+
82
+ def visualize_scene(self, di, session_id):
83
+ raw_datum = self.dataset.get_raw_data(di)
84
+ language_command = self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])
85
+
86
+ obj_xyz = raw_datum["pcs"]
87
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
88
+
89
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
90
+
91
+ scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
92
+ scene.export(scene_filename)
93
+
94
+ return language_command, scene_filename
95
+
96
+ def infer(self, di, session_id, progress=gr.Progress()):
97
+
98
+ # di = np.random.choice(len(self.dataset))
99
+
100
+ raw_datum = self.dataset.get_raw_data(di)
101
+ print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
102
+ datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
103
+ batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
104
+
105
+ num_poses = datum["goal_poses"].shape[0]
106
+ xs = self.sampler.sample(batch, num_poses, progress)
107
+
108
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
109
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
110
+
111
+ # vis
112
+ vis_obj_xyzs = new_obj_xyzs[:3]
113
+ if torch.is_tensor(vis_obj_xyzs):
114
+ if vis_obj_xyzs.is_cuda:
115
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
116
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
117
+
118
+ # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
119
+ # if verbose:
120
+ # print("example {}".format(bi))
121
+ # print(vis_obj_xyz.shape)
122
+ #
123
+ # if trimesh:
124
+ # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
125
+ vis_obj_xyz = vis_obj_xyzs[0]
126
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
127
+
128
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
129
+
130
+ scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
131
+ scene.export(scene_filename)
132
+
133
+ # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
134
+ # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
135
+ #
136
+ # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
137
+ # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
138
+ # vis_pc.export(pc_filename)
139
+ #
140
+ # scene = trimesh.Scene()
141
+ # # add the coordinate frame first
142
+ # # geom = trimesh.creation.axis(0.01)
143
+ # # scene.add_geometry(geom)
144
+ # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
145
+ # table.apply_translation([0.5, 0, -0.01])
146
+ # table.visual.vertex_colors = [150, 111, 87, 125]
147
+ # scene.add_geometry(table)
148
+ # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
149
+ # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
150
+ # # bounds.apply_translation([0, 0, 0])
151
+ # # bounds.visual.vertex_colors = [30, 30, 30, 30]
152
+ # # scene.add_geometry(bounds)
153
+ # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
154
+ # # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
155
+ # # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
156
+ # # [0.0, 0.0, 0.0, 1.0]])
157
+ # # RT_4x4 = np.linalg.inv(RT_4x4)
158
+ # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
159
+ # # scene.camera_transform = RT_4x4
160
+ #
161
+ # mesh_list = trimesh.util.concatenate(scene.dump())
162
+ # print(mesh_list)
163
+ # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')
164
+
165
+ return scene_filename
166
+
167
+ def infer_new(self, di, session_id, progress=gr.Progress()):
168
+
169
+ # di = np.random.choice(len(self.dataset))
170
+
171
+ raw_datum = self.dataset.get_raw_data(di)
172
+ print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
173
+ datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
174
+ batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
175
+
176
+ num_poses = datum["goal_poses"].shape[0]
177
+ xs = self.sampler.sample(batch, num_poses, progress)
178
+
179
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
180
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
181
+
182
+ # vis
183
+ vis_obj_xyzs = new_obj_xyzs[:3]
184
+ if torch.is_tensor(vis_obj_xyzs):
185
+ if vis_obj_xyzs.is_cuda:
186
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
187
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
188
+
189
+ # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
190
+ # if verbose:
191
+ # print("example {}".format(bi))
192
+ # print(vis_obj_xyz.shape)
193
+ #
194
+ # if trimesh:
195
+ # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
196
+ vis_obj_xyz = vis_obj_xyzs[0]
197
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
198
+
199
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
200
+
201
+ scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
202
+ scene.export(scene_filename)
203
+
204
+ # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
205
+ # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
206
+ #
207
+ # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
208
+ # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
209
+ # vis_pc.export(pc_filename)
210
+ #
211
+ # scene = trimesh.Scene()
212
+ # # add the coordinate frame first
213
+ # # geom = trimesh.creation.axis(0.01)
214
+ # # scene.add_geometry(geom)
215
+ # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
216
+ # table.apply_translation([0.5, 0, -0.01])
217
+ # table.visual.vertex_colors = [150, 111, 87, 125]
218
+ # scene.add_geometry(table)
219
+ # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
220
+ # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
221
+ # # bounds.apply_translation([0, 0, 0])
222
+ # # bounds.visual.vertex_colors = [30, 30, 30, 30]
223
+ # # scene.add_geometry(bounds)
224
+ # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
225
+ # # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
226
+ # # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
227
+ # # [0.0, 0.0, 0.0, 1.0]])
228
+ # # RT_4x4 = np.linalg.inv(RT_4x4)
229
+ # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
230
+ # # scene.camera_transform = RT_4x4
231
+ #
232
+ # mesh_list = trimesh.util.concatenate(scene.dump())
233
+ # print(mesh_list)
234
+ # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')
235
+
236
+ return scene_filename
237
+
238
+
239
+ args = OmegaConf.create()
240
+ args.base_config_file = "./configs/base.yaml"
241
+ args.config_file = "./configs/conditional_pose_diffusion.yaml"
242
+ args.checkpoint_id = "ConditionalPoseDiffusion"
243
+ args.eval_random_seed = 42
244
+ args.num_samples = 1
245
+
246
+ base_cfg = OmegaConf.load(args.base_config_file)
247
+ cfg = OmegaConf.load(args.config_file)
248
+ cfg = OmegaConf.merge(base_cfg, cfg)
249
+
250
+ infer_wrapper = Infer_Wrapper(args, cfg)
251
+
252
+ # version 0
253
+ # demo = gr.Interface(
254
+ # fn=infer_wrapper.run,
255
+ # inputs=gr.Slider(0, len(infer_wrapper.dataset)),
256
+ # # clear color range [0-1.0]
257
+ # outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
258
+ # )
259
+ #
260
+ # demo.launch()
261
+
262
+ # version 1
263
+ demo = gr.Blocks(theme=gr.themes.Soft())
264
+ with demo:
265
+ gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
266
+ # font-size:18px
267
+ gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
268
+
269
+ session_id = gr.State(value=np.random.randint(0, 1000))
270
+ data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
271
+ input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
272
+ language_command = gr.Textbox(label="Input Language Command")
273
+ output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
274
+
275
+ b1 = gr.Button("Show Input Language and Scene")
276
+ b2 = gr.Button("Generate 3D Structure")
277
+
278
+ b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
279
+ b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
280
+
281
+ demo.queue(concurrency_count=10)
282
+ demo.launch()
app_v1.py ADDED
@@ -0,0 +1,217 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import gradio as gr
8
+ from omegaconf import OmegaConf
9
+
10
+ import sys
11
+ sys.path.append('./src')
12
+
13
+ from StructDiffusion.data.semantic_arrangement_language_demo import SemanticArrangementDataset
14
+ from StructDiffusion.language.tokenizer import Tokenizer
15
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
16
+ from StructDiffusion.diffusion.sampler import Sampler, SamplerV2
17
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
18
+ from StructDiffusion.utils.files import get_checkpoint_path_from_dir
19
+ from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table
20
+ import StructDiffusion.utils.transformations as tra
21
+ from StructDiffusion.language.sentence_encoder import SentenceBertEncoder
22
+
23
+
24
+ def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
25
+
26
+ device = obj_xyzs.device
27
+
28
+ # obj_xyzs: B, N, P, 3 or 6
29
+ # struct_pose: B, 1, 4, 4
30
+ # pc_poses_in_struct: B, N, 4, 4
31
+
32
+ B, N, _, _ = pc_poses_in_struct.shape
33
+ _, _, P, _ = obj_xyzs.shape
34
+
35
+ current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
36
+ # print(torch.mean(obj_xyzs, dim=2).shape)
37
+ current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
38
+ current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
39
+
40
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
41
+ struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
42
+ pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
43
+
44
+ goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
45
+ # print("goal pc poses")
46
+ # print(goal_pc_pose)
47
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
48
+
49
+ # # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
50
+ # transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
51
+ # new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
52
+ # new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
53
+
54
+ # a verision that does not rely on pytorch3d
55
+ new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3] # B x N, P, 3
56
+ new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
57
+ new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3] # # B x N, P, 3
58
+
59
+ # put it back to B, N, P, 3
60
+ obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)
61
+
62
+ return obj_xyzs
63
+
64
+
65
+ class Infer_Wrapper:
66
+
67
+ def __init__(self, args, cfg):
68
+
69
+ # load
70
+ pl.seed_everything(args.eval_random_seed)
71
+ self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
72
+
73
+ diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
74
+ collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))
75
+
76
+ self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
77
+ # override ignore_rgb for visualization
78
+ cfg.DATASET.ignore_rgb = False
79
+ self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
80
+
81
+ self.sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
82
+ PairwiseCollisionModel, collision_checkpoint_path, self.device)
83
+
84
+ def visualize_scene(self, di, session_id):
85
+
86
+ raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
87
+ language_command = raw_datum["template_sentence"]
88
+
89
+ obj_xyz = raw_datum["pcs"]
90
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
91
+
92
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
93
+
94
+ scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
95
+ scene.export(scene_filename)
96
+
97
+ return language_command, scene_filename
98
+
99
+
100
+ def infer(self, di, session_id, progress=gr.Progress()):
101
+
102
+ # di = np.random.choice(len(self.dataset))
103
+
104
+ raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
105
+ print(raw_datum["template_sentence"])
106
+ datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer, use_sentence_embedding=self.dataset.use_sentence_embedding)
107
+ batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
108
+
109
+ num_poses = datum["goal_poses"].shape[0]
110
+ struct_pose, pc_poses_in_struct = self.sampler.sample(batch, num_poses, args.num_elites, args.discriminator_batch_size)
111
+
112
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"][:args.num_elites], struct_pose, pc_poses_in_struct)
113
+
114
+ # vis
115
+ vis_obj_xyzs = new_obj_xyzs[:3]
116
+ if torch.is_tensor(vis_obj_xyzs):
117
+ if vis_obj_xyzs.is_cuda:
118
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
119
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
120
+
121
+ # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
122
+ # if verbose:
123
+ # print("example {}".format(bi))
124
+ # print(vis_obj_xyz.shape)
125
+ #
126
+ # if trimesh:
127
+ # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
128
+ vis_obj_xyz = vis_obj_xyzs[0]
129
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
130
+
131
+ scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
132
+
133
+ scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
134
+ scene.export(scene_filename)
135
+
136
+ # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
137
+ # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
138
+ #
139
+ # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
140
+ # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
141
+ # vis_pc.export(pc_filename)
142
+ #
143
+ # scene = trimesh.Scene()
144
+ # # add the coordinate frame first
145
+ # # geom = trimesh.creation.axis(0.01)
146
+ # # scene.add_geometry(geom)
147
+ # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
148
+ # table.apply_translation([0.5, 0, -0.01])
149
+ # table.visual.vertex_colors = [150, 111, 87, 125]
150
+ # scene.add_geometry(table)
151
+ # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
152
+ # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
153
+ # # bounds.apply_translation([0, 0, 0])
154
+ # # bounds.visual.vertex_colors = [30, 30, 30, 30]
155
+ # # scene.add_geometry(bounds)
156
+ # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
157
+ # # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
158
+ # # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
159
+ # # [0.0, 0.0, 0.0, 1.0]])
160
+ # # RT_4x4 = np.linalg.inv(RT_4x4)
161
+ # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
162
+ # # scene.camera_transform = RT_4x4
163
+ #
164
+ # mesh_list = trimesh.util.concatenate(scene.dump())
165
+ # print(mesh_list)
166
+ # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')
167
+
168
+ return scene_filename
169
+
170
+
171
+ args = OmegaConf.create()
172
+ args.base_config_file = "./configs/base.yaml"
173
+ args.config_file = "./configs/conditional_pose_diffusion_language.yaml"
174
+ args.diffusion_checkpoint_id = "ConditionalPoseDiffusionLanguage"
175
+ args.collision_checkpoint_id = "CollisionDiscriminator"
176
+ args.eval_random_seed = 42
177
+ args.num_samples = 50
178
+ args.num_elites = 3
179
+ args.discriminator_batch_size = 10
180
+
181
+ base_cfg = OmegaConf.load(args.base_config_file)
182
+ cfg = OmegaConf.load(args.config_file)
183
+ cfg = OmegaConf.merge(base_cfg, cfg)
184
+
185
+ infer_wrapper = Infer_Wrapper(args, cfg)
186
+
187
+ # version 0
188
+ # demo = gr.Interface(
189
+ # fn=infer_wrapper.run,
190
+ # inputs=gr.Slider(0, len(infer_wrapper.dataset)),
191
+ # # clear color range [0-1.0]
192
+ # outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
193
+ # )
194
+ #
195
+ # demo.launch()
196
+
197
+ # version 1
198
+ demo = gr.Blocks(theme=gr.themes.Soft())
199
+ with demo:
200
+ gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
201
+ # font-size:18px
202
+ gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
203
+
204
+ session_id = gr.State(value=np.random.randint(0, 1000))
205
+ data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
206
+ input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
207
+ language_command = gr.Textbox(label="Input Language Command")
208
+ output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
209
+
210
+ b1 = gr.Button("Show Input Language and Scene")
211
+ b2 = gr.Button("Generate 3D Structure")
212
+
213
+ b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
214
+ b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
215
+
216
+ demo.queue(concurrency_count=10)
217
+ demo.launch()
configs/conditional_pose_diffusion_language.yaml ADDED
@@ -0,0 +1,92 @@
+ random_seed: 1
+
+ WANDB:
+ project: StructDiffusion
+ save_dir: ${base_dirs.wandb_dir}
+ name: conditional_pose_diffusion_language_shuffle
+
+ DATASET:
+ data_root: ${base_dirs.data}
+ vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json
+
+ # important
+ use_virtual_structure_frame: True
+ ignore_distractor_objects: True
+ ignore_rgb: True
+
+ # the following are determined by the dataset
+ max_num_target_objects: 7
+ max_num_distractor_objects: 5
+ # set to 1 because we use sentence embedding, which only takes one spot in the input seq to transformer diffusion
+ max_num_shape_parameters: 1
+ # set to zeros because they are not used for now
+ max_num_rearrange_features: 0
+ max_num_anchor_features: 0
+
+ # language
+ sentence_embedding_file: ${base_dirs.data}/template_sentence_data.pkl
+ use_incomplete_sentence: True
+
+ # shuffle
+ shuffle_object_index: True
+
+ num_pts: 1024
+ filter_num_moved_objects_range:
+ data_augmentation: False
+
+ DATALOADER:
+ batch_size: 64
+ num_workers: 8
+ pin_memory: True
+
+ MODEL:
+ # transformer encoder
+ encoder_input_dim: 256
+ num_attention_heads: 8
+ encoder_hidden_dim: 512
+ encoder_dropout: 0.0
+ encoder_activation: relu
+ encoder_num_layers: 8
+ # output head
+ structure_dropout: 0
+ object_dropout: 0
+ # pc encoder
+ ignore_rgb: ${DATASET.ignore_rgb}
+ pc_emb_dim: 256
+ posed_pc_emb_dim: 80
+ # pose encoder
+ pose_emb_dim: 80
+ # language
+ word_emb_dim: 160
+ # diffusion step
+ time_emb_dim: 80
+ # sequence embeddings
+ # max_num_target_objects (+ max_num_distractor_objects if not ignore_distractor_objects)
+ max_seq_size: 7
+ max_token_type_size: 4
+ seq_pos_emb_dim: 8
+ seq_type_emb_dim: 8
+ # virtual frame
+ use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame}
+ # language
+ use_sentence_embedding: True
+ sentence_embedding_dim: 384
+
+ NOISE_SCHEDULE:
+ timesteps: 200
+
+ LOSS:
+ type: huber
+
+ OPTIMIZER:
+ lr: 0.0001
+ weight_decay: 0 #0.0001
+ # lr_restart: 3000
+ # warmup: 10
+
+ TRAINER:
+ max_epochs: 200
+ gradient_clip_val: 1.0
+ gpus: 1
+ deterministic: False
+ # enable_progress_bar: False
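This file only carries the language-specific overrides; as in app.py above, it is composed with configs/base.yaml at startup, which is assumed to supply the base_dirs.* values used by the ${...} interpolations. A small sketch of the composition, mirroring what the demo does:

from omegaconf import OmegaConf

base_cfg = OmegaConf.load("./configs/base.yaml")  # assumed to define base_dirs.* referenced above
cfg = OmegaConf.merge(base_cfg, OmegaConf.load("./configs/conditional_pose_diffusion_language.yaml"))
print(cfg.MODEL.sentence_embedding_dim)  # 384, the sentence embedding size fed to the diffusion transformer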
data/template_sentence_data.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbb41fd847ca6ee48d484c346fd8b0bbf478dfcc4559bf9646e3b7e7bf9fe83b
+ size 5680159
requirements.txt CHANGED
@@ -7,4 +7,5 @@ pyglet==1.5.0
openpyxl
pytorch_lightning==1.6.1
wandb===0.13.10
- omegaconf==2.2.2
+ omegaconf==2.2.2
+ sentence-transformers
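sentence-transformers backs the new SentenceBertEncoder that app.py uses to embed the typed command. The added src/StructDiffusion/language/sentence_encoder.py is not shown in this diff, so the following is only a plausible sketch, assuming the all-MiniLM-L6-v2 model, whose 384-dimensional output matches sentence_embedding_dim: 384 in the new config:

# Hypothetical sketch; the actual sentence_encoder.py in this commit may differ.
from sentence_transformers import SentenceTransformer

class SentenceBertEncoder:
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def encode(self, sentences):
        # returns a numpy array of shape (len(sentences), 384) for this model
        return self.model.encode(sentences)

# Usage in app.py: SentenceBertEncoder().encode([language_command]).flatten() -> shape (384,)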
src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc and b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc differ
 
src/StructDiffusion/data/__pycache__/semantic_arrangement_language.cpython-38.pyc ADDED
Binary file (18.5 kB).
 
src/StructDiffusion/data/__pycache__/semantic_arrangement_language_demo.cpython-38.pyc ADDED
Binary file (19.3 kB).
 
src/StructDiffusion/data/pairwise_collision.py CHANGED
@@ -32,11 +32,27 @@ def load_pairwise_collision_data(h5_filename):
return data_dict


+ def replace_root_directory(original_filename: str, new_root: str) -> str:
+ # Split the original filename into a list by directory
+ original_parts = original_filename.split('/')
+
+ # Find the index of the "data_new_objects" part
+ data_index = original_parts.index('data_new_objects')
+
+ # Split the new root into a list by directory
+ new_root_parts = new_root.split('/')
+
+ # Combine the new root with the rest of the original filename
+ updated_filename = '/'.join(new_root_parts + original_parts[data_index + 1:])
+
+ return updated_filename
+
+
class PairwiseCollisionDataset(torch.utils.data.Dataset):

def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
- debug=False):
+ debug=False, new_data_root=None):

# load dictionary mapping from urdf to list of pc data, each sample is
# {"step_t": step_t, "obj": obj, "filename": filename}
@@ -49,6 +65,8 @@ class PairwiseCollisionDataset(torch.utils.data.Dataset):
filename = pd["filename"]
if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
continue
+ if new_data_root:
+ pd["filename"] = replace_root_directory(pd["filename"], new_data_root)
valid_pc_data.append(pd)
if valid_pc_data:
self.urdf_to_pc_data[urdf] = valid_pc_data
@@ -297,65 +315,3 @@ class PairwiseCollisionDataset(torch.utils.data.Dataset):
"label": torch.FloatTensor([label]),
}
return datum
-
- # @staticmethod
- # def collate_fn(data):
- # """
- # :param data:
- # :return:
- # """
- #
- # batched_data_dict = {}
- # for key in ["is_circle"]:
- # batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0)
- # for key in ["scene_xyz"]:
- # batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0)
- #
- # return batched_data_dict
- #
- # # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False):
- # #
- # # new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs]
- # #
- # # # compute pairwise collision
- # # scene_xyzs = []
- # # obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2))
- # #
- # # for obj_xyz_pair_idx in obj_xyz_pair_idxs:
- # # obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]]
- # # num_indicator = 2
- # # obj_xyz_pair_ind = []
- # # for oi, obj_xyz in enumerate(obj_xyz_pair):
- # # obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
- # # obj_xyz_pair_ind.append(obj_xyz)
- # # pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0)
- # #
- # # # subsampling and normalizing pc
- # # rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts)
- # # pair_scene_xyz = pair_scene_xyz[rand_idx]
- # # if self.normalize_pc:
- # # pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3])
- # #
- # # scene_xyzs.append(array_to_tensor(pair_scene_xyz))
- # #
- # # if debug:
- # # for scene_xyz in scene_xyzs:
- # # show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))],
- # # add_coordinate_frame=True)
- # #
- # # return scene_xyzs
-
-
- if __name__ == "__main__":
- dataset = PairwiseCollisionDataset(urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl",
- collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data",
- debug=False)
-
- for i in tqdm.tqdm(np.random.permutation(len(dataset))):
- # print(i)
- d = dataset[i]
- # print(d["label"])
-
- # dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)
- # for b in tqdm.tqdm(dl):
- # pass
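The replace_root_directory helper added in the first hunk lets PairwiseCollisionDataset remap its cached point-cloud filenames when the data_new_objects directory has been moved (pass new_data_root to the constructor). A quick example with hypothetical paths:

# Hypothetical paths, for illustration only.
replace_root_directory(
    "/old/disk/data_new_objects/examples_circle_new_objects/result/data00000001.h5",
    "/new/disk/data_new_objects")
# -> "/new/disk/data_new_objects/examples_circle_new_objects/result/data00000001.h5"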
src/StructDiffusion/data/semantic_arrangement.py CHANGED
@@ -533,47 +533,4 @@ def compute_min_max(dataloader):
current_min, _ = torch.min(goal_poses, dim=0)
max_value[max_value < current_max] = current_max[max_value < current_max]
max_value[max_value > current_min] = current_min[max_value > current_min]
- print(f"{min_value} - {max_value}")
-
-
- if __name__ == "__main__":
-
- tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
-
- data_roots = []
- index_roots = []
- for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
- data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
- index_roots.append(index)
-
- dataset = SemanticArrangementDataset(data_roots=data_roots,
- index_roots=index_roots,
- split="valid", tokenizer=tokenizer,
- max_num_target_objects=7,
- max_num_distractor_objects=5,
- max_num_shape_parameters=5,
- max_num_rearrange_features=0,
- max_num_anchor_features=0,
- num_pts=1024,
- use_virtual_structure_frame=True,
- ignore_distractor_objects=True,
- ignore_rgb=True,
- filter_num_moved_objects_range=None, # [5, 5]
- data_augmentation=False,
- shuffle_object_index=False,
- debug=False)
-
- # print(len(dataset))
- # for d in dataset:
- # print("\n\n" + "="*100)
-
- dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
- for i, d in enumerate(tqdm(dataloader)):
- pass
- # for k in d:
- # if isinstance(d[k], torch.Tensor):
- # print("--size", k, d[k].shape)
- # for k in d:
- # print(k, d[k])
- #
- # input("next?")
+ print(f"{min_value} - {max_value}")
src/StructDiffusion/data/semantic_arrangement_language.py ADDED
@@ -0,0 +1,633 @@
1
+ import copy
2
+ import cv2
3
+ import h5py
4
+ import numpy as np
5
+ import os
6
+ import trimesh
7
+ import torch
8
+ from tqdm import tqdm
9
+ import json
10
+ import random
11
+ import pickle
12
+
13
+ from torch.utils.data import DataLoader
14
+
15
+ # Local imports
16
+ from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
17
+ from StructDiffusion.language.tokenizer import Tokenizer
18
+
19
+ import StructDiffusion.utils.brain2.camera as cam
20
+ import StructDiffusion.utils.brain2.image as img
21
+ import StructDiffusion.utils.transformations as tra
22
+
23
+
24
+ class SemanticArrangementDataset(torch.utils.data.Dataset):
25
+
26
+ def __init__(self, data_roots, index_roots, split, tokenizer,
27
+ max_num_target_objects=11, max_num_distractor_objects=5,
28
+ max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
29
+ num_pts=1024,
30
+ use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
31
+ filter_num_moved_objects_range=None, shuffle_object_index=False,
32
+ sentence_embedding_file=None, use_incomplete_sentence=False,
33
+ data_augmentation=True, debug=False, **kwargs):
34
+ """
35
+
36
+ Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
37
+
38
+ :param data_root:
39
+ :param split: train, valid, or test
40
+ :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
41
+ :param debug:
42
+ :param max_num_shape_parameters:
43
+ :param max_num_objects:
44
+ :param max_num_rearrange_features:
45
+ :param max_num_anchor_features:
46
+ :param num_pts:
47
+ :param use_stored_arrangement_indices:
48
+ :param kwargs:
49
+ """
50
+
51
+ self.use_virtual_structure_frame = use_virtual_structure_frame
52
+ self.ignore_distractor_objects = ignore_distractor_objects
53
+ self.ignore_rgb = ignore_rgb and not debug
54
+
55
+ self.num_pts = num_pts
56
+ self.debug = debug
57
+
58
+ self.max_num_objects = max_num_target_objects
59
+ self.max_num_other_objects = max_num_distractor_objects
60
+ self.max_num_shape_parameters = max_num_shape_parameters
61
+ self.max_num_rearrange_features = max_num_rearrange_features
62
+ self.max_num_anchor_features = max_num_anchor_features
63
+ self.shuffle_object_index = shuffle_object_index
64
+
65
+ # used to tokenize the language part
66
+ self.tokenizer = tokenizer
67
+
68
+ # retrieve data
69
+ self.data_roots = data_roots
70
+ self.arrangement_data = []
71
+ arrangement_steps = []
72
+ for ddx in range(len(data_roots)):
73
+ data_root = data_roots[ddx]
74
+ index_root = index_roots[ddx]
75
+ arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
76
+ if os.path.exists(arrangement_indices_file):
77
+ with open(arrangement_indices_file, "r") as fh:
78
+ arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
79
+ else:
80
+ print("{} does not exist".format(arrangement_indices_file))
81
+ # only keep the goal, ignore the intermediate steps
82
+ for filename, step_t in arrangement_steps:
83
+ if step_t == 0:
84
+ if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename:
85
+ continue
86
+ self.arrangement_data.append((filename, step_t))
87
+ # if specified, filter data
88
+ if filter_num_moved_objects_range is not None:
89
+ self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range)
90
+ print("{} valid sequences".format(len(self.arrangement_data)))
91
+
92
+ # language
93
+ if sentence_embedding_file:
94
+ assert max_num_shape_parameters == 1
95
+ # since we do not use them right now, ignore them
96
+ # assert max_num_rearrange_features == 0
97
+ # assert max_num_anchor_features == 0
98
+ with open(sentence_embedding_file, "rb") as fh:
99
+ template_sentence_data = pickle.load(fh)
100
+ self.use_sentence_embedding = True
101
+ self.type_value_tuple_to_template_sentences = template_sentence_data["type_value_tuple_to_template_sentences"]
102
+ self.template_sentence_to_embedding = template_sentence_data["template_sentence_to_embedding"]
103
+ self.use_incomplete_sentence = use_incomplete_sentence
104
+ print("use sentence embedding")
105
+ print(len(self.type_value_tuple_to_template_sentences))
106
+ print(len(self.template_sentence_to_embedding))
107
+ else:
108
+ self.use_sentence_embedding = False
109
+
110
+ # Data Aug
111
+ self.data_augmentation = data_augmentation
112
+ # additive noise
113
+ self.gp_rescale_factor_range = [12, 20]
114
+ self.gaussian_scale_range = [0., 0.003]
115
+ # multiplicative noise
116
+ self.gamma_shape = 1000.
117
+ self.gamma_scale = 0.001
118
+
119
+ def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
120
+ assert len(list(filter_num_moved_objects_range)) == 2
121
+ min_num, max_num = filter_num_moved_objects_range
122
+ print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
123
+ ok_data = []
124
+ for filename, step_t in self.arrangement_data:
125
+ h5 = h5py.File(filename, 'r')
126
+ moved_objs = h5['moved_objs'][()].split(',')
127
+ if min_num <= len(moved_objs) <= max_num:
128
+ ok_data.append((filename, step_t))
129
+ print("{} valid sequences left".format(len(ok_data)))
130
+ return ok_data
131
+
132
+ def get_data_idx(self, idx):
133
+ # Create the datum to return
134
+ file_idx = np.argmax(idx < self.file_to_count)
135
+ data = h5py.File(self.data_files[file_idx], 'r')
136
+ if file_idx > 0:
137
+ # for lang2sym, idx is always 0
138
+ idx = idx - self.file_to_count[file_idx - 1]
139
+ return data, idx, file_idx
140
+
141
+ def add_noise_to_depth(self, depth_img):
142
+ """ add depth noise """
143
+ multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
144
+ depth_img = multiplicative_noise * depth_img
145
+ return depth_img
146
+
147
+ def add_noise_to_xyz(self, xyz_img, depth_img):
148
+ """ TODO: remove this code or at least clean it up"""
149
+ xyz_img = xyz_img.copy()
150
+ H, W, C = xyz_img.shape
151
+ gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
152
+ self.gp_rescale_factor_range[1])
153
+ gp_scale = np.random.uniform(self.gaussian_scale_range[0],
154
+ self.gaussian_scale_range[1])
155
+ small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
156
+ additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
157
+ additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
158
+ xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
159
+ return xyz_img
160
+
161
+ def random_index(self):
162
+ return self[np.random.randint(len(self))]
163
+
164
+ def _get_rgb(self, h5, idx, ee=True):
165
+ RGB = "ee_rgb" if ee else "rgb"
166
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
167
+ return rgb1
168
+
169
+ def _get_depth(self, h5, idx, ee=True):
170
+ DEPTH = "ee_depth" if ee else "depth"
171
+
172
+ def _get_images(self, h5, idx, ee=True):
173
+ if ee:
174
+ RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
175
+ DMIN, DMAX = "ee_depth_min", "ee_depth_max"
176
+ else:
177
+ RGB, DEPTH, SEG = "rgb", "depth", "seg"
178
+ DMIN, DMAX = "depth_min", "depth_max"
179
+ dmin = h5[DMIN][idx]
180
+ dmax = h5[DMAX][idx]
181
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
182
+ depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
183
+ seg1 = img.PNGToNumpy(h5[SEG][idx])
184
+
185
+ valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
186
+
187
+ # proj_matrix = h5['proj_matrix'][()]
188
+ camera = cam.get_camera_from_h5(h5)
189
+ if self.data_augmentation:
190
+ depth1 = self.add_noise_to_depth(depth1)
191
+
192
+ xyz1 = cam.compute_xyz(depth1, camera)
193
+ if self.data_augmentation:
194
+ xyz1 = self.add_noise_to_xyz(xyz1, depth1)
195
+
196
+ # Transform the point cloud
197
+ # Here it is...
198
+ # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
199
+ CAM_POSE = "ee_camera_view" if ee else "camera_view"
200
+ cam_pose = h5[CAM_POSE][idx]
201
+ if ee:
202
+ # ee_camera_view has 0s for x, y, z
203
+ cam_pos = h5["ee_cam_pose"][:][:3, 3]
204
+ cam_pose[:3, 3] = cam_pos
205
+
206
+ # Get transformed point cloud
207
+ h, w, d = xyz1.shape
208
+ xyz1 = xyz1.reshape(h * w, -1)
209
+ xyz1 = trimesh.transform_points(xyz1, cam_pose)
210
+ xyz1 = xyz1.reshape(h, w, -1)
211
+
212
+ scene1 = rgb1, depth1, seg1, valid1, xyz1
213
+
214
+ return scene1
215
+
216
+ def __len__(self):
217
+ return len(self.arrangement_data)
218
+
219
+ def _get_ids(self, h5):
220
+ """
221
+ get object ids
222
+
223
+ @param h5:
224
+ @return:
225
+ """
226
+ ids = {}
227
+ for k in h5.keys():
228
+ if k.startswith("id_"):
229
+ ids[k[3:]] = h5[k][()]
230
+ return ids
231
+
232
+ def get_positive_ratio(self):
233
+ num_pos = 0
234
+ for d in self.arrangement_data:
235
+ filename, step_t = d
236
+ if step_t == 0:
237
+ num_pos += 1
238
+ return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos
239
+
240
+ def get_object_position_vocab_sizes(self):
241
+ return self.tokenizer.get_object_position_vocab_sizes()
242
+
243
+ def get_vocab_size(self):
244
+ return self.tokenizer.get_vocab_size()
245
+
246
+ def get_data_index(self, idx):
247
+ filename = self.arrangement_data[idx]
248
+ return filename
249
+
250
+ def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
251
+ """
252
+
253
+ :param idx:
254
+ :param inference_mode:
255
+ :param shuffle_object_index: used to test different orders of objects
256
+ :return:
257
+ """
258
+
259
+ filename, _ = self.arrangement_data[idx]
260
+
261
+ h5 = h5py.File(filename, 'r')
262
+ ids = self._get_ids(h5)
263
+ all_objs = sorted([o for o in ids.keys() if "object_" in o])
264
+ goal_specification = json.loads(str(np.array(h5["goal_specification"])))
265
+ num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
266
+ num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
267
+ assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
268
+ assert num_rearrange_objs <= self.max_num_objects
269
+ assert num_other_objs <= self.max_num_other_objects
270
+
271
+ # important: only using the last step
272
+ step_t = num_rearrange_objs
273
+
274
+ target_objs = all_objs[:num_rearrange_objs]
275
+ other_objs = all_objs[num_rearrange_objs:]
276
+
277
+ structure_parameters = goal_specification["shape"]
278
+
279
+ # Important: ensure the order is correct
280
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
281
+ target_objs = target_objs[::-1]
282
+ elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
283
+ target_objs = target_objs
284
+ else:
285
+ raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
286
+ all_objs = target_objs + other_objs
287
+
288
+ ###################################
289
+ # getting scene images and point clouds
290
+ scene = self._get_images(h5, step_t, ee=True)
291
+ rgb, depth, seg, valid, xyz = scene
292
+ if inference_mode:
293
+ initial_scene = scene
294
+
295
+ # getting object point clouds
296
+ obj_pcs = []
297
+ obj_pad_mask = []
298
+ current_pc_poses = []
299
+ other_obj_pcs = []
300
+ other_obj_pad_mask = []
301
+ for obj in all_objs:
302
+ obj_mask = np.logical_and(seg == ids[obj], valid)
303
+ if np.sum(obj_mask) <= 0:
304
+ raise Exception
305
+ ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
306
+ if not ok:
307
+ raise Exception
308
+
309
+ if obj in target_objs:
310
+ if self.ignore_rgb:
311
+ obj_pcs.append(obj_xyz)
312
+ else:
313
+ obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
314
+ obj_pad_mask.append(0)
315
+ pc_pose = np.eye(4)
316
+ pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
317
+ current_pc_poses.append(pc_pose)
318
+ elif obj in other_objs:
319
+ if self.ignore_rgb:
320
+ other_obj_pcs.append(obj_xyz)
321
+ else:
322
+ other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
323
+ other_obj_pad_mask.append(0)
324
+ else:
325
+ raise Exception
326
+
327
+ ###################################
328
+ # computes goal positions for objects
329
+ # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
330
+ if self.use_virtual_structure_frame:
331
+ goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
332
+ structure_parameters["rotation"][2])
333
+ goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
334
+ structure_parameters["position"][2]]
335
+ goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
336
+
337
+ goal_obj_poses = []
338
+ current_obj_poses = []
339
+ goal_pc_poses = []
340
+ for obj, current_pc_pose in zip(target_objs, current_pc_poses):
341
+ goal_pose = h5[obj][0]
342
+ current_pose = h5[obj][step_t]
343
+ if inference_mode:
344
+ goal_obj_poses.append(goal_pose)
345
+ current_obj_poses.append(current_pose)
346
+
347
+ goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
348
+ if self.use_virtual_structure_frame:
349
+ goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
350
+ goal_pc_poses.append(goal_pc_pose)
351
+
352
+ # transform current object point cloud to the goal point cloud in the world frame
353
+ if self.debug:
354
+ new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
355
+ for i, obj_pc in enumerate(new_obj_pcs):
356
+
357
+ current_pc_pose = current_pc_poses[i]
358
+ goal_pc_pose = goal_pc_poses[i]
359
+ if self.use_virtual_structure_frame:
360
+ goal_pc_pose = goal_structure_pose @ goal_pc_pose
361
+ print("current pc pose", current_pc_pose)
362
+ print("goal pc pose", goal_pc_pose)
363
+
364
+ goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
365
+ print("transform", goal_pc_transform)
366
+ new_obj_pc = copy.deepcopy(obj_pc)
367
+ new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
368
+ print(new_obj_pc.shape)
369
+
370
+ # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
371
+ new_obj_pcs[i] = new_obj_pc
372
+ new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
373
+ new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
374
+ show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
375
+ [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
376
+ add_coordinate_frame=True)
377
+ show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
378
+
379
+ # pad data
380
+ for i in range(self.max_num_objects - len(target_objs)):
381
+ obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
382
+ obj_pad_mask.append(1)
383
+ for i in range(self.max_num_other_objects - len(other_objs)):
384
+ other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
385
+ other_obj_pad_mask.append(1)
386
+
387
+ ###################################
388
+ # preparing sentence
389
+ sentence = []
390
+ sentence_pad_mask = []
391
+
392
+ # structure parameters
393
+ # 5 parameters
394
+ structure_parameters = goal_specification["shape"]
395
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
396
+ sentence.append((structure_parameters["type"], "shape"))
397
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
398
+ sentence.append((structure_parameters["position"][0], "position_x"))
399
+ sentence.append((structure_parameters["position"][1], "position_y"))
400
+ if structure_parameters["type"] == "circle":
401
+ sentence.append((structure_parameters["radius"], "radius"))
402
+ elif structure_parameters["type"] == "line":
403
+ sentence.append((structure_parameters["length"] / 2.0, "radius"))
404
+ if not self.use_sentence_embedding:
405
+ for _ in range(5):
406
+ sentence_pad_mask.append(0)
407
+ else:
408
+ sentence.append((structure_parameters["type"], "shape"))
409
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
410
+ sentence.append((structure_parameters["position"][0], "position_x"))
411
+ sentence.append((structure_parameters["position"][1], "position_y"))
412
+ if not self.use_sentence_embedding:
413
+ for _ in range(4):
414
+ sentence_pad_mask.append(0)
415
+ sentence.append(("PAD", None))
416
+ sentence_pad_mask.append(1)
417
+
418
+ if self.use_sentence_embedding:
419
+
420
+ if self.use_incomplete_sentence:
421
+ token_idxs = np.random.permutation(len(sentence))
422
+ token_idxs = token_idxs[:np.random.randint(1, len(sentence) + 1)]
423
+ token_idxs = sorted(token_idxs)
424
+ incomplete_sentence = [sentence[ti] for ti in token_idxs]
425
+ else:
426
+ incomplete_sentence = sentence
427
+
428
+ type_value_tuple = self.tokenizer.convert_structure_params_to_type_value_tuple(incomplete_sentence)
429
+ template_sentence = np.random.choice(self.type_value_tuple_to_template_sentences[type_value_tuple])
430
+ sentence_embedding = self.template_sentence_to_embedding[template_sentence]
431
+ sentence_pad_mask = [0]
432
+
433
+ ###################################
434
+ # paddings
435
+ for i in range(self.max_num_objects - len(target_objs)):
436
+ goal_pc_poses.append(np.eye(4))
437
+
438
+ ###################################
439
+ if self.debug:
440
+ print("---")
441
+ print("all objects:", all_objs)
442
+ print("target objects:", target_objs)
443
+ print("other objects:", other_objs)
444
+ print("goal specification:", goal_specification)
445
+ print("sentence:", sentence)
446
+ if self.use_sentence_embedding:
447
+ print("use sentence embedding")
448
+ if self.use_incomplete_sentence:
449
+ print("incomplete_sentence:", incomplete_sentence)
450
+ print("template sentence:", template_sentence)
451
+ show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
452
+
453
+ assert len(obj_pcs) == len(goal_pc_poses)
454
+ ###################################
455
+
456
+ # shuffle the position of objects
457
+ # important: only shuffle for dinner
458
+ if shuffle_object_index and structure_parameters["type"] == "dinner":
459
+ num_target_objs = len(target_objs)
460
+ shuffle_target_object_indices = list(range(num_target_objs))
461
+ random.shuffle(shuffle_target_object_indices)
462
+ shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, self.max_num_objects))
463
+ obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
464
+ goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
465
+ if inference_mode:
466
+ goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
467
+ current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
468
+ target_objs = [target_objs[i] for i in shuffle_target_object_indices[:num_target_objs]]
469
+ current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices[:num_target_objs]]
470
+
471
+ ###################################
472
+ if self.use_virtual_structure_frame:
473
+ if self.ignore_distractor_objects:
474
+ # language, structure virtual frame, target objects
475
+ pcs = obj_pcs
476
+ type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
477
+ position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
478
+ pad_mask = sentence_pad_mask + [0] + obj_pad_mask
479
+ else:
480
+ # language, distractor objects, structure virtual frame, target objects
481
+ pcs = other_obj_pcs + obj_pcs
482
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
483
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
484
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
485
+ goal_poses = [goal_structure_pose] + goal_pc_poses
486
+ else:
487
+ if self.ignore_distractor_objects:
488
+ # language, target objects
489
+ pcs = obj_pcs
490
+ type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
491
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
492
+ pad_mask = sentence_pad_mask + obj_pad_mask
493
+ else:
494
+ # language, distractor objects, target objects
495
+ pcs = other_obj_pcs + obj_pcs
496
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
497
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
498
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
499
+ goal_poses = goal_pc_poses
500
+
501
+ datum = {
502
+ "pcs": pcs,
503
+ "goal_poses": goal_poses,
504
+ "type_index": type_index,
505
+ "position_index": position_index,
506
+ "pad_mask": pad_mask,
507
+ "t": step_t,
508
+ "filename": filename
509
+ }
510
+ if self.use_sentence_embedding:
511
+ datum["sentence"] = sentence_embedding
512
+ else:
513
+ datum["sentence"] = sentence
514
+
515
+ if inference_mode:
516
+ datum["rgb"] = rgb
517
+ datum["goal_obj_poses"] = goal_obj_poses
518
+ datum["current_obj_poses"] = current_obj_poses
519
+ datum["target_objs"] = target_objs
520
+ datum["initial_scene"] = initial_scene
521
+ datum["ids"] = ids
522
+ datum["goal_specification"] = goal_specification
523
+ datum["current_pc_poses"] = current_pc_poses
524
+ if self.use_sentence_embedding:
525
+ datum["template_sentence"] = template_sentence
526
+
527
+ return datum
528
+
529
+ @staticmethod
530
+ def convert_to_tensors(datum, tokenizer, use_sentence_embedding=False):
531
+ tensors = {
532
+ "pcs": torch.stack(datum["pcs"], dim=0),
533
+ "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
534
+ "type_index": torch.LongTensor(np.array(datum["type_index"])),
535
+ "position_index": torch.LongTensor(np.array(datum["position_index"])),
536
+ "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
537
+ "t": datum["t"],
538
+ "filename": datum["filename"]
539
+ }
540
+ if use_sentence_embedding:
541
+ tensors["sentence"] = torch.FloatTensor(datum["sentence"]) # after batching, B x sentence embed dim
542
+ else:
543
+ tensors["sentence"] = torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]]))
544
+ return tensors
545
+
546
+ def __getitem__(self, idx):
547
+
548
+ datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
549
+ self.tokenizer,
550
+ self.use_sentence_embedding)
551
+
552
+ return datum
553
+
554
+ def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
555
+ tensor_x = {}
556
+
557
+ tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
558
+ tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
559
+ if not inference_mode:
560
+ tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
561
+
562
+ tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
563
+ tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
564
+ tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
565
+
566
+ return tensor_x
567
+
568
+
569
+ def compute_min_max(dataloader):
570
+
571
+ # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
572
+ # -0.9079, -0.8668, -0.9105, -0.4186])
573
+ # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
574
+ # 0.4787, 0.6421, 1.0000])
575
+ # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
576
+ # -0.0000, 0.0000, 0.0000, 1.0000])
577
+ # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
578
+ # 0.0000, 0.0000, 1.0000])
579
+
580
+ min_value = torch.ones(16) * 10000
581
+ max_value = torch.ones(16) * -10000
582
+ for d in tqdm(dataloader):
583
+ goal_poses = d["goal_poses"]
584
+ goal_poses = goal_poses.reshape(-1, 16)
585
+ current_max, _ = torch.max(goal_poses, dim=0)
586
+ current_min, _ = torch.min(goal_poses, dim=0)
587
+ max_value[max_value < current_max] = current_max[max_value < current_max]
588
+ min_value[min_value > current_min] = current_min[min_value > current_min]
589
+ print(f"{min_value} - {max_value}")
590
+
591
+
592
+ if __name__ == "__main__":
593
+
594
+ tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
595
+
596
+ data_roots = []
597
+ index_roots = []
598
+ for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
599
+ data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
600
+ index_roots.append(index)
601
+
602
+ dataset = SemanticArrangementDataset(data_roots=data_roots,
603
+ index_roots=index_roots,
604
+ split="valid", tokenizer=tokenizer,
605
+ max_num_target_objects=7,
606
+ max_num_distractor_objects=5,
607
+ max_num_shape_parameters=1,
608
+ max_num_rearrange_features=0,
609
+ max_num_anchor_features=0,
610
+ num_pts=1024,
611
+ use_virtual_structure_frame=True,
612
+ ignore_distractor_objects=True,
613
+ ignore_rgb=True,
614
+ filter_num_moved_objects_range=None, # [5, 5]
615
+ data_augmentation=False,
616
+ shuffle_object_index=True,
617
+ sentence_embedding_file="/home/weiyu/Research/StructDiffusion/old/StructDiffusion/src/StructDiffusion/language/template_sentence_data.pkl",
618
+ use_incomplete_sentence=True,
619
+ debug=False)
620
+
621
+ # print(len(dataset))
622
+ # for d in dataset:
623
+ # print("\n\n" + "="*100)
624
+
625
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
626
+ for i, d in enumerate(tqdm(dataloader)):
627
+ for k in d:
628
+ if isinstance(d[k], torch.Tensor):
629
+ print("--size", k, d[k].shape)
630
+ for k in d:
631
+ print(k, d[k])
632
+
633
+ input("next?")
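For reference, a minimal sketch of what the sentence_embedding_file pickle is assumed to contain, based only on how it is read in __init__ and used in get_raw_data; the key format, example sentence, and embedding size are illustrative assumptions, not the actual file contents:
# Illustrative only: keys are whatever tokenizer.convert_structure_params_to_type_value_tuple returns,
# and the values map to paraphrased template sentences and their precomputed embeddings.
import pickle
import numpy as np
template_sentence_data = {
    "type_value_tuple_to_template_sentences": {
        (("shape", "circle"), ("rotation", 0.0)): ["Put the objects into a circle facing forward."],  # hypothetical key and sentence
    },
    "template_sentence_to_embedding": {
        "Put the objects into a circle facing forward.": np.zeros(384, dtype=np.float32),  # assumed embedding size
    },
}
with open("template_sentence_data.pkl", "wb") as fh:
    pickle.dump(template_sentence_data, fh)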
src/StructDiffusion/data/semantic_arrangement_language_demo.py ADDED
@@ -0,0 +1,693 @@
1
+ import copy
2
+ import cv2
3
+ import h5py
4
+ import numpy as np
5
+ import os
6
+ import trimesh
7
+ import torch
8
+ from tqdm import tqdm
9
+ import json
10
+ import random
11
+ import pickle
12
+
13
+ from torch.utils.data import DataLoader
14
+
15
+ # Local imports
16
+ from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
17
+ from StructDiffusion.language.tokenizer import Tokenizer
18
+
19
+ import StructDiffusion.utils.brain2.camera as cam
20
+ import StructDiffusion.utils.brain2.image as img
21
+ import StructDiffusion.utils.transformations as tra
22
+
23
+
24
+ class SemanticArrangementDataset(torch.utils.data.Dataset):
25
+
26
+ def __init__(self, data_root, tokenizer,
27
+ max_num_target_objects=11, max_num_distractor_objects=5,
28
+ max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
29
+ num_pts=1024,
30
+ use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
31
+ filter_num_moved_objects_range=None, shuffle_object_index=False,
32
+ sentence_embedding_file=None, use_incomplete_sentence=False,
33
+ data_augmentation=True, debug=False, **kwargs):
34
+ """
35
+
36
+ Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
37
+
38
+ :param data_root:
39
+ :param split: train, valid, or test
40
+ :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
41
+ :param debug:
42
+ :param max_num_shape_parameters:
43
+ :param max_num_objects:
44
+ :param max_num_rearrange_features:
45
+ :param max_num_anchor_features:
46
+ :param num_pts:
47
+ :param use_stored_arrangement_indices:
48
+ :param kwargs:
49
+ """
50
+
51
+ self.use_virtual_structure_frame = use_virtual_structure_frame
52
+ self.ignore_distractor_objects = ignore_distractor_objects
53
+ self.ignore_rgb = ignore_rgb and not debug
54
+
55
+ self.num_pts = num_pts
56
+ self.debug = debug
57
+
58
+ self.max_num_objects = max_num_target_objects
59
+ self.max_num_other_objects = max_num_distractor_objects
60
+ self.max_num_shape_parameters = max_num_shape_parameters
61
+ self.max_num_rearrange_features = max_num_rearrange_features
62
+ self.max_num_anchor_features = max_num_anchor_features
63
+ self.shuffle_object_index = shuffle_object_index
64
+
65
+ # used to tokenize the language part
66
+ self.tokenizer = tokenizer
67
+
68
+ # retrieve data
69
+ self.data_root = data_root
70
+ self.arrangement_data = []
71
+ for filename in os.listdir(data_root):
72
+ if ".h5" in filename:
73
+ self.arrangement_data.append((os.path.join(data_root, filename), 0))
74
+ print("{} valid sequences".format(len(self.arrangement_data)))
75
+
76
+ # language
77
+ if sentence_embedding_file:
78
+ assert max_num_shape_parameters == 1
79
+ # since we do not use them right now, ignore them
80
+ # assert max_num_rearrange_features == 0
81
+ # assert max_num_anchor_features == 0
82
+ with open(sentence_embedding_file, "rb") as fh:
83
+ template_sentence_data = pickle.load(fh)
84
+ self.use_sentence_embedding = True
85
+ self.type_value_tuple_to_template_sentences = template_sentence_data["type_value_tuple_to_template_sentences"]
86
+ self.template_sentence_to_embedding = template_sentence_data["template_sentence_to_embedding"]
87
+ self.use_incomplete_sentence = use_incomplete_sentence
88
+ print("use sentence embedding")
89
+ print(len(self.type_value_tuple_to_template_sentences))
90
+ print(len(self.template_sentence_to_embedding))
91
+ else:
92
+ self.use_sentence_embedding = False
93
+
94
+ # Data Aug
95
+ self.data_augmentation = data_augmentation
96
+ # additive noise
97
+ self.gp_rescale_factor_range = [12, 20]
98
+ self.gaussian_scale_range = [0., 0.003]
99
+ # multiplicative noise
100
+ self.gamma_shape = 1000.
101
+ self.gamma_scale = 0.001
102
+
103
+ def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
104
+ assert len(list(filter_num_moved_objects_range)) == 2
105
+ min_num, max_num = filter_num_moved_objects_range
106
+ print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
107
+ ok_data = []
108
+ for filename, step_t in self.arrangement_data:
109
+ h5 = h5py.File(filename, 'r')
110
+ moved_objs = h5['moved_objs'][()].split(',')
111
+ if min_num <= len(moved_objs) <= max_num:
112
+ ok_data.append((filename, step_t))
113
+ print("{} valid sequences left".format(len(ok_data)))
114
+ return ok_data
115
+
116
+ def get_data_idx(self, idx):
117
+ # Create the datum to return
118
+ file_idx = np.argmax(idx < self.file_to_count)
119
+ data = h5py.File(self.data_files[file_idx], 'r')
120
+ if file_idx > 0:
121
+ # for lang2sym, idx is always 0
122
+ idx = idx - self.file_to_count[file_idx - 1]
123
+ return data, idx, file_idx
124
+
125
+ def add_noise_to_depth(self, depth_img):
126
+ """ add depth noise """
127
+ multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
128
+ depth_img = multiplicative_noise * depth_img
129
+ return depth_img
130
+
131
+ def add_noise_to_xyz(self, xyz_img, depth_img):
132
+ """ TODO: remove this code or at least clean it up"""
133
+ xyz_img = xyz_img.copy()
134
+ H, W, C = xyz_img.shape
135
+ gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
136
+ self.gp_rescale_factor_range[1])
137
+ gp_scale = np.random.uniform(self.gaussian_scale_range[0],
138
+ self.gaussian_scale_range[1])
139
+ small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
140
+ additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
141
+ additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
142
+ xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
143
+ return xyz_img
144
+
145
+ def random_index(self):
146
+ return self[np.random.randint(len(self))]
147
+
148
+ def _get_rgb(self, h5, idx, ee=True):
149
+ RGB = "ee_rgb" if ee else "rgb"
150
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
151
+ return rgb1
152
+
153
+ def _get_depth(self, h5, idx, ee=True):
154
+ DEPTH = "ee_depth" if ee else "depth"
155
+
156
+ def _get_images(self, h5, idx, ee=True):
157
+ if ee:
158
+ RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
159
+ DMIN, DMAX = "ee_depth_min", "ee_depth_max"
160
+ else:
161
+ RGB, DEPTH, SEG = "rgb", "depth", "seg"
162
+ DMIN, DMAX = "depth_min", "depth_max"
163
+ dmin = h5[DMIN][idx]
164
+ dmax = h5[DMAX][idx]
165
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
166
+ depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
167
+ seg1 = img.PNGToNumpy(h5[SEG][idx])
168
+
169
+ valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
170
+
171
+ # proj_matrix = h5['proj_matrix'][()]
172
+ camera = cam.get_camera_from_h5(h5)
173
+ if self.data_augmentation:
174
+ depth1 = self.add_noise_to_depth(depth1)
175
+
176
+ xyz1 = cam.compute_xyz(depth1, camera)
177
+ if self.data_augmentation:
178
+ xyz1 = self.add_noise_to_xyz(xyz1, depth1)
179
+
180
+ # Transform the point cloud
181
+ # Here it is...
182
+ # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
183
+ CAM_POSE = "ee_camera_view" if ee else "camera_view"
184
+ cam_pose = h5[CAM_POSE][idx]
185
+ if ee:
186
+ # ee_camera_view has 0s for x, y, z
187
+ cam_pos = h5["ee_cam_pose"][:][:3, 3]
188
+ cam_pose[:3, 3] = cam_pos
189
+
190
+ # Get transformed point cloud
191
+ h, w, d = xyz1.shape
192
+ xyz1 = xyz1.reshape(h * w, -1)
193
+ xyz1 = trimesh.transform_points(xyz1, cam_pose)
194
+ xyz1 = xyz1.reshape(h, w, -1)
195
+
196
+ scene1 = rgb1, depth1, seg1, valid1, xyz1
197
+
198
+ return scene1
199
+
200
+ def __len__(self):
201
+ return len(self.arrangement_data)
202
+
203
+ def _get_ids(self, h5):
204
+ """
205
+ get object ids
206
+
207
+ @param h5:
208
+ @return:
209
+ """
210
+ ids = {}
211
+ for k in h5.keys():
212
+ if k.startswith("id_"):
213
+ ids[k[3:]] = h5[k][()]
214
+ return ids
215
+
216
+ def get_positive_ratio(self):
217
+ num_pos = 0
218
+ for d in self.arrangement_data:
219
+ filename, step_t = d
220
+ if step_t == 0:
221
+ num_pos += 1
222
+ return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos
223
+
224
+ def get_object_position_vocab_sizes(self):
225
+ return self.tokenizer.get_object_position_vocab_sizes()
226
+
227
+ def get_vocab_size(self):
228
+ return self.tokenizer.get_vocab_size()
229
+
230
+ def get_data_index(self, idx):
231
+ filename = self.arrangement_data[idx]
232
+ return filename
233
+
234
+ def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
235
+ """
236
+
237
+ :param idx:
238
+ :param inference_mode:
239
+ :param shuffle_object_index: used to test different orders of objects
240
+ :return:
241
+ """
242
+
243
+ filename, _ = self.arrangement_data[idx]
244
+
245
+ h5 = h5py.File(filename, 'r')
246
+ ids = self._get_ids(h5)
247
+ all_objs = sorted([o for o in ids.keys() if "object_" in o])
248
+ goal_specification = json.loads(str(np.array(h5["goal_specification"])))
249
+ num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
250
+ num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
251
+ assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
252
+ assert num_rearrange_objs <= self.max_num_objects
253
+ assert num_other_objs <= self.max_num_other_objects
254
+
255
+ # important: only using the last step
256
+ step_t = num_rearrange_objs
257
+
258
+ target_objs = all_objs[:num_rearrange_objs]
259
+ other_objs = all_objs[num_rearrange_objs:]
260
+
261
+ structure_parameters = goal_specification["shape"]
262
+
263
+ # Important: ensure the order is correct
264
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
265
+ target_objs = target_objs[::-1]
266
+ elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
267
+ target_objs = target_objs
268
+ else:
269
+ raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
270
+ all_objs = target_objs + other_objs
271
+
272
+ ###################################
273
+ # getting scene images and point clouds
274
+ scene = self._get_images(h5, step_t, ee=True)
275
+ rgb, depth, seg, valid, xyz = scene
276
+ if inference_mode:
277
+ initial_scene = scene
278
+
279
+ # getting object point clouds
280
+ obj_pcs = []
281
+ obj_pad_mask = []
282
+ current_pc_poses = []
283
+ other_obj_pcs = []
284
+ other_obj_pad_mask = []
285
+ for obj in all_objs:
286
+ obj_mask = np.logical_and(seg == ids[obj], valid)
287
+ if np.sum(obj_mask) <= 0:
288
+ raise Exception
289
+ ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
290
+ if not ok:
291
+ raise Exception
292
+
293
+ if obj in target_objs:
294
+ if self.ignore_rgb:
295
+ obj_pcs.append(obj_xyz)
296
+ else:
297
+ obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
298
+ obj_pad_mask.append(0)
299
+ pc_pose = np.eye(4)
300
+ pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
301
+ current_pc_poses.append(pc_pose)
302
+ elif obj in other_objs:
303
+ if self.ignore_rgb:
304
+ other_obj_pcs.append(obj_xyz)
305
+ else:
306
+ other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
307
+ other_obj_pad_mask.append(0)
308
+ else:
309
+ raise Exception
310
+
311
+ ###################################
312
+ # computes goal positions for objects
313
+ # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
314
+ if self.use_virtual_structure_frame:
315
+ goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
316
+ structure_parameters["rotation"][2])
317
+ goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
318
+ structure_parameters["position"][2]]
319
+ goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
320
+
321
+ goal_obj_poses = []
322
+ current_obj_poses = []
323
+ goal_pc_poses = []
324
+ for obj, current_pc_pose in zip(target_objs, current_pc_poses):
325
+ goal_pose = h5[obj][0]
326
+ current_pose = h5[obj][step_t]
327
+ if inference_mode:
328
+ goal_obj_poses.append(goal_pose)
329
+ current_obj_poses.append(current_pose)
330
+
331
+ goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
332
+ if self.use_virtual_structure_frame:
333
+ goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
334
+ goal_pc_poses.append(goal_pc_pose)
335
+
336
+ # transform current object point cloud to the goal point cloud in the world frame
337
+ if self.debug:
338
+ new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
339
+ for i, obj_pc in enumerate(new_obj_pcs):
340
+
341
+ current_pc_pose = current_pc_poses[i]
342
+ goal_pc_pose = goal_pc_poses[i]
343
+ if self.use_virtual_structure_frame:
344
+ goal_pc_pose = goal_structure_pose @ goal_pc_pose
345
+ print("current pc pose", current_pc_pose)
346
+ print("goal pc pose", goal_pc_pose)
347
+
348
+ goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
349
+ print("transform", goal_pc_transform)
350
+ new_obj_pc = copy.deepcopy(obj_pc)
351
+ new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
352
+ print(new_obj_pc.shape)
353
+
354
+ # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
355
+ new_obj_pcs[i] = new_obj_pc
356
+ new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
357
+ new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
358
+ show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
359
+ [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
360
+ add_coordinate_frame=True)
361
+ show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
362
+
363
+ # pad data
364
+ for i in range(self.max_num_objects - len(target_objs)):
365
+ obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
366
+ obj_pad_mask.append(1)
367
+ for i in range(self.max_num_other_objects - len(other_objs)):
368
+ other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
369
+ other_obj_pad_mask.append(1)
370
+
371
+ ###################################
372
+ # preparing sentence
373
+ sentence = []
374
+ sentence_pad_mask = []
375
+
376
+ # structure parameters
377
+ # 5 parameters
378
+ structure_parameters = goal_specification["shape"]
379
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
380
+ sentence.append((structure_parameters["type"], "shape"))
381
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
382
+ sentence.append((structure_parameters["position"][0], "position_x"))
383
+ sentence.append((structure_parameters["position"][1], "position_y"))
384
+ if structure_parameters["type"] == "circle":
385
+ sentence.append((structure_parameters["radius"], "radius"))
386
+ elif structure_parameters["type"] == "line":
387
+ sentence.append((structure_parameters["length"] / 2.0, "radius"))
388
+ if not self.use_sentence_embedding:
389
+ for _ in range(5):
390
+ sentence_pad_mask.append(0)
391
+ else:
392
+ sentence.append((structure_parameters["type"], "shape"))
393
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
394
+ sentence.append((structure_parameters["position"][0], "position_x"))
395
+ sentence.append((structure_parameters["position"][1], "position_y"))
396
+ if not self.use_sentence_embedding:
397
+ for _ in range(4):
398
+ sentence_pad_mask.append(0)
399
+ sentence.append(("PAD", None))
400
+ sentence_pad_mask.append(1)
401
+
402
+ if self.use_sentence_embedding:
403
+
404
+ if self.use_incomplete_sentence:
405
+ token_idxs = np.random.permutation(len(sentence))
406
+ token_idxs = token_idxs[:np.random.randint(1, len(sentence) + 1)]
407
+ token_idxs = sorted(token_idxs)
408
+ incomplete_sentence = [sentence[ti] for ti in token_idxs]
409
+ else:
410
+ incomplete_sentence = sentence
411
+
412
+ type_value_tuple = self.tokenizer.convert_structure_params_to_type_value_tuple(incomplete_sentence)
413
+ template_sentence = np.random.choice(self.type_value_tuple_to_template_sentences[type_value_tuple])
414
+ sentence_embedding = self.template_sentence_to_embedding[template_sentence]
415
+ sentence_pad_mask = [0]
416
+
417
+ ###################################
418
+ # paddings
419
+ for i in range(self.max_num_objects - len(target_objs)):
420
+ goal_pc_poses.append(np.eye(4))
421
+
422
+ ###################################
423
+ if self.debug:
424
+ print("---")
425
+ print("all objects:", all_objs)
426
+ print("target objects:", target_objs)
427
+ print("other objects:", other_objs)
428
+ print("goal specification:", goal_specification)
429
+ print("sentence:", sentence)
430
+ if self.use_sentence_embedding:
431
+ print("use sentence embedding")
432
+ if self.use_incomplete_sentence:
433
+ print("incomplete_sentence:", incomplete_sentence)
434
+ print("template sentence:", template_sentence)
435
+ show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
436
+
437
+ assert len(obj_pcs) == len(goal_pc_poses)
438
+ ###################################
439
+
440
+ # shuffle the position of objects
441
+ # important: only shuffle for dinner
442
+ if shuffle_object_index and structure_parameters["type"] == "dinner":
443
+ num_target_objs = len(target_objs)
444
+ shuffle_target_object_indices = list(range(num_target_objs))
445
+ random.shuffle(shuffle_target_object_indices)
446
+ shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, self.max_num_objects))
447
+ obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
448
+ goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
449
+ if inference_mode:
450
+ goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
451
+ current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
452
+ target_objs = [target_objs[i] for i in shuffle_target_object_indices[:num_target_objs]]
453
+ current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices[:num_target_objs]]
454
+
455
+ ###################################
456
+ if self.use_virtual_structure_frame:
457
+ if self.ignore_distractor_objects:
458
+ # language, structure virtual frame, target objects
459
+ pcs = obj_pcs
460
+ type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
461
+ position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
462
+ pad_mask = sentence_pad_mask + [0] + obj_pad_mask
463
+ else:
464
+ # language, distractor objects, structure virtual frame, target objects
465
+ pcs = other_obj_pcs + obj_pcs
466
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
467
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
468
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
469
+ goal_poses = [goal_structure_pose] + goal_pc_poses
470
+ else:
471
+ if self.ignore_distractor_objects:
472
+ # language, target objects
473
+ pcs = obj_pcs
474
+ type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
475
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
476
+ pad_mask = sentence_pad_mask + obj_pad_mask
477
+ else:
478
+ # language, distractor objects, target objects
479
+ pcs = other_obj_pcs + obj_pcs
480
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
481
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
482
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
483
+ goal_poses = goal_pc_poses
484
+
485
+ datum = {
486
+ "pcs": pcs,
487
+ "goal_poses": goal_poses,
488
+ "type_index": type_index,
489
+ "position_index": position_index,
490
+ "pad_mask": pad_mask,
491
+ "t": step_t,
492
+ "filename": filename
493
+ }
494
+ if self.use_sentence_embedding:
495
+ datum["sentence"] = sentence_embedding
496
+ else:
497
+ datum["sentence"] = sentence
498
+
499
+ if inference_mode:
500
+ datum["rgb"] = rgb
501
+ datum["goal_obj_poses"] = goal_obj_poses
502
+ datum["current_obj_poses"] = current_obj_poses
503
+ datum["target_objs"] = target_objs
504
+ datum["initial_scene"] = initial_scene
505
+ datum["ids"] = ids
506
+ datum["goal_specification"] = goal_specification
507
+ datum["current_pc_poses"] = current_pc_poses
508
+ if self.use_sentence_embedding:
509
+ datum["template_sentence"] = template_sentence
510
+
511
+ return datum
512
+
513
+ def build_data_from_xyzs(self, obj_xyzs, sentence_embedding, shuffle_object_index=True):
514
+
515
+ ## objects
516
+ obj_pcs = []
517
+ obj_pad_mask = []
518
+ current_pc_poses = []
519
+ other_obj_pcs = []
520
+ other_obj_pad_mask = []
521
+ for obj_xyz in obj_xyzs:
522
+ obj_pcs.append(torch.from_numpy(obj_xyz.astype(np.float32)))
523
+ obj_pad_mask.append(0)
524
+
525
+ # pad data
526
+ num_target_objs = len(obj_pcs)
527
+ for i in range(self.max_num_objects - num_target_objs):
528
+ obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
529
+ obj_pad_mask.append(1)
530
+ for i in range(self.max_num_other_objects):
531
+ other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
532
+ other_obj_pad_mask.append(1)
533
+
534
+ ## sentence
535
+ sentence_pad_mask = [0]
536
+
537
+ if shuffle_object_index:
538
+ num_target_objs = num_target_objs
539
+ shuffle_target_object_indices = list(range(num_target_objs))
540
+ random.shuffle(shuffle_target_object_indices)
541
+ shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, self.max_num_objects))
542
+ obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
543
+
544
+ ###################################
545
+ if self.use_virtual_structure_frame:
546
+ if self.ignore_distractor_objects:
547
+ # language, structure virtual frame, target objects
548
+ pcs = obj_pcs
549
+ type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
550
+ position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
551
+ pad_mask = sentence_pad_mask + [0] + obj_pad_mask
552
+ else:
553
+ # language, distractor objects, structure virtual frame, target objects
554
+ pcs = other_obj_pcs + obj_pcs
555
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
556
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
557
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
558
+ num_goal_poses = self.max_num_objects + 1
559
+ else:
560
+ if self.ignore_distractor_objects:
561
+ # language, target objects
562
+ pcs = obj_pcs
563
+ type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
564
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
565
+ pad_mask = sentence_pad_mask + obj_pad_mask
566
+ else:
567
+ # language, distractor objects, target objects
568
+ pcs = other_obj_pcs + obj_pcs
569
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
570
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
571
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
572
+ num_goal_poses = self.max_num_objects
573
+
574
+ datum = {
575
+ "pcs": pcs,
576
+ "type_index": type_index,
577
+ "position_index": position_index,
578
+ "pad_mask": pad_mask,
579
+ "sentence": sentence_embedding,
580
+ "num_goal_poses": num_goal_poses,
581
+ "t": 0,
582
+ "filename": "inference"
583
+ }
584
+
585
+ return datum
586
+
587
+ @staticmethod
588
+ def convert_to_tensors(datum, tokenizer, use_sentence_embedding=False):
589
+ tensors = {
590
+ "pcs": torch.stack(datum["pcs"], dim=0),
591
+ "type_index": torch.LongTensor(np.array(datum["type_index"])),
592
+ "position_index": torch.LongTensor(np.array(datum["position_index"])),
593
+ "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
594
+ "t": datum["t"],
595
+ "filename": datum["filename"]
596
+ }
597
+ if "goal_poses" in datum:
598
+ tensors["goal_poses"] = torch.FloatTensor(np.array(datum["goal_poses"]))
599
+
600
+ if use_sentence_embedding:
601
+ tensors["sentence"] = torch.FloatTensor(datum["sentence"]) # after batching, B x sentence embed dim
602
+ else:
603
+ tensors["sentence"] = torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]]))
604
+ return tensors
605
+
606
+ def __getitem__(self, idx):
607
+
608
+ datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
609
+ self.tokenizer,
610
+ self.use_sentence_embedding)
611
+
612
+ return datum
613
+
614
+ def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
615
+ tensor_x = {}
616
+
617
+ tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
618
+ tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
619
+ if not inference_mode:
620
+ tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
621
+
622
+ tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
623
+ tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
624
+ tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
625
+
626
+ return tensor_x
627
+
628
+
629
+ def compute_min_max(dataloader):
630
+
631
+ # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
632
+ # -0.9079, -0.8668, -0.9105, -0.4186])
633
+ # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
634
+ # 0.4787, 0.6421, 1.0000])
635
+ # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
636
+ # -0.0000, 0.0000, 0.0000, 1.0000])
637
+ # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
638
+ # 0.0000, 0.0000, 1.0000])
639
+
640
+ min_value = torch.ones(16) * 10000
641
+ max_value = torch.ones(16) * -10000
642
+ for d in tqdm(dataloader):
643
+ goal_poses = d["goal_poses"]
644
+ goal_poses = goal_poses.reshape(-1, 16)
645
+ current_max, _ = torch.max(goal_poses, dim=0)
646
+ current_min, _ = torch.min(goal_poses, dim=0)
647
+ max_value[max_value < current_max] = current_max[max_value < current_max]
648
+ min_value[min_value > current_min] = current_min[min_value > current_min]
649
+ print(f"{min_value} - {max_value}")
650
+
651
+
652
+ if __name__ == "__main__":
653
+
654
+ tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
655
+
656
+ data_roots = []
657
+ index_roots = []
658
+ for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
659
+ data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
660
+ index_roots.append(index)
661
+
662
+ dataset = SemanticArrangementDataset(data_roots=data_roots,
663
+ index_roots=index_roots,
664
+ split="valid", tokenizer=tokenizer,
665
+ max_num_target_objects=7,
666
+ max_num_distractor_objects=5,
667
+ max_num_shape_parameters=1,
668
+ max_num_rearrange_features=0,
669
+ max_num_anchor_features=0,
670
+ num_pts=1024,
671
+ use_virtual_structure_frame=True,
672
+ ignore_distractor_objects=True,
673
+ ignore_rgb=True,
674
+ filter_num_moved_objects_range=None, # [5, 5]
675
+ data_augmentation=False,
676
+ shuffle_object_index=True,
677
+ sentence_embedding_file="/home/weiyu/Research/StructDiffusion/old/StructDiffusion/src/StructDiffusion/language/template_sentence_data.pkl",
678
+ use_incomplete_sentence=True,
679
+ debug=False)
680
+
681
+ # print(len(dataset))
682
+ # for d in dataset:
683
+ # print("\n\n" + "="*100)
684
+
685
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
686
+ for i, d in enumerate(tqdm(dataloader)):
687
+ for k in d:
688
+ if isinstance(d[k], torch.Tensor):
689
+ print("--size", k, d[k].shape)
690
+ for k in d:
691
+ print(k, d[k])
692
+
693
+ input("next?")
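A minimal usage sketch for the dataset above (not part of the commit; the index, sample count, and device are illustrative, and dataset refers to the instance built in the __main__ block):

    # hypothetical example: tile one converted scene into a batch for the diffusion sampler
    datum = dataset[0]                                             # tensors produced by __getitem__ above
    batch = dataset.single_datum_to_batch(datum, num_samples=10, device="cuda")
    # each entry of batch now carries a leading sample dimension of size 10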
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ
 
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc differ
 
src/StructDiffusion/diffusion/sampler.py CHANGED
@@ -1,6 +1,11 @@
1
  import torch
2
  from tqdm import tqdm
 
3
  from StructDiffusion.diffusion.noise_schedule import extract
 
 
 
 
4
 
5
  class Sampler:
6
 
@@ -14,7 +19,7 @@ class Sampler:
14
  self.backbone.to(device)
15
  self.backbone.eval()
16
 
17
- def sample(self, batch, num_poses, progress):
18
 
19
  noise_schedule = self.model.noise_schedule
20
 
@@ -23,7 +28,7 @@ class Sampler:
23
  x_noisy = torch.randn((B, num_poses, 9), device=self.device)
24
 
25
  xs = []
26
- for t_index in progress.tqdm(reversed(range(0, noise_schedule.timesteps)),
27
  desc='sampling loop time step', total=noise_schedule.timesteps):
28
 
29
  t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
@@ -57,236 +62,239 @@ class Sampler:
57
  xs = list(reversed(xs))
58
  return xs
59
 
60
- # class SamplerV2:
61
- #
62
- # def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
63
- # collision_model_class, collision_checkpoint_path,
64
- # device, debug=False):
65
- #
66
- # self.debug = debug
67
- # self.device = device
68
- #
69
- # self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
70
- # self.diffusion_backbone = self.diffusion_model.model
71
- # self.diffusion_backbone.to(device)
72
- # self.diffusion_backbone.eval()
73
- #
74
- # self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
75
- # self.collision_backbone = self.collision_model.model
76
- # self.collision_backbone.to(device)
77
- # self.collision_backbone.eval()
78
- #
79
- # def sample(self, batch, num_poses):
80
- #
81
- # noise_schedule = self.diffusion_model.noise_schedule
82
- #
83
- # B = batch["pcs"].shape[0]
84
- #
85
- # x_noisy = torch.randn((B, num_poses, 9), device=self.device)
86
- #
87
- # xs = []
88
- # for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
89
- # desc='sampling loop time step', total=noise_schedule.timesteps):
90
- #
91
- # t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
92
- #
93
- # # noise schedule
94
- # betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
95
- # sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
96
- # sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
97
- #
98
- # # predict noise
99
- # pcs = batch["pcs"]
100
- # sentence = batch["sentence"]
101
- # type_index = batch["type_index"]
102
- # position_index = batch["position_index"]
103
- # pad_mask = batch["pad_mask"]
104
- # # calling the backbone instead of the pytorch-lightning model
105
- # with torch.no_grad():
106
- # predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
107
- #
108
- # # compute noisy x at t
109
- # model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
110
- # if t_index == 0:
111
- # x_noisy = model_mean
112
- # else:
113
- # posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
114
- # noise = torch.randn_like(x_noisy)
115
- # x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
116
- #
117
- # xs.append(x_noisy)
118
- #
119
- # xs = list(reversed(xs))
120
- #
121
- # visualize = True
122
- #
123
- # struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
124
- # # struct_pose: B, 1, 4, 4
125
- # # pc_poses_in_struct: B, N, 4, 4
126
- #
127
- # S = B
128
- # num_elite = 10
129
- # ####################################################
130
- # # only keep one copy
131
- #
132
- # # N, P, 3
133
- # obj_xyzs = batch["pcs"][0][:, :, :3]
134
- # print("obj_xyzs shape", obj_xyzs.shape)
135
- #
136
- # # 1, N
137
- # # object_pad_mask: padding location has 1
138
- # num_target_objs = num_poses
139
- # if self.diffusion_backbone.use_virtual_structure_frame:
140
- # num_target_objs -= 1
141
- # object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
142
- # target_object_inds = 1 - object_pad_mask
143
- # print("target_object_inds shape", target_object_inds.shape)
144
- # print("target_object_inds", target_object_inds)
145
- #
146
- # N, P, _ = obj_xyzs.shape
147
- # print("S, N, P: {}, {}, {}".format(S, N, P))
148
- #
149
- # ####################################################
150
- # # S, N, ...
151
- #
152
- # struct_pose = struct_pose.repeat(1, N, 1, 1) # S, N, 4, 4
153
- # struct_pose = struct_pose.reshape(S * N, 4, 4) # S x N, 4, 4
154
- #
155
- # new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1) # S, N, P, 3
156
- # current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device) # S, N, 4, 4
157
- # current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2) # S, N, 4, 4
158
- # current_pc_pose = current_pc_pose.reshape(S * N, 4, 4) # S x N, 4, 4
159
- #
160
- # # optimize xyzrpy
161
- # obj_params = torch.zeros((S, N, 6)).to(self.device)
162
- # obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
163
- # obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ") # S, N, 6
164
- # #
165
- # # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
166
- # #
167
- # # if visualize:
168
- # # print("visualizing rearrangements predicted by the generator")
169
- # # visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
170
- #
171
- # ####################################################
172
- # # rank
173
- #
174
- # # evaluate in batches
175
- # scores = torch.zeros(S).to(self.device)
176
- # no_intersection_scores = torch.zeros(S).to(self.device) # the higher the better
177
- # num_batches = int(S / B)
178
- # if S % B != 0:
179
- # num_batches += 1
180
- # for b in range(num_batches):
181
- # if b + 1 == num_batches:
182
- # cur_batch_idxs_start = b * B
183
- # cur_batch_idxs_end = S
184
- # else:
185
- # cur_batch_idxs_start = b * B
186
- # cur_batch_idxs_end = (b + 1) * B
187
- # cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
188
- #
189
- # # print("current batch idxs start", cur_batch_idxs_start)
190
- # # print("current batch idxs end", cur_batch_idxs_end)
191
- # # print("size of the current batch", cur_batch_size)
192
- #
193
- # batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
194
- # batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
195
- # batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
196
- #
197
- # new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
198
- # move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
199
- # target_object_inds, self.device,
200
- # return_scene_pts=False,
201
- # return_scene_pts_and_pc_idxs=False,
202
- # num_scene_pts=False,
203
- # normalize_pc=False,
204
- # return_pair_pc=True,
205
- # num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
206
- # normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
207
- #
208
- # #######################################
209
- # # predict whether there are pairwise collisions
210
- # # if collision_score_weight > 0:
211
- # with torch.no_grad():
212
- # _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
213
- # # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
214
- # collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
215
- # collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb) # cur_batch_size, num_comb
216
- #
217
- # # debug
218
- # # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
219
- # # print("batch id", bi)
220
- # # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
221
- # # print("pair", pi)
222
- # # # obj_pair_xyzs: 2 * P, 5
223
- # # print("collision score", collision_scores[bi, pi])
224
- # # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
225
- #
226
- # # 1 - mean() since the collision model predicts 1 if there is a collision
227
- # no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
228
- # if visualize:
229
- # print("no intersection scores", no_intersection_scores)
230
- # # #######################################
231
- # # if discriminator_score_weight > 0:
232
- # # # # debug:
233
- # # # print(subsampled_scene_xyz.shape)
234
- # # # print(subsampled_scene_xyz[0])
235
- # # # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
236
- # # #
237
- # # with torch.no_grad():
238
- # #
239
- # # # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
240
- # # # local_sentence = sentence[:, [0, 4]]
241
- # # # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
242
- # # # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
243
- # #
244
- # # sentence_disc = torch.LongTensor(
245
- # # [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
246
- # # sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
247
- # # position_index_dic = torch.LongTensor(raw_position_index_discriminator)
248
- # #
249
- # # preds = discriminator_model.forward(subsampled_scene_xyz,
250
- # # sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
251
- # # sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size,
252
- # # 1).to(device),
253
- # # position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(
254
- # # device))
255
- # # # preds = discriminator_model.forward(subsampled_scene_xyz)
256
- # # preds = discriminator_model.convert_logits(preds)
257
- # # preds = preds["is_circle"] # cur_batch_size,
258
- # # scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
259
- # # if visualize:
260
- # # print("discriminator scores", scores)
261
- #
262
- # # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
263
- # scores = no_intersection_scores
264
- # sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
265
- # elite_obj_params = obj_params[sort_idx] # num_elite, N, 6
266
- # elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx] # num_elite, N, 4, 4
267
- # elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4) # num_elite x N, 4, 4
268
- # elite_scores = scores[sort_idx]
269
- # print("elite scores:", elite_scores)
270
- #
271
- # ####################################################
272
- # # # visualize best samples
273
- # # num_scene_pts = 4096 # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
274
- # # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
275
- # # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
276
- # # move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
277
- # # target_object_inds, self.device,
278
- # # return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
279
- # # if visualize:
280
- # # print("visualizing elite rearrangements ranked by collision model/discriminator")
281
- # # visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
282
- #
283
- # # num_elite, N, 6
284
- # elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
285
- # pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
286
- # pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
287
- # pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
288
- # pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4) # num_elite, N, 4, 4
289
- #
290
- # struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0,].unsqueeze(1) # num_elite, 1, 4, 4
291
- #
292
- # return struct_pose, pc_poses_in_struct
 
 
 
 
1
  import torch
2
  from tqdm import tqdm
3
+
4
  from StructDiffusion.diffusion.noise_schedule import extract
5
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
6
+ from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_new
7
+ import StructDiffusion.utils.tra3d as tra3d
8
+
9
 
10
  class Sampler:
11
 
 
19
  self.backbone.to(device)
20
  self.backbone.eval()
21
 
22
+ def sample(self, batch, num_poses):
23
 
24
  noise_schedule = self.model.noise_schedule
25
 
 
28
  x_noisy = torch.randn((B, num_poses, 9), device=self.device)
29
 
30
  xs = []
31
+ for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
32
  desc='sampling loop time step', total=noise_schedule.timesteps):
33
 
34
  t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
 
62
  xs = list(reversed(xs))
63
  return xs
64
 
65
+ class SamplerV2:
66
+
67
+ def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
68
+ collision_model_class, collision_checkpoint_path,
69
+ device, debug=False):
70
+
71
+ self.debug = debug
72
+ self.device = device
73
+
74
+ self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
75
+ self.diffusion_backbone = self.diffusion_model.model
76
+ self.diffusion_backbone.to(device)
77
+ self.diffusion_backbone.eval()
78
+
79
+ self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
80
+ self.collision_backbone = self.collision_model.model
81
+ self.collision_backbone.to(device)
82
+ self.collision_backbone.eval()
83
+
84
+ def sample(self, batch, num_poses, num_elite, discriminator_batch_size):
85
+
86
+ noise_schedule = self.diffusion_model.noise_schedule
87
+
88
+ B = batch["pcs"].shape[0]
89
+
90
+ x_noisy = torch.randn((B, num_poses, 9), device=self.device)
91
+
92
+ xs = []
93
+ for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
94
+ desc='sampling loop time step', total=noise_schedule.timesteps):
95
+
96
+ t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
97
+
98
+ # noise schedule
99
+ betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
100
+ sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
101
+ sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
102
+
103
+ # predict noise
104
+ pcs = batch["pcs"]
105
+ sentence = batch["sentence"]
106
+ type_index = batch["type_index"]
107
+ position_index = batch["position_index"]
108
+ pad_mask = batch["pad_mask"]
109
+ # calling the backbone instead of the pytorch-lightning model
110
+ with torch.no_grad():
111
+ predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
112
+
113
+ # compute noisy x at t
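+ # the next line is the DDPM posterior mean: (1 / sqrt(alpha_t)) * (x_t - beta_t * predicted_noise / sqrt(1 - alpha_bar_t))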
114
+ model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
115
+ if t_index == 0:
116
+ x_noisy = model_mean
117
+ else:
118
+ posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
119
+ noise = torch.randn_like(x_noisy)
120
+ x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
121
+
122
+ xs.append(x_noisy)
123
+
124
+ xs = list(reversed(xs))
125
+
126
+ visualize = True
127
+
128
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
129
+ # struct_pose: B, 1, 4, 4
130
+ # pc_poses_in_struct: B, N, 4, 4
131
+
132
+ S = B
133
+ B_discriminator = discriminator_batch_size
134
+ ####################################################
135
+ # only keep one copy
136
+
137
+ # N, P, 3
138
+ obj_xyzs = batch["pcs"][0][:, :, :3]
139
+ print("obj_xyzs shape", obj_xyzs.shape)
140
+
141
+ # 1, N
142
+ # object_pad_mask: padding location has 1
143
+ num_target_objs = num_poses
144
+ if self.diffusion_backbone.use_virtual_structure_frame:
145
+ num_target_objs -= 1
146
+ object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
147
+ target_object_inds = 1 - object_pad_mask
148
+ print("target_object_inds shape", target_object_inds.shape)
149
+ print("target_object_inds", target_object_inds)
150
+
151
+ N, P, _ = obj_xyzs.shape
152
+ print("S, N, P: {}, {}, {}".format(S, N, P))
153
+
154
+ ####################################################
155
+ # S, N, ...
156
+
157
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # S, N, 4, 4
158
+ struct_pose = struct_pose.reshape(S * N, 4, 4) # S x N, 4, 4
159
+
160
+ new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1) # S, N, P, 3
161
+ current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device) # S, N, 4, 4
162
+ current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2) # S, N, 4, 4
163
+ current_pc_pose = current_pc_pose.reshape(S * N, 4, 4) # S x N, 4, 4
164
+
165
+ # optimize xyzrpy
166
+ obj_params = torch.zeros((S, N, 6)).to(self.device)
167
+ obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
168
+ obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ") # S, N, 6
169
+ #
170
+ # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
171
+ #
172
+ # if visualize:
173
+ # print("visualizing rearrangements predicted by the generator")
174
+ # visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
175
+
176
+ ####################################################
177
+ # rank
178
+
179
+ # evaluate in batches
180
+ scores = torch.zeros(S).to(self.device)
181
+ no_intersection_scores = torch.zeros(S).to(self.device) # the higher the better
182
+ num_batches = int(S / B_discriminator)
183
+ if S % B_discriminator != 0:
184
+ num_batches += 1
185
+ for b in range(num_batches):
186
+ if b + 1 == num_batches:
187
+ cur_batch_idxs_start = b * B_discriminator
188
+ cur_batch_idxs_end = S
189
+ else:
190
+ cur_batch_idxs_start = b * B_discriminator
191
+ cur_batch_idxs_end = (b + 1) * B_discriminator
192
+ cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
193
+
194
+ # print("current batch idxs start", cur_batch_idxs_start)
195
+ # print("current batch idxs end", cur_batch_idxs_end)
196
+ # print("size of the current batch", cur_batch_size)
197
+
198
+ batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
199
+ batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
200
+ batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
201
+
202
+ new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
203
+ move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
204
+ target_object_inds, self.device,
205
+ return_scene_pts=False,
206
+ return_scene_pts_and_pc_idxs=False,
207
+ num_scene_pts=False,
208
+ normalize_pc=False,
209
+ return_pair_pc=True,
210
+ num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
211
+ normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
212
+
213
+ #######################################
214
+ # predict whether there are pairwise collisions
215
+ # if collision_score_weight > 0:
216
+ with torch.no_grad():
217
+ _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
218
+ # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
219
+ collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
220
+ collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb) # cur_batch_size, num_comb
221
+
222
+ # debug
223
+ # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
224
+ # print("batch id", bi)
225
+ # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
226
+ # print("pair", pi)
227
+ # # obj_pair_xyzs: 2 * P, 5
228
+ # print("collision score", collision_scores[bi, pi])
229
+ # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
230
+
231
+ # 1 - mean() since the collision model predicts 1 if there is a collision
232
+ no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
233
+ if visualize:
234
+ print("no intersection scores", no_intersection_scores)
235
+ # #######################################
236
+ # if discriminator_score_weight > 0:
237
+ # # # debug:
238
+ # # print(subsampled_scene_xyz.shape)
239
+ # # print(subsampled_scene_xyz[0])
240
+ # # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
241
+ # #
242
+ # with torch.no_grad():
243
+ #
244
+ # # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
245
+ # # local_sentence = sentence[:, [0, 4]]
246
+ # # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
247
+ # # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
248
+ #
249
+ # sentence_disc = torch.LongTensor(
250
+ # [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
251
+ # sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
252
+ # position_index_dic = torch.LongTensor(raw_position_index_discriminator)
253
+ #
254
+ # preds = discriminator_model.forward(subsampled_scene_xyz,
255
+ # sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
256
+ # sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size,
257
+ # 1).to(device),
258
+ # position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(
259
+ # device))
260
+ # # preds = discriminator_model.forward(subsampled_scene_xyz)
261
+ # preds = discriminator_model.convert_logits(preds)
262
+ # preds = preds["is_circle"] # cur_batch_size,
263
+ # scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
264
+ # if visualize:
265
+ # print("discriminator scores", scores)
266
+
267
+ # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
268
+ scores = no_intersection_scores
269
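+ # argsort is ascending, so flip to descending order and keep the num_elite highest-scoring samples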
+ sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
270
+ elite_obj_params = obj_params[sort_idx] # num_elite, N, 6
271
+ elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx] # num_elite, N, 4, 4
272
+ elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4) # num_elite x N, 4, 4
273
+ elite_scores = scores[sort_idx]
274
+ print("elite scores:", elite_scores)
275
+
276
+ ####################################################
277
+ # # visualize best samples
278
+ # num_scene_pts = 4096 # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
279
+ # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
280
+ # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
281
+ # move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
282
+ # target_object_inds, self.device,
283
+ # return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
284
+ # if visualize:
285
+ # print("visualizing elite rearrangements ranked by collision model/discriminator")
286
+ # visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
287
+
288
+ # num_elite, N, 6
289
+ elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
290
+ pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
291
+ pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
292
+ pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
293
+ pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4) # num_elite, N, 4, 4
294
+
295
+ struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0,].unsqueeze(1) # num_elite, 1, 4, 4
296
+
297
+ print(struct_pose.shape)
298
+ print(pc_poses_in_struct.shape)
299
+
300
+ return struct_pose, pc_poses_in_struct
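A minimal usage sketch for SamplerV2 (not part of the commit; the model classes, checkpoint paths, and argument values are placeholder assumptions, and the final composition assumes the usual world = struct_pose @ pose_in_struct frame convention):

    # hypothetical example: DiffusionModelClass / CollisionModelClass stand in for the repo's Lightning modules
    sampler = SamplerV2(DiffusionModelClass, "diffusion.ckpt",
                        CollisionModelClass, "collision.ckpt",
                        device=torch.device("cuda"))
    struct_pose, pc_poses_in_struct = sampler.sample(batch, num_poses=8, num_elite=10, discriminator_batch_size=32)
    # struct_pose: (num_elite, 1, 4, 4); pc_poses_in_struct: (num_elite, N, 4, 4)
    goal_poses_in_world = struct_pose @ pc_poses_in_struct        # broadcasts to (num_elite, N, 4, 4)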
src/StructDiffusion/language/__pycache__/sentence_encoder.cpython-38.pyc ADDED
Binary file (881 Bytes). View file
 
src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc differ
 
src/StructDiffusion/language/convert_to_natural_language.ipynb ADDED
@@ -0,0 +1,773 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 51,
6
+ "outputs": [],
7
+ "source": [
8
+ "import os\n",
9
+ "import h5py\n",
10
+ "import json\n",
11
+ "import numpy as np\n",
12
+ "import tqdm\n",
13
+ "import itertools\n",
14
+ "import copy\n",
15
+ "from collections import defaultdict\n",
16
+ "\n",
17
+ "from StructDiffuser.tokenizer import Tokenizer"
18
+ ],
19
+ "metadata": {
20
+ "collapsed": false,
21
+ "pycharm": {
22
+ "name": "#%%\n"
23
+ }
24
+ }
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 13,
29
+ "metadata": {
30
+ "collapsed": true,
31
+ "pycharm": {
32
+ "name": "#%%\n"
33
+ }
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "class SemanticArrangementDataset:\n",
38
+ "\n",
39
+ " def __init__(self, data_roots, index_roots, splits, tokenizer):\n",
40
+ "\n",
41
+ " self.data_roots = data_roots\n",
42
+ " print(\"data dirs:\", self.data_roots)\n",
43
+ "\n",
44
+ " self.tokenizer = tokenizer\n",
45
+ "\n",
46
+ " self.arrangement_data = []\n",
47
+ " arrangement_steps = []\n",
48
+ " for split in splits:\n",
49
+ " for data_root, index_root in zip(data_roots, index_roots):\n",
50
+ " arrangement_indices_file = os.path.join(data_root, index_root, \"{}_arrangement_indices_file_all.txt\".format(split))\n",
51
+ " if os.path.exists(arrangement_indices_file):\n",
52
+ " with open(arrangement_indices_file, \"r\") as fh:\n",
53
+ " arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])\n",
54
+ " else:\n",
55
+ " print(\"{} does not exist\".format(arrangement_indices_file))\n",
56
+ "\n",
57
+ " # only keep one dummy step for each rearrangement\n",
58
+ " for filename, step_t in arrangement_steps:\n",
59
+ " if step_t == 0:\n",
60
+ " self.arrangement_data.append(filename)\n",
61
+ " print(\"{} valid sequences\".format(len(self.arrangement_data)))\n",
62
+ "\n",
63
+ " def __len__(self):\n",
64
+ " return len(self.arrangement_data)\n",
65
+ "\n",
66
+ " def get_raw_data(self, idx):\n",
67
+ "\n",
68
+ " filename = self.arrangement_data[idx]\n",
69
+ " h5 = h5py.File(filename, 'r')\n",
70
+ " goal_specification = json.loads(str(np.array(h5[\"goal_specification\"])))\n",
71
+ "\n",
72
+ " ###################################\n",
73
+ " # preparing sentence\n",
74
+ " struct_spec = []\n",
75
+ "\n",
76
+ " # structure parameters\n",
77
+ " # 5 parameters\n",
78
+ " structure_parameters = goal_specification[\"shape\"]\n",
79
+ " if structure_parameters[\"type\"] == \"circle\" or structure_parameters[\"type\"] == \"line\":\n",
80
+ " struct_spec.append((structure_parameters[\"type\"], \"shape\"))\n",
81
+ " struct_spec.append((structure_parameters[\"rotation\"][2], \"rotation\"))\n",
82
+ " struct_spec.append((structure_parameters[\"position\"][0], \"position_x\"))\n",
83
+ " struct_spec.append((structure_parameters[\"position\"][1], \"position_y\"))\n",
84
+ " if structure_parameters[\"type\"] == \"circle\":\n",
85
+ " struct_spec.append((structure_parameters[\"radius\"], \"radius\"))\n",
86
+ " elif structure_parameters[\"type\"] == \"line\":\n",
87
+ " struct_spec.append((structure_parameters[\"length\"] / 2.0, \"radius\"))\n",
88
+ " else:\n",
89
+ " struct_spec.append((structure_parameters[\"type\"], \"shape\"))\n",
90
+ " struct_spec.append((structure_parameters[\"rotation\"][2], \"rotation\"))\n",
91
+ " struct_spec.append((structure_parameters[\"position\"][0], \"position_x\"))\n",
92
+ " struct_spec.append((structure_parameters[\"position\"][1], \"position_y\"))\n",
93
+ "\n",
94
+ " return struct_spec"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "source": [],
100
+ "metadata": {
101
+ "collapsed": false,
102
+ "pycharm": {
103
+ "name": "#%% md\n"
104
+ }
105
+ }
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 14,
110
+ "outputs": [
111
+ {
112
+ "name": "stdout",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "\n",
116
+ "Build one vacab for everything...\n",
117
+ "The vocab has 124 tokens: {'PAD': 0, 'CLS': 1, 'class:MASK': 2, 'class:Basket': 3, 'class:BeerBottle': 4, 'class:Book': 5, 'class:Bottle': 6, 'class:Bowl': 7, 'class:Calculator': 8, 'class:Candle': 9, 'class:CellPhone': 10, 'class:ComputerMouse': 11, 'class:Controller': 12, 'class:Cup': 13, 'class:Donut': 14, 'class:Fork': 15, 'class:Hammer': 16, 'class:Knife': 17, 'class:Marker': 18, 'class:MilkCarton': 19, 'class:Mug': 20, 'class:Pan': 21, 'class:Pen': 22, 'class:PillBottle': 23, 'class:Plate': 24, 'class:PowerStrip': 25, 'class:Scissors': 26, 'class:SoapBottle': 27, 'class:SodaCan': 28, 'class:Spoon': 29, 'class:Stapler': 30, 'class:Teapot': 31, 'class:VideoGameController': 32, 'class:WineBottle': 33, 'class:CanOpener': 34, 'class:Fruit': 35, 'scene:MASK': 36, 'scene:dinner': 37, 'size:MASK': 38, 'size:L': 39, 'size:M': 40, 'size:S': 41, 'color:MASK': 42, 'color:blue': 43, 'color:cyan': 44, 'color:green': 45, 'color:magenta': 46, 'color:red': 47, 'color:yellow': 48, 'material:MASK': 49, 'material:glass': 50, 'material:metal': 51, 'material:plastic': 52, 'radius:MASK': 53, 'radius:less': 54, 'radius:greater': 55, 'radius:equal': 56, 'radius:0': 57, 'radius:1': 58, 'radius:2': 59, 'position_x:MASK': 60, 'position_x:less': 61, 'position_x:greater': 62, 'position_x:equal': 63, 'position_x:0': 64, 'position_x:1': 65, 'position_x:2': 66, 'position_y:MASK': 67, 'position_y:less': 68, 'position_y:greater': 69, 'position_y:equal': 70, 'position_y:0': 71, 'position_y:1': 72, 'position_y:2': 73, 'rotation:MASK': 74, 'rotation:less': 75, 'rotation:greater': 76, 'rotation:equal': 77, 'rotation:0': 78, 'rotation:1': 79, 'rotation:2': 80, 'rotation:3': 81, 'height:MASK': 82, 'height:less': 83, 'height:greater': 84, 'height:equal': 85, 'height:0': 86, 'height:1': 87, 'height:2': 88, 'height:3': 89, 'height:4': 90, 'height:5': 91, 'height:6': 92, 'height:7': 93, 'height:8': 94, 'height:9': 95, 'volumn:MASK': 96, 'volumn:less': 97, 'volumn:greater': 98, 'volumn:equal': 99, 'volumn:0': 100, 'volumn:1': 101, 'volumn:2': 102, 'volumn:3': 103, 'volumn:4': 104, 'volumn:5': 105, 'volumn:6': 106, 'volumn:7': 107, 'volumn:8': 108, 'volumn:9': 109, 'uniform_angle:MASK': 110, 'uniform_angle:False': 111, 'uniform_angle:True': 112, 'face_center:MASK': 113, 'face_center:False': 114, 'face_center:True': 115, 'angle_ratio:MASK': 116, 'angle_ratio:0.5': 117, 'angle_ratio:1.0': 118, 'shape:MASK': 119, 'shape:circle': 120, 'shape:line': 121, 'shape:tower': 122, 'shape:dinner': 123}\n",
118
+ "\n",
119
+ "Build vocabs for object position\n",
120
+ "The obj_x vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
121
+ "The obj_y vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
122
+ "The obj_z vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
123
+ "The obj_rr vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
124
+ "The obj_rp vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
125
+ "The obj_ry vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
126
+ "The struct_x vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
127
+ "The struct_y vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
128
+ "The struct_z vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
129
+ "The struct_rr vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
130
+ "The struct_rp vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
131
+ "The struct_ry vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, 
'309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
132
+ "data dirs: ['/home/weiyu/data_drive/data_new_objects/examples_circle_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_line_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_tower_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_dinner_new_objects/result']\n",
133
+ "40000 valid sequences\n"
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "tokenizer = Tokenizer(\"/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json\")\n",
139
+ "\n",
140
+ "data_roots = []\n",
141
+ "index_roots = []\n",
142
+ "for shape, index in [(\"circle\", \"index_10k\"), (\"line\", \"index_10k\"), (\"tower\", \"index_10k\"), (\"dinner\", \"index_10k\")]:\n",
143
+ " data_roots.append(\"/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result\".format(shape))\n",
144
+ " index_roots.append(index)\n",
145
+ "\n",
146
+ "dataset = SemanticArrangementDataset(data_roots=data_roots, index_roots=index_roots, splits=[\"train\", \"valid\", \"test\"], tokenizer=tokenizer)"
147
+ ],
148
+ "metadata": {
149
+ "collapsed": false,
150
+ "pycharm": {
151
+ "name": "#%%\n"
152
+ }
153
+ }
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 4,
158
+ "outputs": [
159
+ {
160
+ "name": "stdout",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "\n",
164
+ "\n",
165
+ "{'place_at_once': 'False', 'position': [0.4530459674902468, 0.2866384076623889, 0.011194709806729462], 'rotation': [5.101818936729106e-05, 1.362746309147995e-06, 2.145504341444197], 'type': 'tower'}\n",
166
+ "[('tower', 'shape'), (2.145504341444197, 'rotation'), (0.4530459674902468, 'position_x'), (0.2866384076623889, 'position_y')]\n",
167
+ "tower in the middle left of the table facing west\n",
168
+ "[('tower', 'shape'), (2.145504341444197, 'rotation'), (0.4530459674902468, 'position_x'), (0.2866384076623889, 'position_y')]\n",
169
+ "tower in the middle left of the table facing west\n",
170
+ "(('rotation', 'west'), ('shape', 'tower'), ('x', 'middle'), ('y', 'left'))\n",
171
+ "\n",
172
+ "\n",
173
+ "{'length': 0.15789473684210525, 'length_increment': 0.05, 'max_length': 1.0, 'min_length': 0.0, 'place_at_once': 'True', 'position': [0.5744088910421017, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'dinner', 'uniform_space': 'False'}\n",
174
+ "[('dinner', 'shape'), (0.0, 'rotation'), (0.5744088910421017, 'position_x'), (0.0, 'position_y')]\n",
175
+ "dinner in the middle center of the table facing south\n",
176
+ "[('dinner', 'shape'), (0.0, 'rotation'), (0.5744088910421017, 'position_x'), (0.0, 'position_y')]\n",
177
+ "dinner in the middle center of the table facing south\n",
178
+ "(('rotation', 'south'), ('shape', 'dinner'), ('x', 'middle'), ('y', 'center'))\n",
179
+ "\n",
180
+ "\n",
181
+ "{'place_at_once': 'False', 'position': [0.5300184865230677, -0.11749143967722209, 0.043775766459831195], 'rotation': [8.311828443210225e-05, 2.8403995850279114e-05, -1.9831750137833084], 'type': 'tower'}\n",
182
+ "[('tower', 'shape'), (-1.9831750137833084, 'rotation'), (0.5300184865230677, 'position_x'), (-0.11749143967722209, 'position_y')]\n",
183
+ "tower in the middle center of the table facing north\n",
184
+ "[('tower', 'shape')]\n",
185
+ "tower\n",
186
+ "(('shape', 'tower'),)\n",
187
+ "\n",
188
+ "\n",
189
+ "{'length': 0.3157894736842105, 'length_increment': 0.05, 'max_length': 1.0, 'min_length': 0.0, 'place_at_once': 'True', 'position': [0.6482385523146229, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'dinner', 'uniform_space': 'False'}\n",
190
+ "[('dinner', 'shape'), (0.0, 'rotation'), (0.6482385523146229, 'position_x'), (0.0, 'position_y')]\n",
191
+ "dinner in the top center of the table facing south\n",
192
+ "[('dinner', 'shape')]\n",
193
+ "dinner\n",
194
+ "(('shape', 'dinner'),)\n",
195
+ "\n",
196
+ "\n",
197
+ "{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.050687861718942046, 'place_at_once': 'True', 'position': [0.2998438437491998, -0.03599718247376027, 0.0], 'radius': 0.0966402394976866, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, 2.053106459668934], 'type': 'circle', 'uniform_angle': 'True'}\n",
198
+ "[('circle', 'shape'), (2.053106459668934, 'rotation'), (0.2998438437491998, 'position_x'), (-0.03599718247376027, 'position_y'), (0.0966402394976866, 'radius')]\n",
199
+ "small circle in the middle center of the table facing west\n",
200
+ "[('circle', 'shape'), (2.053106459668934, 'rotation'), (0.2998438437491998, 'position_x'), (-0.03599718247376027, 'position_y'), (0.0966402394976866, 'radius')]\n",
201
+ "small circle in the middle center of the table facing west\n",
202
+ "(('rotation', 'west'), ('shape', 'circle'), ('size', 'small'), ('x', 'middle'), ('y', 'center'))\n",
203
+ "\n",
204
+ "\n",
205
+ "{'length': 0.4245597103515504, 'length_increment': 0.005, 'max_length': 1.0, 'min_length': 0.21760311495166934, 'place_at_once': 'True', 'position': [0.6672547106460816, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'line', 'uniform_space': 'True'}\n",
206
+ "[('line', 'shape'), (0.0, 'rotation'), (0.6672547106460816, 'position_x'), (0.0, 'position_y'), (0.2122798551757752, 'radius')]\n",
207
+ "medium line in the top center of the table facing south\n",
208
+ "[('line', 'shape'), (0.0, 'rotation'), (0.6672547106460816, 'position_x'), (0.2122798551757752, 'radius')]\n",
209
+ "medium line in the top facing south\n",
210
+ "(('rotation', 'south'), ('shape', 'line'), ('size', 'medium'), ('x', 'top'))\n",
211
+ "\n",
212
+ "\n",
213
+ "{'place_at_once': 'False', 'position': [0.6555576184899171, 0.22241488561049588, 0.006522659915853506], 'rotation': [-0.000139418832574769, -7.243860660016997e-05, 2.2437880740062814], 'type': 'tower'}\n",
214
+ "[('tower', 'shape'), (2.2437880740062814, 'rotation'), (0.6555576184899171, 'position_x'), (0.22241488561049588, 'position_y')]\n",
215
+ "tower in the top left of the table facing west\n",
216
+ "[(2.2437880740062814, 'rotation'), (0.6555576184899171, 'position_x')]\n",
217
+ "in the top facing west\n",
218
+ "(('rotation', 'west'), ('x', 'top'))\n",
219
+ "\n",
220
+ "\n",
221
+ "{'length': 0.4925060249864075, 'length_increment': 0.005, 'max_length': 1.0, 'min_length': 0.4925060249864075, 'place_at_once': 'True', 'position': [0.7754676784901477, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'line', 'uniform_space': 'False'}\n",
222
+ "[('line', 'shape'), (0.0, 'rotation'), (0.7754676784901477, 'position_x'), (0.0, 'position_y'), (0.24625301249320375, 'radius')]\n",
223
+ "medium line in the top center of the table facing south\n",
224
+ "[(0.0, 'rotation'), (0.7754676784901477, 'position_x')]\n",
225
+ "in the top facing south\n",
226
+ "(('rotation', 'south'), ('x', 'top'))\n",
227
+ "\n",
228
+ "\n",
229
+ "{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.2260219063147572, 'place_at_once': 'True', 'position': [0.6256453430245876, 0.1131426073908803, 0.0], 'radius': 0.2260219063147572, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, 1.6063513593439724], 'type': 'circle', 'uniform_angle': 'True'}\n",
230
+ "[('circle', 'shape'), (1.6063513593439724, 'rotation'), (0.6256453430245876, 'position_x'), (0.1131426073908803, 'position_y'), (0.2260219063147572, 'radius')]\n",
231
+ "medium circle in the middle center of the table facing west\n",
232
+ "[(1.6063513593439724, 'rotation'), (0.6256453430245876, 'position_x')]\n",
233
+ "in the middle facing west\n",
234
+ "(('rotation', 'west'), ('x', 'middle'))\n",
235
+ "\n",
236
+ "\n",
237
+ "{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.14976631196286583, 'place_at_once': 'True', 'position': [0.5157008668336853, 0.11005531020590054, 0.0], 'radius': 0.15991801306539147, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, -2.2145659262893918], 'type': 'circle', 'uniform_angle': 'True'}\n",
238
+ "[('circle', 'shape'), (-2.2145659262893918, 'rotation'), (0.5157008668336853, 'position_x'), (0.11005531020590054, 'position_y'), (0.15991801306539147, 'radius')]\n",
239
+ "small circle in the middle center of the table facing north\n",
240
+ "[('circle', 'shape'), (0.5157008668336853, 'position_x'), (0.15991801306539147, 'radius')]\n",
241
+ "small circle in the middle\n",
242
+ "(('shape', 'circle'), ('size', 'small'), ('x', 'middle'))\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "idxs = np.random.permutation(len(dataset))\n",
248
+ "for i in idxs[:10]:\n",
249
+ " print(\"\\n\")\n",
250
+ " struct_spec = dataset.get_raw_data(i)\n",
251
+ " print(struct_spec)\n",
252
+ " struct_word_spec = tokenizer.convert_structure_params_to_natural_language(struct_spec)\n",
253
+ " print(struct_word_spec)\n",
254
+ "\n",
255
+ " token_idxs = np.random.permutation(len(struct_spec))\n",
256
+ " token_idxs = token_idxs[:np.random.randint(1, len(struct_spec) + 1)]\n",
257
+ " token_idxs = sorted(token_idxs)\n",
258
+ " incomplete_struct_spec = [struct_spec[ti] for ti in token_idxs]\n",
259
+ "\n",
260
+ " print(incomplete_struct_spec)\n",
261
+ " print(tokenizer.convert_structure_params_to_natural_language(incomplete_struct_spec))\n",
262
+ "\n",
263
+ " type_value_tuple = tokenizer.convert_structure_params_to_type_value_tuple(incomplete_struct_spec)\n",
264
+ " print(type_value_tuple)"
265
+ ],
266
+ "metadata": {
267
+ "collapsed": false,
268
+ "pycharm": {
269
+ "name": "#%%\n"
270
+ }
271
+ }
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 49,
276
+ "outputs": [
277
+ {
278
+ "name": "stderr",
279
+ "output_type": "stream",
280
+ "text": [
281
+ "100%|██████████| 40000/40000 [00:23<00:00, 1699.94it/s]"
282
+ ]
283
+ },
284
+ {
285
+ "name": "stdout",
286
+ "output_type": "stream",
287
+ "text": [
288
+ "669\n"
289
+ ]
290
+ },
291
+ {
292
+ "name": "stderr",
293
+ "output_type": "stream",
294
+ "text": [
295
+ "\n"
296
+ ]
297
+ }
298
+ ],
299
+ "source": [
300
+ "unique_type_value_tuples = set()\n",
301
+ "for i in tqdm.tqdm(idxs):\n",
302
+ " struct_spec = dataset.get_raw_data(i)\n",
303
+ "\n",
304
+ " incomplete_struct_specs = []\n",
305
+ " for L in range(1, len(struct_spec) + 1):\n",
306
+ " for subset in itertools.combinations(struct_spec, L):\n",
307
+ " incomplete_struct_specs.append(subset)\n",
308
+ "\n",
309
+ " # print(incomplete_struct_specs)\n",
310
+ "\n",
311
+ " type_value_tuples = []\n",
312
+ " for incomplete_struct_spec in incomplete_struct_specs:\n",
313
+ " type_value_tuples.append(tokenizer.convert_structure_params_to_type_value_tuple(incomplete_struct_spec))\n",
314
+ "\n",
315
+ " unique_type_value_tuples.update(type_value_tuples)\n",
316
+ "\n",
317
+ "print(len(unique_type_value_tuples))"
318
+ ],
319
+ "metadata": {
320
+ "collapsed": false,
321
+ "pycharm": {
322
+ "name": "#%%\n"
323
+ }
324
+ }
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": null,
329
+ "outputs": [],
330
+ "source": [
331
+ "sentence_template = [\n",
332
+ " \"Put the objects {in a [size][shape]} on the {[x][y] of} the table {facing [rotation]}.\",\n",
333
+ " \"Build a [size][shape] of the [objects] on the [x][y] of the table facing [rotation].\",\n",
334
+ " \"Put the [objects] on the [x][y] of the table and make a [shape] facing [rotation].\",\n",
335
+ " \"Rearrange the [objects] into a [shape], and put the structure on the [x][y] of the table facing [rotation].\",\n",
336
+ " \"Could you ...\",\n",
337
+ " \"Please ...\",\n",
338
+ " \"Pick up the objects, put them into a [size][shape], place the [shape] on the [x][y] of table, make sure the [shape] is facing [rotation].\"]\n",
339
+ "\n"
340
+ ],
341
+ "metadata": {
342
+ "collapsed": false,
343
+ "pycharm": {
344
+ "name": "#%%\n"
345
+ }
346
+ }
347
+ },
348
+ {
349
+ "cell_type": "markdown",
350
+ "source": [
351
+ "Enumerate all possible combinations of types"
352
+ ],
353
+ "metadata": {
354
+ "collapsed": false,
355
+ "pycharm": {
356
+ "name": "#%% md\n"
357
+ }
358
+ }
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 31,
363
+ "outputs": [
364
+ {
365
+ "name": "stdout",
366
+ "output_type": "stream",
367
+ "text": [
368
+ "31\n",
369
+ "[('size',), ('shape',), ('x',), ('y',), ('rotation',), ('shape', 'size'), ('size', 'x'), ('size', 'y'), ('rotation', 'size'), ('shape', 'x'), ('shape', 'y'), ('rotation', 'shape'), ('x', 'y'), ('rotation', 'x'), ('rotation', 'y'), ('shape', 'size', 'x'), ('shape', 'size', 'y'), ('rotation', 'shape', 'size'), ('size', 'x', 'y'), ('rotation', 'size', 'x'), ('rotation', 'size', 'y'), ('shape', 'x', 'y'), ('rotation', 'shape', 'x'), ('rotation', 'shape', 'y'), ('rotation', 'x', 'y'), ('shape', 'size', 'x', 'y'), ('rotation', 'shape', 'size', 'x'), ('rotation', 'shape', 'size', 'y'), ('rotation', 'size', 'x', 'y'), ('rotation', 'shape', 'x', 'y'), ('rotation', 'shape', 'size', 'x', 'y')]\n"
370
+ ]
371
+ }
372
+ ],
373
+ "source": [
374
+ "import itertools\n",
375
+ "types = [\"size\", \"shape\", \"x\", \"y\", \"rotation\"]\n",
376
+ "\n",
377
+ "type_combs = []\n",
378
+ "for L in range(1, len(types) + 1):\n",
379
+ " for subset in itertools.combinations(types, L):\n",
380
+ " type_combs.append(tuple(sorted(subset)))\n",
381
+ "\n",
382
+ "print(len(type_combs))\n",
383
+ "print(type_combs)"
384
+ ],
385
+ "metadata": {
386
+ "collapsed": false,
387
+ "pycharm": {
388
+ "name": "#%%\n"
389
+ }
390
+ }
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 46,
395
+ "outputs": [
396
+ {
397
+ "name": "stdout",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "build a [size] shape from the objects ('size',)\n",
401
+ "put the objects in to a [size] shape ('size',)\n",
402
+ "place the objects as a [size] shape ('size',)\n",
403
+ "make a [size] shape from the objects ('size',)\n",
404
+ "rearrange the objects into a [size] structure ('size',)\n",
405
+ "build a [shape] ('shape',)\n",
406
+ "make a [shape] ('shape',)\n",
407
+ "put the objects into a [shape] ('shape',)\n",
408
+ "place the objects as a [shape] ('shape',)\n",
409
+ "pick up the objects, and place them as a [shape] ('shape',)\n",
410
+ "place the objects on the [x] of the table ('x',)\n",
411
+ "put the objects on [x] ('x',)\n",
412
+ "make a structure from the objects and place it on [x] ('x',)\n",
413
+ "on the [x] of the table, place the objects ('x',)\n",
414
+ "move the objects to the [x] ('x',)\n",
415
+ "place the objects on the [y] of the table ('y',)\n",
416
+ "put the objects on [y] ('y',)\n",
417
+ "make a structure from the objects and place it on [y] ('y',)\n",
418
+ "on the [y] of the table, place the objects ('y',)\n",
419
+ "move the objects to the [y] ('y',)\n",
420
+ "build a structure facing [rotation] ('rotation',)\n",
421
+ "make a structure from the objects and make sure it is pointing [rotation] ('rotation',)\n",
422
+ "put the objects in a structure that faces [rotation] ('rotation',)\n",
423
+ "rotate the object structure so that it points [rotation] ('rotation',)\n",
424
+ "[rotation] is the direction the structure built from the objects should be facing ('rotation',)\n",
425
+ "build a [size] [shape] ('shape', 'size')\n",
426
+ "make a [size] [shape] ('shape', 'size')\n",
427
+ "put the objects into a [size] [shape] ('shape', 'size')\n",
428
+ "place the objects as a [size] [shape] ('shape', 'size')\n",
429
+ "pick up the objects, and place them as a [size] [shape] ('shape', 'size')\n",
430
+ "build a [size] shape from the objects on the [x] of the table ('size', 'x')\n",
431
+ "put the objects in to a [size] shape and place it on [x] ('size', 'x')\n",
432
+ "on the [x] of the table, place the objects as a [size] shape ('size', 'x')\n",
433
+ "make a [size] shape from the objects and move it to [x] ('size', 'x')\n",
434
+ "rearrange the objects into a [size] structure on [x] ('size', 'x')\n",
435
+ "build a [size] shape from the objects on the [y] of the table ('size', 'y')\n",
436
+ "put the objects in to a [size] shape and place it on [y] ('size', 'y')\n",
437
+ "on the [y] of the table, place the objects as a [size] shape ('size', 'y')\n",
438
+ "make a [size] shape from the objects and move it to [y] ('size', 'y')\n",
439
+ "rearrange the objects into a [size] structure on [y] ('size', 'y')\n",
440
+ "build a [size] shape from the objects facing [rotation] ('rotation', 'size')\n",
441
+ "put the objects in to a [size] shape and place it so that it faces [rotation] ('rotation', 'size')\n",
442
+ "place the objects as a [size] shape and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'size')\n",
443
+ "make a [size] structure from the objects and rotate the object structure so that it points [rotation] ('rotation', 'size')\n",
444
+ "rearrange the objects into a [size] structure that points to [rotation] ('rotation', 'size')\n",
445
+ "build a [shape] from the objects on the [x] of the table ('shape', 'x')\n",
446
+ "put the objects in to a [shape] and place it on [x] ('shape', 'x')\n",
447
+ "on the [x] of the table, place the objects as a [shape] ('shape', 'x')\n",
448
+ "make a [shape] from the objects and move it to [x] ('shape', 'x')\n",
449
+ "rearrange the objects into a [shape] on [x] ('shape', 'x')\n",
450
+ "build a [shape] from the objects on the [y] of the table ('shape', 'y')\n",
451
+ "put the objects in to a [shape] and place it on [y] ('shape', 'y')\n",
452
+ "on the [y] of the table, place the objects as a [shape] ('shape', 'y')\n",
453
+ "make a [shape] from the objects and move it to [y] ('shape', 'y')\n",
454
+ "rearrange the objects into a [shape] on [y] ('shape', 'y')\n",
455
+ "build a [shape] from the objects facing [rotation] ('rotation', 'shape')\n",
456
+ "put the objects in to a [shape] and place it so that it faces [rotation] ('rotation', 'shape')\n",
457
+ "place the objects as a [shape] and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'shape')\n",
458
+ "make a [shape] from the objects and rotate the shape so that it points [rotation] ('rotation', 'shape')\n",
459
+ "rearrange the objects into a [shape] that points to [rotation] ('rotation', 'shape')\n",
460
+ "place the objects on the [x] and [y] of the table ('x', 'y')\n",
461
+ "put the objects on [x] [y] of the table ('x', 'y')\n",
462
+ "make a structure from the objects and place it on [x] [y] ('x', 'y')\n",
463
+ "on the [x] [y] of the table, place the objects ('x', 'y')\n",
464
+ "move the objects to the [x] [y] ('x', 'y')\n",
465
+ "build a structure on the [x] of the table facing [rotation] ('rotation', 'x')\n",
466
+ "make a structure from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'x')\n",
467
+ "rearrange the objects in a structure that faces [rotation] and place it on [x] ('rotation', 'x')\n",
468
+ "move and rotate the object structure so that it is on [x] and points [rotation] ('rotation', 'x')\n",
469
+ "[rotation] is the direction the structure built from the objects should be facing, [x] is the location ('rotation', 'x')\n",
470
+ "build a structure on the [y] of the table facing [rotation] ('rotation', 'y')\n",
471
+ "make a structure from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'y')\n",
472
+ "rearrange the objects in a structure that faces [rotation] and place it on [y] ('rotation', 'y')\n",
473
+ "move and rotate the object structure so that it is on [y] and points [rotation] ('rotation', 'y')\n",
474
+ "[rotation] is the direction the structure built from the objects should be facing, [y] is the location ('rotation', 'y')\n",
475
+ "build a [size] [shape] from the objects on the [x] of the table ('shape', 'size', 'x')\n",
476
+ "put the objects in to a [size] [shape] and place it on [x] ('shape', 'size', 'x')\n",
477
+ "on the [x] of the table, place the objects as a [shape], make the shape [size] ('shape', 'size', 'x')\n",
478
+ "make a [size] [shape] from the objects and move it to [x] ('shape', 'size', 'x')\n",
479
+ "rearrange the objects into a [size] [shape] on [x] ('shape', 'size', 'x')\n",
480
+ "build a [size] [shape] from the objects on the [y] of the table ('shape', 'size', 'y')\n",
481
+ "put the objects in to a [size] [shape] and place it on [y] ('shape', 'size', 'y')\n",
482
+ "on the [y] of the table, place the objects as a [shape], make the shape [size] ('shape', 'size', 'y')\n",
483
+ "make a [size] [shape] from the objects and move it to [y] ('shape', 'size', 'y')\n",
484
+ "rearrange the objects into a [size] [shape] on [y] ('shape', 'size', 'y')\n",
485
+ "build a [size] [shape] from the objects facing [rotation] ('rotation', 'shape', 'size')\n",
486
+ "put the objects in to a [size] [shape] and place it so that it faces [rotation] ('rotation', 'shape', 'size')\n",
487
+ "place the objects as a [size] [shape] and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'shape', 'size')\n",
488
+ "make a [size] [shape] from the objects and rotate the shape so that it points [rotation] ('rotation', 'shape', 'size')\n",
489
+ "rearrange the objects into a [size] [shape] that points to [rotation] ('rotation', 'shape', 'size')\n",
490
+ "build a [size] shape from the objects on the [x] [y] of the table ('size', 'x', 'y')\n",
491
+ "put the objects in to a [size] shape and place it on [x] and [y] ('size', 'x', 'y')\n",
492
+ "on the [x] [y] of the table, place the objects as a [size] shape ('size', 'x', 'y')\n",
493
+ "make a [size] shape from the objects and move it to [x] [y] ('size', 'x', 'y')\n",
494
+ "rearrange the objects into a [size] structure on [x] and on [y] ('size', 'x', 'y')\n",
495
+ "build a [size] structure on the [x] of the table facing [rotation] ('rotation', 'size', 'x')\n",
496
+ "make a [size] structure from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'size', 'x')\n",
497
+ "rearrange the objects in a [size] structure that faces [rotation] and place it on [x] ('rotation', 'size', 'x')\n",
498
+ "move and rotate the [size] object structure so that it is on [x] and points [rotation] ('rotation', 'size', 'x')\n",
499
+ "[rotation] is the direction the [size] structure built from the objects should be facing, [x] is the location ('rotation', 'size', 'x')\n",
500
+ "build a [size] structure on the [y] of the table facing [rotation] ('rotation', 'size', 'y')\n",
501
+ "make a [size] structure from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'size', 'y')\n",
502
+ "rearrange the objects in a [size] structure that faces [rotation] and place it on [y] ('rotation', 'size', 'y')\n",
503
+ "move and rotate the [size] object structure so that it is on [y] and points [rotation] ('rotation', 'size', 'y')\n",
504
+ "[rotation] is the direction the [size] structure built from the objects should be facing, [y] is the location ('rotation', 'size', 'y')\n",
505
+ "build a [shape] from the objects on the [x] [y] of the table ('shape', 'x', 'y')\n",
506
+ "put the objects in to a [shape] and place it on [x] and [y] ('shape', 'x', 'y')\n",
507
+ "on the [x] [y] of the table, place the objects as a [shape] ('shape', 'x', 'y')\n",
508
+ "make a [shape] from the objects and move it to [x] [y] ('shape', 'x', 'y')\n",
509
+ "rearrange the objects into a [shape] on [x] and on [y] ('shape', 'x', 'y')\n",
510
+ "build a [shape] on the [x] of the table facing [rotation] ('rotation', 'shape', 'x')\n",
511
+ "make a [shape] from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'shape', 'x')\n",
512
+ "rearrange the objects in a [shape] that faces [rotation] and place it on [x] ('rotation', 'shape', 'x')\n",
513
+ "move and rotate the [shape] so that it is on [x] and points [rotation] ('rotation', 'shape', 'x')\n",
514
+ "[rotation] is the direction the [shape] built from the objects should be facing, [x] is the location ('rotation', 'shape', 'x')\n",
515
+ "build a [shape] on the [y] of the table facing [rotation] ('rotation', 'shape', 'y')\n",
516
+ "make a [shape] from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'shape', 'y')\n",
517
+ "rearrange the objects in a [shape] that faces [rotation] and place it on [y] ('rotation', 'shape', 'y')\n",
518
+ "move and rotate the [shape] so that it is on [y] and points [rotation] ('rotation', 'shape', 'y')\n",
519
+ "[rotation] is the direction the [shape] built from the objects should be facing, [y] is the location ('rotation', 'shape', 'y')\n",
520
+ "build a structure on the [x] [y] of the table facing [rotation] ('rotation', 'x', 'y')\n",
521
+ "make a structure from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'x', 'y')\n",
522
+ "rearrange the objects in a structure that faces [rotation] and place it on [x] [y] ('rotation', 'x', 'y')\n",
523
+ "move and rotate the object structure so that it is on [x] [y] and points [rotation] ('rotation', 'x', 'y')\n",
524
+ "[rotation] is the direction the structure built from the objects should be facing, [x] [y] is the location ('rotation', 'x', 'y')\n",
525
+ "build a [shape] from the objects on the [x] [y] of the table, make the [shape] [size] ('shape', 'size', 'x', 'y')\n",
526
+ "put the objects in to a [size] [shape] and place it on [x] and [y] ('shape', 'size', 'x', 'y')\n",
527
+ "on the [x] [y] of the table, place the objects as a [size] [shape] ('shape', 'size', 'x', 'y')\n",
528
+ "make a [size] [shape] from the objects and move it to [x] [y] ('shape', 'size', 'x', 'y')\n",
529
+ "rearrange the objects into a [size] [shape] on [x] and on [y] ('shape', 'size', 'x', 'y')\n",
530
+ "build a [size] [shape] on the [x] of the table facing [rotation] ('rotation', 'shape', 'size', 'x')\n",
531
+ "make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'shape', 'size', 'x')\n",
532
+ "rearrange the objects in a [size] [shape] that faces [rotation] and place it on [x] ('rotation', 'shape', 'size', 'x')\n",
533
+ "move and rotate the [size] [shape] so that it is on [x] and points [rotation] ('rotation', 'shape', 'size', 'x')\n",
534
+ "[rotation] is the direction the [size] [shape] built from the objects should be facing, [x] is the location ('rotation', 'shape', 'size', 'x')\n",
535
+ "build a [size] [shape] on the [y] of the table facing [rotation] ('rotation', 'shape', 'size', 'y')\n",
536
+ "make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'shape', 'size', 'y')\n",
537
+ "rearrange the objects in a [size] [shape] that faces [rotation] and place it on [y] ('rotation', 'shape', 'size', 'y')\n",
538
+ "move and rotate the [size] [shape] so that it is on [y] and points [rotation] ('rotation', 'shape', 'size', 'y')\n",
539
+ "[rotation] is the direction the [size] [shape] built from the objects should be facing, [y] is the location ('rotation', 'shape', 'size', 'y')\n",
540
+ "build a [size] structure on the [x] [y] of the table facing [rotation] ('rotation', 'size', 'x', 'y')\n",
541
+ "make a [size] structure from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'size', 'x', 'y')\n",
542
+ "rearrange the objects in a [size] structure that faces [rotation] and place it on [x] [y] ('rotation', 'size', 'x', 'y')\n",
543
+ "move and rotate the [size] object structure so that it is on [x] [y] and points [rotation] ('rotation', 'size', 'x', 'y')\n",
544
+ "[rotation] is the direction the [size] structure built from the objects should be facing, [x] [y] is the location ('rotation', 'size', 'x', 'y')\n",
545
+ "build a [shape] on the [x] [y] of the table facing [rotation] ('rotation', 'shape', 'x', 'y')\n",
546
+ "make a [shape] from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'shape', 'x', 'y')\n",
547
+ "rearrange the objects as a [shape] that faces [rotation] and place it on [x] [y] ('rotation', 'shape', 'x', 'y')\n",
548
+ "move and rotate the [shape] so that it is on [x] [y] and points [rotation] ('rotation', 'shape', 'x', 'y')\n",
549
+ "[rotation] is the direction the [shape] built from the objects should be facing, [x] [y] is the location ('rotation', 'shape', 'x', 'y')\n",
550
+ "build a [size] [shape] on the [x] [y] of the table facing [rotation] ('rotation', 'shape', 'size', 'x', 'y')\n",
551
+ "make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'shape', 'size', 'x', 'y')\n",
552
+ "rearrange the objects as a [size] [shape] that faces [rotation] and place it on [x] [y] ('rotation', 'shape', 'size', 'x', 'y')\n",
553
+ "move and rotate the [size] [shape] so that it is on [x] [y] and points [rotation] ('rotation', 'shape', 'size', 'x', 'y')\n",
554
+ "[rotation] is the direction the [size] [shape] built from the objects should be facing, [x] [y] is the location ('rotation', 'shape', 'size', 'x', 'y')\n"
555
+ ]
556
+ }
557
+ ],
558
+ "source": [
559
+ "sentence_template_file = \"/home/weiyu/Research/intern/StructDiffuser/src/StructDiffuser/language/sentence_template.txt\"\n",
560
+ "\n",
561
+ "import re\n",
562
+ "\n",
563
+ "type_comb_to_templates = {}\n",
564
+ "for type_comb in type_combs:\n",
565
+ " type_comb_to_templates[type_comb] = []\n",
566
+ "\n",
567
+ "with open(sentence_template_file, \"r\") as fh:\n",
568
+ " for line in fh:\n",
569
+ " line = line.strip()\n",
570
+ " if line:\n",
571
+ " if line[0] == \"#\":\n",
572
+ " continue\n",
573
+ " type_list = re.findall('\\[[^\\]]*\\]', line)\n",
574
+ " type_comb = tuple(sorted(list(set([t[1:-1] for t in type_list]))))\n",
575
+ " print(line, type_comb)\n",
576
+ "\n",
577
+ " type_comb_to_templates[type_comb].append(line)"
578
+ ],
579
+ "metadata": {
580
+ "collapsed": false,
581
+ "pycharm": {
582
+ "name": "#%%\n"
583
+ }
584
+ }
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "execution_count": 47,
589
+ "outputs": [],
590
+ "source": [
591
+ "for type_comb in type_comb_to_templates:\n",
592
+ " if len(type_comb_to_templates[type_comb]) != 5:\n",
593
+ " print(\"{} does not have 5 templates\".format(type_comb))"
594
+ ],
595
+ "metadata": {
596
+ "collapsed": false,
597
+ "pycharm": {
598
+ "name": "#%%\n"
599
+ }
600
+ }
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": 58,
605
+ "outputs": [
606
+ {
607
+ "name": "stderr",
608
+ "output_type": "stream",
609
+ "text": [
610
+ "100%|██████████| 669/669 [00:00<00:00, 60546.98it/s]\n"
611
+ ]
612
+ }
613
+ ],
614
+ "source": [
615
+ "template_sentences = []\n",
616
+ "type_value_tuple_to_template_sentences = defaultdict(set)\n",
617
+ "for type_value_tuple in tqdm.tqdm(list(unique_type_value_tuples)):\n",
618
+ " type_comb = tuple(sorted([tv[0] for tv in type_value_tuple]))\n",
619
+ " template_sentences = copy.deepcopy(type_comb_to_templates[type_comb])\n",
620
+ "\n",
621
+ " # print(type_value_tuple)\n",
622
+ " for template_sentence in template_sentences:\n",
623
+ " for t, v in type_value_tuple:\n",
624
+ " template_sentence = template_sentence.replace(\"[{}]\".format(t), v)\n",
625
+ " # print(template_sentence)\n",
626
+ "\n",
627
+ " type_value_tuple_to_template_sentences[type_value_tuple].add(template_sentence)\n",
628
+ "\n",
629
+ "# convert to list\n",
630
+ "for type_value_tuple in type_value_tuple_to_template_sentences:\n",
631
+ " type_value_tuple_to_template_sentences[type_value_tuple] = list(type_value_tuple_to_template_sentences[type_value_tuple])"
632
+ ],
633
+ "metadata": {
634
+ "collapsed": false,
635
+ "pycharm": {
636
+ "name": "#%%\n"
637
+ }
638
+ }
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": 73,
643
+ "outputs": [
644
+ {
645
+ "name": "stdout",
646
+ "output_type": "stream",
647
+ "text": [
648
+ "3345 unique template sentences\n"
649
+ ]
650
+ }
651
+ ],
652
+ "source": [
653
+ "unique_template_sentences = set()\n",
654
+ "\n",
655
+ "for type_value_tuple in type_value_tuple_to_template_sentences:\n",
656
+ " # print(\"\\n\")\n",
657
+ " # print(type_value_tuple)\n",
658
+ " for template_sentence in type_value_tuple_to_template_sentences[type_value_tuple]:\n",
659
+ " # print(template_sentence)\n",
660
+ " unique_template_sentences.add(template_sentence)\n",
661
+ "\n",
662
+ "unique_template_sentences = list(unique_template_sentences)\n",
663
+ "print(\"{} unique template sentences\".format(len(unique_template_sentences)))"
664
+ ],
665
+ "metadata": {
666
+ "collapsed": false,
667
+ "pycharm": {
668
+ "name": "#%%\n"
669
+ }
670
+ }
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "execution_count": 72,
675
+ "outputs": [],
676
+ "source": [
677
+ "from sentence_transformers import SentenceTransformer\n",
678
+ "model = SentenceTransformer('all-MiniLM-L6-v2')"
679
+ ],
680
+ "metadata": {
681
+ "collapsed": false,
682
+ "pycharm": {
683
+ "name": "#%%\n"
684
+ }
685
+ }
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 76,
690
+ "outputs": [
691
+ {
692
+ "name": "stdout",
693
+ "output_type": "stream",
694
+ "text": [
695
+ "(3345, 384)\n"
696
+ ]
697
+ }
698
+ ],
699
+ "source": [
700
+ "#Our sentences we like to encode\n",
701
+ "# sentences = ['This framework generates embeddings for each input sentence',\n",
702
+ "# 'Sentences are passed as a list of string.',\n",
703
+ "# 'The quick brown fox jumps over the lazy dog.']\n",
704
+ "#Sentences are encoded by calling model.encode()\n",
705
+ "\n",
706
+ "\n",
707
+ "embeddings = model.encode(unique_template_sentences)\n",
708
+ "print(embeddings.shape)"
709
+ ],
710
+ "metadata": {
711
+ "collapsed": false,
712
+ "pycharm": {
713
+ "name": "#%%\n"
714
+ }
715
+ }
716
+ },
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": 80,
720
+ "outputs": [],
721
+ "source": [
722
+ "template_sentence_to_embedding = {}\n",
723
+ "for embedding, template_sentence in zip(embeddings, unique_template_sentences):\n",
724
+ " template_sentence_to_embedding[template_sentence] = embedding"
725
+ ],
726
+ "metadata": {
727
+ "collapsed": false,
728
+ "pycharm": {
729
+ "name": "#%%\n"
730
+ }
731
+ }
732
+ },
733
+ {
734
+ "cell_type": "code",
735
+ "execution_count": 82,
736
+ "outputs": [],
737
+ "source": [
738
+ "import pickle\n",
739
+ "template_sentence_data = {\"template_sentence_to_embedding\": template_sentence_to_embedding,\n",
740
+ " \"type_value_tuple_to_template_sentences\": type_value_tuple_to_template_sentences}\n",
741
+ "with open(\"/home/weiyu/Research/intern/StructDiffuser/src/StructDiffuser/language/template_sentence_data.pkl\", \"wb\") as fh:\n",
742
+ " pickle.dump(template_sentence_data, fh)"
743
+ ],
744
+ "metadata": {
745
+ "collapsed": false,
746
+ "pycharm": {
747
+ "name": "#%%\n"
748
+ }
749
+ }
750
+ }
751
+ ],
752
+ "metadata": {
753
+ "kernelspec": {
754
+ "display_name": "Python 3",
755
+ "language": "python",
756
+ "name": "python3"
757
+ },
758
+ "language_info": {
759
+ "codemirror_mode": {
760
+ "name": "ipython",
761
+ "version": 2
762
+ },
763
+ "file_extension": ".py",
764
+ "mimetype": "text/x-python",
765
+ "name": "python",
766
+ "nbconvert_exporter": "python",
767
+ "pygments_lexer": "ipython2",
768
+ "version": "2.7.6"
769
+ }
770
+ },
771
+ "nbformat": 4,
772
+ "nbformat_minor": 0
773
+ }
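For reference, a minimal sketch of how the template data pickled at the end of this notebook could be consumed when generating language conditioning; the load path and the random sampling here are assumptions for illustration, not part of this commit:

import pickle
import random

# The pickle stores the two dictionaries built above:
#   "type_value_tuple_to_template_sentences": sorted (type, value) tuple -> filled-in template sentences
#   "template_sentence_to_embedding": sentence -> precomputed 384-d SentenceBERT embedding
with open("data/template_sentence_data.pkl", "rb") as fh:  # assumed path
    template_data = pickle.load(fh)

type_value_tuple = (('rotation', 'west'), ('shape', 'tower'), ('x', 'middle'), ('y', 'left'))
candidate_sentences = template_data["type_value_tuple_to_template_sentences"][type_value_tuple]
sentence = random.choice(candidate_sentences)                           # pick one paraphrase
embedding = template_data["template_sentence_to_embedding"][sentence]   # numpy array, shape (384,)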
src/StructDiffusion/language/sentence_encoder.py ADDED
@@ -0,0 +1,23 @@
1
+ from sentence_transformers import SentenceTransformer
2
+
3
+ class SentenceBertEncoder:
4
+
5
+ def __init__(self):
6
+ self.model = SentenceTransformer('all-MiniLM-L6-v2')
7
+
8
+ def encode(self, sentences):
9
+ # Our sentences we'd like to encode
10
+ # sentences = ['This framework generates embeddings for each input sentence',
11
+ # 'Sentences are passed as a list of string.',
12
+ # 'The quick brown fox jumps over the lazy dog.']
13
+ #Sentences are encoded by calling model.encode()
14
+
15
+ embeddings = self.model.encode(sentences)
16
+ # print(embeddings.shape)
17
+ return embeddings
18
+
19
+
20
+ if __name__ == "__main__":
21
+ sentence_encoder = SentenceBertEncoder()
22
+ embedding = sentence_encoder.encode(["this is cool!"])
23
+ print(embedding.shape)
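As a quick usage note (an assumed example, not part of the file above): all-MiniLM-L6-v2 produces one 384-dimensional vector per input sentence, which is why the notebook's embedding matrix has shape (3345, 384).

encoder = SentenceBertEncoder()
embeddings = encoder.encode([
    "tower in the middle left of the table facing west",
    "small circle in the middle center of the table facing west",
])
print(embeddings.shape)  # (2, 384) with the all-MiniLM-L6-v2 model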
src/StructDiffusion/language/test_parrot_paraphrase.py ADDED
@@ -0,0 +1,38 @@
1
+ from parrot import Parrot
2
+ import torch
3
+ import warnings
4
+ warnings.filterwarnings("ignore")
5
+
6
+ # [top]
7
+
8
+ # Put the [objects] in a [size][shape] on the [x][y] of the table facing [rotation].
9
+ # Build a [size][shape] of the [objects] on the [x][y] of the table facing [rotation].
10
+ # Put the [objects] on the [x][y] of the table and make a [shape] facing [rotation].
11
+ # Rearrange the [objects] into a [shape], and put the structure on the [x][y] of the table facing [rotation].
12
+ # Could you ...
13
+ # Please ...
14
+ # Pick up the objects, put them into a [size][shape], place the [shape] on the [x][y] of table, make sure the [shape] is facing [rotation].
15
+
16
+ if __name__ == "__main__":
17
+ '''
18
+ uncomment to get reproducible paraphrase generations
19
+ def random_state(seed):
20
+ torch.manual_seed(seed)
21
+ if torch.cuda.is_available():
22
+ torch.cuda.manual_seed_all(seed)
23
+
24
+ random_state(1234)
25
+ '''
26
+
27
+ # Init models (make sure you init ONLY once if you integrate this into your code)
28
+ parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")
29
+
30
+ phrases = ["Rearrange the mugs in a circle on the top left of the table."]
31
+
32
+ for phrase in phrases:
33
+ print("-"*100)
34
+ print("Input_phrase: ", phrase)
35
+ print("-"*100)
36
+ para_phrases = parrot.augment(input_phrase=phrase, use_gpu=False, max_return_phrases=100, do_diverse=True)
37
+ for para_phrase in para_phrases:
38
+ print(para_phrase)
src/StructDiffusion/language/tokenizer.py CHANGED
@@ -517,25 +517,4 @@ class ContinuousTokenizer:
517
  idx = value
518
  else:
519
  raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
520
- return idx
521
-
522
-
523
- if __name__ == "__main__":
524
- tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
525
- # print(tokenizer.get_all_values_of_type("class"))
526
- # print(tokenizer.get_all_values_of_type("color"))
527
- # print(tokenizer.get_all_values_of_type("material"))
528
- #
529
- # for type in tokenizer.type_vocabs:
530
- # print(type, tokenizer.type_vocabs[type])
531
-
532
- tokenizer.prepare_grounding_reference()
533
-
534
- # for i in range(100):
535
- # types = list(tokenizer.continuous_types) + list(tokenizer.discrete_types)
536
- # for t in types:
537
- # v = tokenizer.get_valid_random_value(t)
538
- # print(v)
539
- # print(tokenizer.tokenize(v, t))
540
-
541
- # build_vocab("/home/weiyu/data_drive/examples_v4/leonardo/vocab.json", "/home/weiyu/data_drive/examples_v4/leonardo/type_vocabs.json")
 
517
  idx = value
518
  else:
519
  raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
520
+ return idx
 
src/StructDiffusion/models/__pycache__/models.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc and b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc differ
 
src/StructDiffusion/models/models.py CHANGED
@@ -26,6 +26,8 @@ class TransformerDiffusionModel(torch.nn.Module):
26
  word_emb_dim=160,
27
  time_emb_dim=80,
28
  use_virtual_structure_frame=True,
 
 
29
  ):
30
  super(TransformerDiffusionModel, self).__init__()
31
 
@@ -53,7 +55,12 @@ class TransformerDiffusionModel(torch.nn.Module):
53
  self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim
54
 
55
  # for language
56
- self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)
 
 
 
 
 
57
 
58
  # for diffusion
59
  self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
@@ -88,7 +95,10 @@ class TransformerDiffusionModel(torch.nn.Module):
88
 
89
  batch_size, num_objects, num_pts, _ = pcs.shape
90
  _, num_poses, _ = poses.shape
91
- _, sentence_len = sentence.shape
 
 
 
92
  _, total_len = type_index.shape
93
 
94
  pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
@@ -102,7 +112,11 @@ class TransformerDiffusionModel(torch.nn.Module):
102
  tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)
103
 
104
  #########################
105
- sentence_embed = self.word_embeddings(sentence)
 
 
 
 
106
 
107
  #########################
108
 
 
26
  word_emb_dim=160,
27
  time_emb_dim=80,
28
  use_virtual_structure_frame=True,
29
+ use_sentence_embedding=False,
30
+ sentence_embedding_dim=None,
31
  ):
32
  super(TransformerDiffusionModel, self).__init__()
33
 
 
55
  self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim
56
 
57
  # for language
58
+ self.sentence_embedding_dim = sentence_embedding_dim
59
+ self.use_sentence_embedding = use_sentence_embedding
60
+ if use_sentence_embedding:
61
+ self.sentence_embedding_down_sample = torch.nn.Linear(sentence_embedding_dim, word_emb_dim)
62
+ else:
63
+ self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)
64
 
65
  # for diffusion
66
  self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
 
95
 
96
  batch_size, num_objects, num_pts, _ = pcs.shape
97
  _, num_poses, _ = poses.shape
98
+ if self.use_sentence_embedding:
99
+ assert sentence.shape == (batch_size, self.sentence_embedding_dim), sentence.shape
100
+ else:
101
+ _, sentence_len = sentence.shape
102
  _, total_len = type_index.shape
103
 
104
  pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
 
112
  tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)
113
 
114
  #########################
115
+ if self.use_sentence_embedding:
116
+ # sentence: B, sentence_embedding_dim
117
+ sentence_embed = self.sentence_embedding_down_sample(sentence).unsqueeze(1) # B, 1, word_emb_dim
118
+ else:
119
+ sentence_embed = self.word_embeddings(sentence)
120
 
121
  #########################
122
 
src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc differ
 
src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc differ
 
src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc differ
 
src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc differ
 
src/StructDiffusion/utils/__pycache__/tra3d.cpython-38.pyc ADDED
Binary file (4.72 kB).
 
src/StructDiffusion/utils/batch_inference.py CHANGED
@@ -1,175 +1,10 @@
1
  import os
2
  import torch
3
  import numpy as np
4
- import pytorch3d.transforms as tra3d
5
 
6
  from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
7
  from StructDiffusion.utils.pointnet import random_point_sample, index_points
8
-
9
-
10
- def move_pc_and_create_scene(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds,
11
- num_scene_pts, device, normalize_pc=False,
12
- return_pair_pc=False, normalize_pair_pc=False, num_pair_pc_pts=None,
13
- return_scene_pts=True, return_scene_pts_and_pc_idxs=False):
14
-
15
- # obj_xyzs: N, P, 3
16
- # obj_params: B, N, 6
17
- # struct_pose: B x N, 4, 4
18
- # current_pc_pose: B x N, 4, 4
19
- # target_object_inds: 1, N
20
-
21
- B, N, _ = obj_params.shape
22
- _, P, _ = obj_xyzs.shape
23
-
24
- # B, N, 6
25
- flat_obj_params = obj_params.reshape(B * N, -1)
26
- goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
27
- goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
28
- goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
29
-
30
- goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
31
- goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
32
-
33
- # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
34
- transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
35
-
36
- # obj_xyzs: N, P, 3
37
- new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
38
- new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
39
-
40
- # put it back to B, N, P, 3
41
- new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
42
- # visualize_batch_pcs(new_obj_xyzs, S, N, P)
43
-
44
- # ===================================
45
- # Pass to discriminator
46
- subsampled_scene_xyz = None
47
- if return_scene_pts:
48
-
49
- num_indicator = N
50
-
51
- # add one hot
52
- indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
53
- # print(indicator_variables.shape)
54
- # print(new_obj_xyzs.shape)
55
- new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
56
-
57
- # combine pcs in each scene
58
- scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
59
-
60
- # ToDo: maybe convert this to a batch operation
61
- subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
62
- for si, scene_xyz in enumerate(scene_xyzs):
63
- # scene_xyz: N*P, 3+N
64
- # target_object_inds: 1, N
65
- subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
66
- subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
67
-
68
- # # debug:
69
- # print("-"*50)
70
- # if si < 10:
71
- # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
72
- # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
73
-
74
- # subsampled_scene_xyz: B, num_scene_pts, 3+N
75
- # new_obj_xyzs: B, N, P, 3
76
- # goal_pc_pose: B, N, 4, 4
77
-
78
- # important:
79
- if normalize_pc:
80
- subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
81
-
82
- # # debug:
83
- # for si in range(10):
84
- # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
85
-
86
- if return_scene_pts_and_pc_idxs:
87
- num_indicator = N
88
- pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
89
- # new_obj_xyzs: B, N, P, 3 + 1
90
-
91
- # combine pcs in each scene
92
- scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
93
- pc_idxs = pc_idxs.reshape(B, N*P)
94
-
95
- subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
96
- subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
97
- for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
98
- # scene_xyz: N*P, 3+1
99
- # target_object_inds: 1, N
100
- subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
101
- subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
102
- subsampled_pc_idxs[si] = pc_idx[subsample_idx]
103
-
104
- # subsampled_scene_xyz: B, num_scene_pts, 3
105
- # subsampled_pc_idxs: B, num_scene_pts
106
- # new_obj_xyzs: B, N, P, 3
107
- # goal_pc_pose: B, N, 4, 4
108
-
109
- # important:
110
- if normalize_pc:
111
- subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
112
-
113
- # TODO: visualize each individual object
114
- # debug
115
- # print(subsampled_scene_xyz.shape)
116
- # print(subsampled_pc_idxs.shape)
117
- # print("visualize subsampled scene")
118
- # for si in range(5):
119
- # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
120
-
121
- ###############################################
122
- # Create input for pairwise collision detector
123
- if return_pair_pc:
124
-
125
- assert num_pair_pc_pts is not None
126
-
127
- # new_obj_xyzs: B, N, P, 3 + N
128
- # target_object_inds: 1, N
129
- # ignore paddings
130
- num_objs = torch.sum(target_object_inds[0])
131
- obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
132
-
133
- # use [:, :, :, :3] to get obj_xyzs without object-wise indicator
134
- obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
135
- num_comb = obj_pair_xyzs.shape[1]
136
- pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
137
- obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
138
- obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
139
-
140
- # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
141
- obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
142
- # random_point_sample() input dim: B, N, C
143
- rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
144
- obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
145
-
146
- if normalize_pair_pc:
147
- # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
148
- # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
149
- obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
150
- obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
151
-
152
- # # debug
153
- # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
154
- # print("batch id", bi)
155
- # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
156
- # print("pair", pi)
157
- # # obj_pair_xyzs: 2 * P, 5
158
- # print(obj_pair_xyz[:, :3].shape)
159
- # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
160
-
161
- # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
162
- goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
163
-
164
- # TODO: update the return logic, a mess right now
165
- if return_scene_pts_and_pc_idxs:
166
- return subsampled_scene_xyz, subsampled_pc_idxs, new_obj_xyzs, goal_pc_pose
167
-
168
- if return_pair_pc:
169
- return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose, obj_pair_xyzs
170
- else:
171
- return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose
172
-
173
 
174
  def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
175
  return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
@@ -193,12 +28,17 @@ def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_p
193
  goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
194
  goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
195
 
196
- # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
197
- transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
 
 
 
 
 
 
 
 
198
 
199
- # obj_xyzs: N, P, 3
200
- new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
201
- new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
202
 
203
  # put it back to B, N, P, 3
204
  new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
@@ -332,82 +172,6 @@ def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_p
332
  return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
333
 
334
 
335
- def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
336
-
337
- # obj_xyzs: N, P, 3
338
- # obj_params: B, N, 6
339
- # struct_pose: B x N, 4, 4
340
- # current_pc_pose: B x N, 4, 4
341
- # target_object_inds: 1, N
342
-
343
- B, N, _ = obj_params.shape
344
- _, P, _ = obj_xyzs.shape
345
-
346
- # B, N, 6
347
- flat_obj_params = obj_params.reshape(B * N, -1)
348
- goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
349
- goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
350
- goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
351
-
352
- goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
353
- goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
354
-
355
- # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
356
- transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
357
-
358
- # obj_xyzs: N, P, 3
359
- new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
360
- new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
361
-
362
- # put it back to B, N, P, 3
363
- new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
364
- # visualize_batch_pcs(new_obj_xyzs, S, N, P)
365
-
366
- # subsampled_scene_xyz: B, num_scene_pts, 3+N
367
- # new_obj_xyzs: B, N, P, 3
368
- # goal_pc_pose: B, N, 4, 4
369
-
370
- goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
371
- return new_obj_xyzs, goal_pc_pose
372
-
373
-
374
- def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
375
-
376
- device = obj_xyzs.device
377
-
378
- # obj_xyzs: B, N, P, 3 or 6
379
- # struct_pose: B, 1, 4, 4
380
- # pc_poses_in_struct: B, N, 4, 4
381
-
382
- B, N, _, _ = pc_poses_in_struct.shape
383
- _, _, P, _ = obj_xyzs.shape
384
-
385
- current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
386
- # print(torch.mean(obj_xyzs, dim=2).shape)
387
- current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
388
- current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
389
-
390
- struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
391
- struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
392
- pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
393
-
394
- goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
395
- # print("goal pc poses")
396
- # print(goal_pc_pose)
397
- goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
398
-
399
- # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
400
- transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
401
-
402
- new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
403
- new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
404
-
405
- # put it back to B, N, P, 3
406
- new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
407
-
408
- return new_obj_xyzs
409
-
410
-
411
  def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
412
 
413
  device = obj_xyzs.device
 
1
  import os
2
  import torch
3
  import numpy as np
 
4
 
5
  from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
6
  from StructDiffusion.utils.pointnet import random_point_sample, index_points
7
+ import StructDiffusion.utils.tra3d as tra3d
8
 
9
  def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
10
  return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
 
28
  goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
29
  goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
30
 
31
+ # # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
32
+ # transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
33
+ # # obj_xyzs: N, P, 3
34
+ # new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
35
+ # new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
36
+
37
+ # a version of the transform that does not rely on pytorch3d
38
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) # B x N, P, 3
39
+ new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
40
+ new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3] # B x N, P, 3
41
 
 
 
 
42
 
43
  # put it back to B, N, P, 3
44
  new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
 
172
  return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
173
 
174

175
  def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
176
 
177
  device = obj_xyzs.device
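The einsum line above replaces pytorch3d's Transform3d.transform_points with a plain homogeneous-coordinate transform. A small self-contained sanity check of that replacement, with made-up poses and point clouds:

import torch

B_N, P = 6, 128                        # (batch x num objects), points per cloud
pts = torch.randn(B_N, P, 3)
tfm = torch.eye(4).repeat(B_N, 1, 1)
tfm[:, :3, 3] = torch.randn(B_N, 3)    # random translations, identity rotations

# homogeneous coordinates + einsum, as in move_pc_and_create_scene_new
pts_h = torch.cat([pts, torch.ones(B_N, P, 1)], dim=-1)          # B_N, P, 4
out_einsum = torch.einsum('bij,bkj->bki', tfm, pts_h)[:, :, :3]  # B_N, P, 3

# reference: batched matmul on transposed points
out_ref = (tfm @ pts_h.transpose(1, 2)).transpose(1, 2)[:, :, :3]

print(torch.allclose(out_einsum, out_ref, atol=1e-6))  # expected: True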
src/StructDiffusion/utils/files.py CHANGED
@@ -1,9 +1,17 @@
1
  import os
2
 
 
3
  def get_checkpoint_path_from_dir(checkpoint_dir):
4
  checkpoint_path = None
5
  for file in os.listdir(checkpoint_dir):
6
  if "ckpt" in file:
7
  checkpoint_path = os.path.join(checkpoint_dir, file)
8
  assert checkpoint_path is not None
9
- return checkpoint_path
1
  import os
2
 
3
+
4
  def get_checkpoint_path_from_dir(checkpoint_dir):
5
  checkpoint_path = None
6
  for file in os.listdir(checkpoint_dir):
7
  if "ckpt" in file:
8
  checkpoint_path = os.path.join(checkpoint_dir, file)
9
  assert checkpoint_path is not None
10
+ return checkpoint_path
11
+
12
+
13
+ def replace_config_for_testing_data(cfg, testing_data_cfg):
14
+ cfg.DATASET.data_roots = testing_data_cfg.DATASET.data_roots
15
+ cfg.DATASET.index_roots = testing_data_cfg.DATASET.index_roots
16
+ cfg.DATASET.vocab_dir = testing_data_cfg.DATASET.vocab_dir
17
+
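replace_config_for_testing_data simply overwrites the dataset paths of a training config with those of a testing/demo config. A hypothetical usage sketch, assuming both configs are OmegaConf files that expose a DATASET section with data_roots, index_roots, and vocab_dir; the testing config filename here is a placeholder.

from omegaconf import OmegaConf
from StructDiffusion.utils.files import replace_config_for_testing_data

cfg = OmegaConf.load("configs/conditional_pose_diffusion_language.yaml")
testing_data_cfg = OmegaConf.load("configs/testing_data.yaml")  # hypothetical demo-data config

replace_config_for_testing_data(cfg, testing_data_cfg)
print(cfg.DATASET.data_roots, cfg.DATASET.index_roots, cfg.DATASET.vocab_dir)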
src/StructDiffusion/utils/np_speed_test.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ import time
3
+
4
+ def numpy_test(w, h, r):
5
+ v, u = np.indices((h, w))
6
+ theta = (u - w / 2.) * 2 * np.pi / w
7
+ phi = (v - h / 2.) * np.pi / h
8
+ cos_phi = np.cos(phi)
9
+ x = cos_phi * np.sin(theta)
10
+ y = np.sin(phi)
11
+ z = cos_phi * np.cos(theta)
12
+
13
+ ray = np.dstack((x, y, z))
14
+ ray = ray.reshape(-1, 3).dot(r.T)
15
+ ray.shape = (h, w, 3)
16
+
17
+ x, y, z = np.dsplit(ray, 3)
18
+ theta = np.arctan2(x, z)
19
+ phi = np.arcsin(y)
20
+ u = theta * w / 2 / np.pi + w / 2.
21
+ v = phi * h / np.pi + h / 2.
22
+ xymap = np.dstack((u, v)).astype(np.float32)
23
+
24
+ return xymap
25
+
26
+ def matrix_multiplication():
27
+ for i in range(100):
28
+ np.random.random((1000, 1000)) @ np.random.random((1000, 1000))
29
+
30
+
31
+ if __name__ == "__main__":
32
+ w = 3584
33
+ h = int(w / 2)
34
+ r = np.array([
35
+ [0.61566148, -0.78369395, 0.08236955],
36
+ [0.78801075, 0.61228882, -0.06435415],
37
+ [0, 0.10452846, 0.9945219],])
38
+ begin_time = time.time()
39
+ # numpy_test(w, h, r)
40
+ matrix_multiplication()
41
+ print(time.time() - begin_time)
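The script times a single run with time.time(). If a steadier number is wanted, a sketch like the following could be used instead, repeating the benchmark and reporting the best of several runs with time.perf_counter.

import time
import numpy as np

def matrix_multiplication():
    # same workload as in np_speed_test.py
    for _ in range(100):
        np.random.random((1000, 1000)) @ np.random.random((1000, 1000))

runs = []
for _ in range(3):
    start = time.perf_counter()
    matrix_multiplication()
    runs.append(time.perf_counter() - start)
print(f"best of 3: {min(runs):.3f} s")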
src/StructDiffusion/utils/rearrangement.py CHANGED
@@ -558,9 +558,12 @@ def fit_gaussians(samples, sigma_eps=0.01):
558
  return mus, sigmas
559
 
560
 
561
- def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False):
562
- vis_pcs = [trimesh.PointCloud(obj_xyz, colors=np.concatenate([obj_rgb * 255, np.ones([obj_rgb.shape[0], 1]) * 255], axis=-1)) for
563
- obj_xyz, obj_rgb in zip(obj_xyzs, obj_rgbs)]
 
 
 
564
  scene = trimesh.Scene()
565
  # add the coordinate frame first
566
  geom = trimesh.creation.axis(0.01)
@@ -582,13 +585,35 @@ def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False):
582
  RT_4x4 = np.linalg.inv(RT_4x4)
583
  RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
584
  scene.camera_transform = RT_4x4
585
-
586
  if return_scene:
587
  return scene
588
  else:
589
  scene.show()
590
 
591

592
  def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
593
  """ Display point clouds """
594
 
 
558
  return mus, sigmas
559
 
560
 
561
+ def show_pcs_with_trimesh(obj_xyzs, obj_rgbs=None, return_scene=False):
562
+ if obj_rgbs is not None:
563
+ vis_pcs = [trimesh.PointCloud(obj_xyz, colors=np.concatenate([obj_rgb * 255, np.ones([obj_rgb.shape[0], 1]) * 255], axis=-1)) for
564
+ obj_xyz, obj_rgb in zip(obj_xyzs, obj_rgbs)]
565
+ else:
566
+ vis_pcs = [trimesh.PointCloud(obj_xyz) for obj_xyz in obj_xyzs]
567
  scene = trimesh.Scene()
568
  # add the coordinate frame first
569
  geom = trimesh.creation.axis(0.01)
 
585
  RT_4x4 = np.linalg.inv(RT_4x4)
586
  RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
587
  scene.camera_transform = RT_4x4
 
588
  if return_scene:
589
  return scene
590
  else:
591
  scene.show()
592
 
593
 
594
+ def get_trimesh_scene_with_table():
595
+ scene = trimesh.Scene()
596
+ # add the coordinate frame first
597
+ geom = trimesh.creation.axis(0.01)
598
+ scene.add_geometry(geom)
599
+ table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
600
+ table.apply_translation([0.5, 0, -0.01])
601
+ table.visual.vertex_colors = [150, 111, 87, 125]
602
+ scene.add_geometry(table)
603
+ # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
604
+ bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
605
+ bounds.apply_translation([0, 0, 0])
606
+ bounds.visual.vertex_colors = [30, 30, 30, 30]
607
+ # scene.add_geometry(bounds)
608
+ RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
609
+ [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
610
+ [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
611
+ [0.0, 0.0, 0.0, 1.0]])
612
+ RT_4x4 = np.linalg.inv(RT_4x4)
613
+ RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
614
+ scene.camera_transform = RT_4x4
615
+ return scene
616
+
617
  def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
618
  """ Display point clouds """
619
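show_pcs_with_trimesh now also accepts point clouds without colors, and get_trimesh_scene_with_table builds an empty tabletop scene with the same camera setup. A minimal usage sketch with random stand-in point clouds; the export path is only an example.

import numpy as np
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table

obj_xyzs = [np.random.rand(1024, 3) * 0.1 + np.array([0.5, i * 0.15, 0.0]) for i in range(3)]
obj_rgbs = [np.random.rand(1024, 3) for _ in obj_xyzs]

scene_colored = show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=True)  # original colored path
scene_plain = show_pcs_with_trimesh(obj_xyzs, return_scene=True)              # new obj_rgbs=None path
table_scene = get_trimesh_scene_with_table()                                  # empty tabletop scene
table_scene.export("tmp_data/empty_table.glb")                                # placeholder output path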
 
src/StructDiffusion/utils/tra3d.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+
3
+
4
+ # source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_euler_angles
5
+ # we don't want to require building pytorch3d, so we only copy the functions we need
6
+
7
+ def _index_from_letter(letter: str) -> int:
8
+ if letter == "X":
9
+ return 0
10
+ if letter == "Y":
11
+ return 1
12
+ if letter == "Z":
13
+ return 2
14
+ raise ValueError("letter must be either X, Y or Z.")
15
+
16
+
17
+ def _angle_from_tan(
18
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
19
+ ) -> torch.Tensor:
20
+ """
21
+ Extract the first or third Euler angle from the two members of
22
+ the matrix which are positive constant times its sine and cosine.
23
+
24
+ Args:
25
+ axis: Axis label "X", "Y", or "Z" for the angle we are finding.
26
+ other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
27
+ convention.
28
+ data: Rotation matrices as tensor of shape (..., 3, 3).
29
+ horizontal: Whether we are looking for the angle for the third axis,
30
+ which means the relevant entries are in the same row of the
31
+ rotation matrix. If not, they are in the same column.
32
+ tait_bryan: Whether the first and third axes in the convention differ.
33
+
34
+ Returns:
35
+ Euler Angles in radians for each matrix in data as a tensor
36
+ of shape (...).
37
+ """
38
+
39
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
40
+ if horizontal:
41
+ i2, i1 = i1, i2
42
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
43
+ if horizontal == even:
44
+ return torch.atan2(data[..., i1], data[..., i2])
45
+ if tait_bryan:
46
+ return torch.atan2(-data[..., i2], data[..., i1])
47
+ return torch.atan2(data[..., i2], -data[..., i1])
48
+
49
+
50
+ def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
51
+ """
52
+ Return the rotation matrices for a rotation about a single axis of an
53
+ Euler-angle convention, for each value of the angle given.
54
+
55
+ Args:
56
+ axis: Axis label "X", "Y", or "Z".
57
+ angle: any shape tensor of Euler angles in radians
58
+
59
+ Returns:
60
+ Rotation matrices as tensor of shape (..., 3, 3).
61
+ """
62
+
63
+ cos = torch.cos(angle)
64
+ sin = torch.sin(angle)
65
+ one = torch.ones_like(angle)
66
+ zero = torch.zeros_like(angle)
67
+
68
+ if axis == "X":
69
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
70
+ elif axis == "Y":
71
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
72
+ elif axis == "Z":
73
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
74
+ else:
75
+ raise ValueError("letter must be either X, Y or Z.")
76
+
77
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
78
+
79
+
80
+ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
81
+ """
82
+ Convert rotations given as rotation matrices to Euler angles in radians.
83
+
84
+ Args:
85
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
86
+ convention: Convention string of three uppercase letters.
87
+
88
+ Returns:
89
+ Euler angles in radians as tensor of shape (..., 3).
90
+ """
91
+ if len(convention) != 3:
92
+ raise ValueError("Convention must have 3 letters.")
93
+ if convention[1] in (convention[0], convention[2]):
94
+ raise ValueError(f"Invalid convention {convention}.")
95
+ for letter in convention:
96
+ if letter not in ("X", "Y", "Z"):
97
+ raise ValueError(f"Invalid letter {letter} in convention string.")
98
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
99
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
100
+ i0 = _index_from_letter(convention[0])
101
+ i2 = _index_from_letter(convention[2])
102
+ tait_bryan = i0 != i2
103
+ if tait_bryan:
104
+ central_angle = torch.asin(
105
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
106
+ )
107
+ else:
108
+ central_angle = torch.acos(matrix[..., i0, i0])
109
+
110
+ o = (
111
+ _angle_from_tan(
112
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
113
+ ),
114
+ central_angle,
115
+ _angle_from_tan(
116
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
117
+ ),
118
+ )
119
+ return torch.stack(o, -1)
120
+
121
+
122
+ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
123
+ """
124
+ Convert rotations given as Euler angles in radians to rotation matrices.
125
+
126
+ Args:
127
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
128
+ convention: Convention string of three uppercase letters from
129
+ {"X", "Y", and "Z"}.
130
+
131
+ Returns:
132
+ Rotation matrices as tensor of shape (..., 3, 3).
133
+ """
134
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
135
+ raise ValueError("Invalid input euler angles.")
136
+ if len(convention) != 3:
137
+ raise ValueError("Convention must have 3 letters.")
138
+ if convention[1] in (convention[0], convention[2]):
139
+ raise ValueError(f"Invalid convention {convention}.")
140
+ for letter in convention:
141
+ if letter not in ("X", "Y", "Z"):
142
+ raise ValueError(f"Invalid letter {letter} in convention string.")
143
+ matrices = [
144
+ _axis_angle_rotation(c, e)
145
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
146
+ ]
147
+ # return functools.reduce(torch.matmul, matrices)
148
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
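Since these conversion utilities are copied rather than imported from pytorch3d, a quick round-trip check is an easy way to confirm they behave as expected. A sketch, assuming the src directory is on the Python path:

import torch
from StructDiffusion.utils.tra3d import euler_angles_to_matrix, matrix_to_euler_angles

euler = (torch.rand(8, 3) - 0.5) * 2.0          # radians, safely inside (-pi/2, pi/2)
R = euler_angles_to_matrix(euler, "XYZ")        # 8, 3, 3
euler_back = matrix_to_euler_angles(R, "XYZ")   # 8, 3
print(torch.allclose(euler, euler_back, atol=1e-4))  # expected: True away from gimbal lock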
tmp_data/input_scene.glb ADDED
Binary file (121 kB).
 
tmp_data/input_scene_102.glb ADDED
Binary file (80.8 kB).
 
tmp_data/input_scene_None.glb ADDED
Binary file (66.3 kB).
 
tmp_data/output_scene.glb ADDED
Binary file (121 kB).
 
tmp_data/output_scene_102.glb ADDED
Binary file (91.3 kB).
 
wandb_logs/StructDiffusion/CollisionDiscriminator/checkpoints/epoch=199-step=653400.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed90c4976e69f96324d0245ecc635fa987d66d84c512d1e2105ce9f7f3df39ea
3
+ size 34533413
wandb_logs/StructDiffusion/ConditionalPoseDiffusionLanguage/checkpoints/epoch=199-step=100000.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edd6c08c5fafd129f365fd70fc1b6ad68e643a9111e522432279afb7ba387a89
3
+ size 59947673
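The two .ckpt entries above are git-lfs pointer files, not the weights themselves. A hedged sketch of how such a checkpoint might be located with the helper from files.py; loading via a specific LightningModule class is left as a comment because the exact class is not shown in this section.

import os
from StructDiffusion.utils.files import get_checkpoint_path_from_dir

ckpt_dir = "wandb_logs/StructDiffusion/ConditionalPoseDiffusionLanguage/checkpoints"
ckpt_path = get_checkpoint_path_from_dir(ckpt_dir)
print(os.path.basename(ckpt_path))
# The actual weights are fetched by `git lfs pull`; without it the .ckpt files are
# only small pointer files. Loading would then go through the corresponding
# LightningModule's load_from_checkpoint(ckpt_path).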