Spaces (Paused)

Weiyu Liu committed 8c02843 (parent: a77a4ae): add demo

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- app.py +131 -0
- configs/base.yaml +3 -0
- configs/conditional_pose_diffusion.yaml +81 -0
- configs/pairwise_collision.yaml +42 -0
- data/data00000000.h5 +3 -0
- data/data00000002.h5 +3 -0
- data/data00000003.h5 +3 -0
- data/data00000004.h5 +3 -0
- data/data00000006.h5 +3 -0
- data/data00000008.h5 +3 -0
- data/data00000009.h5 +3 -0
- data/data00000012.h5 +3 -0
- data/data00000013.h5 +3 -0
- data/data00000015.h5 +3 -0
- data/type_vocabs_coarse.json +1 -0
- packages.txt +1 -0
- requirements.txt +13 -0
- scripts/infer.py +78 -0
- scripts/infer_with_discriminator.py +81 -0
- scripts/train_discriminator.py +46 -0
- scripts/train_generator.py +49 -0
- src/StructDiffusion/__init__.py +0 -0
- src/StructDiffusion/__pycache__/__init__.cpython-37.pyc +0 -0
- src/StructDiffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- src/StructDiffusion/data/__init__.py +0 -0
- src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc +0 -0
- src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc +0 -0
- src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc +0 -0
- src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc +0 -0
- src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc +0 -0
- src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc +0 -0
- src/StructDiffusion/data/pairwise_collision.py +361 -0
- src/StructDiffusion/data/semantic_arrangement.py +579 -0
- src/StructDiffusion/data/semantic_arrangement_demo.py +563 -0
- src/StructDiffusion/diffusion/__init__.py +0 -0
- src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/noise_schedule.py +81 -0
- src/StructDiffusion/diffusion/pose_conversion.py +103 -0
- src/StructDiffusion/diffusion/sampler.py +296 -0
- src/StructDiffusion/language/__init__.py +0 -0
- src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc +0 -0
- src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc +0 -0
- src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc +0 -0
app.py
ADDED
@@ -0,0 +1,131 @@
import os
import argparse
import torch
import trimesh
import numpy as np
import pytorch_lightning as pl
import gradio as gr
from omegaconf import OmegaConf

import sys
sys.path.append('./src')

from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh


class Infer_Wrapper:

    def __init__(self, args, cfg):

        # load
        pl.seed_everything(args.eval_random_seed)
        self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

        checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
        checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)

        self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)

        self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)

    def run(self, di):

        # di = np.random.choice(len(self.dataset))

        raw_datum = self.dataset.get_raw_data(di)
        print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
        datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
        batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)

        num_poses = datum["goal_poses"].shape[0]
        xs = self.sampler.sample(batch, num_poses)

        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
        new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)

        # vis
        vis_obj_xyzs = new_obj_xyzs[:3]
        if torch.is_tensor(vis_obj_xyzs):
            if vis_obj_xyzs.is_cuda:
                vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
            vis_obj_xyzs = vis_obj_xyzs.numpy()

        # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
        #     if verbose:
        #         print("example {}".format(bi))
        #         print(vis_obj_xyz.shape)
        #
        #     if trimesh:
        #         show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
        vis_obj_xyz = vis_obj_xyzs[0]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)

        scene_filename = "./tmp_data/scene.glb"
        scene.export(scene_filename)

        # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
        # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
        #
        # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
        # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
        # vis_pc.export(pc_filename)
        #
        # scene = trimesh.Scene()
        # # add the coordinate frame first
        # # geom = trimesh.creation.axis(0.01)
        # # scene.add_geometry(geom)
        # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
        # table.apply_translation([0.5, 0, -0.01])
        # table.visual.vertex_colors = [150, 111, 87, 125]
        # scene.add_geometry(table)
        # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
        # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
        # # bounds.apply_translation([0, 0, 0])
        # # bounds.visual.vertex_colors = [30, 30, 30, 30]
        # # scene.add_geometry(bounds)
        # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
        # #                    [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
        # #                    [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
        # #                    [0.0, 0.0, 0.0, 1.0]])
        # # RT_4x4 = np.linalg.inv(RT_4x4)
        # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
        # # scene.camera_transform = RT_4x4
        #
        # mesh_list = trimesh.util.concatenate(scene.dump())
        # print(mesh_list)
        # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')

        return scene_filename


args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion.yaml"
args.checkpoint_id = "ConditionalPoseDiffusion"
args.eval_random_seed = 42
args.num_samples = 1

base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)

demo = gr.Interface(
    fn=infer_wrapper.run,
    inputs=gr.Slider(0, len(infer_wrapper.dataset)),
    # clear color range [0-1.0]
    outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
)

demo.launch()
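One practical note on the export step above: scene.export("./tmp_data/scene.glb") assumes the tmp_data directory already exists, and trimesh will not create parent directories for you. A minimal sketch (not part of the commit) that guards against a missing directory on a fresh clone:

import os

os.makedirs("./tmp_data", exist_ok=True)  # scene.export() opens the file directly and fails if the folder is absent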
configs/base.yaml
ADDED
@@ -0,0 +1,3 @@
base_dirs:
  data: data
  wandb_dir: wandb_logs
configs/conditional_pose_diffusion.yaml
ADDED
@@ -0,0 +1,81 @@
random_seed: 1

WANDB:
  project: StructDiffusion
  save_dir: ${base_dirs.wandb_dir}
  name: conditional_pose_diffusion

DATASET:
  data_root: ${base_dirs.data}
  vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json

  # important
  use_virtual_structure_frame: True
  ignore_distractor_objects: True
  ignore_rgb: True

  # the following are determined by the dataset
  max_num_target_objects: 7
  max_num_distractor_objects: 5
  max_num_shape_parameters: 5
  # set to zeros because they are not used for now
  max_num_rearrange_features: 0
  max_num_anchor_features: 0

  num_pts: 1024
  filter_num_moved_objects_range:
  data_augmentation: False

DATALOADER:
  batch_size: 64
  num_workers: 8
  pin_memory: True

MODEL:
  # transformer encoder
  encoder_input_dim: 256
  num_attention_heads: 8
  encoder_hidden_dim: 512
  encoder_dropout: 0.0
  encoder_activation: relu
  encoder_num_layers: 8
  # output head
  structure_dropout: 0
  object_dropout: 0
  # pc encoder
  ignore_rgb: ${DATASET.ignore_rgb}
  pc_emb_dim: 256
  posed_pc_emb_dim: 80
  # pose encoder
  pose_emb_dim: 80
  # language
  word_emb_dim: 160
  # diffusion step
  time_emb_dim: 80
  # sequence embeddings
  # max_num_target_objects (+ max_num_distractor_objects if not ignore_distractor_objects)
  max_seq_size: 7
  max_token_type_size: 4
  seq_pos_emb_dim: 8
  seq_type_emb_dim: 8
  # virtual frame
  use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame}

NOISE_SCHEDULE:
  timesteps: 200

LOSS:
  type: huber

OPTIMIZER:
  lr: 0.0001
  weight_decay: 0 #0.0001
  # lr_restart: 3000
  # warmup: 10

TRAINER:
  max_epochs: 200
  gradient_clip_val: 1.0
  gpus: 1
  deterministic: False
  # enable_progress_bar: False
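The ${base_dirs.*} and ${DATASET.*} references above are standard OmegaConf interpolations; they only resolve once this file is merged on top of configs/base.yaml, which is exactly what app.py and the scripts do. A small sketch of that resolution (paths assume the repo root as working directory):

from omegaconf import OmegaConf

base_cfg = OmegaConf.load("configs/base.yaml")
cfg = OmegaConf.merge(base_cfg, OmegaConf.load("configs/conditional_pose_diffusion.yaml"))
print(cfg.DATASET.vocab_dir)   # data/type_vocabs_coarse.json, via ${base_dirs.data}
print(cfg.MODEL.ignore_rgb)    # True, mirrored from ${DATASET.ignore_rgb}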
configs/pairwise_collision.yaml
ADDED
@@ -0,0 +1,42 @@
random_seed: 1

WANDB:
  project: StructDiffusion
  save_dir: ${base_dirs.wandb_dir}
  name: pairwise_collision

DATASET:
  urdf_pc_idx_file: ${base_dirs.pairwise_collision_data}/urdf_pc_idx.pkl
  collision_data_dir: ${base_dirs.pairwise_collision_data}

  # important
  num_pts: 1024
  num_scene_pts: 2048
  normalize_pc: True
  random_rotation: True
  data_augmentation: False

DATALOADER:
  batch_size: 32
  num_workers: 8
  pin_memory: True

MODEL:
  max_num_objects: 2
  include_env_pc: False
  pct_random_sampling: True

LOSS:
  type: Focal
  focal_gamma: 2

OPTIMIZER:
  lr: 0.0001
  weight_decay: 0

TRAINER:
  max_epochs: 200
  gradient_clip_val: 1.0
  gpus: 1
  deterministic: False
  # enable_progress_bar: False
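Note that ${base_dirs.pairwise_collision_data} is not defined in the bundled configs/base.yaml (which only sets data and wandb_dir), so resolving these DATASET paths as-is would fail with an OmegaConf interpolation error. A hedged sketch of one way to supply it before training the discriminator; the path shown is hypothetical:

from omegaconf import OmegaConf

base_cfg = OmegaConf.load("configs/base.yaml")
base_cfg.base_dirs.pairwise_collision_data = "/path/to/pairwise_collision_data"  # hypothetical location
cfg = OmegaConf.merge(base_cfg, OmegaConf.load("configs/pairwise_collision.yaml"))
print(OmegaConf.to_container(cfg.DATASET, resolve=True))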
data/data00000000.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:947574252625d338b9f37217eacf61f520136e27b458b6d3e65330339e8b299c
size 1271489
data/data00000002.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3302432de555fed767c5b0d99c35ca01d5e4ac38cf4a0760b8ccb456b432e0e0
size 3235242
data/data00000003.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b907ba7c3a17f98a438617b462b2a4d3d3f8593c2dc47feb5a6cc3da8c034fc
size 2059708
data/data00000004.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8ec0136dd4055d304e9b7f5697b79613099b8f8f1e5eec94281f22d8d47cca1
size 2591656
data/data00000006.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0e74ebf185b0af58df0fa2483d5fd58a12b3b62ccac27ff665f35c5c7a13b8d8
size 1572332
data/data00000008.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db015354a9d53e6fbaf0b040ce226484150b0af226a5c13a0b9f5cb9961db73c
size 2167265
data/data00000009.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:990ad13f423d9089b30de81d002d23d9d00cf3e007fd7073793cbec03c456ebb
size 3607752
data/data00000012.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93161f5666c54dbc259c9efa516b67340613b592a1ed42e6c63d4cc8a495002a
size 2525622
data/data00000013.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94c9cfe6d9f0df176eb0a3baccdf53c7e6e5fc807e5e7ea9e138ad7159f500d9
size 1715352
data/data00000015.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9aab522f9ace1a03b1705fe3bd693b589971d133d45c232ef9ec53842a540bfa
size 2647026
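The ten data/*.h5 entries above are Git LFS pointers, not the HDF5 payloads themselves; the actual files have to be pulled (e.g. with git lfs) before the demo dataset can read them. A short sketch for inspecting one of them once it is materialized; the key names are whatever the dataset code expects (see semantic_arrangement_demo.py), not something this snippet assumes:

import h5py

with h5py.File("data/data00000000.h5", "r") as fh:
    print(sorted(fh.keys()))  # e.g. camera views, id_* object ids, goal_specification, ...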
data/type_vocabs_coarse.json
ADDED
@@ -0,0 +1 @@
{"class": {"Basket": 0, "BeerBottle": 1, "Book": 2, "Bottle": 3, "Bowl": 4, "Calculator": 5, "Candle": 6, "CellPhone": 7, "ComputerMouse": 8, "Controller": 9, "Cup": 10, "Donut": 11, "Fork": 12, "Hammer": 13, "Knife": 14, "Marker": 15, "MilkCarton": 16, "Mug": 17, "Pan": 18, "Pen": 19, "PillBottle": 20, "Plate": 21, "PowerStrip": 22, "Scissors": 23, "SoapBottle": 24, "SodaCan": 25, "Spoon": 26, "Stapler": 27, "Teapot": 28, "VideoGameController": 29, "WineBottle": 30, "CanOpener":31, "Fruit": 32}, "scene": {"dinner": 0}, "size": {"L": 0, "M": 1, "S": 2}, "color": {"blue": 0, "cyan": 1, "green": 2, "magenta": 3, "red": 4, "yellow": 5}, "material": {"glass": 0, "metal": 1, "plastic": 2}, "comparator": {"less": 1, "greater": 2, "equal": 3}, "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4], "height": [0.0, 0.5, 10], "volumn": [0.0, 0.015, 10], "uniform_angle": {"False": 0, "True": 1}, "face_center": {"False": 0, "True": 1}, "angle_ratio": {"0.5": 0, "1.0": 1}, "shape": {"circle": 0, "line": 1, "tower": 2, "dinner": 3}, "obj_x": [-1.0, 1.0, 200], "obj_y": [-1.0, 1.0, 200], "obj_z": [-1.0, 1.0, 200], "obj_rr": [-3.15, 3.15, 360], "obj_rp": [-3.15, 3.15, 360], "obj_ry": [-3.15, 3.15, 360],"struct_x": [-1.0, 1.0, 200], "struct_y": [-1.0, 1.0, 200], "struct_z": [-1.0, 1.0, 200], "struct_rr": [-3.15, 3.15, 360], "struct_rp": [-3.15, 3.15, 360], "struct_ry": [-3.15, 3.15, 360]}
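A hedged reading of this vocabulary (inferred from how the Tokenizer is used elsewhere in the diff; the tokenizer source itself is not shown): the string-valued maps are discrete token classes, and each numeric triple appears to be [min, max, number of bins] for discretizing a continuous quantity such as an object or structure pose component.

import json

with open("data/type_vocabs_coarse.json") as fh:
    vocab = json.load(fh)

lo, hi, bins = vocab["obj_x"]             # assumed [min, max, num_bins] layout
print(len(vocab["class"]), lo, hi, bins)  # 33 object classes; obj_x spans [-1, 1] over 200 bins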
packages.txt
ADDED
@@ -0,0 +1 @@
python3-opencv
requirements.txt
ADDED
@@ -0,0 +1,13 @@
numpy==1.21
h5py==2.10.0
opencv-python
open3d
trimesh==3.10.2
pyglet==1.5.0
pybullet==3.1.7
nvisii==1.1.70
openpyxl
pytorch_lightning==1.6.1
wandb==0.13.10
pytorch3d==0.3.0
omegaconf==2.2.2
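Note that gradio and torch are imported by the code but not pinned here; on Spaces, gradio comes from the runtime and torch is typically pulled in transitively. A quick, hedged sanity check that the resulting environment exposes what the demo imports:

from importlib.metadata import version

for pkg in ["torch", "gradio", "omegaconf", "trimesh", "h5py"]:
    print(pkg, version(pkg))  # raises PackageNotFoundError if a package is missing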
scripts/infer.py
ADDED
@@ -0,0 +1,78 @@
import os
import argparse
import torch
import numpy as np
import pytorch_lightning as pl
from omegaconf import OmegaConf

from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs


def main(args, cfg):

    pl.seed_everything(args.eval_random_seed)

    device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

    checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
    checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)

    if args.eval_mode == "infer":

        tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

        sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device)

        data_idxs = np.random.permutation(len(dataset))
        for di in data_idxs:
            raw_datum = dataset.get_raw_data(di)
            print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
            datum = dataset.convert_to_tensors(raw_datum, tokenizer)
            batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True)

            num_poses = datum["goal_poses"].shape[0]
            xs = sampler.sample(batch, num_poses)

            struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
            new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
            visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--base_config_file", help='base config yaml file',
                        default='../configs/base.yaml',
                        type=str)
    parser.add_argument("--config_file", help='config yaml file',
                        default='../configs/conditional_pose_diffusion.yaml',
                        type=str)
    parser.add_argument("--checkpoint_id",
                        default="ConditionalPoseDiffusion",
                        type=str)
    parser.add_argument("--eval_mode",
                        default="infer",
                        type=str)
    parser.add_argument("--eval_random_seed",
                        default=42,
                        type=int)
    parser.add_argument("--num_samples",
                        default=10,
                        type=int)
    args = parser.parse_args()

    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, cfg)

    main(args, cfg)
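The argparse defaults above use relative paths ('../configs/...'), which assume the script is launched from inside scripts/. A sketch of driving the same entry point programmatically from the repo root, mirroring how app.py builds its arguments (the import line is an assumption and depends on the repo root being on sys.path):

from omegaconf import OmegaConf
# from scripts.infer import main   # assumed import path

args = OmegaConf.create({
    "checkpoint_id": "ConditionalPoseDiffusion",
    "eval_mode": "infer",
    "eval_random_seed": 42,
    "num_samples": 4,
})
cfg = OmegaConf.merge(OmegaConf.load("configs/base.yaml"),
                      OmegaConf.load("configs/conditional_pose_diffusion.yaml"))
# main(args, cfg)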
scripts/infer_with_discriminator.py
ADDED
@@ -0,0 +1,81 @@
import os
import argparse
import torch
import numpy as np
import pytorch_lightning as pl
from omegaconf import OmegaConf

from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
from StructDiffusion.diffusion.sampler import SamplerV2
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs


def main(args, cfg):

    pl.seed_everything(args.eval_random_seed)

    device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

    diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
    collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))

    if args.eval_mode == "infer":

        tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

        sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
                            PairwiseCollisionModel, collision_checkpoint_path, device)

        data_idxs = np.random.permutation(len(dataset))
        for di in data_idxs:
            raw_datum = dataset.get_raw_data(di)
            print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
            datum = dataset.convert_to_tensors(raw_datum, tokenizer)
            batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True)

            num_poses = datum["goal_poses"].shape[0]
            struct_pose, pc_poses_in_struct = sampler.sample(batch, num_poses)

            new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
            visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--base_config_file", help='base config yaml file',
                        default='../configs/base.yaml',
                        type=str)
    parser.add_argument("--config_file", help='config yaml file',
                        default='../configs/conditional_pose_diffusion.yaml',
                        type=str)
    parser.add_argument("--diffusion_checkpoint_id",
                        default="ConditionalPoseDiffusion",
                        type=str)
    parser.add_argument("--collision_checkpoint_id",
                        default="curhl56k",
                        type=str)
    parser.add_argument("--eval_mode",
                        default="infer",
                        type=str)
    parser.add_argument("--eval_random_seed",
                        default=42,
                        type=int)
    parser.add_argument("--num_samples",
                        default=10,
                        type=int)
    args = parser.parse_args()

    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, cfg)

    main(args, cfg)
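The default --collision_checkpoint_id of "curhl56k" looks like a Weights & Biases run id rather than a human-readable name; both checkpoints are resolved through the same wandb-style directory layout. A sketch of the paths get_checkpoint_path_from_dir is asked to search, using values from the merged config:

import os

save_dir, project = "wandb_logs", "StructDiffusion"   # from base.yaml and the WANDB section above
for run_id in ["ConditionalPoseDiffusion", "curhl56k"]:
    print(os.path.join(save_dir, project, run_id, "checkpoints"))
    # get_checkpoint_path_from_dir() presumably picks a .ckpt file inside this directory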
scripts/train_discriminator.py
ADDED
@@ -0,0 +1,46 @@
import argparse
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from StructDiffusion.data.pairwise_collision import PairwiseCollisionDataset
from StructDiffusion.models.pl_models import PairwiseCollisionModel


def main(cfg):

    pl.seed_everything(cfg.random_seed)

    wandb_logger = WandbLogger(**cfg.WANDB)
    wandb_logger.experiment.config.update(cfg)
    checkpoint_callback = ModelCheckpoint()

    full_dataset = PairwiseCollisionDataset(**cfg.DATASET)
    train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [int(len(full_dataset) * 0.7), len(full_dataset) - int(len(full_dataset) * 0.7)])
    train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER)
    valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER)

    model = PairwiseCollisionModel(cfg.MODEL, cfg.LOSS, cfg.OPTIMIZER, cfg.DATASET)

    trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER)

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--base_config_file", help='base config yaml file',
                        default='../configs/base.yaml',
                        type=str)
    parser.add_argument("--config_file", help='config yaml file',
                        default='../configs/pairwise_collision.yaml',
                        type=str)
    args = parser.parse_args()
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, cfg)

    main(cfg)
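A small aside on the 70/30 split above: computing the validation size as the remainder, rather than int(len * 0.3), guarantees the two lengths sum to the dataset size, which random_split requires.

# Worked example of the split arithmetic used above.
n = 10_001                 # hypothetical dataset length
n_train = int(n * 0.7)     # 7000
n_valid = n - n_train      # 3001; int(n * 0.3) would give 3000 and random_split would reject the pair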
scripts/train_generator.py
ADDED
@@ -0,0 +1,49 @@
from torch.utils.data import DataLoader
import argparse
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel


def main(cfg):

    pl.seed_everything(cfg.random_seed)

    wandb_logger = WandbLogger(**cfg.WANDB)
    wandb_logger.experiment.config.update(cfg)
    checkpoint_callback = ModelCheckpoint()

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    vocab_size = tokenizer.get_vocab_size()

    train_dataset = SemanticArrangementDataset(split="train", tokenizer=tokenizer, **cfg.DATASET)
    valid_dataset = SemanticArrangementDataset(split="valid", tokenizer=tokenizer, **cfg.DATASET)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER)
    valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER)

    model = ConditionalPoseDiffusionModel(vocab_size, cfg.MODEL, cfg.LOSS, cfg.NOISE_SCHEDULE, cfg.OPTIMIZER)

    trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER)

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--base_config_file", help='base config yaml file',
                        default='../configs/base.yaml',
                        type=str)
    parser.add_argument("--config_file", help='config yaml file',
                        default='../configs/conditional_pose_diffusion.yaml',
                        type=str)
    args = parser.parse_args()
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, cfg)

    main(cfg)
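The TRAINER block of the config is splatted directly into pl.Trainer, so its keys must be valid Trainer arguments for the pinned pytorch_lightning==1.6.1 (where gpus=1 is still the way to request a single GPU). The equivalent explicit call, for reference:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=200, gradient_clip_val=1.0, gpus=1, deterministic=False)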
src/StructDiffusion/__init__.py
ADDED
File without changes

src/StructDiffusion/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (171 Bytes)

src/StructDiffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (175 Bytes)

src/StructDiffusion/data/__init__.py
ADDED
File without changes

src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (176 Bytes)

src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (180 Bytes)

src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc
ADDED
Binary file (9.72 kB)

src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc
ADDED
Binary file (17.1 kB)

src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc
ADDED
Binary file (17.1 kB)

src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc
ADDED
Binary file (16.4 kB)
src/StructDiffusion/data/pairwise_collision.py
ADDED
@@ -0,0 +1,361 @@
import cv2
import h5py
import numpy as np
import os
import trimesh
import torch
import json
from collections import defaultdict
import tqdm
import pickle
from random import shuffle

# Local imports
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, array_to_tensor
from StructDiffusion.utils.pointnet import pc_normalize

import StructDiffusion.utils.brain2.camera as cam
import StructDiffusion.utils.brain2.image as img
import StructDiffusion.utils.transformations as tra


def load_pairwise_collision_data(h5_filename):

    fh = h5py.File(h5_filename, 'r')
    data_dict = {}
    data_dict["obj1_info"] = eval(fh["obj1_info"][()])
    data_dict["obj2_info"] = eval(fh["obj2_info"][()])
    data_dict["obj1_poses"] = fh["obj1_poses"][:]
    data_dict["obj2_poses"] = fh["obj2_poses"][:]
    data_dict["intersection_labels"] = fh["intersection_labels"][:]

    return data_dict


class PairwiseCollisionDataset(torch.utils.data.Dataset):

    def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
                 num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
                 debug=False):

        # load dictionary mapping from urdf to list of pc data, each sample is
        # {"step_t": step_t, "obj": obj, "filename": filename}
        with open(urdf_pc_idx_file, "rb") as fh:
            self.urdf_to_pc_data = pickle.load(fh)
        # filter out broken files
        for urdf in self.urdf_to_pc_data:
            valid_pc_data = []
            for pd in self.urdf_to_pc_data[urdf]:
                filename = pd["filename"]
                if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
                    continue
                valid_pc_data.append(pd)
            if valid_pc_data:
                self.urdf_to_pc_data[urdf] = valid_pc_data

        # build data index
        # each sample is a tuple of (collision filename, idx for the labels and poses)
        if collision_data_dir is not None:
            self.data_idxs = self.build_data_idxs(collision_data_dir)
        else:
            print("WARNING: collision_data_dir is None")

        self.num_pts = num_pts
        self.debug = debug
        self.normalize_pc = normalize_pc
        self.num_scene_pts = num_scene_pts
        self.random_rotation = random_rotation

        # Noise
        self.data_augmentation = data_augmentation
        # additive noise
        self.gp_rescale_factor_range = [12, 20]
        self.gaussian_scale_range = [0., 0.003]
        # multiplicative noise
        self.gamma_shape = 1000.
        self.gamma_scale = 0.001

    def build_data_idxs(self, collision_data_dir):
        print("Load collision data...")
        positive_data = []
        negative_data = []
        for filename in tqdm.tqdm(os.listdir(collision_data_dir)):
            if "h5" not in filename:
                continue
            h5_filename = os.path.join(collision_data_dir, filename)
            data_dict = load_pairwise_collision_data(h5_filename)
            obj1_urdf = data_dict["obj1_info"]["urdf"]
            obj2_urdf = data_dict["obj2_info"]["urdf"]
            if obj1_urdf not in self.urdf_to_pc_data:
                print("no pc data for urdf:", obj1_urdf)
                continue
            if obj2_urdf not in self.urdf_to_pc_data:
                print("no pc data for urdf:", obj2_urdf)
                continue
            for idx, l in enumerate(data_dict["intersection_labels"]):
                if l:
                    # intersection
                    positive_data.append((h5_filename, idx))
                else:
                    negative_data.append((h5_filename, idx))
        print("Num pairwise intersections:", len(positive_data))
        print("Num pairwise no intersections:", len(negative_data))

        if len(negative_data) != len(positive_data):
            min_len = min(len(negative_data), len(positive_data))
            positive_data = [positive_data[i] for i in np.random.permutation(len(positive_data))[:min_len]]
            negative_data = [negative_data[i] for i in np.random.permutation(len(negative_data))[:min_len]]
            print("after balancing")
            print("Num pairwise intersections:", len(positive_data))
            print("Num pairwise no intersections:", len(negative_data))

        return positive_data + negative_data

    def create_urdf_pc_idxs(self, urdf_pc_idx_file, data_roots, index_roots):
        print("Load pc data")
        arrangement_steps = []
        for split in ["train"]:
            for data_root, index_root in zip(data_roots, index_roots):
                arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
                if os.path.exists(arrangement_indices_file):
                    with open(arrangement_indices_file, "r") as fh:
                        arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
                else:
                    print("{} does not exist".format(arrangement_indices_file))

        urdf_to_pc_data = defaultdict(list)
        for filename, step_t in tqdm.tqdm(arrangement_steps):
            h5 = h5py.File(filename, 'r')
            ids = self._get_ids(h5)
            # moved_objs = h5['moved_objs'][()].split(',')
            all_objs = sorted([o for o in ids.keys() if "object_" in o])
            goal_specification = json.loads(str(np.array(h5["goal_specification"])))
            obj_infos = goal_specification["rearrange"]["objects"] + goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]
            for obj, obj_info in zip(all_objs, obj_infos):
                urdf_to_pc_data[obj_info["urdf"]].append({"step_t": step_t, "obj": obj, "filename": filename})

        with open(urdf_pc_idx_file, "wb") as fh:
            pickle.dump(urdf_to_pc_data, fh)

        return urdf_to_pc_data

    def add_noise_to_depth(self, depth_img):
        """ add depth noise """
        multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
        depth_img = multiplicative_noise * depth_img
        return depth_img

    def add_noise_to_xyz(self, xyz_img, depth_img):
        """ TODO: remove this code or at least clean it up"""
        xyz_img = xyz_img.copy()
        H, W, C = xyz_img.shape
        gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                              self.gp_rescale_factor_range[1])
        gp_scale = np.random.uniform(self.gaussian_scale_range[0],
                                     self.gaussian_scale_range[1])
        small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
        additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
        additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
        xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
        return xyz_img

    def _get_images(self, h5, idx, ee=True):
        if ee:
            RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
            DMIN, DMAX = "ee_depth_min", "ee_depth_max"
        else:
            RGB, DEPTH, SEG = "rgb", "depth", "seg"
            DMIN, DMAX = "depth_min", "depth_max"
        dmin = h5[DMIN][idx]
        dmax = h5[DMAX][idx]
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
        seg1 = img.PNGToNumpy(h5[SEG][idx])

        valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)

        # proj_matrix = h5['proj_matrix'][()]
        camera = cam.get_camera_from_h5(h5)
        if self.data_augmentation:
            depth1 = self.add_noise_to_depth(depth1)

        xyz1 = cam.compute_xyz(depth1, camera)
        if self.data_augmentation:
            xyz1 = self.add_noise_to_xyz(xyz1, depth1)

        # Transform the point cloud
        # Here it is...
        # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
        CAM_POSE = "ee_camera_view" if ee else "camera_view"
        cam_pose = h5[CAM_POSE][idx]
        if ee:
            # ee_camera_view has 0s for x, y, z
            cam_pos = h5["ee_cam_pose"][:][:3, 3]
            cam_pose[:3, 3] = cam_pos

        # Get transformed point cloud
        h, w, d = xyz1.shape
        xyz1 = xyz1.reshape(h * w, -1)
        xyz1 = trimesh.transform_points(xyz1, cam_pose)
        xyz1 = xyz1.reshape(h, w, -1)

        scene1 = rgb1, depth1, seg1, valid1, xyz1

        return scene1

    def _get_ids(self, h5):
        """
        get object ids

        @param h5:
        @return:
        """
        ids = {}
        for k in h5.keys():
            if k.startswith("id_"):
                ids[k[3:]] = h5[k][()]
        return ids

    def get_obj_pc(self, h5, step_t, obj):
        scene = self._get_images(h5, step_t, ee=True)
        rgb, depth, seg, valid, xyz = scene

        # getting object point clouds
        ids = self._get_ids(h5)
        obj_mask = np.logical_and(seg == ids[obj], valid)
        if np.sum(obj_mask) <= 0:
            raise Exception
        ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts, to_tensor=False)
        obj_pc_center = np.mean(obj_xyz, axis=0)
        obj_pose = h5[obj][step_t]

        obj_pc_pose = np.eye(4)
        obj_pc_pose[:3, 3] = obj_pc_center[:3]

        return obj_xyz, obj_rgb, obj_pc_pose, obj_pose

    def __len__(self):
        return len(self.data_idxs)

    def __getitem__(self, idx):
        collision_filename, collision_idx = self.data_idxs[idx]
        collision_data_dict = load_pairwise_collision_data(collision_filename)

        obj1_urdf = collision_data_dict["obj1_info"]["urdf"]
        obj2_urdf = collision_data_dict["obj2_info"]["urdf"]

        # TODO: find a better way to sample pc data?
        obj1_pc_data = np.random.choice(self.urdf_to_pc_data[obj1_urdf])
        obj2_pc_data = np.random.choice(self.urdf_to_pc_data[obj2_urdf])

        obj1_xyz, obj1_rgb, obj1_pc_pose, obj1_pose = self.get_obj_pc(h5py.File(obj1_pc_data["filename"], "r"), obj1_pc_data["step_t"], obj1_pc_data["obj"])
        obj2_xyz, obj2_rgb, obj2_pc_pose, obj2_pose = self.get_obj_pc(h5py.File(obj2_pc_data["filename"], "r"), obj2_pc_data["step_t"], obj2_pc_data["obj"])

        obj1_c_pose = collision_data_dict["obj1_poses"][collision_idx]
        obj2_c_pose = collision_data_dict["obj2_poses"][collision_idx]
        label = collision_data_dict["intersection_labels"][collision_idx]

        obj1_transform = obj1_c_pose @ np.linalg.inv(obj1_pose)
        obj2_transform = obj2_c_pose @ np.linalg.inv(obj2_pose)
        obj1_c_xyz = trimesh.transform_points(obj1_xyz, obj1_transform)
        obj2_c_xyz = trimesh.transform_points(obj2_xyz, obj2_transform)

        # if self.debug:
        #     show_pcs([obj1_c_xyz, obj2_c_xyz], [obj1_rgb, obj2_rgb], add_coordinate_frame=True)

        ###################################
        obj_xyzs = [obj1_c_xyz, obj2_c_xyz]
        shuffle(obj_xyzs)

        num_indicator = 2
        new_obj_xyzs = []
        for oi, obj_xyz in enumerate(obj_xyzs):
            obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
            new_obj_xyzs.append(obj_xyz)
        scene_xyz = np.concatenate(new_obj_xyzs, axis=0)

        # subsampling and normalizing pc
        idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
        scene_xyz = scene_xyz[idx]
        if self.normalize_pc:
            scene_xyz[:, 0:3] = pc_normalize(scene_xyz[:, 0:3])

        if self.random_rotation:
            scene_xyz[:, 0:3] = trimesh.transform_points(scene_xyz[:, 0:3], tra.euler_matrix(0, 0, np.random.uniform(low=0, high=2 * np.pi)))

        ###################################
        scene_xyz = array_to_tensor(scene_xyz)
        # convert to torch data
        label = int(label)

        if self.debug:
            print("intersection:", label)
            show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))], add_coordinate_frame=True)

        datum = {
            "scene_xyz": scene_xyz,
            "label": torch.FloatTensor([label]),
        }
        return datum

    # @staticmethod
    # def collate_fn(data):
    #     """
    #     :param data:
    #     :return:
    #     """
    #
    #     batched_data_dict = {}
    #     for key in ["is_circle"]:
    #         batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0)
    #     for key in ["scene_xyz"]:
    #         batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0)
    #
    #     return batched_data_dict
    #
    # # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False):
    # #
    # #     new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs]
    # #
    # #     # compute pairwise collision
    # #     scene_xyzs = []
    # #     obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2))
    # #
    # #     for obj_xyz_pair_idx in obj_xyz_pair_idxs:
    # #         obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]]
    # #         num_indicator = 2
    # #         obj_xyz_pair_ind = []
    # #         for oi, obj_xyz in enumerate(obj_xyz_pair):
    # #             obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
    # #             obj_xyz_pair_ind.append(obj_xyz)
    # #         pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0)
    # #
    # #         # subsampling and normalizing pc
    # #         rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts)
    # #         pair_scene_xyz = pair_scene_xyz[rand_idx]
    # #         if self.normalize_pc:
    # #             pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3])
    # #
    # #         scene_xyzs.append(array_to_tensor(pair_scene_xyz))
    # #
    # #     if debug:
    # #         for scene_xyz in scene_xyzs:
    # #             show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))],
    # #                      add_coordinate_frame=True)
    # #
    # #     return scene_xyzs


if __name__ == "__main__":
    dataset = PairwiseCollisionDataset(urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl",
                                       collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data",
                                       debug=False)

    for i in tqdm.tqdm(np.random.permutation(len(dataset))):
        # print(i)
        d = dataset[i]
        # print(d["label"])

    # dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)
    # for b in tqdm.tqdm(dl):
    #     pass
ADDED
@@ -0,0 +1,579 @@
|
import copy
import cv2
import h5py
import numpy as np
import os
import trimesh
import torch
from tqdm import tqdm
import json
import random

from torch.utils.data import DataLoader

# Local imports
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
from StructDiffusion.language.tokenizer import Tokenizer

import StructDiffusion.utils.brain2.camera as cam
import StructDiffusion.utils.brain2.image as img
import StructDiffusion.utils.transformations as tra


class SemanticArrangementDataset(torch.utils.data.Dataset):

    def __init__(self, data_roots, index_roots, split, tokenizer,
                 max_num_target_objects=11, max_num_distractor_objects=5,
                 max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
                 num_pts=1024,
                 use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
                 filter_num_moved_objects_range=None, shuffle_object_index=False,
                 data_augmentation=True, debug=False, **kwargs):
        """

        Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs

        :param data_root:
        :param split: train, valid, or test
        :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
        :param debug:
        :param max_num_shape_parameters:
        :param max_num_objects:
        :param max_num_rearrange_features:
        :param max_num_anchor_features:
        :param num_pts:
        :param use_stored_arrangement_indices:
        :param kwargs:
        """

        self.use_virtual_structure_frame = use_virtual_structure_frame
        self.ignore_distractor_objects = ignore_distractor_objects
        self.ignore_rgb = ignore_rgb and not debug

        self.num_pts = num_pts
        self.debug = debug

        self.max_num_objects = max_num_target_objects
        self.max_num_other_objects = max_num_distractor_objects
        self.max_num_shape_parameters = max_num_shape_parameters
        self.max_num_rearrange_features = max_num_rearrange_features
        self.max_num_anchor_features = max_num_anchor_features
        self.shuffle_object_index = shuffle_object_index

        # used to tokenize the language part
        self.tokenizer = tokenizer

        # retrieve data
        self.data_roots = data_roots
        self.arrangement_data = []
        arrangement_steps = []
        for ddx in range(len(data_roots)):
            data_root = data_roots[ddx]
            index_root = index_roots[ddx]
            arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
            if os.path.exists(arrangement_indices_file):
                with open(arrangement_indices_file, "r") as fh:
                    arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
            else:
                print("{} does not exist".format(arrangement_indices_file))
        # only keep the goal, ignore the intermediate steps
        for filename, step_t in arrangement_steps:
            if step_t == 0:
                if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename:
                    continue
                self.arrangement_data.append((filename, step_t))
        # if specified, filter data
        if filter_num_moved_objects_range is not None:
            self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range)
        print("{} valid sequences".format(len(self.arrangement_data)))

        # Data Aug
        self.data_augmentation = data_augmentation
        # additive noise
        self.gp_rescale_factor_range = [12, 20]
        self.gaussian_scale_range = [0., 0.003]
        # multiplicative noise
        self.gamma_shape = 1000.
        self.gamma_scale = 0.001

    def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
        assert len(list(filter_num_moved_objects_range)) == 2
        min_num, max_num = filter_num_moved_objects_range
        print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
        ok_data = []
        for filename, step_t in self.arrangement_data:
            h5 = h5py.File(filename, 'r')
            moved_objs = h5['moved_objs'][()].split(',')
            if min_num <= len(moved_objs) <= max_num:
                ok_data.append((filename, step_t))
        print("{} valid sequences left".format(len(ok_data)))
        return ok_data

    def get_data_idx(self, idx):
        # Create the datum to return
        file_idx = np.argmax(idx < self.file_to_count)
        data = h5py.File(self.data_files[file_idx], 'r')
        if file_idx > 0:
            # for lang2sym, idx is always 0
            idx = idx - self.file_to_count[file_idx - 1]
        return data, idx, file_idx

    def add_noise_to_depth(self, depth_img):
        """ add depth noise """
        multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
        depth_img = multiplicative_noise * depth_img
        return depth_img

    def add_noise_to_xyz(self, xyz_img, depth_img):
        """ TODO: remove this code or at least clean it up"""
        xyz_img = xyz_img.copy()
        H, W, C = xyz_img.shape
        gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                              self.gp_rescale_factor_range[1])
        gp_scale = np.random.uniform(self.gaussian_scale_range[0],
                                     self.gaussian_scale_range[1])
        small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
        additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
        additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
        xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
        return xyz_img

    def random_index(self):
        return self[np.random.randint(len(self))]

    def _get_rgb(self, h5, idx, ee=True):
        RGB = "ee_rgb" if ee else "rgb"
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        return rgb1

    def _get_depth(self, h5, idx, ee=True):
        DEPTH = "ee_depth" if ee else "depth"

    def _get_images(self, h5, idx, ee=True):
        if ee:
            RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
            DMIN, DMAX = "ee_depth_min", "ee_depth_max"
        else:
            RGB, DEPTH, SEG = "rgb", "depth", "seg"
            DMIN, DMAX = "depth_min", "depth_max"
        dmin = h5[DMIN][idx]
        dmax = h5[DMAX][idx]
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
        seg1 = img.PNGToNumpy(h5[SEG][idx])

        valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)

        # proj_matrix = h5['proj_matrix'][()]
        camera = cam.get_camera_from_h5(h5)
        if self.data_augmentation:
            depth1 = self.add_noise_to_depth(depth1)

        xyz1 = cam.compute_xyz(depth1, camera)
        if self.data_augmentation:
            xyz1 = self.add_noise_to_xyz(xyz1, depth1)

        # Transform the point cloud
        # Here it is...
        # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
        CAM_POSE = "ee_camera_view" if ee else "camera_view"
        cam_pose = h5[CAM_POSE][idx]
        if ee:
            # ee_camera_view has 0s for x, y, z
            cam_pos = h5["ee_cam_pose"][:][:3, 3]
            cam_pose[:3, 3] = cam_pos

        # Get transformed point cloud
        h, w, d = xyz1.shape
        xyz1 = xyz1.reshape(h * w, -1)
        xyz1 = trimesh.transform_points(xyz1, cam_pose)
        xyz1 = xyz1.reshape(h, w, -1)

        scene1 = rgb1, depth1, seg1, valid1, xyz1

        return scene1

    def __len__(self):
        return len(self.arrangement_data)

    def _get_ids(self, h5):
        """
        get object ids

        @param h5:
        @return:
        """
        ids = {}
        for k in h5.keys():
            if k.startswith("id_"):
                ids[k[3:]] = h5[k][()]
        return ids

    def get_positive_ratio(self):
        num_pos = 0
        for d in self.arrangement_data:
            filename, step_t = d
            if step_t == 0:
                num_pos += 1
        return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos

    def get_object_position_vocab_sizes(self):
        return self.tokenizer.get_object_position_vocab_sizes()

    def get_vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def get_data_index(self, idx):
        filename = self.arrangement_data[idx]
        return filename

    def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
        """

        :param idx:
        :param inference_mode:
        :param shuffle_object_index: used to test different orders of objects
|
236 |
+
:return:
|
237 |
+
"""
|
238 |
+
|
239 |
+
filename, _ = self.arrangement_data[idx]
|
240 |
+
|
241 |
+
h5 = h5py.File(filename, 'r')
|
242 |
+
ids = self._get_ids(h5)
|
243 |
+
all_objs = sorted([o for o in ids.keys() if "object_" in o])
|
244 |
+
goal_specification = json.loads(str(np.array(h5["goal_specification"])))
|
245 |
+
num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
|
246 |
+
num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
|
247 |
+
assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
|
248 |
+
assert num_rearrange_objs <= self.max_num_objects
|
249 |
+
assert num_other_objs <= self.max_num_other_objects
|
250 |
+
|
251 |
+
# important: only using the last step
|
252 |
+
step_t = num_rearrange_objs
|
253 |
+
|
254 |
+
target_objs = all_objs[:num_rearrange_objs]
|
255 |
+
other_objs = all_objs[num_rearrange_objs:]
|
256 |
+
|
257 |
+
structure_parameters = goal_specification["shape"]
|
258 |
+
|
259 |
+
# Important: ensure the order is correct
|
260 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
261 |
+
target_objs = target_objs[::-1]
|
262 |
+
elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
|
263 |
+
target_objs = target_objs
|
264 |
+
else:
|
265 |
+
raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
|
266 |
+
all_objs = target_objs + other_objs
|
267 |
+
|
268 |
+
###################################
|
269 |
+
# getting scene images and point clouds
|
270 |
+
scene = self._get_images(h5, step_t, ee=True)
|
271 |
+
rgb, depth, seg, valid, xyz = scene
|
272 |
+
if inference_mode:
|
273 |
+
initial_scene = scene
|
274 |
+
|
275 |
+
# getting object point clouds
|
276 |
+
obj_pcs = []
|
277 |
+
obj_pad_mask = []
|
278 |
+
current_pc_poses = []
|
279 |
+
other_obj_pcs = []
|
280 |
+
other_obj_pad_mask = []
|
281 |
+
for obj in all_objs:
|
282 |
+
obj_mask = np.logical_and(seg == ids[obj], valid)
|
283 |
+
if np.sum(obj_mask) <= 0:
|
284 |
+
raise Exception
|
285 |
+
ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
|
286 |
+
if not ok:
|
287 |
+
raise Exception
|
288 |
+
|
289 |
+
if obj in target_objs:
|
290 |
+
if self.ignore_rgb:
|
291 |
+
obj_pcs.append(obj_xyz)
|
292 |
+
else:
|
293 |
+
obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
294 |
+
obj_pad_mask.append(0)
|
295 |
+
pc_pose = np.eye(4)
|
296 |
+
pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
|
297 |
+
current_pc_poses.append(pc_pose)
|
298 |
+
elif obj in other_objs:
|
299 |
+
if self.ignore_rgb:
|
300 |
+
other_obj_pcs.append(obj_xyz)
|
301 |
+
else:
|
302 |
+
other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
303 |
+
other_obj_pad_mask.append(0)
|
304 |
+
else:
|
305 |
+
raise Exception
|
306 |
+
|
307 |
+
###################################
|
308 |
+
# computes goal positions for objects
|
309 |
+
# Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
|
310 |
+
if self.use_virtual_structure_frame:
|
311 |
+
goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
|
312 |
+
structure_parameters["rotation"][2])
|
313 |
+
goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
|
314 |
+
structure_parameters["position"][2]]
|
315 |
+
goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
|
316 |
+
|
317 |
+
goal_obj_poses = []
|
318 |
+
current_obj_poses = []
|
319 |
+
goal_pc_poses = []
|
320 |
+
for obj, current_pc_pose in zip(target_objs, current_pc_poses):
|
321 |
+
goal_pose = h5[obj][0]
|
322 |
+
current_pose = h5[obj][step_t]
|
323 |
+
if inference_mode:
|
324 |
+
goal_obj_poses.append(goal_pose)
|
325 |
+
current_obj_poses.append(current_pose)
|
326 |
+
|
327 |
+
goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
|
328 |
+
if self.use_virtual_structure_frame:
|
329 |
+
goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
|
330 |
+
goal_pc_poses.append(goal_pc_pose)
|
331 |
+
|
332 |
+
# transform current object point cloud to the goal point cloud in the world frame
|
333 |
+
if self.debug:
|
334 |
+
new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
|
335 |
+
for i, obj_pc in enumerate(new_obj_pcs):
|
336 |
+
|
337 |
+
current_pc_pose = current_pc_poses[i]
|
338 |
+
goal_pc_pose = goal_pc_poses[i]
|
339 |
+
if self.use_virtual_structure_frame:
|
340 |
+
goal_pc_pose = goal_structure_pose @ goal_pc_pose
|
341 |
+
print("current pc pose", current_pc_pose)
|
342 |
+
print("goal pc pose", goal_pc_pose)
|
343 |
+
|
344 |
+
goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
|
345 |
+
print("transform", goal_pc_transform)
|
346 |
+
new_obj_pc = copy.deepcopy(obj_pc)
|
347 |
+
new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
|
348 |
+
print(new_obj_pc.shape)
|
349 |
+
|
350 |
+
# visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
|
351 |
+
new_obj_pcs[i] = new_obj_pc
|
352 |
+
new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
|
353 |
+
new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
|
354 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
|
355 |
+
[pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
|
356 |
+
add_coordinate_frame=True)
|
357 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
|
358 |
+
|
359 |
+
# pad data
|
360 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
361 |
+
obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
362 |
+
obj_pad_mask.append(1)
|
363 |
+
for i in range(self.max_num_other_objects - len(other_objs)):
|
364 |
+
other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
365 |
+
other_obj_pad_mask.append(1)
|
366 |
+
|
367 |
+
###################################
|
368 |
+
# preparing sentence
|
369 |
+
sentence = []
|
370 |
+
sentence_pad_mask = []
|
371 |
+
|
372 |
+
# structure parameters
|
373 |
+
# 5 parameters
|
374 |
+
structure_parameters = goal_specification["shape"]
|
375 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
376 |
+
sentence.append((structure_parameters["type"], "shape"))
|
377 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
378 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
379 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
380 |
+
if structure_parameters["type"] == "circle":
|
381 |
+
sentence.append((structure_parameters["radius"], "radius"))
|
382 |
+
elif structure_parameters["type"] == "line":
|
383 |
+
sentence.append((structure_parameters["length"] / 2.0, "radius"))
|
384 |
+
for _ in range(5):
|
385 |
+
sentence_pad_mask.append(0)
|
386 |
+
else:
|
387 |
+
sentence.append((structure_parameters["type"], "shape"))
|
388 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
389 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
390 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
391 |
+
for _ in range(4):
|
392 |
+
sentence_pad_mask.append(0)
|
393 |
+
sentence.append(("PAD", None))
|
394 |
+
sentence_pad_mask.append(1)
|
395 |
+
|
396 |
+
###################################
|
397 |
+
# paddings
|
398 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
399 |
+
goal_pc_poses.append(np.eye(4))
|
400 |
+
|
401 |
+
###################################
|
402 |
+
if self.debug:
|
403 |
+
print("---")
|
404 |
+
print("all objects:", all_objs)
|
405 |
+
print("target objects:", target_objs)
|
406 |
+
print("other objects:", other_objs)
|
407 |
+
print("goal specification:", goal_specification)
|
408 |
+
print("sentence:", sentence)
|
409 |
+
show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
|
410 |
+
|
411 |
+
assert len(obj_pcs) == len(goal_pc_poses)
|
412 |
+
###################################
|
413 |
+
|
414 |
+
# shuffle the position of objects
|
415 |
+
if shuffle_object_index:
|
416 |
+
shuffle_target_object_indices = list(range(len(target_objs)))
|
417 |
+
random.shuffle(shuffle_target_object_indices)
|
418 |
+
shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects))
|
419 |
+
obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
|
420 |
+
goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
|
421 |
+
if inference_mode:
|
422 |
+
goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices]
|
423 |
+
current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices]
|
424 |
+
target_objs = [target_objs[i] for i in shuffle_target_object_indices]
|
425 |
+
current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices]
|
426 |
+
|
427 |
+
###################################
|
428 |
+
if self.use_virtual_structure_frame:
|
429 |
+
if self.ignore_distractor_objects:
|
430 |
+
# language, structure virtual frame, target objects
|
431 |
+
pcs = obj_pcs
|
432 |
+
type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
|
433 |
+
position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
|
434 |
+
pad_mask = sentence_pad_mask + [0] + obj_pad_mask
|
435 |
+
else:
|
436 |
+
# language, distractor objects, structure virtual frame, target objects
|
437 |
+
pcs = other_obj_pcs + obj_pcs
|
438 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
|
439 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
|
440 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
|
441 |
+
goal_poses = [goal_structure_pose] + goal_pc_poses
|
442 |
+
else:
|
443 |
+
if self.ignore_distractor_objects:
|
444 |
+
# language, target objects
|
445 |
+
pcs = obj_pcs
|
446 |
+
type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
|
447 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
|
448 |
+
pad_mask = sentence_pad_mask + obj_pad_mask
|
449 |
+
else:
|
450 |
+
# language, distractor objects, target objects
|
451 |
+
pcs = other_obj_pcs + obj_pcs
|
452 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
|
453 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
|
454 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
|
455 |
+
goal_poses = goal_pc_poses
|
456 |
+
|
457 |
+
datum = {
|
458 |
+
"pcs": pcs,
|
459 |
+
"sentence": sentence,
|
460 |
+
"goal_poses": goal_poses,
|
461 |
+
"type_index": type_index,
|
462 |
+
"position_index": position_index,
|
463 |
+
"pad_mask": pad_mask,
|
464 |
+
"t": step_t,
|
465 |
+
"filename": filename
|
466 |
+
}
|
467 |
+
|
468 |
+
if inference_mode:
|
469 |
+
datum["rgb"] = rgb
|
470 |
+
datum["goal_obj_poses"] = goal_obj_poses
|
471 |
+
datum["current_obj_poses"] = current_obj_poses
|
472 |
+
datum["target_objs"] = target_objs
|
473 |
+
datum["initial_scene"] = initial_scene
|
474 |
+
datum["ids"] = ids
|
475 |
+
datum["goal_specification"] = goal_specification
|
476 |
+
datum["current_pc_poses"] = current_pc_poses
|
477 |
+
|
478 |
+
return datum
|
479 |
+
|
480 |
+
@staticmethod
|
481 |
+
def convert_to_tensors(datum, tokenizer):
|
482 |
+
tensors = {
|
483 |
+
"pcs": torch.stack(datum["pcs"], dim=0),
|
484 |
+
"sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])),
|
485 |
+
"goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
|
486 |
+
"type_index": torch.LongTensor(np.array(datum["type_index"])),
|
487 |
+
"position_index": torch.LongTensor(np.array(datum["position_index"])),
|
488 |
+
"pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
|
489 |
+
"t": datum["t"],
|
490 |
+
"filename": datum["filename"]
|
491 |
+
}
|
492 |
+
return tensors
|
493 |
+
|
494 |
+
def __getitem__(self, idx):
|
495 |
+
|
496 |
+
datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
|
497 |
+
self.tokenizer)
|
498 |
+
|
499 |
+
return datum
|
500 |
+
|
501 |
+
def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
|
502 |
+
tensor_x = {}
|
503 |
+
|
504 |
+
tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
|
505 |
+
tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
|
506 |
+
if not inference_mode:
|
507 |
+
tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
|
508 |
+
|
509 |
+
tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
|
510 |
+
tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
|
511 |
+
tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
|
512 |
+
|
513 |
+
return tensor_x
|
514 |
+
|
515 |
+
|
516 |
+
def compute_min_max(dataloader):
|
517 |
+
|
518 |
+
# tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
|
519 |
+
# -0.9079, -0.8668, -0.9105, -0.4186])
|
520 |
+
# tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
|
521 |
+
# 0.4787, 0.6421, 1.0000])
|
522 |
+
# tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
|
523 |
+
# -0.0000, 0.0000, 0.0000, 1.0000])
|
524 |
+
# tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
|
525 |
+
# 0.0000, 0.0000, 1.0000])
|
526 |
+
|
527 |
+
min_value = torch.ones(16) * 10000
|
528 |
+
max_value = torch.ones(16) * -10000
|
529 |
+
for d in tqdm(dataloader):
|
530 |
+
goal_poses = d["goal_poses"]
|
531 |
+
goal_poses = goal_poses.reshape(-1, 16)
|
532 |
+
current_max, _ = torch.max(goal_poses, dim=0)
|
533 |
+
current_min, _ = torch.min(goal_poses, dim=0)
|
534 |
+
max_value[max_value < current_max] = current_max[max_value < current_max]
|
535 |
+
min_value[min_value > current_min] = current_min[min_value > current_min]
|
536 |
+
print(f"{min_value} - {max_value}")
|
537 |
+
|
538 |
+
|
539 |
+
if __name__ == "__main__":
|
540 |
+
|
541 |
+
tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
|
542 |
+
|
543 |
+
data_roots = []
|
544 |
+
index_roots = []
|
545 |
+
for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
|
546 |
+
data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
|
547 |
+
index_roots.append(index)
|
548 |
+
|
549 |
+
dataset = SemanticArrangementDataset(data_roots=data_roots,
|
550 |
+
index_roots=index_roots,
|
551 |
+
split="valid", tokenizer=tokenizer,
|
552 |
+
max_num_target_objects=7,
|
553 |
+
max_num_distractor_objects=5,
|
554 |
+
max_num_shape_parameters=5,
|
555 |
+
max_num_rearrange_features=0,
|
556 |
+
max_num_anchor_features=0,
|
557 |
+
num_pts=1024,
|
558 |
+
use_virtual_structure_frame=True,
|
559 |
+
ignore_distractor_objects=True,
|
560 |
+
ignore_rgb=True,
|
561 |
+
filter_num_moved_objects_range=None, # [5, 5]
|
562 |
+
data_augmentation=False,
|
563 |
+
shuffle_object_index=False,
|
564 |
+
debug=False)
|
565 |
+
|
566 |
+
# print(len(dataset))
|
567 |
+
# for d in dataset:
|
568 |
+
# print("\n\n" + "="*100)
|
569 |
+
|
570 |
+
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
|
571 |
+
for i, d in enumerate(tqdm(dataloader)):
|
572 |
+
pass
|
573 |
+
# for k in d:
|
574 |
+
# if isinstance(d[k], torch.Tensor):
|
575 |
+
# print("--size", k, d[k].shape)
|
576 |
+
# for k in d:
|
577 |
+
# print(k, d[k])
|
578 |
+
#
|
579 |
+
# input("next?")
|
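The goal-pose computation in get_raw_data above composes three transforms: the recorded goal pose of the object, the inverse of its current pose, and the frame centered on the observed point-cloud centroid; with use_virtual_structure_frame the result is further expressed in the structure frame. A minimal numeric sketch of that composition (the matrices below are illustrative placeholders, not values from the dataset):

import numpy as np

# placeholder poses; in the dataset these come from h5[obj][0], h5[obj][step_t],
# and the centroid of the segmented object point cloud
goal_pose = np.eye(4)
goal_pose[:3, 3] = [0.50, 0.00, 0.10]
current_pose = np.eye(4)
current_pose[:3, 3] = [0.20, -0.10, 0.10]
current_pc_pose = np.eye(4)
current_pc_pose[:3, 3] = [0.21, -0.09, 0.12]
goal_structure_pose = np.eye(4)
goal_structure_pose[:3, 3] = [0.40, 0.00, 0.00]

# same composition as in get_raw_data: apply the relative transform that moves the
# object from its current pose to its goal pose to the point-cloud frame
goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose

# with use_virtual_structure_frame=True the pose is expressed relative to the structure frame
goal_pc_pose_in_struct = np.linalg.inv(goal_structure_pose) @ goal_pc_pose

print(goal_pc_pose[:3, 3])            # [0.51 0.01 0.12]
print(goal_pc_pose_in_struct[:3, 3])  # [0.11 0.01 0.12]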
src/StructDiffusion/data/semantic_arrangement_demo.py
ADDED
@@ -0,0 +1,563 @@
1 |
+
import copy
|
2 |
+
import cv2
|
3 |
+
import h5py
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import trimesh
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
# Local imports
|
15 |
+
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
|
16 |
+
from StructDiffusion.language.tokenizer import Tokenizer
|
17 |
+
|
18 |
+
import StructDiffusion.utils.brain2.camera as cam
|
19 |
+
import StructDiffusion.utils.brain2.image as img
|
20 |
+
import StructDiffusion.utils.transformations as tra
|
21 |
+
|
22 |
+
|
23 |
+
class SemanticArrangementDataset(torch.utils.data.Dataset):
|
24 |
+
|
25 |
+
def __init__(self, data_root, tokenizer,
|
26 |
+
max_num_target_objects=11, max_num_distractor_objects=5,
|
27 |
+
max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
|
28 |
+
num_pts=1024,
|
29 |
+
use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
|
30 |
+
filter_num_moved_objects_range=None, shuffle_object_index=False,
|
31 |
+
data_augmentation=True, debug=False, **kwargs):
|
32 |
+
"""
|
33 |
+
|
34 |
+
Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
|
35 |
+
|
36 |
+
:param data_root:
|
37 |
+
:param split: train, valid, or test
|
38 |
+
:param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
|
39 |
+
:param debug:
|
40 |
+
:param max_num_shape_parameters:
|
41 |
+
:param max_num_objects:
|
42 |
+
:param max_num_rearrange_features:
|
43 |
+
:param max_num_anchor_features:
|
44 |
+
:param num_pts:
|
45 |
+
:param use_stored_arrangement_indices:
|
46 |
+
:param kwargs:
|
47 |
+
"""
|
48 |
+
|
49 |
+
self.use_virtual_structure_frame = use_virtual_structure_frame
|
50 |
+
self.ignore_distractor_objects = ignore_distractor_objects
|
51 |
+
self.ignore_rgb = ignore_rgb and not debug
|
52 |
+
|
53 |
+
self.num_pts = num_pts
|
54 |
+
self.debug = debug
|
55 |
+
|
56 |
+
self.max_num_objects = max_num_target_objects
|
57 |
+
self.max_num_other_objects = max_num_distractor_objects
|
58 |
+
self.max_num_shape_parameters = max_num_shape_parameters
|
59 |
+
self.max_num_rearrange_features = max_num_rearrange_features
|
60 |
+
self.max_num_anchor_features = max_num_anchor_features
|
61 |
+
self.shuffle_object_index = shuffle_object_index
|
62 |
+
|
63 |
+
# used to tokenize the language part
|
64 |
+
self.tokenizer = tokenizer
|
65 |
+
|
66 |
+
# retrieve data
|
67 |
+
self.data_root = data_root
|
68 |
+
self.arrangement_data = []
|
69 |
+
for filename in os.listdir(data_root):
|
70 |
+
if ".h5" in filename:
|
71 |
+
self.arrangement_data.append((os.path.join(data_root, filename), 0))
|
72 |
+
print("{} valid sequences".format(len(self.arrangement_data)))
|
73 |
+
|
74 |
+
# Data Aug
|
75 |
+
self.data_augmentation = data_augmentation
|
76 |
+
# additive noise
|
77 |
+
self.gp_rescale_factor_range = [12, 20]
|
78 |
+
self.gaussian_scale_range = [0., 0.003]
|
79 |
+
# multiplicative noise
|
80 |
+
self.gamma_shape = 1000.
|
81 |
+
self.gamma_scale = 0.001
|
82 |
+
|
83 |
+
def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
|
84 |
+
assert len(list(filter_num_moved_objects_range)) == 2
|
85 |
+
min_num, max_num = filter_num_moved_objects_range
|
86 |
+
print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
|
87 |
+
ok_data = []
|
88 |
+
for filename, step_t in self.arrangement_data:
|
89 |
+
h5 = h5py.File(filename, 'r')
|
90 |
+
moved_objs = h5['moved_objs'][()].split(',')
|
91 |
+
if min_num <= len(moved_objs) <= max_num:
|
92 |
+
ok_data.append((filename, step_t))
|
93 |
+
print("{} valid sequences left".format(len(ok_data)))
|
94 |
+
return ok_data
|
95 |
+
|
96 |
+
def get_data_idx(self, idx):
|
97 |
+
# Create the datum to return
|
98 |
+
file_idx = np.argmax(idx < self.file_to_count)
|
99 |
+
data = h5py.File(self.data_files[file_idx], 'r')
|
100 |
+
if file_idx > 0:
|
101 |
+
# for lang2sym, idx is always 0
|
102 |
+
idx = idx - self.file_to_count[file_idx - 1]
|
103 |
+
return data, idx, file_idx
|
104 |
+
|
105 |
+
def add_noise_to_depth(self, depth_img):
|
106 |
+
""" add depth noise """
|
107 |
+
multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
|
108 |
+
depth_img = multiplicative_noise * depth_img
|
109 |
+
return depth_img
|
110 |
+
|
111 |
+
def add_noise_to_xyz(self, xyz_img, depth_img):
|
112 |
+
""" TODO: remove this code or at least celean it up"""
|
113 |
+
xyz_img = xyz_img.copy()
|
114 |
+
H, W, C = xyz_img.shape
|
115 |
+
gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
|
116 |
+
self.gp_rescale_factor_range[1])
|
117 |
+
gp_scale = np.random.uniform(self.gaussian_scale_range[0],
|
118 |
+
self.gaussian_scale_range[1])
|
119 |
+
small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
|
120 |
+
additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
|
121 |
+
additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
|
122 |
+
xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
|
123 |
+
return xyz_img
|
124 |
+
|
125 |
+
def random_index(self):
|
126 |
+
return self[np.random.randint(len(self))]
|
127 |
+
|
128 |
+
def _get_rgb(self, h5, idx, ee=True):
|
129 |
+
RGB = "ee_rgb" if ee else "rgb"
|
130 |
+
rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
|
131 |
+
return rgb1
|
132 |
+
|
133 |
+
def _get_depth(self, h5, idx, ee=True):
|
134 |
+
DEPTH = "ee_depth" if ee else "depth"
|
135 |
+
|
136 |
+
def _get_images(self, h5, idx, ee=True):
|
137 |
+
if ee:
|
138 |
+
RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
|
139 |
+
DMIN, DMAX = "ee_depth_min", "ee_depth_max"
|
140 |
+
else:
|
141 |
+
RGB, DEPTH, SEG = "rgb", "depth", "seg"
|
142 |
+
DMIN, DMAX = "depth_min", "depth_max"
|
143 |
+
dmin = h5[DMIN][idx]
|
144 |
+
dmax = h5[DMAX][idx]
|
145 |
+
rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
|
146 |
+
depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
|
147 |
+
seg1 = img.PNGToNumpy(h5[SEG][idx])
|
148 |
+
|
149 |
+
valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
|
150 |
+
|
151 |
+
# proj_matrix = h5['proj_matrix'][()]
|
152 |
+
camera = cam.get_camera_from_h5(h5)
|
153 |
+
if self.data_augmentation:
|
154 |
+
depth1 = self.add_noise_to_depth(depth1)
|
155 |
+
|
156 |
+
xyz1 = cam.compute_xyz(depth1, camera)
|
157 |
+
if self.data_augmentation:
|
158 |
+
xyz1 = self.add_noise_to_xyz(xyz1, depth1)
|
159 |
+
|
160 |
+
# Transform the point cloud
|
161 |
+
# Here it is...
|
162 |
+
# CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
|
163 |
+
CAM_POSE = "ee_camera_view" if ee else "camera_view"
|
164 |
+
cam_pose = h5[CAM_POSE][idx]
|
165 |
+
if ee:
|
166 |
+
# ee_camera_view has 0s for x, y, z
|
167 |
+
cam_pos = h5["ee_cam_pose"][:][:3, 3]
|
168 |
+
cam_pose[:3, 3] = cam_pos
|
169 |
+
|
170 |
+
# Get transformed point cloud
|
171 |
+
h, w, d = xyz1.shape
|
172 |
+
xyz1 = xyz1.reshape(h * w, -1)
|
173 |
+
xyz1 = trimesh.transform_points(xyz1, cam_pose)
|
174 |
+
xyz1 = xyz1.reshape(h, w, -1)
|
175 |
+
|
176 |
+
scene1 = rgb1, depth1, seg1, valid1, xyz1
|
177 |
+
|
178 |
+
return scene1
|
179 |
+
|
180 |
+
def __len__(self):
|
181 |
+
return len(self.arrangement_data)
|
182 |
+
|
183 |
+
def _get_ids(self, h5):
|
184 |
+
"""
|
185 |
+
get object ids
|
186 |
+
|
187 |
+
@param h5:
|
188 |
+
@return:
|
189 |
+
"""
|
190 |
+
ids = {}
|
191 |
+
for k in h5.keys():
|
192 |
+
if k.startswith("id_"):
|
193 |
+
ids[k[3:]] = h5[k][()]
|
194 |
+
return ids
|
195 |
+
|
196 |
+
def get_positive_ratio(self):
|
197 |
+
num_pos = 0
|
198 |
+
for d in self.arrangement_data:
|
199 |
+
filename, step_t = d
|
200 |
+
if step_t == 0:
|
201 |
+
num_pos += 1
|
202 |
+
return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos
|
203 |
+
|
204 |
+
def get_object_position_vocab_sizes(self):
|
205 |
+
return self.tokenizer.get_object_position_vocab_sizes()
|
206 |
+
|
207 |
+
def get_vocab_size(self):
|
208 |
+
return self.tokenizer.get_vocab_size()
|
209 |
+
|
210 |
+
def get_data_index(self, idx):
|
211 |
+
filename = self.arrangement_data[idx]
|
212 |
+
return filename
|
213 |
+
|
214 |
+
def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
|
215 |
+
"""
|
216 |
+
|
217 |
+
:param idx:
|
218 |
+
:param inference_mode:
|
219 |
+
:param shuffle_object_index: used to test different orders of objects
|
220 |
+
:return:
|
221 |
+
"""
|
222 |
+
|
223 |
+
filename, _ = self.arrangement_data[idx]
|
224 |
+
|
225 |
+
h5 = h5py.File(filename, 'r')
|
226 |
+
ids = self._get_ids(h5)
|
227 |
+
all_objs = sorted([o for o in ids.keys() if "object_" in o])
|
228 |
+
goal_specification = json.loads(str(np.array(h5["goal_specification"])))
|
229 |
+
num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
|
230 |
+
num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
|
231 |
+
assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
|
232 |
+
assert num_rearrange_objs <= self.max_num_objects
|
233 |
+
assert num_other_objs <= self.max_num_other_objects
|
234 |
+
|
235 |
+
# important: only using the last step
|
236 |
+
step_t = num_rearrange_objs
|
237 |
+
|
238 |
+
target_objs = all_objs[:num_rearrange_objs]
|
239 |
+
other_objs = all_objs[num_rearrange_objs:]
|
240 |
+
|
241 |
+
structure_parameters = goal_specification["shape"]
|
242 |
+
|
243 |
+
# Important: ensure the order is correct
|
244 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
245 |
+
target_objs = target_objs[::-1]
|
246 |
+
elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
|
247 |
+
target_objs = target_objs
|
248 |
+
else:
|
249 |
+
raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
|
250 |
+
all_objs = target_objs + other_objs
|
251 |
+
|
252 |
+
###################################
|
253 |
+
# getting scene images and point clouds
|
254 |
+
scene = self._get_images(h5, step_t, ee=True)
|
255 |
+
rgb, depth, seg, valid, xyz = scene
|
256 |
+
if inference_mode:
|
257 |
+
initial_scene = scene
|
258 |
+
|
259 |
+
# getting object point clouds
|
260 |
+
obj_pcs = []
|
261 |
+
obj_pad_mask = []
|
262 |
+
current_pc_poses = []
|
263 |
+
other_obj_pcs = []
|
264 |
+
other_obj_pad_mask = []
|
265 |
+
for obj in all_objs:
|
266 |
+
obj_mask = np.logical_and(seg == ids[obj], valid)
|
267 |
+
if np.sum(obj_mask) <= 0:
|
268 |
+
raise Exception
|
269 |
+
ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
|
270 |
+
if not ok:
|
271 |
+
raise Exception
|
272 |
+
|
273 |
+
if obj in target_objs:
|
274 |
+
if self.ignore_rgb:
|
275 |
+
obj_pcs.append(obj_xyz)
|
276 |
+
else:
|
277 |
+
obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
278 |
+
obj_pad_mask.append(0)
|
279 |
+
pc_pose = np.eye(4)
|
280 |
+
pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
|
281 |
+
current_pc_poses.append(pc_pose)
|
282 |
+
elif obj in other_objs:
|
283 |
+
if self.ignore_rgb:
|
284 |
+
other_obj_pcs.append(obj_xyz)
|
285 |
+
else:
|
286 |
+
other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
287 |
+
other_obj_pad_mask.append(0)
|
288 |
+
else:
|
289 |
+
raise Exception
|
290 |
+
|
291 |
+
###################################
|
292 |
+
# computes goal positions for objects
|
293 |
+
# Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
|
294 |
+
if self.use_virtual_structure_frame:
|
295 |
+
goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
|
296 |
+
structure_parameters["rotation"][2])
|
297 |
+
goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
|
298 |
+
structure_parameters["position"][2]]
|
299 |
+
goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
|
300 |
+
|
301 |
+
goal_obj_poses = []
|
302 |
+
current_obj_poses = []
|
303 |
+
goal_pc_poses = []
|
304 |
+
for obj, current_pc_pose in zip(target_objs, current_pc_poses):
|
305 |
+
goal_pose = h5[obj][0]
|
306 |
+
current_pose = h5[obj][step_t]
|
307 |
+
if inference_mode:
|
308 |
+
goal_obj_poses.append(goal_pose)
|
309 |
+
current_obj_poses.append(current_pose)
|
310 |
+
|
311 |
+
goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
|
312 |
+
if self.use_virtual_structure_frame:
|
313 |
+
goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
|
314 |
+
goal_pc_poses.append(goal_pc_pose)
|
315 |
+
|
316 |
+
# transform current object point cloud to the goal point cloud in the world frame
|
317 |
+
if self.debug:
|
318 |
+
new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
|
319 |
+
for i, obj_pc in enumerate(new_obj_pcs):
|
320 |
+
|
321 |
+
current_pc_pose = current_pc_poses[i]
|
322 |
+
goal_pc_pose = goal_pc_poses[i]
|
323 |
+
if self.use_virtual_structure_frame:
|
324 |
+
goal_pc_pose = goal_structure_pose @ goal_pc_pose
|
325 |
+
print("current pc pose", current_pc_pose)
|
326 |
+
print("goal pc pose", goal_pc_pose)
|
327 |
+
|
328 |
+
goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
|
329 |
+
print("transform", goal_pc_transform)
|
330 |
+
new_obj_pc = copy.deepcopy(obj_pc)
|
331 |
+
new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
|
332 |
+
print(new_obj_pc.shape)
|
333 |
+
|
334 |
+
# visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
|
335 |
+
new_obj_pcs[i] = new_obj_pc
|
336 |
+
new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
|
337 |
+
new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
|
338 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
|
339 |
+
[pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
|
340 |
+
add_coordinate_frame=True)
|
341 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
|
342 |
+
|
343 |
+
# pad data
|
344 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
345 |
+
obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
346 |
+
obj_pad_mask.append(1)
|
347 |
+
for i in range(self.max_num_other_objects - len(other_objs)):
|
348 |
+
other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
349 |
+
other_obj_pad_mask.append(1)
|
350 |
+
|
351 |
+
###################################
|
352 |
+
# preparing sentence
|
353 |
+
sentence = []
|
354 |
+
sentence_pad_mask = []
|
355 |
+
|
356 |
+
# structure parameters
|
357 |
+
# 5 parameters
|
358 |
+
structure_parameters = goal_specification["shape"]
|
359 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
360 |
+
sentence.append((structure_parameters["type"], "shape"))
|
361 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
362 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
363 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
364 |
+
if structure_parameters["type"] == "circle":
|
365 |
+
sentence.append((structure_parameters["radius"], "radius"))
|
366 |
+
elif structure_parameters["type"] == "line":
|
367 |
+
sentence.append((structure_parameters["length"] / 2.0, "radius"))
|
368 |
+
for _ in range(5):
|
369 |
+
sentence_pad_mask.append(0)
|
370 |
+
else:
|
371 |
+
sentence.append((structure_parameters["type"], "shape"))
|
372 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
373 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
374 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
375 |
+
for _ in range(4):
|
376 |
+
sentence_pad_mask.append(0)
|
377 |
+
sentence.append(("PAD", None))
|
378 |
+
sentence_pad_mask.append(1)
|
379 |
+
|
380 |
+
###################################
|
381 |
+
# paddings
|
382 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
383 |
+
goal_pc_poses.append(np.eye(4))
|
384 |
+
|
385 |
+
###################################
|
386 |
+
if self.debug:
|
387 |
+
print("---")
|
388 |
+
print("all objects:", all_objs)
|
389 |
+
print("target objects:", target_objs)
|
390 |
+
print("other objects:", other_objs)
|
391 |
+
print("goal specification:", goal_specification)
|
392 |
+
print("sentence:", sentence)
|
393 |
+
show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
|
394 |
+
|
395 |
+
assert len(obj_pcs) == len(goal_pc_poses)
|
396 |
+
###################################
|
397 |
+
|
398 |
+
# shuffle the position of objects
|
399 |
+
if shuffle_object_index:
|
400 |
+
shuffle_target_object_indices = list(range(len(target_objs)))
|
401 |
+
random.shuffle(shuffle_target_object_indices)
|
402 |
+
shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects))
|
403 |
+
obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
|
404 |
+
goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
|
405 |
+
if inference_mode:
|
406 |
+
goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices]
|
407 |
+
current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices]
|
408 |
+
target_objs = [target_objs[i] for i in shuffle_target_object_indices]
|
409 |
+
current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices]
|
410 |
+
|
411 |
+
###################################
|
412 |
+
if self.use_virtual_structure_frame:
|
413 |
+
if self.ignore_distractor_objects:
|
414 |
+
# language, structure virtual frame, target objects
|
415 |
+
pcs = obj_pcs
|
416 |
+
type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
|
417 |
+
position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
|
418 |
+
pad_mask = sentence_pad_mask + [0] + obj_pad_mask
|
419 |
+
else:
|
420 |
+
# language, distractor objects, structure virtual frame, target objects
|
421 |
+
pcs = other_obj_pcs + obj_pcs
|
422 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
|
423 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
|
424 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
|
425 |
+
goal_poses = [goal_structure_pose] + goal_pc_poses
|
426 |
+
else:
|
427 |
+
if self.ignore_distractor_objects:
|
428 |
+
# language, target objects
|
429 |
+
pcs = obj_pcs
|
430 |
+
type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
|
431 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
|
432 |
+
pad_mask = sentence_pad_mask + obj_pad_mask
|
433 |
+
else:
|
434 |
+
# language, distractor objects, target objects
|
435 |
+
pcs = other_obj_pcs + obj_pcs
|
436 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
|
437 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
|
438 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
|
439 |
+
goal_poses = goal_pc_poses
|
440 |
+
|
441 |
+
datum = {
|
442 |
+
"pcs": pcs,
|
443 |
+
"sentence": sentence,
|
444 |
+
"goal_poses": goal_poses,
|
445 |
+
"type_index": type_index,
|
446 |
+
"position_index": position_index,
|
447 |
+
"pad_mask": pad_mask,
|
448 |
+
"t": step_t,
|
449 |
+
"filename": filename
|
450 |
+
}
|
451 |
+
|
452 |
+
if inference_mode:
|
453 |
+
datum["rgb"] = rgb
|
454 |
+
datum["goal_obj_poses"] = goal_obj_poses
|
455 |
+
datum["current_obj_poses"] = current_obj_poses
|
456 |
+
datum["target_objs"] = target_objs
|
457 |
+
datum["initial_scene"] = initial_scene
|
458 |
+
datum["ids"] = ids
|
459 |
+
datum["goal_specification"] = goal_specification
|
460 |
+
datum["current_pc_poses"] = current_pc_poses
|
461 |
+
|
462 |
+
return datum
|
463 |
+
|
464 |
+
@staticmethod
|
465 |
+
def convert_to_tensors(datum, tokenizer):
|
466 |
+
tensors = {
|
467 |
+
"pcs": torch.stack(datum["pcs"], dim=0),
|
468 |
+
"sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])),
|
469 |
+
"goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
|
470 |
+
"type_index": torch.LongTensor(np.array(datum["type_index"])),
|
471 |
+
"position_index": torch.LongTensor(np.array(datum["position_index"])),
|
472 |
+
"pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
|
473 |
+
"t": datum["t"],
|
474 |
+
"filename": datum["filename"]
|
475 |
+
}
|
476 |
+
return tensors
|
477 |
+
|
478 |
+
def __getitem__(self, idx):
|
479 |
+
|
480 |
+
datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
|
481 |
+
self.tokenizer)
|
482 |
+
|
483 |
+
return datum
|
484 |
+
|
485 |
+
def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
|
486 |
+
tensor_x = {}
|
487 |
+
|
488 |
+
tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
|
489 |
+
tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
|
490 |
+
if not inference_mode:
|
491 |
+
tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
|
492 |
+
|
493 |
+
tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
|
494 |
+
tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
|
495 |
+
tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
|
496 |
+
|
497 |
+
return tensor_x
|
498 |
+
|
499 |
+
|
500 |
+
def compute_min_max(dataloader):
|
501 |
+
|
502 |
+
# tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
|
503 |
+
# -0.9079, -0.8668, -0.9105, -0.4186])
|
504 |
+
# tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
|
505 |
+
# 0.4787, 0.6421, 1.0000])
|
506 |
+
# tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
|
507 |
+
# -0.0000, 0.0000, 0.0000, 1.0000])
|
508 |
+
# tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
|
509 |
+
# 0.0000, 0.0000, 1.0000])
|
510 |
+
|
511 |
+
min_value = torch.ones(16) * 10000
|
512 |
+
max_value = torch.ones(16) * -10000
|
513 |
+
for d in tqdm(dataloader):
|
514 |
+
goal_poses = d["goal_poses"]
|
515 |
+
goal_poses = goal_poses.reshape(-1, 16)
|
516 |
+
current_max, _ = torch.max(goal_poses, dim=0)
|
517 |
+
current_min, _ = torch.min(goal_poses, dim=0)
|
518 |
+
max_value[max_value < current_max] = current_max[max_value < current_max]
|
519 |
+
min_value[min_value > current_min] = current_min[min_value > current_min]
|
520 |
+
print(f"{min_value} - {max_value}")
|
521 |
+
|
522 |
+
|
523 |
+
if __name__ == "__main__":
|
524 |
+
|
525 |
+
tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
|
526 |
+
|
527 |
+
data_roots = []
|
528 |
+
index_roots = []
|
529 |
+
for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
|
530 |
+
data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
|
531 |
+
index_roots.append(index)
|
532 |
+
|
533 |
+
dataset = SemanticArrangementDataset(data_roots=data_roots,
|
534 |
+
index_roots=index_roots,
|
535 |
+
split="valid", tokenizer=tokenizer,
|
536 |
+
max_num_target_objects=7,
|
537 |
+
max_num_distractor_objects=5,
|
538 |
+
max_num_shape_parameters=5,
|
539 |
+
max_num_rearrange_features=0,
|
540 |
+
max_num_anchor_features=0,
|
541 |
+
num_pts=1024,
|
542 |
+
use_virtual_structure_frame=True,
|
543 |
+
ignore_distractor_objects=True,
|
544 |
+
ignore_rgb=True,
|
545 |
+
filter_num_moved_objects_range=None, # [5, 5]
|
546 |
+
data_augmentation=False,
|
547 |
+
shuffle_object_index=False,
|
548 |
+
debug=False)
|
549 |
+
|
550 |
+
# print(len(dataset))
|
551 |
+
# for d in dataset:
|
552 |
+
# print("\n\n" + "="*100)
|
553 |
+
|
554 |
+
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
|
555 |
+
for i, d in enumerate(tqdm(dataloader)):
|
556 |
+
pass
|
557 |
+
# for k in d:
|
558 |
+
# if isinstance(d[k], torch.Tensor):
|
559 |
+
# print("--size", k, d[k].shape)
|
560 |
+
# for k in d:
|
561 |
+
# print(k, d[k])
|
562 |
+
#
|
563 |
+
# input("next?")
|
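A minimal usage sketch for the demo dataset above; the two paths are placeholders for wherever the demo .h5 files and the coarse type vocabulary live, and indexing the dataset runs get_raw_data followed by convert_to_tensors:

from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset

# placeholder paths; point these at the demo h5 directory and the type vocab json
tokenizer = Tokenizer("/path/to/type_vocabs_coarse.json")
dataset = SemanticArrangementDataset(data_root="/path/to/demo_h5_dir",
                                     tokenizer=tokenizer,
                                     max_num_target_objects=7,
                                     max_num_distractor_objects=5,
                                     max_num_shape_parameters=5,
                                     num_pts=1024,
                                     use_virtual_structure_frame=True,
                                     ignore_distractor_objects=True,
                                     ignore_rgb=True,
                                     data_augmentation=False)

datum = dataset[0]                                  # tokenized tensors via convert_to_tensors
raw = dataset.get_raw_data(0, inference_mode=True)  # also returns poses, ids, and scene images
print(datum["pcs"].shape, datum["sentence"].shape, len(raw["target_objs"]))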
src/StructDiffusion/diffusion/__init__.py
ADDED
File without changes
src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (181 Bytes).
src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (185 Bytes).
src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc
ADDED
Binary file (2.57 kB).
src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc
ADDED
Binary file (2.57 kB).
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc
ADDED
Binary file (2.25 kB).
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc
ADDED
Binary file (2.27 kB).
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc
ADDED
Binary file (5.74 kB).
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc
ADDED
Binary file (5.71 kB).
src/StructDiffusion/diffusion/noise_schedule.py
ADDED
@@ -0,0 +1,81 @@
import math
import torch
import torch.nn.functional as F


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


class NoiseSchedule:

    def __init__(self, timesteps=200):

        self.timesteps = timesteps

        # define beta schedule
        self.betas = linear_beta_schedule(timesteps=timesteps)
        # self.betas = cosine_beta_schedule(timesteps=timesteps)

        # define alphas
        self.alphas = 1. - self.betas
        # alphas_cumprod: alpha bar
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)


def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# forward diffusion (using the nice property)
def q_sample(x_start, t, noise_schedule, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
    # print("sqrt_alphas_cumprod_t", sqrt_alphas_cumprod_t)
    sqrt_one_minus_alphas_cumprod_t = extract(
        noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    # print("sqrt_one_minus_alphas_cumprod_t", sqrt_one_minus_alphas_cumprod_t)
    # print("noise", noise)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
src/StructDiffusion/diffusion/pose_conversion.py
ADDED
@@ -0,0 +1,103 @@
import os
import torch
import pytorch3d.transforms as tra3d

from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d


def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):

    # important: we need to get the first two columns, not first two rows
    # array([[ 3,  4,  5],
    #        [ 6,  7,  8],
    #        [ 9, 10, 11]])
    xyz_6d_idxs = [0, 1, 2, 3, 6, 9, 4, 7, 10]

    # print(batch_data["obj_xyztheta_inputs"].shape)
    # print(batch_data["struct_xyztheta_inputs"].shape)

    # only get the first and second columns of rotation
    obj_xyztheta_inputs = obj_xyztheta_inputs[:, :, xyz_6d_idxs]  # B, N, 9
    struct_xyztheta_inputs = struct_xyztheta_inputs[:, :, xyz_6d_idxs]  # B, 1, 9

    x = torch.cat([struct_xyztheta_inputs, obj_xyztheta_inputs], dim=1)  # B, 1 + N, 9

    # print(x.shape)

    return x


def get_diffusion_variables_from_H(poses):
    """
    [[0,1,2,3],
     [4,5,6,7],
     [8,9,10,11],
     [12,13,14,15]
    :param obj_xyztheta_inputs: B, N, 4, 4
    :return:
    """

    xyz_6d_idxs = [3, 7, 11, 0, 4, 8, 1, 5, 9]

    B, N, _, _ = poses.shape
    x = poses.reshape(B, N, 16)[:, :, xyz_6d_idxs]  # B, N, 9
    return x


def get_struct_objs_poses(x):

    on_gpu = x.is_cuda
    if not on_gpu:
        x = x.cuda()

    # assert x.is_cuda, "compute_rotation_matrix_from_ortho6d requires input to be on gpu"
    device = x.device

    # important: the noisy x can go out of bounds
    x = torch.clamp(x, min=-1, max=1)

    # x: B, 1 + N, 9
    B = x.shape[0]
    N = x.shape[1] - 1

    # compute_rotation_matrix_from_ortho6d takes in [B, 6], outputs [B, 3, 3]
    x_6d = x[:, :, 3:].reshape(-1, 6)
    x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, N+1, 3, 3)  # B, 1 + N, 3, 3

    x_trans = x[:, :, :3]  # B, 1 + N, 3

    x_full = torch.eye(4).repeat(B, 1 + N, 1, 1).to(device)
    x_full[:, :, :3, :3] = x_rot
    x_full[:, :, :3, 3] = x_trans

    struct_pose = x_full[:, 0].unsqueeze(1)  # B, 1, 4, 4
    pc_poses_in_struct = x_full[:, 1:]  # B, N, 4, 4

    if not on_gpu:
        struct_pose = struct_pose.cpu()
        pc_poses_in_struct = pc_poses_in_struct.cpu()

    return struct_pose, pc_poses_in_struct


def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):

    device = obj_xyzs.device

    # obj_xyzs: B, N, P, 3
    # struct_pose: B, 1, 4, 4
    # pc_poses_in_struct: B, N, 4, 4
    B, N, _, _ = pc_poses_in_struct.shape
    _, _, P, _ = obj_xyzs.shape

    current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device)  # B, N, 4, 4
    # print(torch.mean(obj_xyzs, dim=2).shape)
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2)  # B, N, 4, 4

    struct_pose = struct_pose.repeat(1, N, 1, 1)  # B, N, 4, 4
    struct_pose = struct_pose.reshape(B * N, 4, 4)  # B x 1, 4, 4
    pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4)  # B x N, 4, 4

    goal_pc_poses = struct_pose @ pc_poses_in_struct  # B x N, 4, 4
    goal_pc_poses = goal_pc_poses.reshape(B, N, 4, 4)  # B, N, 4, 4
    return current_pc_poses, goal_pc_poses
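In the conversion above, each 4x4 pose becomes 9 numbers: the translation followed by the first two columns of the rotation matrix (the ortho6d parameterization). A small sketch of the forward direction, which runs on CPU; the inverse get_struct_objs_poses expects a GPU because compute_rotation_matrix_from_ortho6d does:

import torch
from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H

# two scenes, each with a structure frame plus three objects, identity poses as placeholders
poses = torch.eye(4).repeat(2, 4, 1, 1)   # B=2, 1 + N=4, 4, 4
x = get_diffusion_variables_from_H(poses)
print(x.shape)   # torch.Size([2, 4, 9])
print(x[0, 0])   # tensor([0., 0., 0., 1., 0., 0., 0., 1., 0.]) -> xyz, then two rotation columns

# on a CUDA machine, get_struct_objs_poses(x) recovers (struct_pose, pc_poses_in_struct)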
src/StructDiffusion/diffusion/sampler.py
ADDED
@@ -0,0 +1,296 @@
+import torch
+from tqdm import tqdm
+import pytorch3d.transforms as tra3d
+
+from StructDiffusion.diffusion.noise_schedule import extract
+from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
+from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs, move_pc_and_create_scene_new
+
+class Sampler:
+
+    def __init__(self, model_class, checkpoint_path, device, debug=False):
+
+        self.debug = debug
+        self.device = device
+
+        self.model = model_class.load_from_checkpoint(checkpoint_path)
+        self.backbone = self.model.model
+        self.backbone.to(device)
+        self.backbone.eval()
+
+    def sample(self, batch, num_poses):
+
+        noise_schedule = self.model.noise_schedule
+
+        B = batch["pcs"].shape[0]
+
+        x_noisy = torch.randn((B, num_poses, 9), device=self.device)
+
+        xs = []
+        for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
+                            desc='sampling loop time step', total=noise_schedule.timesteps):
+
+            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
+
+            # noise schedule
+            betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
+            sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+            sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
+
+            # predict noise
+            pcs = batch["pcs"]
+            sentence = batch["sentence"]
+            type_index = batch["type_index"]
+            position_index = batch["position_index"]
+            pad_mask = batch["pad_mask"]
+            # calling the backbone instead of the pytorch-lightning model
+            with torch.no_grad():
+                predicted_noise = self.backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+
+            # compute noisy x at t
+            model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
+            if t_index == 0:
+                x_noisy = model_mean
+            else:
+                posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
+                noise = torch.randn_like(x_noisy)
+                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
+
+            xs.append(x_noisy)
+
+        xs = list(reversed(xs))
+        return xs
+
+class SamplerV2:
+
+    def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
+                 collision_model_class, collision_checkpoint_path,
+                 device, debug=False):
+
+        self.debug = debug
+        self.device = device
+
+        self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
+        self.diffusion_backbone = self.diffusion_model.model
+        self.diffusion_backbone.to(device)
+        self.diffusion_backbone.eval()
+
+        self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
+        self.collision_backbone = self.collision_model.model
+        self.collision_backbone.to(device)
+        self.collision_backbone.eval()
+
+    def sample(self, batch, num_poses):
+
+        noise_schedule = self.diffusion_model.noise_schedule
+
+        B = batch["pcs"].shape[0]
+
+        x_noisy = torch.randn((B, num_poses, 9), device=self.device)
+
+        xs = []
+        for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
+                            desc='sampling loop time step', total=noise_schedule.timesteps):
+
+            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
+
+            # noise schedule
+            betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
+            sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+            sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
+
+            # predict noise
+            pcs = batch["pcs"]
+            sentence = batch["sentence"]
+            type_index = batch["type_index"]
+            position_index = batch["position_index"]
+            pad_mask = batch["pad_mask"]
+            # calling the backbone instead of the pytorch-lightning model
+            with torch.no_grad():
+                predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+
+            # compute noisy x at t
+            model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
+            if t_index == 0:
+                x_noisy = model_mean
+            else:
+                posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
+                noise = torch.randn_like(x_noisy)
+                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
+
+            xs.append(x_noisy)
+
+        xs = list(reversed(xs))
+
+        visualize = True
+
+        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
+        # struct_pose: B, 1, 4, 4
+        # pc_poses_in_struct: B, N, 4, 4
+
+        S = B
+        num_elite = 10
+        ####################################################
+        # only keep one copy
+
+        # N, P, 3
+        obj_xyzs = batch["pcs"][0][:, :, :3]
+        print("obj_xyzs shape", obj_xyzs.shape)
+
+        # 1, N
+        # object_pad_mask: padding location has 1
+        num_target_objs = num_poses
+        if self.diffusion_backbone.use_virtual_structure_frame:
+            num_target_objs -= 1
+        object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
+        target_object_inds = 1 - object_pad_mask
+        print("target_object_inds shape", target_object_inds.shape)
+        print("target_object_inds", target_object_inds)
+
+        N, P, _ = obj_xyzs.shape
+        print("S, N, P: {}, {}, {}".format(S, N, P))
+
+        ####################################################
+        # S, N, ...
+
+        struct_pose = struct_pose.repeat(1, N, 1, 1)  # S, N, 4, 4
+        struct_pose = struct_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
+
+        new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1)  # S, N, P, 3
+        current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device)  # S, N, 4, 4
+        current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2)  # S, N, 4, 4
+        current_pc_pose = current_pc_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
+
+        # optimize xyzrpy
+        obj_params = torch.zeros((S, N, 6)).to(self.device)
+        obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
+        obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ")  # S, N, 6
+        #
+        # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
+        #
+        # if visualize:
+        #     print("visualizing rearrangements predicted by the generator")
+        #     visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
+
+        ####################################################
+        # rank
+
+        # evaluate in batches
+        scores = torch.zeros(S).to(self.device)
+        no_intersection_scores = torch.zeros(S).to(self.device)  # the higher the better
+        num_batches = int(S / B)
+        if S % B != 0:
+            num_batches += 1
+        for b in range(num_batches):
+            if b + 1 == num_batches:
+                cur_batch_idxs_start = b * B
+                cur_batch_idxs_end = S
+            else:
+                cur_batch_idxs_start = b * B
+                cur_batch_idxs_end = (b + 1) * B
+            cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
+
+            # print("current batch idxs start", cur_batch_idxs_start)
+            # print("current batch idxs end", cur_batch_idxs_end)
+            # print("size of the current batch", cur_batch_size)
+
+            batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
+            batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
+            batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
+
+            new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
+                move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
+                                             target_object_inds, self.device,
+                                             return_scene_pts=False,
+                                             return_scene_pts_and_pc_idxs=False,
+                                             num_scene_pts=False,
+                                             normalize_pc=False,
+                                             return_pair_pc=True,
+                                             num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
+                                             normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
+
+            #######################################
+            # predict whether there are pairwise collisions
+            # if collision_score_weight > 0:
+            with torch.no_grad():
+                _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
+                # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
+                collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
+                collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb)  # cur_batch_size, num_comb
+
+                # debug
+                # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+                #     print("batch id", bi)
+                #     for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+                #         print("pair", pi)
+                #         # obj_pair_xyzs: 2 * P, 5
+                #         print("collision score", collision_scores[bi, pi])
+                #         trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+
+                # 1 - mean() since the collision model predicts 1 if there is a collision
+                no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
+            if visualize:
+                print("no intersection scores", no_intersection_scores)
+            # #######################################
+            # if discriminator_score_weight > 0:
+            #     # # debug:
+            #     # print(subsampled_scene_xyz.shape)
+            #     # print(subsampled_scene_xyz[0])
+            #     # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
+            #     #
+            #     with torch.no_grad():
+            #
+            #         # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
+            #         # local_sentence = sentence[:, [0, 4]]
+            #         # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
+            #         # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
+            #
+            #         sentence_disc = torch.LongTensor(
+            #             [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
+            #         sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
+            #         position_index_dic = torch.LongTensor(raw_position_index_discriminator)
+            #
+            #         preds = discriminator_model.forward(subsampled_scene_xyz,
+            #                                             sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
+            #                                             sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
+            #                                             position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(device))
+            #         # preds = discriminator_model.forward(subsampled_scene_xyz)
+            #         preds = discriminator_model.convert_logits(preds)
+            #         preds = preds["is_circle"]  # cur_batch_size,
+            #         scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
+            #         if visualize:
+            #             print("discriminator scores", scores)
+
+        # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
+        scores = no_intersection_scores
+        sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
+        elite_obj_params = obj_params[sort_idx]  # num_elite, N, 6
+        elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx]  # num_elite, N, 4, 4
+        elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4)  # num_elite x N, 4, 4
+        elite_scores = scores[sort_idx]
+        print("elite scores:", elite_scores)
+
+        ####################################################
+        # # visualize best samples
+        # num_scene_pts = 4096  # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
+        # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
+        # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
+        #     move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
+        #                                  target_object_inds, self.device,
+        #                                  return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
+        # if visualize:
+        #     print("visualizing elite rearrangements ranked by collision model/discriminator")
+        #     visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
+
+        # num_elite, N, 6
+        elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
+        pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
+        pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
+        pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
+        pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4)  # num_elite, N, 4, 4
+
+        struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0].unsqueeze(1)  # num_elite, 1, 4, 4
+
+        return struct_pose, pc_poses_in_struct
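Both sampling loops above implement the standard DDPM reverse update: the backbone's predicted noise gives the posterior mean, and Gaussian noise scaled by the posterior variance is added at every step except the last. The following stand-alone toy sketch reproduces just that update with a dummy noise predictor; the schedule values mirror what noise_schedule.py provides through extract, but all names and numbers here are illustrative, not part of this commit.

import torch

# linear beta schedule and the derived quantities used by the reverse update
timesteps = 200
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def predict_noise(x, t_index):
    # stand-in for the conditional transformer backbone
    return torch.zeros_like(x)

B, num_poses = 4, 8
x = torch.randn(B, num_poses, 9)  # structure frame + object poses in 9-D form
for t_index in reversed(range(timesteps)):
    eps = predict_noise(x, t_index)
    model_mean = sqrt_recip_alphas[t_index] * (x - betas[t_index] * eps / sqrt_one_minus_alphas_cumprod[t_index])
    if t_index == 0:
        x = model_mean
    else:
        x = model_mean + torch.sqrt(posterior_variance[t_index]) * torch.randn_like(x)
print(x.shape)  # torch.Size([4, 8, 9])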
src/StructDiffusion/language/__init__.py
ADDED
File without changes

src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (180 Bytes)

src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (184 Bytes)

src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc
ADDED
Binary file (11.4 kB)