diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..25625008b6a8963639be071f8cbc067f692d51ce
--- /dev/null
+++ b/app.py
@@ -0,0 +1,131 @@
+import os
+import argparse
+import torch
+import trimesh
+import numpy as np
+import pytorch_lightning as pl
+import gradio as gr
+from omegaconf import OmegaConf
+
+import sys
+sys.path.append('./src')
+
+from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
+from StructDiffusion.language.tokenizer import Tokenizer
+from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
+from StructDiffusion.diffusion.sampler import Sampler
+from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
+from StructDiffusion.utils.files import get_checkpoint_path_from_dir
+from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
+from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
+
+
+class Infer_Wrapper:
+
+ def __init__(self, args, cfg):
+
+ # load
+ pl.seed_everything(args.eval_random_seed)
+ self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+
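+        # checkpoints are looked up under <wandb save_dir>/<project>/<checkpoint_id>/checkpoints,
+        # which should match the directory layout the training scripts' WandbLogger + ModelCheckpoint produce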
+ checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
+ checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
+
+ self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
+ # override ignore_rgb for visualization
+ cfg.DATASET.ignore_rgb = False
+ self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
+
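+        # the Sampler wraps the pretrained ConditionalPoseDiffusionModel checkpoint and runs
+        # the reverse diffusion process to sample object poses at inference time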
+ self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)
+
+ def run(self, di):
+
+ # di = np.random.choice(len(self.dataset))
+
+        di = int(di)  # the Gradio slider passes a float; dataset indexing needs an int
+        raw_datum = self.dataset.get_raw_data(di)
+ print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
+ datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
+ batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
+
+ num_poses = datum["goal_poses"].shape[0]
+ xs = self.sampler.sample(batch, num_poses)
+
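+        # split the sampled output into the pose of the virtual structure frame and the
+        # per-object poses expressed in that frame, then move the observed point clouds
+        # into the predicted arrangement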
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
+
+ # vis
+ vis_obj_xyzs = new_obj_xyzs[:3]
+ if torch.is_tensor(vis_obj_xyzs):
+ if vis_obj_xyzs.is_cuda:
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
+
+ # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
+ # if verbose:
+ # print("example {}".format(bi))
+ # print(vis_obj_xyz.shape)
+ #
+ # if trimesh:
+ # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
+ vis_obj_xyz = vis_obj_xyzs[0]
+ scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
+
+        scene_filename = "./tmp_data/scene.glb"
+        os.makedirs(os.path.dirname(scene_filename), exist_ok=True)
+        scene.export(scene_filename)
+
+ # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
+ # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
+ #
+ # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
+ # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
+ # vis_pc.export(pc_filename)
+ #
+ # scene = trimesh.Scene()
+ # # add the coordinate frame first
+ # # geom = trimesh.creation.axis(0.01)
+ # # scene.add_geometry(geom)
+ # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
+ # table.apply_translation([0.5, 0, -0.01])
+ # table.visual.vertex_colors = [150, 111, 87, 125]
+ # scene.add_geometry(table)
+ # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
+ # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
+ # # bounds.apply_translation([0, 0, 0])
+ # # bounds.visual.vertex_colors = [30, 30, 30, 30]
+ # # scene.add_geometry(bounds)
+ # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
+ # # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
+ # # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
+ # # [0.0, 0.0, 0.0, 1.0]])
+ # # RT_4x4 = np.linalg.inv(RT_4x4)
+ # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
+ # # scene.camera_transform = RT_4x4
+ #
+ # mesh_list = trimesh.util.concatenate(scene.dump())
+ # print(mesh_list)
+ # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')
+
+ return scene_filename
+
+
+args = OmegaConf.create()
+args.base_config_file = "./configs/base.yaml"
+args.config_file = "./configs/conditional_pose_diffusion.yaml"
+args.checkpoint_id = "ConditionalPoseDiffusion"
+args.eval_random_seed = 42
+args.num_samples = 1
+
+base_cfg = OmegaConf.load(args.base_config_file)
+cfg = OmegaConf.load(args.config_file)
+cfg = OmegaConf.merge(base_cfg, cfg)
+
+infer_wrapper = Infer_Wrapper(args, cfg)
+
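+# the slider picks a dataset index; run() samples an arrangement for that scene and returns
+# a GLB file that the Model3D component renders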
+demo = gr.Interface(
+ fn=infer_wrapper.run,
+    inputs=gr.Slider(0, len(infer_wrapper.dataset) - 1, step=1),
+ # clear color range [0-1.0]
+ outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
+)
+
+demo.launch()
\ No newline at end of file
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f57f0e51988fbb67ca874b8520ee214e12f19bb
--- /dev/null
+++ b/configs/base.yaml
@@ -0,0 +1,3 @@
+base_dirs:
+ data: data
+ wandb_dir: wandb_logs
\ No newline at end of file
diff --git a/configs/conditional_pose_diffusion.yaml b/configs/conditional_pose_diffusion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..71086185256cb26421c452e015047aabcaa960ce
--- /dev/null
+++ b/configs/conditional_pose_diffusion.yaml
@@ -0,0 +1,81 @@
+random_seed: 1
+
+WANDB:
+ project: StructDiffusion
+ save_dir: ${base_dirs.wandb_dir}
+ name: conditional_pose_diffusion
+
+DATASET:
+ data_root: ${base_dirs.data}
+ vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json
+
+ # important
+ use_virtual_structure_frame: True
+ ignore_distractor_objects: True
+ ignore_rgb: True
+
+ # the following are determined by the dataset
+ max_num_target_objects: 7
+ max_num_distractor_objects: 5
+ max_num_shape_parameters: 5
+ # set to zeros because they are not used for now
+ max_num_rearrange_features: 0
+ max_num_anchor_features: 0
+
+ num_pts: 1024
+ filter_num_moved_objects_range:
+ data_augmentation: False
+
+DATALOADER:
+ batch_size: 64
+ num_workers: 8
+ pin_memory: True
+
+MODEL:
+ # transformer encoder
+ encoder_input_dim: 256
+ num_attention_heads: 8
+ encoder_hidden_dim: 512
+ encoder_dropout: 0.0
+ encoder_activation: relu
+ encoder_num_layers: 8
+ # output head
+ structure_dropout: 0
+ object_dropout: 0
+ # pc encoder
+ ignore_rgb: ${DATASET.ignore_rgb}
+ pc_emb_dim: 256
+ posed_pc_emb_dim: 80
+ # pose encoder
+ pose_emb_dim: 80
+ # language
+ word_emb_dim: 160
+ # diffusion step
+ time_emb_dim: 80
+ # sequence embeddings
+ # max_num_target_objects (+ max_num_distractor_objects if not ignore_distractor_objects)
+ max_seq_size: 7
+ max_token_type_size: 4
+ seq_pos_emb_dim: 8
+ seq_type_emb_dim: 8
+ # virtual frame
+ use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame}
+
+NOISE_SCHEDULE:
+ timesteps: 200
+
+LOSS:
+ type: huber
+
+OPTIMIZER:
+ lr: 0.0001
+ weight_decay: 0 #0.0001
+ # lr_restart: 3000
+ # warmup: 10
+
+TRAINER:
+ max_epochs: 200
+ gradient_clip_val: 1.0
+ gpus: 1
+ deterministic: False
+ # enable_progress_bar: False
\ No newline at end of file
diff --git a/configs/pairwise_collision.yaml b/configs/pairwise_collision.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d53be3bca4468d0e054ebf2f6edad233a4fb19dd
--- /dev/null
+++ b/configs/pairwise_collision.yaml
@@ -0,0 +1,42 @@
+random_seed: 1
+
+WANDB:
+ project: StructDiffusion
+ save_dir: ${base_dirs.wandb_dir}
+ name: pairwise_collision
+
+DATASET:
+ urdf_pc_idx_file: ${base_dirs.pairwise_collision_data}/urdf_pc_idx.pkl
+ collision_data_dir: ${base_dirs.pairwise_collision_data}
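+  # note: base_dirs.pairwise_collision_data is not defined in configs/base.yaml in this demo;
+  # it must be added and pointed at the downloaded pairwise collision dataset before training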
+
+ # important
+ num_pts: 1024
+ num_scene_pts: 2048
+ normalize_pc: True
+ random_rotation: True
+ data_augmentation: False
+
+DATALOADER:
+ batch_size: 32
+ num_workers: 8
+ pin_memory: True
+
+MODEL:
+ max_num_objects: 2
+ include_env_pc: False
+ pct_random_sampling: True
+
+LOSS:
+ type: Focal
+ focal_gamma: 2
+
+OPTIMIZER:
+ lr: 0.0001
+ weight_decay: 0
+
+TRAINER:
+ max_epochs: 200
+ gradient_clip_val: 1.0
+ gpus: 1
+ deterministic: False
+ # enable_progress_bar: False
\ No newline at end of file
diff --git a/data/data00000000.h5 b/data/data00000000.h5
new file mode 100755
index 0000000000000000000000000000000000000000..ef5df9df8fc9935abde7e9888abad26f7f72e10c
--- /dev/null
+++ b/data/data00000000.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:947574252625d338b9f37217eacf61f520136e27b458b6d3e65330339e8b299c
+size 1271489
diff --git a/data/data00000002.h5 b/data/data00000002.h5
new file mode 100755
index 0000000000000000000000000000000000000000..17e1ce30b8ae08f51059f46d2c1350e750713627
--- /dev/null
+++ b/data/data00000002.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3302432de555fed767c5b0d99c35ca01d5e4ac38cf4a0760b8ccb456b432e0e0
+size 3235242
diff --git a/data/data00000003.h5 b/data/data00000003.h5
new file mode 100755
index 0000000000000000000000000000000000000000..b0875002cae581f22752bb4da8981eb2455bf407
--- /dev/null
+++ b/data/data00000003.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b907ba7c3a17f98a438617b462b2a4d3d3f8593c2dc47feb5a6cc3da8c034fc
+size 2059708
diff --git a/data/data00000004.h5 b/data/data00000004.h5
new file mode 100755
index 0000000000000000000000000000000000000000..61ba41edc4070d40a6628aa154154ac513dd376b
--- /dev/null
+++ b/data/data00000004.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8ec0136dd4055d304e9b7f5697b79613099b8f8f1e5eec94281f22d8d47cca1
+size 2591656
diff --git a/data/data00000006.h5 b/data/data00000006.h5
new file mode 100755
index 0000000000000000000000000000000000000000..78f26a2a27967b099024348c8955bb3b92c64997
--- /dev/null
+++ b/data/data00000006.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e74ebf185b0af58df0fa2483d5fd58a12b3b62ccac27ff665f35c5c7a13b8d8
+size 1572332
diff --git a/data/data00000008.h5 b/data/data00000008.h5
new file mode 100755
index 0000000000000000000000000000000000000000..0cefd1068e399ee5423bfe4e5821d5d903fbc24f
--- /dev/null
+++ b/data/data00000008.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db015354a9d53e6fbaf0b040ce226484150b0af226a5c13a0b9f5cb9961db73c
+size 2167265
diff --git a/data/data00000009.h5 b/data/data00000009.h5
new file mode 100755
index 0000000000000000000000000000000000000000..944f7f52fe5a8fbe75f05b178bd1639501d8a757
--- /dev/null
+++ b/data/data00000009.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:990ad13f423d9089b30de81d002d23d9d00cf3e007fd7073793cbec03c456ebb
+size 3607752
diff --git a/data/data00000012.h5 b/data/data00000012.h5
new file mode 100755
index 0000000000000000000000000000000000000000..4d03a37a210f47f01463e9567c603fb9f9451d16
--- /dev/null
+++ b/data/data00000012.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93161f5666c54dbc259c9efa516b67340613b592a1ed42e6c63d4cc8a495002a
+size 2525622
diff --git a/data/data00000013.h5 b/data/data00000013.h5
new file mode 100755
index 0000000000000000000000000000000000000000..51389ce34ced6cab838f8898f6c9344f138eec82
--- /dev/null
+++ b/data/data00000013.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94c9cfe6d9f0df176eb0a3baccdf53c7e6e5fc807e5e7ea9e138ad7159f500d9
+size 1715352
diff --git a/data/data00000015.h5 b/data/data00000015.h5
new file mode 100755
index 0000000000000000000000000000000000000000..2112d04990e9d6facec890933d59b47e2f2c8053
--- /dev/null
+++ b/data/data00000015.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9aab522f9ace1a03b1705fe3bd693b589971d133d45c232ef9ec53842a540bfa
+size 2647026
diff --git a/data/type_vocabs_coarse.json b/data/type_vocabs_coarse.json
new file mode 100644
index 0000000000000000000000000000000000000000..91544d8abba1f3aae15fb12f5a64559bd83cab29
--- /dev/null
+++ b/data/type_vocabs_coarse.json
@@ -0,0 +1 @@
+{"class": {"Basket": 0, "BeerBottle": 1, "Book": 2, "Bottle": 3, "Bowl": 4, "Calculator": 5, "Candle": 6, "CellPhone": 7, "ComputerMouse": 8, "Controller": 9, "Cup": 10, "Donut": 11, "Fork": 12, "Hammer": 13, "Knife": 14, "Marker": 15, "MilkCarton": 16, "Mug": 17, "Pan": 18, "Pen": 19, "PillBottle": 20, "Plate": 21, "PowerStrip": 22, "Scissors": 23, "SoapBottle": 24, "SodaCan": 25, "Spoon": 26, "Stapler": 27, "Teapot": 28, "VideoGameController": 29, "WineBottle": 30, "CanOpener":31, "Fruit": 32}, "scene": {"dinner": 0}, "size": {"L": 0, "M": 1, "S": 2}, "color": {"blue": 0, "cyan": 1, "green": 2, "magenta": 3, "red": 4, "yellow": 5}, "material": {"glass": 0, "metal": 1, "plastic": 2}, "comparator": {"less": 1, "greater": 2, "equal": 3}, "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4], "height": [0.0, 0.5, 10], "volumn": [0.0, 0.015, 10], "uniform_angle": {"False": 0, "True": 1}, "face_center": {"False": 0, "True": 1}, "angle_ratio": {"0.5": 0, "1.0": 1}, "shape": {"circle": 0, "line": 1, "tower": 2, "dinner": 3}, "obj_x": [-1.0, 1.0, 200], "obj_y": [-1.0, 1.0, 200], "obj_z": [-1.0, 1.0, 200], "obj_rr": [-3.15, 3.15, 360], "obj_rp": [-3.15, 3.15, 360], "obj_ry": [-3.15, 3.15, 360],"struct_x": [-1.0, 1.0, 200], "struct_y": [-1.0, 1.0, 200], "struct_z": [-1.0, 1.0, 200], "struct_rr": [-3.15, 3.15, 360], "struct_rp": [-3.15, 3.15, 360], "struct_ry": [-3.15, 3.15, 360]}
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d4458ef18f1d7c83facf3d8f9653b23a58947437
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+python3-opencv
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3d14e06166201ca5f642f6b1b7c626aef6ea86aa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+numpy==1.21
+h5py==2.10.0
+opencv-python
+open3d
+trimesh==3.10.2
+pyglet==1.5.0
+pybullet==3.1.7
+nvisii==1.1.70
+openpyxl
+pytorch_lightning==1.6.1
+wandb==0.13.10
+pytorch3d==0.3.0
+omegaconf==2.2.2
\ No newline at end of file
diff --git a/scripts/infer.py b/scripts/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a332032715d1fe655876d9a36d588a1380f75d9
--- /dev/null
+++ b/scripts/infer.py
@@ -0,0 +1,78 @@
+import os
+import argparse
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+
+from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
+from StructDiffusion.language.tokenizer import Tokenizer
+from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
+from StructDiffusion.diffusion.sampler import Sampler
+from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
+from StructDiffusion.utils.files import get_checkpoint_path_from_dir
+from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
+
+
+def main(args, cfg):
+
+ pl.seed_everything(args.eval_random_seed)
+
+ device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+
+ checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
+ checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
+
+ if args.eval_mode == "infer":
+
+ tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
+ # override ignore_rgb for visualization
+ cfg.DATASET.ignore_rgb = False
+ dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)
+
+ sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device)
+
+ data_idxs = np.random.permutation(len(dataset))
+ for di in data_idxs:
+ raw_datum = dataset.get_raw_data(di)
+ print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
+ datum = dataset.convert_to_tensors(raw_datum, tokenizer)
+ batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True)
+
+ num_poses = datum["goal_poses"].shape[0]
+ xs = sampler.sample(batch, num_poses)
+
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
+ visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="infer")
+ parser.add_argument("--base_config_file", help='base config yaml file',
+ default='../configs/base.yaml',
+ type=str)
+ parser.add_argument("--config_file", help='config yaml file',
+ default='../configs/conditional_pose_diffusion.yaml',
+ type=str)
+ parser.add_argument("--checkpoint_id",
+ default="ConditionalPoseDiffusion",
+ type=str)
+ parser.add_argument("--eval_mode",
+ default="infer",
+ type=str)
+ parser.add_argument("--eval_random_seed",
+ default=42,
+ type=int)
+ parser.add_argument("--num_samples",
+ default=10,
+ type=int)
+ args = parser.parse_args()
+
+ base_cfg = OmegaConf.load(args.base_config_file)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(base_cfg, cfg)
+
+ main(args, cfg)
+
+
diff --git a/scripts/infer_with_discriminator.py b/scripts/infer_with_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3452eb7fd2acd90d4482dd388f926bbd318b3515
--- /dev/null
+++ b/scripts/infer_with_discriminator.py
@@ -0,0 +1,81 @@
+import os
+import argparse
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+
+from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
+from StructDiffusion.language.tokenizer import Tokenizer
+from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
+from StructDiffusion.diffusion.sampler import SamplerV2
+from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
+from StructDiffusion.utils.files import get_checkpoint_path_from_dir
+from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
+
+
+def main(args, cfg):
+
+ pl.seed_everything(args.eval_random_seed)
+
+ device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+
+ diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
+ collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))
+
+ if args.eval_mode == "infer":
+
+ tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
+ # override ignore_rgb for visualization
+ cfg.DATASET.ignore_rgb = False
+ dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)
+
+ sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
+ PairwiseCollisionModel, collision_checkpoint_path, device)
+
+ data_idxs = np.random.permutation(len(dataset))
+ for di in data_idxs:
+ raw_datum = dataset.get_raw_data(di)
+ print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
+ datum = dataset.convert_to_tensors(raw_datum, tokenizer)
+ batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True)
+
+ num_poses = datum["goal_poses"].shape[0]
+ struct_pose, pc_poses_in_struct = sampler.sample(batch, num_poses)
+
+ new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
+ visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="infer")
+ parser.add_argument("--base_config_file", help='base config yaml file',
+ default='../configs/base.yaml',
+ type=str)
+ parser.add_argument("--config_file", help='config yaml file',
+ default='../configs/conditional_pose_diffusion.yaml',
+ type=str)
+ parser.add_argument("--diffusion_checkpoint_id",
+ default="ConditionalPoseDiffusion",
+ type=str)
+ parser.add_argument("--collision_checkpoint_id",
+ default="curhl56k",
+ type=str)
+ parser.add_argument("--eval_mode",
+ default="infer",
+ type=str)
+ parser.add_argument("--eval_random_seed",
+ default=42,
+ type=int)
+ parser.add_argument("--num_samples",
+ default=10,
+ type=int)
+ args = parser.parse_args()
+
+ base_cfg = OmegaConf.load(args.base_config_file)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(base_cfg, cfg)
+
+ main(args, cfg)
+
+
diff --git a/scripts/train_discriminator.py b/scripts/train_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa5a476c37b2c8c2c0e4192107800cd5986f6124
--- /dev/null
+++ b/scripts/train_discriminator.py
@@ -0,0 +1,46 @@
+import argparse
+import torch
+from torch.utils.data import DataLoader
+from omegaconf import OmegaConf
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from StructDiffusion.data.pairwise_collision import PairwiseCollisionDataset
+from StructDiffusion.models.pl_models import PairwiseCollisionModel
+
+
+def main(cfg):
+
+ pl.seed_everything(cfg.random_seed)
+
+ wandb_logger = WandbLogger(**cfg.WANDB)
+ wandb_logger.experiment.config.update(cfg)
+ checkpoint_callback = ModelCheckpoint()
+
+ full_dataset = PairwiseCollisionDataset(**cfg.DATASET)
+ train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [int(len(full_dataset) * 0.7), len(full_dataset) - int(len(full_dataset) * 0.7)])
+ train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER)
+ valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER)
+
+ model = PairwiseCollisionModel(cfg.MODEL, cfg.LOSS, cfg.OPTIMIZER, cfg.DATASET)
+
+ trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER)
+
+ trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="train")
+ parser.add_argument("--base_config_file", help='base config yaml file',
+ default='../configs/base.yaml',
+ type=str)
+ parser.add_argument("--config_file", help='config yaml file',
+ default='../configs/pairwise_collision.yaml',
+ type=str)
+ args = parser.parse_args()
+ base_cfg = OmegaConf.load(args.base_config_file)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(base_cfg, cfg)
+
+ main(cfg)
\ No newline at end of file
diff --git a/scripts/train_generator.py b/scripts/train_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8131a9f2626b0f84bc91e1d76250a4e5819f8785
--- /dev/null
+++ b/scripts/train_generator.py
@@ -0,0 +1,49 @@
+from torch.utils.data import DataLoader
+import argparse
+from omegaconf import OmegaConf
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
+from StructDiffusion.language.tokenizer import Tokenizer
+from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
+
+
+def main(cfg):
+
+ pl.seed_everything(cfg.random_seed)
+
+ wandb_logger = WandbLogger(**cfg.WANDB)
+ wandb_logger.experiment.config.update(cfg)
+ checkpoint_callback = ModelCheckpoint()
+
+ tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
+ vocab_size = tokenizer.get_vocab_size()
+
+ train_dataset = SemanticArrangementDataset(split="train", tokenizer=tokenizer, **cfg.DATASET)
+ valid_dataset = SemanticArrangementDataset(split="valid", tokenizer=tokenizer, **cfg.DATASET)
+ train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER)
+ valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER)
+
+ model = ConditionalPoseDiffusionModel(vocab_size, cfg.MODEL, cfg.LOSS, cfg.NOISE_SCHEDULE, cfg.OPTIMIZER)
+
+ trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER)
+
+ trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="train")
+ parser.add_argument("--base_config_file", help='base config yaml file',
+ default='../configs/base.yaml',
+ type=str)
+ parser.add_argument("--config_file", help='config yaml file',
+ default='../configs/conditional_pose_diffusion.yaml',
+ type=str)
+ args = parser.parse_args()
+ base_cfg = OmegaConf.load(args.base_config_file)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(base_cfg, cfg)
+
+ main(cfg)
\ No newline at end of file
diff --git a/src/StructDiffusion/__init__.py b/src/StructDiffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73b43dd4131a1a5fe63caa5e0cc1e02ee7a25d93
Binary files /dev/null and b/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..962a4af7693c5c3f4471bf199241b94961d2a085
Binary files /dev/null and b/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/data/__init__.py b/src/StructDiffusion/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49140fb3e29de7aa109cdcf66678a0a1e2f63717
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..196f5127bc03b7b7fdba37a1dc8f835363f6266f
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a858b5344ce7e72d040d4a4ec23d3640534dedf
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc differ
diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4da8facf199b17a7cf2d255dd2bacf22774df8a1
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc differ
diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62d26963f2419f17e4dbb98fa821f224df59a9d6
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc differ
diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ae89fd0a7d26e8c837fd3d41b8e8f118d2c12ec
Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc differ
diff --git a/src/StructDiffusion/data/pairwise_collision.py b/src/StructDiffusion/data/pairwise_collision.py
new file mode 100644
index 0000000000000000000000000000000000000000..2519d16fce3fe0df4a38b02e24ce174cfef09c25
--- /dev/null
+++ b/src/StructDiffusion/data/pairwise_collision.py
@@ -0,0 +1,361 @@
+import cv2
+import h5py
+import numpy as np
+import os
+import trimesh
+import torch
+import json
+from collections import defaultdict
+import tqdm
+import pickle
+from random import shuffle
+
+# Local imports
+from StructDiffusion.utils.rearrangement import show_pcs, get_pts, array_to_tensor
+from StructDiffusion.utils.pointnet import pc_normalize
+
+import StructDiffusion.utils.brain2.camera as cam
+import StructDiffusion.utils.brain2.image as img
+import StructDiffusion.utils.transformations as tra
+
+
+def load_pairwise_collision_data(h5_filename):
+
+ fh = h5py.File(h5_filename, 'r')
+ data_dict = {}
+ data_dict["obj1_info"] = eval(fh["obj1_info"][()])
+ data_dict["obj2_info"] = eval(fh["obj2_info"][()])
+ data_dict["obj1_poses"] = fh["obj1_poses"][:]
+ data_dict["obj2_poses"] = fh["obj2_poses"][:]
+ data_dict["intersection_labels"] = fh["intersection_labels"][:]
+
+ return data_dict
+
+
+class PairwiseCollisionDataset(torch.utils.data.Dataset):
+
+ def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
+ num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
+ debug=False):
+
+ # load dictionary mapping from urdf to list of pc data, each sample is
+ # {"step_t": step_t, "obj": obj, "filename": filename}
+ with open(urdf_pc_idx_file, "rb") as fh:
+ self.urdf_to_pc_data = pickle.load(fh)
+ # filter out broken files
+ for urdf in self.urdf_to_pc_data:
+ valid_pc_data = []
+ for pd in self.urdf_to_pc_data[urdf]:
+ filename = pd["filename"]
+ if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
+ continue
+ valid_pc_data.append(pd)
+ if valid_pc_data:
+ self.urdf_to_pc_data[urdf] = valid_pc_data
+
+ # build data index
+ # each sample is a tuple of (collision filename, idx for the labels and poses)
+ if collision_data_dir is not None:
+ self.data_idxs = self.build_data_idxs(collision_data_dir)
+ else:
+ print("WARNING: collision_data_dir is None")
+
+ self.num_pts = num_pts
+ self.debug = debug
+ self.normalize_pc = normalize_pc
+ self.num_scene_pts = num_scene_pts
+ self.random_rotation = random_rotation
+
+ # Noise
+ self.data_augmentation = data_augmentation
+ # additive noise
+ self.gp_rescale_factor_range = [12, 20]
+ self.gaussian_scale_range = [0., 0.003]
+ # multiplicative noise
+ self.gamma_shape = 1000.
+ self.gamma_scale = 0.001
+
+ def build_data_idxs(self, collision_data_dir):
+ print("Load collision data...")
+ positive_data = []
+ negative_data = []
+ for filename in tqdm.tqdm(os.listdir(collision_data_dir)):
+ if "h5" not in filename:
+ continue
+ h5_filename = os.path.join(collision_data_dir, filename)
+ data_dict = load_pairwise_collision_data(h5_filename)
+ obj1_urdf = data_dict["obj1_info"]["urdf"]
+ obj2_urdf = data_dict["obj2_info"]["urdf"]
+ if obj1_urdf not in self.urdf_to_pc_data:
+ print("no pc data for urdf:", obj1_urdf)
+ continue
+ if obj2_urdf not in self.urdf_to_pc_data:
+ print("no pc data for urdf:", obj2_urdf)
+ continue
+ for idx, l in enumerate(data_dict["intersection_labels"]):
+ if l:
+ # intersection
+ positive_data.append((h5_filename, idx))
+ else:
+ negative_data.append((h5_filename, idx))
+ print("Num pairwise intersections:", len(positive_data))
+ print("Num pairwise no intersections:", len(negative_data))
+
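+        # balance the dataset by randomly downsampling the larger class to the size of the smaller one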
+ if len(negative_data) != len(positive_data):
+ min_len = min(len(negative_data), len(positive_data))
+ positive_data = [positive_data[i] for i in np.random.permutation(len(positive_data))[:min_len]]
+ negative_data = [negative_data[i] for i in np.random.permutation(len(negative_data))[:min_len]]
+ print("after balancing")
+ print("Num pairwise intersections:", len(positive_data))
+ print("Num pairwise no intersections:", len(negative_data))
+
+ return positive_data + negative_data
+
+ def create_urdf_pc_idxs(self, urdf_pc_idx_file, data_roots, index_roots):
+ print("Load pc data")
+ arrangement_steps = []
+ for split in ["train"]:
+ for data_root, index_root in zip(data_roots, index_roots):
+ arrangement_indices_file = os.path.join(data_root, index_root,"{}_arrangement_indices_file_all.txt".format(split))
+ if os.path.exists(arrangement_indices_file):
+ with open(arrangement_indices_file, "r") as fh:
+ arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
+ else:
+ print("{} does not exist".format(arrangement_indices_file))
+
+ urdf_to_pc_data = defaultdict(list)
+ for filename, step_t in tqdm.tqdm(arrangement_steps):
+ h5 = h5py.File(filename, 'r')
+ ids = self._get_ids(h5)
+ # moved_objs = h5['moved_objs'][()].split(',')
+ all_objs = sorted([o for o in ids.keys() if "object_" in o])
+ goal_specification = json.loads(str(np.array(h5["goal_specification"])))
+ obj_infos = goal_specification["rearrange"]["objects"] + goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]
+ for obj, obj_info in zip(all_objs, obj_infos):
+ urdf_to_pc_data[obj_info["urdf"]].append({"step_t": step_t, "obj": obj, "filename": filename})
+
+ with open(urdf_pc_idx_file, "wb") as fh:
+ pickle.dump(urdf_to_pc_data, fh)
+
+ return urdf_to_pc_data
+
+ def add_noise_to_depth(self, depth_img):
+ """ add depth noise """
+ multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
+ depth_img = multiplicative_noise * depth_img
+ return depth_img
+
+ def add_noise_to_xyz(self, xyz_img, depth_img):
+ """ TODO: remove this code or at least celean it up"""
+ xyz_img = xyz_img.copy()
+ H, W, C = xyz_img.shape
+ gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
+ self.gp_rescale_factor_range[1])
+ gp_scale = np.random.uniform(self.gaussian_scale_range[0],
+ self.gaussian_scale_range[1])
+ small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
+ additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
+ additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
+ xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
+ return xyz_img
+
+ def _get_images(self, h5, idx, ee=True):
+ if ee:
+ RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
+ DMIN, DMAX = "ee_depth_min", "ee_depth_max"
+ else:
+ RGB, DEPTH, SEG = "rgb", "depth", "seg"
+ DMIN, DMAX = "depth_min", "depth_max"
+ dmin = h5[DMIN][idx]
+ dmax = h5[DMAX][idx]
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
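+        # depth appears to be stored as 16-bit values normalized to [0, 20000]; rescale to the metric range [dmin, dmax]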
+ depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
+ seg1 = img.PNGToNumpy(h5[SEG][idx])
+
+ valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
+
+ # proj_matrix = h5['proj_matrix'][()]
+ camera = cam.get_camera_from_h5(h5)
+ if self.data_augmentation:
+ depth1 = self.add_noise_to_depth(depth1)
+
+ xyz1 = cam.compute_xyz(depth1, camera)
+ if self.data_augmentation:
+ xyz1 = self.add_noise_to_xyz(xyz1, depth1)
+
+ # Transform the point cloud
+ # Here it is...
+ # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
+ CAM_POSE = "ee_camera_view" if ee else "camera_view"
+ cam_pose = h5[CAM_POSE][idx]
+ if ee:
+ # ee_camera_view has 0s for x, y, z
+ cam_pos = h5["ee_cam_pose"][:][:3, 3]
+ cam_pose[:3, 3] = cam_pos
+
+ # Get transformed point cloud
+ h, w, d = xyz1.shape
+ xyz1 = xyz1.reshape(h * w, -1)
+ xyz1 = trimesh.transform_points(xyz1, cam_pose)
+ xyz1 = xyz1.reshape(h, w, -1)
+
+ scene1 = rgb1, depth1, seg1, valid1, xyz1
+
+ return scene1
+
+ def _get_ids(self, h5):
+ """
+ get object ids
+
+ @param h5:
+ @return:
+ """
+ ids = {}
+ for k in h5.keys():
+ if k.startswith("id_"):
+ ids[k[3:]] = h5[k][()]
+ return ids
+
+ def get_obj_pc(self, h5, step_t, obj):
+ scene = self._get_images(h5, step_t, ee=True)
+ rgb, depth, seg, valid, xyz = scene
+
+ # getting object point clouds
+ ids = self._get_ids(h5)
+ obj_mask = np.logical_and(seg == ids[obj], valid)
+ if np.sum(obj_mask) <= 0:
+ raise Exception
+ ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts, to_tensor=False)
+ obj_pc_center = np.mean(obj_xyz, axis=0)
+ obj_pose = h5[obj][step_t]
+
+ obj_pc_pose = np.eye(4)
+ obj_pc_pose[:3, 3] = obj_pc_center[:3]
+
+ return obj_xyz, obj_rgb, obj_pc_pose, obj_pose
+
+ def __len__(self):
+ return len(self.data_idxs)
+
+ def __getitem__(self, idx):
+ collision_filename, collision_idx = self.data_idxs[idx]
+ collision_data_dict = load_pairwise_collision_data(collision_filename)
+
+ obj1_urdf = collision_data_dict["obj1_info"]["urdf"]
+ obj2_urdf = collision_data_dict["obj2_info"]["urdf"]
+
+ # TODO: find a better way to sample pc data?
+ obj1_pc_data = np.random.choice(self.urdf_to_pc_data[obj1_urdf])
+ obj2_pc_data = np.random.choice(self.urdf_to_pc_data[obj2_urdf])
+
+ obj1_xyz, obj1_rgb, obj1_pc_pose, obj1_pose = self.get_obj_pc(h5py.File(obj1_pc_data["filename"], "r"), obj1_pc_data["step_t"], obj1_pc_data["obj"])
+ obj2_xyz, obj2_rgb, obj2_pc_pose, obj2_pose = self.get_obj_pc(h5py.File(obj2_pc_data["filename"], "r"), obj2_pc_data["step_t"], obj2_pc_data["obj"])
+
+ obj1_c_pose = collision_data_dict["obj1_poses"][collision_idx]
+ obj2_c_pose = collision_data_dict["obj2_poses"][collision_idx]
+ label = collision_data_dict["intersection_labels"][collision_idx]
+
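+        # relative transform from the pose each object was observed in to the pose stored in the
+        # collision sample; applying it moves the observed points into the queried configuration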
+ obj1_transform = obj1_c_pose @ np.linalg.inv(obj1_pose)
+ obj2_transform = obj2_c_pose @ np.linalg.inv(obj2_pose)
+ obj1_c_xyz = trimesh.transform_points(obj1_xyz, obj1_transform)
+ obj2_c_xyz = trimesh.transform_points(obj2_xyz, obj2_transform)
+
+ # if self.debug:
+ # show_pcs([obj1_c_xyz, obj2_c_xyz], [obj1_rgb, obj2_rgb], add_coordinate_frame=True)
+
+ ###################################
+ obj_xyzs = [obj1_c_xyz, obj2_c_xyz]
+ shuffle(obj_xyzs)
+
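+        # append a one-hot indicator channel so the model can tell which of the two objects
+        # each scene point belongs to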
+ num_indicator = 2
+ new_obj_xyzs = []
+ for oi, obj_xyz in enumerate(obj_xyzs):
+ obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
+ new_obj_xyzs.append(obj_xyz)
+ scene_xyz = np.concatenate(new_obj_xyzs, axis=0)
+
+ # subsampling and normalizing pc
+ idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
+ scene_xyz = scene_xyz[idx]
+ if self.normalize_pc:
+ scene_xyz[:, 0:3] = pc_normalize(scene_xyz[:, 0:3])
+
+ if self.random_rotation:
+ scene_xyz[:, 0:3] = trimesh.transform_points(scene_xyz[:, 0:3], tra.euler_matrix(0, 0, np.random.uniform(low=0, high=2 * np.pi)))
+
+ ###################################
+ scene_xyz = array_to_tensor(scene_xyz)
+ # convert to torch data
+ label = int(label)
+
+ if self.debug:
+ print("intersection:", label)
+            show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=float), (scene_xyz.shape[0], 1))], add_coordinate_frame=True)
+
+ datum = {
+ "scene_xyz": scene_xyz,
+ "label": torch.FloatTensor([label]),
+ }
+ return datum
+
+ # @staticmethod
+ # def collate_fn(data):
+ # """
+ # :param data:
+ # :return:
+ # """
+ #
+ # batched_data_dict = {}
+ # for key in ["is_circle"]:
+ # batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0)
+ # for key in ["scene_xyz"]:
+ # batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0)
+ #
+ # return batched_data_dict
+ #
+ # # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False):
+ # #
+ # # new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs]
+ # #
+ # # # compute pairwise collision
+ # # scene_xyzs = []
+ # # obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2))
+ # #
+ # # for obj_xyz_pair_idx in obj_xyz_pair_idxs:
+ # # obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]]
+ # # num_indicator = 2
+ # # obj_xyz_pair_ind = []
+ # # for oi, obj_xyz in enumerate(obj_xyz_pair):
+ # # obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
+ # # obj_xyz_pair_ind.append(obj_xyz)
+ # # pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0)
+ # #
+ # # # subsampling and normalizing pc
+ # # rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts)
+ # # pair_scene_xyz = pair_scene_xyz[rand_idx]
+ # # if self.normalize_pc:
+ # # pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3])
+ # #
+ # # scene_xyzs.append(array_to_tensor(pair_scene_xyz))
+ # #
+ # # if debug:
+ # # for scene_xyz in scene_xyzs:
+ # # show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))],
+ # # add_coordinate_frame=True)
+ # #
+ # # return scene_xyzs
+
+
+if __name__ == "__main__":
+ dataset = PairwiseCollisionDataset(urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl",
+ collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data",
+ debug=False)
+
+ for i in tqdm.tqdm(np.random.permutation(len(dataset))):
+ # print(i)
+ d = dataset[i]
+ # print(d["label"])
+
+ # dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)
+ # for b in tqdm.tqdm(dl):
+ # pass
diff --git a/src/StructDiffusion/data/semantic_arrangement.py b/src/StructDiffusion/data/semantic_arrangement.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac4ef12d7ebdb768b7af5b58d8263d171bb91bb
--- /dev/null
+++ b/src/StructDiffusion/data/semantic_arrangement.py
@@ -0,0 +1,579 @@
+import copy
+import cv2
+import h5py
+import numpy as np
+import os
+import trimesh
+import torch
+from tqdm import tqdm
+import json
+import random
+
+from torch.utils.data import DataLoader
+
+# Local imports
+from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
+from StructDiffusion.language.tokenizer import Tokenizer
+
+import StructDiffusion.utils.brain2.camera as cam
+import StructDiffusion.utils.brain2.image as img
+import StructDiffusion.utils.transformations as tra
+
+
+class SemanticArrangementDataset(torch.utils.data.Dataset):
+
+ def __init__(self, data_roots, index_roots, split, tokenizer,
+ max_num_target_objects=11, max_num_distractor_objects=5,
+ max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
+ num_pts=1024,
+ use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
+ filter_num_moved_objects_range=None, shuffle_object_index=False,
+ data_augmentation=True, debug=False, **kwargs):
+ """
+
+ Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
+
+ :param data_root:
+ :param split: train, valid, or test
+ :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
+ :param debug:
+ :param max_num_shape_parameters:
+ :param max_num_objects:
+ :param max_num_rearrange_features:
+ :param max_num_anchor_features:
+ :param num_pts:
+ :param use_stored_arrangement_indices:
+ :param kwargs:
+ """
+
+ self.use_virtual_structure_frame = use_virtual_structure_frame
+ self.ignore_distractor_objects = ignore_distractor_objects
+ self.ignore_rgb = ignore_rgb and not debug
+
+ self.num_pts = num_pts
+ self.debug = debug
+
+ self.max_num_objects = max_num_target_objects
+ self.max_num_other_objects = max_num_distractor_objects
+ self.max_num_shape_parameters = max_num_shape_parameters
+ self.max_num_rearrange_features = max_num_rearrange_features
+ self.max_num_anchor_features = max_num_anchor_features
+ self.shuffle_object_index = shuffle_object_index
+
+ # used to tokenize the language part
+ self.tokenizer = tokenizer
+
+ # retrieve data
+ self.data_roots = data_roots
+ self.arrangement_data = []
+ arrangement_steps = []
+ for ddx in range(len(data_roots)):
+ data_root = data_roots[ddx]
+ index_root = index_roots[ddx]
+ arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
+ if os.path.exists(arrangement_indices_file):
+ with open(arrangement_indices_file, "r") as fh:
+ arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
+ else:
+ print("{} does not exist".format(arrangement_indices_file))
+ # only keep the goal, ignore the intermediate steps
+ for filename, step_t in arrangement_steps:
+ if step_t == 0:
+ if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename:
+ continue
+ self.arrangement_data.append((filename, step_t))
+ # if specified, filter data
+ if filter_num_moved_objects_range is not None:
+ self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range)
+ print("{} valid sequences".format(len(self.arrangement_data)))
+
+ # Data Aug
+ self.data_augmentation = data_augmentation
+ # additive noise
+ self.gp_rescale_factor_range = [12, 20]
+ self.gaussian_scale_range = [0., 0.003]
+ # multiplicative noise
+ self.gamma_shape = 1000.
+ self.gamma_scale = 0.001
+
+ def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
+ assert len(list(filter_num_moved_objects_range)) == 2
+ min_num, max_num = filter_num_moved_objects_range
+ print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
+ ok_data = []
+ for filename, step_t in self.arrangement_data:
+ h5 = h5py.File(filename, 'r')
+ moved_objs = h5['moved_objs'][()].split(',')
+ if min_num <= len(moved_objs) <= max_num:
+ ok_data.append((filename, step_t))
+ print("{} valid sequences left".format(len(ok_data)))
+ return ok_data
+
+ def get_data_idx(self, idx):
+ # Create the datum to return
+ file_idx = np.argmax(idx < self.file_to_count)
+ data = h5py.File(self.data_files[file_idx], 'r')
+ if file_idx > 0:
+ # for lang2sym, idx is always 0
+ idx = idx - self.file_to_count[file_idx - 1]
+ return data, idx, file_idx
+
+ def add_noise_to_depth(self, depth_img):
+ """ add depth noise """
+ multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
+ depth_img = multiplicative_noise * depth_img
+ return depth_img
+
+ def add_noise_to_xyz(self, xyz_img, depth_img):
+ """ TODO: remove this code or at least celean it up"""
+ xyz_img = xyz_img.copy()
+ H, W, C = xyz_img.shape
+ gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
+ self.gp_rescale_factor_range[1])
+ gp_scale = np.random.uniform(self.gaussian_scale_range[0],
+ self.gaussian_scale_range[1])
+ small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
+ additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
+ additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
+ xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
+ return xyz_img
+
+ def random_index(self):
+ return self[np.random.randint(len(self))]
+
+ def _get_rgb(self, h5, idx, ee=True):
+ RGB = "ee_rgb" if ee else "rgb"
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
+ return rgb1
+
+    def _get_depth(self, h5, idx, ee=True):
+        DEPTH = "ee_depth" if ee else "depth"
+        DMIN, DMAX = ("ee_depth_min", "ee_depth_max") if ee else ("depth_min", "depth_max")
+        # rescale stored depth to the metric range [dmin, dmax], matching _get_images
+        return h5[DEPTH][idx] / 20000. * (h5[DMAX][idx] - h5[DMIN][idx]) + h5[DMIN][idx]
+
+ def _get_images(self, h5, idx, ee=True):
+ if ee:
+ RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
+ DMIN, DMAX = "ee_depth_min", "ee_depth_max"
+ else:
+ RGB, DEPTH, SEG = "rgb", "depth", "seg"
+ DMIN, DMAX = "depth_min", "depth_max"
+ dmin = h5[DMIN][idx]
+ dmax = h5[DMAX][idx]
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
+ depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
+ seg1 = img.PNGToNumpy(h5[SEG][idx])
+
+ valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
+
+ # proj_matrix = h5['proj_matrix'][()]
+ camera = cam.get_camera_from_h5(h5)
+ if self.data_augmentation:
+ depth1 = self.add_noise_to_depth(depth1)
+
+ xyz1 = cam.compute_xyz(depth1, camera)
+ if self.data_augmentation:
+ xyz1 = self.add_noise_to_xyz(xyz1, depth1)
+
+ # Transform the point cloud
+ # Here it is...
+ # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
+ CAM_POSE = "ee_camera_view" if ee else "camera_view"
+ cam_pose = h5[CAM_POSE][idx]
+ if ee:
+ # ee_camera_view has 0s for x, y, z
+ cam_pos = h5["ee_cam_pose"][:][:3, 3]
+ cam_pose[:3, 3] = cam_pos
+
+ # Get transformed point cloud
+ h, w, d = xyz1.shape
+ xyz1 = xyz1.reshape(h * w, -1)
+ xyz1 = trimesh.transform_points(xyz1, cam_pose)
+ xyz1 = xyz1.reshape(h, w, -1)
+
+ scene1 = rgb1, depth1, seg1, valid1, xyz1
+
+ return scene1
+
+ def __len__(self):
+ return len(self.arrangement_data)
+
+ def _get_ids(self, h5):
+ """
+ get object ids
+
+ @param h5:
+ @return:
+ """
+ ids = {}
+ for k in h5.keys():
+ if k.startswith("id_"):
+ ids[k[3:]] = h5[k][()]
+ return ids
+
+ def get_positive_ratio(self):
+ num_pos = 0
+ for d in self.arrangement_data:
+ filename, step_t = d
+ if step_t == 0:
+ num_pos += 1
+ return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos
+
+ def get_object_position_vocab_sizes(self):
+ return self.tokenizer.get_object_position_vocab_sizes()
+
+ def get_vocab_size(self):
+ return self.tokenizer.get_vocab_size()
+
+ def get_data_index(self, idx):
+ filename = self.arrangement_data[idx]
+ return filename
+
+ def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
+ """
+
+ :param idx:
+ :param inference_mode:
+ :param shuffle_object_index: used to test different orders of objects
+ :return:
+ """
+
+ filename, _ = self.arrangement_data[idx]
+
+ h5 = h5py.File(filename, 'r')
+ ids = self._get_ids(h5)
+ all_objs = sorted([o for o in ids.keys() if "object_" in o])
+ goal_specification = json.loads(str(np.array(h5["goal_specification"])))
+ num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
+ num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
+ assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
+ assert num_rearrange_objs <= self.max_num_objects
+ assert num_other_objs <= self.max_num_other_objects
+
+ # important: only using the last step
+ step_t = num_rearrange_objs
+
+ target_objs = all_objs[:num_rearrange_objs]
+ other_objs = all_objs[num_rearrange_objs:]
+
+ structure_parameters = goal_specification["shape"]
+
+ # Important: ensure the order is correct
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
+ target_objs = target_objs[::-1]
+ elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
+ target_objs = target_objs
+ else:
+ raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
+ all_objs = target_objs + other_objs
+
+ ###################################
+ # getting scene images and point clouds
+ scene = self._get_images(h5, step_t, ee=True)
+ rgb, depth, seg, valid, xyz = scene
+ if inference_mode:
+ initial_scene = scene
+
+ # getting object point clouds
+ obj_pcs = []
+ obj_pad_mask = []
+ current_pc_poses = []
+ other_obj_pcs = []
+ other_obj_pad_mask = []
+ for obj in all_objs:
+ obj_mask = np.logical_and(seg == ids[obj], valid)
+ if np.sum(obj_mask) <= 0:
+ raise Exception
+ ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
+ if not ok:
+ raise Exception
+
+ if obj in target_objs:
+ if self.ignore_rgb:
+ obj_pcs.append(obj_xyz)
+ else:
+ obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
+ obj_pad_mask.append(0)
+ pc_pose = np.eye(4)
+ pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
+ current_pc_poses.append(pc_pose)
+ elif obj in other_objs:
+ if self.ignore_rgb:
+ other_obj_pcs.append(obj_xyz)
+ else:
+ other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
+ other_obj_pad_mask.append(0)
+ else:
+ raise Exception
+
+ ###################################
+ # computes goal positions for objects
+ # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
+ if self.use_virtual_structure_frame:
+ goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
+ structure_parameters["rotation"][2])
+ goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
+ structure_parameters["position"][2]]
+ goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
+
+ goal_obj_poses = []
+ current_obj_poses = []
+ goal_pc_poses = []
+ for obj, current_pc_pose in zip(target_objs, current_pc_poses):
+ goal_pose = h5[obj][0]
+ current_pose = h5[obj][step_t]
+ if inference_mode:
+ goal_obj_poses.append(goal_pose)
+ current_obj_poses.append(current_pose)
+
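+            # goal_pose @ inv(current_pose) is the rigid transform that carries the object from its
+            # current world pose to its goal world pose; composing it with the current point-cloud frame
+            # gives the goal pose of that frame (optionally re-expressed in the virtual structure frame)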
+ goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
+ if self.use_virtual_structure_frame:
+ goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
+ goal_pc_poses.append(goal_pc_pose)
+
+ # transform current object point cloud to the goal point cloud in the world frame
+ if self.debug:
+ new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
+ for i, obj_pc in enumerate(new_obj_pcs):
+
+ current_pc_pose = current_pc_poses[i]
+ goal_pc_pose = goal_pc_poses[i]
+ if self.use_virtual_structure_frame:
+ goal_pc_pose = goal_structure_pose @ goal_pc_pose
+ print("current pc pose", current_pc_pose)
+ print("goal pc pose", goal_pc_pose)
+
+ goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
+ print("transform", goal_pc_transform)
+ new_obj_pc = copy.deepcopy(obj_pc)
+ new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
+ print(new_obj_pc.shape)
+
+ # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
+ new_obj_pcs[i] = new_obj_pc
+                new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
+                new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
+ show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
+ [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
+ add_coordinate_frame=True)
+ show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
+
+ # pad data
+ for i in range(self.max_num_objects - len(target_objs)):
+ obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
+ obj_pad_mask.append(1)
+ for i in range(self.max_num_other_objects - len(other_objs)):
+ other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
+ other_obj_pad_mask.append(1)
+
+ ###################################
+ # preparing sentence
+ sentence = []
+ sentence_pad_mask = []
+
+ # structure parameters
+ # 5 parameters
+ structure_parameters = goal_specification["shape"]
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
+ sentence.append((structure_parameters["type"], "shape"))
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
+ sentence.append((structure_parameters["position"][0], "position_x"))
+ sentence.append((structure_parameters["position"][1], "position_y"))
+ if structure_parameters["type"] == "circle":
+ sentence.append((structure_parameters["radius"], "radius"))
+ elif structure_parameters["type"] == "line":
+ sentence.append((structure_parameters["length"] / 2.0, "radius"))
+ for _ in range(5):
+ sentence_pad_mask.append(0)
+ else:
+ sentence.append((structure_parameters["type"], "shape"))
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
+ sentence.append((structure_parameters["position"][0], "position_x"))
+ sentence.append((structure_parameters["position"][1], "position_y"))
+ for _ in range(4):
+ sentence_pad_mask.append(0)
+ sentence.append(("PAD", None))
+ sentence_pad_mask.append(1)
+
+ ###################################
+ # paddings
+ for i in range(self.max_num_objects - len(target_objs)):
+ goal_pc_poses.append(np.eye(4))
+
+ ###################################
+ if self.debug:
+ print("---")
+ print("all objects:", all_objs)
+ print("target objects:", target_objs)
+ print("other objects:", other_objs)
+ print("goal specification:", goal_specification)
+ print("sentence:", sentence)
+ show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
+
+ assert len(obj_pcs) == len(goal_pc_poses)
+ ###################################
+
+ # shuffle the position of objects
+ if shuffle_object_index:
+ shuffle_target_object_indices = list(range(len(target_objs)))
+ random.shuffle(shuffle_target_object_indices)
+ shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects))
+ obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
+ goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
+ if inference_mode:
+ goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices]
+ current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices]
+ target_objs = [target_objs[i] for i in shuffle_target_object_indices]
+ current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices]
+
+ ###################################
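+        # token type ids used below: 0 = language token, 1 = distractor object, 2 = virtual structure frame, 3 = target object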
+ if self.use_virtual_structure_frame:
+ if self.ignore_distractor_objects:
+ # language, structure virtual frame, target objects
+ pcs = obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + [0] + obj_pad_mask
+ else:
+ # language, distractor objects, structure virtual frame, target objects
+ pcs = other_obj_pcs + obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
+ goal_poses = [goal_structure_pose] + goal_pc_poses
+ else:
+ if self.ignore_distractor_objects:
+ # language, target objects
+ pcs = obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + obj_pad_mask
+ else:
+ # language, distractor objects, target objects
+ pcs = other_obj_pcs + obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
+ goal_poses = goal_pc_poses
+
+ datum = {
+ "pcs": pcs,
+ "sentence": sentence,
+ "goal_poses": goal_poses,
+ "type_index": type_index,
+ "position_index": position_index,
+ "pad_mask": pad_mask,
+ "t": step_t,
+ "filename": filename
+ }
+
+ if inference_mode:
+ datum["rgb"] = rgb
+ datum["goal_obj_poses"] = goal_obj_poses
+ datum["current_obj_poses"] = current_obj_poses
+ datum["target_objs"] = target_objs
+ datum["initial_scene"] = initial_scene
+ datum["ids"] = ids
+ datum["goal_specification"] = goal_specification
+ datum["current_pc_poses"] = current_pc_poses
+
+ return datum
+
+ @staticmethod
+ def convert_to_tensors(datum, tokenizer):
+ tensors = {
+ "pcs": torch.stack(datum["pcs"], dim=0),
+ "sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])),
+ "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
+ "type_index": torch.LongTensor(np.array(datum["type_index"])),
+ "position_index": torch.LongTensor(np.array(datum["position_index"])),
+ "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
+ "t": datum["t"],
+ "filename": datum["filename"]
+ }
+ return tensors
+
+ def __getitem__(self, idx):
+
+ datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
+ self.tokenizer)
+
+ return datum
+
+ def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
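+        # tile one datum along a new batch dimension so the sampler can generate num_samples arrangements in parallel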
+ tensor_x = {}
+
+ tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
+ tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
+ if not inference_mode:
+ tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
+
+ tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
+ tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
+ tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
+
+ return tensor_x
+
+
+def compute_min_max(dataloader):
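+    # scans the dataloader and tracks the element-wise min/max of the flattened 4x4 goal poses,
+    # e.g. to pick normalization bounds for the diffusion variables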
+
+ # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
+ # -0.9079, -0.8668, -0.9105, -0.4186])
+ # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
+ # 0.4787, 0.6421, 1.0000])
+ # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
+ # -0.0000, 0.0000, 0.0000, 1.0000])
+ # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
+ # 0.0000, 0.0000, 1.0000])
+
+ min_value = torch.ones(16) * 10000
+ max_value = torch.ones(16) * -10000
+ for d in tqdm(dataloader):
+ goal_poses = d["goal_poses"]
+ goal_poses = goal_poses.reshape(-1, 16)
+ current_max, _ = torch.max(goal_poses, dim=0)
+ current_min, _ = torch.min(goal_poses, dim=0)
+ max_value[max_value < current_max] = current_max[max_value < current_max]
+        min_value[min_value > current_min] = current_min[min_value > current_min]
+ print(f"{min_value} - {max_value}")
+
+
+if __name__ == "__main__":
+
+ tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
+
+ data_roots = []
+ index_roots = []
+ for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
+ data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
+ index_roots.append(index)
+
+ dataset = SemanticArrangementDataset(data_roots=data_roots,
+ index_roots=index_roots,
+ split="valid", tokenizer=tokenizer,
+ max_num_target_objects=7,
+ max_num_distractor_objects=5,
+ max_num_shape_parameters=5,
+ max_num_rearrange_features=0,
+ max_num_anchor_features=0,
+ num_pts=1024,
+ use_virtual_structure_frame=True,
+ ignore_distractor_objects=True,
+ ignore_rgb=True,
+ filter_num_moved_objects_range=None, # [5, 5]
+ data_augmentation=False,
+ shuffle_object_index=False,
+ debug=False)
+
+ # print(len(dataset))
+ # for d in dataset:
+ # print("\n\n" + "="*100)
+
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
+ for i, d in enumerate(tqdm(dataloader)):
+ pass
+ # for k in d:
+ # if isinstance(d[k], torch.Tensor):
+ # print("--size", k, d[k].shape)
+ # for k in d:
+ # print(k, d[k])
+ #
+ # input("next?")
\ No newline at end of file
diff --git a/src/StructDiffusion/data/semantic_arrangement_demo.py b/src/StructDiffusion/data/semantic_arrangement_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..653ccfd758aefc73071dfb36ba78ea46774ac7b5
--- /dev/null
+++ b/src/StructDiffusion/data/semantic_arrangement_demo.py
@@ -0,0 +1,563 @@
+import copy
+import cv2
+import h5py
+import numpy as np
+import os
+import trimesh
+import torch
+from tqdm import tqdm
+import json
+import random
+
+from torch.utils.data import DataLoader
+
+# Local imports
+from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
+from StructDiffusion.language.tokenizer import Tokenizer
+
+import StructDiffusion.utils.brain2.camera as cam
+import StructDiffusion.utils.brain2.image as img
+import StructDiffusion.utils.transformations as tra
+
+
+class SemanticArrangementDataset(torch.utils.data.Dataset):
+
+ def __init__(self, data_root, tokenizer,
+ max_num_target_objects=11, max_num_distractor_objects=5,
+ max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
+ num_pts=1024,
+ use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
+ filter_num_moved_objects_range=None, shuffle_object_index=False,
+ data_augmentation=True, debug=False, **kwargs):
+ """
+
+ Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
+
+ :param data_root:
+ :param split: train, valid, or test
+ :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
+ :param debug:
+ :param max_num_shape_parameters:
+ :param max_num_objects:
+ :param max_num_rearrange_features:
+ :param max_num_anchor_features:
+ :param num_pts:
+ :param use_stored_arrangement_indices:
+ :param kwargs:
+ """
+
+ self.use_virtual_structure_frame = use_virtual_structure_frame
+ self.ignore_distractor_objects = ignore_distractor_objects
+ self.ignore_rgb = ignore_rgb and not debug
+
+ self.num_pts = num_pts
+ self.debug = debug
+
+ self.max_num_objects = max_num_target_objects
+ self.max_num_other_objects = max_num_distractor_objects
+ self.max_num_shape_parameters = max_num_shape_parameters
+ self.max_num_rearrange_features = max_num_rearrange_features
+ self.max_num_anchor_features = max_num_anchor_features
+ self.shuffle_object_index = shuffle_object_index
+
+ # used to tokenize the language part
+ self.tokenizer = tokenizer
+
+ # retrieve data
+ self.data_root = data_root
+ self.arrangement_data = []
+ for filename in os.listdir(data_root):
+ if ".h5" in filename:
+ self.arrangement_data.append((os.path.join(data_root, filename), 0))
+ print("{} valid sequences".format(len(self.arrangement_data)))
+
+ # Data Aug
+ self.data_augmentation = data_augmentation
+ # additive noise
+ self.gp_rescale_factor_range = [12, 20]
+ self.gaussian_scale_range = [0., 0.003]
+ # multiplicative noise
+ self.gamma_shape = 1000.
+ self.gamma_scale = 0.001
+
+ def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
+ assert len(list(filter_num_moved_objects_range)) == 2
+ min_num, max_num = filter_num_moved_objects_range
+ print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
+ ok_data = []
+ for filename, step_t in self.arrangement_data:
+ h5 = h5py.File(filename, 'r')
+ moved_objs = h5['moved_objs'][()].split(',')
+ if min_num <= len(moved_objs) <= max_num:
+ ok_data.append((filename, step_t))
+ print("{} valid sequences left".format(len(ok_data)))
+ return ok_data
+
+ def get_data_idx(self, idx):
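+        # note: file_to_count and data_files are not initialized in this demo dataset, so this helper is unused here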
+ # Create the datum to return
+ file_idx = np.argmax(idx < self.file_to_count)
+ data = h5py.File(self.data_files[file_idx], 'r')
+ if file_idx > 0:
+ # for lang2sym, idx is always 0
+ idx = idx - self.file_to_count[file_idx - 1]
+ return data, idx, file_idx
+
+ def add_noise_to_depth(self, depth_img):
+ """ add depth noise """
+ multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
+ depth_img = multiplicative_noise * depth_img
+ return depth_img
+
+ def add_noise_to_xyz(self, xyz_img, depth_img):
+ """ TODO: remove this code or at least celean it up"""
+ xyz_img = xyz_img.copy()
+ H, W, C = xyz_img.shape
+ gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
+ self.gp_rescale_factor_range[1])
+ gp_scale = np.random.uniform(self.gaussian_scale_range[0],
+ self.gaussian_scale_range[1])
+ small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
+ additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
+ additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
+ xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
+ return xyz_img
+
+ def random_index(self):
+ return self[np.random.randint(len(self))]
+
+ def _get_rgb(self, h5, idx, ee=True):
+ RGB = "ee_rgb" if ee else "rgb"
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
+ return rgb1
+
+    def _get_depth(self, h5, idx, ee=True):
+        DEPTH = "ee_depth" if ee else "depth"
+        # return the raw stored depth; _get_images applies the dmin/dmax rescaling when a metric depth is needed
+        return h5[DEPTH][idx]
+
+ def _get_images(self, h5, idx, ee=True):
+ if ee:
+ RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
+ DMIN, DMAX = "ee_depth_min", "ee_depth_max"
+ else:
+ RGB, DEPTH, SEG = "rgb", "depth", "seg"
+ DMIN, DMAX = "depth_min", "depth_max"
+ dmin = h5[DMIN][idx]
+ dmax = h5[DMAX][idx]
+ rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha
+ depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
+ seg1 = img.PNGToNumpy(h5[SEG][idx])
+
+ valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)
+
+ # proj_matrix = h5['proj_matrix'][()]
+ camera = cam.get_camera_from_h5(h5)
+ if self.data_augmentation:
+ depth1 = self.add_noise_to_depth(depth1)
+
+ xyz1 = cam.compute_xyz(depth1, camera)
+ if self.data_augmentation:
+ xyz1 = self.add_noise_to_xyz(xyz1, depth1)
+
+ # Transform the point cloud
+ # Here it is...
+ # CAM_POSE = "ee_cam_pose" if ee else "cam_pose"
+ CAM_POSE = "ee_camera_view" if ee else "camera_view"
+ cam_pose = h5[CAM_POSE][idx]
+ if ee:
+ # ee_camera_view has 0s for x, y, z
+ cam_pos = h5["ee_cam_pose"][:][:3, 3]
+ cam_pose[:3, 3] = cam_pos
+
+ # Get transformed point cloud
+ h, w, d = xyz1.shape
+ xyz1 = xyz1.reshape(h * w, -1)
+ xyz1 = trimesh.transform_points(xyz1, cam_pose)
+ xyz1 = xyz1.reshape(h, w, -1)
+
+ scene1 = rgb1, depth1, seg1, valid1, xyz1
+
+ return scene1
+
+ def __len__(self):
+ return len(self.arrangement_data)
+
+ def _get_ids(self, h5):
+ """
+ get object ids
+
+ @param h5:
+ @return:
+ """
+ ids = {}
+ for k in h5.keys():
+ if k.startswith("id_"):
+ ids[k[3:]] = h5[k][()]
+ return ids
+
+ def get_positive_ratio(self):
+ num_pos = 0
+ for d in self.arrangement_data:
+ filename, step_t = d
+ if step_t == 0:
+ num_pos += 1
+ return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos
+
+ def get_object_position_vocab_sizes(self):
+ return self.tokenizer.get_object_position_vocab_sizes()
+
+ def get_vocab_size(self):
+ return self.tokenizer.get_vocab_size()
+
+ def get_data_index(self, idx):
+ filename = self.arrangement_data[idx]
+ return filename
+
+ def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
+ """
+
+ :param idx:
+ :param inference_mode:
+ :param shuffle_object_index: used to test different orders of objects
+ :return:
+ """
+
+ filename, _ = self.arrangement_data[idx]
+
+ h5 = h5py.File(filename, 'r')
+ ids = self._get_ids(h5)
+ all_objs = sorted([o for o in ids.keys() if "object_" in o])
+ goal_specification = json.loads(str(np.array(h5["goal_specification"])))
+ num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
+ num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
+ assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
+ assert num_rearrange_objs <= self.max_num_objects
+ assert num_other_objs <= self.max_num_other_objects
+
+ # important: only using the last step
+ step_t = num_rearrange_objs
+
+ target_objs = all_objs[:num_rearrange_objs]
+ other_objs = all_objs[num_rearrange_objs:]
+
+ structure_parameters = goal_specification["shape"]
+
+ # Important: ensure the order is correct
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
+ target_objs = target_objs[::-1]
+ elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
+ target_objs = target_objs
+ else:
+ raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
+ all_objs = target_objs + other_objs
+
+ ###################################
+ # getting scene images and point clouds
+ scene = self._get_images(h5, step_t, ee=True)
+ rgb, depth, seg, valid, xyz = scene
+ if inference_mode:
+ initial_scene = scene
+
+ # getting object point clouds
+ obj_pcs = []
+ obj_pad_mask = []
+ current_pc_poses = []
+ other_obj_pcs = []
+ other_obj_pad_mask = []
+ for obj in all_objs:
+ obj_mask = np.logical_and(seg == ids[obj], valid)
+            if np.sum(obj_mask) <= 0:
+                raise Exception("object {} is not visible in the scene".format(obj))
+ ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
+            if not ok:
+                raise Exception("failed to sample {} points for object {}".format(self.num_pts, obj))
+
+ if obj in target_objs:
+ if self.ignore_rgb:
+ obj_pcs.append(obj_xyz)
+ else:
+ obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
+ obj_pad_mask.append(0)
+ pc_pose = np.eye(4)
+ pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
+ current_pc_poses.append(pc_pose)
+ elif obj in other_objs:
+ if self.ignore_rgb:
+ other_obj_pcs.append(obj_xyz)
+ else:
+ other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
+ other_obj_pad_mask.append(0)
+            else:
+                raise Exception("object {} is neither a target nor a distractor".format(obj))
+
+ ###################################
+ # computes goal positions for objects
+ # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
+ if self.use_virtual_structure_frame:
+ goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
+ structure_parameters["rotation"][2])
+ goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
+ structure_parameters["position"][2]]
+ goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
+
+ goal_obj_poses = []
+ current_obj_poses = []
+ goal_pc_poses = []
+ for obj, current_pc_pose in zip(target_objs, current_pc_poses):
+ goal_pose = h5[obj][0]
+ current_pose = h5[obj][step_t]
+ if inference_mode:
+ goal_obj_poses.append(goal_pose)
+ current_obj_poses.append(current_pose)
+
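+            # map the point-cloud centroid frame through the object's goal-from-current transform,
+            # then optionally express it in the virtual structure frame below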
+ goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
+ if self.use_virtual_structure_frame:
+ goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
+ goal_pc_poses.append(goal_pc_pose)
+
+ # transform current object point cloud to the goal point cloud in the world frame
+ if self.debug:
+ new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
+ for i, obj_pc in enumerate(new_obj_pcs):
+
+ current_pc_pose = current_pc_poses[i]
+ goal_pc_pose = goal_pc_poses[i]
+ if self.use_virtual_structure_frame:
+ goal_pc_pose = goal_structure_pose @ goal_pc_pose
+ print("current pc pose", current_pc_pose)
+ print("goal pc pose", goal_pc_pose)
+
+ goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
+ print("transform", goal_pc_transform)
+ new_obj_pc = copy.deepcopy(obj_pc)
+ new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
+ print(new_obj_pc.shape)
+
+ # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
+ new_obj_pcs[i] = new_obj_pc
+                new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
+                new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float64), (new_obj_pc.shape[0], 1))
+ show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
+ [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
+ add_coordinate_frame=True)
+ show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
+
+ # pad data
+ for i in range(self.max_num_objects - len(target_objs)):
+ obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
+ obj_pad_mask.append(1)
+ for i in range(self.max_num_other_objects - len(other_objs)):
+ other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
+ other_obj_pad_mask.append(1)
+
+ ###################################
+ # preparing sentence
+ sentence = []
+ sentence_pad_mask = []
+
+        # structure parameters
+        # the sentence always holds 5 tokens: shape type, z-rotation, x, y, and radius
+        # (half of the length for lines); other shape types pad the fifth slot with a PAD token
+ structure_parameters = goal_specification["shape"]
+ if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
+ sentence.append((structure_parameters["type"], "shape"))
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
+ sentence.append((structure_parameters["position"][0], "position_x"))
+ sentence.append((structure_parameters["position"][1], "position_y"))
+ if structure_parameters["type"] == "circle":
+ sentence.append((structure_parameters["radius"], "radius"))
+ elif structure_parameters["type"] == "line":
+ sentence.append((structure_parameters["length"] / 2.0, "radius"))
+ for _ in range(5):
+ sentence_pad_mask.append(0)
+ else:
+ sentence.append((structure_parameters["type"], "shape"))
+ sentence.append((structure_parameters["rotation"][2], "rotation"))
+ sentence.append((structure_parameters["position"][0], "position_x"))
+ sentence.append((structure_parameters["position"][1], "position_y"))
+ for _ in range(4):
+ sentence_pad_mask.append(0)
+ sentence.append(("PAD", None))
+ sentence_pad_mask.append(1)
+
+ ###################################
+ # paddings
+ for i in range(self.max_num_objects - len(target_objs)):
+ goal_pc_poses.append(np.eye(4))
+
+ ###################################
+ if self.debug:
+ print("---")
+ print("all objects:", all_objs)
+ print("target objects:", target_objs)
+ print("other objects:", other_objs)
+ print("goal specification:", goal_specification)
+ print("sentence:", sentence)
+ show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
+
+ assert len(obj_pcs) == len(goal_pc_poses)
+ ###################################
+
+ # shuffle the position of objects
+ if shuffle_object_index:
+ shuffle_target_object_indices = list(range(len(target_objs)))
+ random.shuffle(shuffle_target_object_indices)
+ shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects))
+ obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
+ goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
+ if inference_mode:
+ goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices]
+ current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices]
+ target_objs = [target_objs[i] for i in shuffle_target_object_indices]
+ current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices]
+
+ ###################################
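+        # token type ids used below: 0 = language token, 1 = distractor object, 2 = virtual structure frame, 3 = target object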
+ if self.use_virtual_structure_frame:
+ if self.ignore_distractor_objects:
+ # language, structure virtual frame, target objects
+ pcs = obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + [0] + obj_pad_mask
+ else:
+ # language, distractor objects, structure virtual frame, target objects
+ pcs = other_obj_pcs + obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
+ goal_poses = [goal_structure_pose] + goal_pc_poses
+ else:
+ if self.ignore_distractor_objects:
+ # language, target objects
+ pcs = obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + obj_pad_mask
+ else:
+ # language, distractor objects, target objects
+ pcs = other_obj_pcs + obj_pcs
+ type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
+ position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
+ pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
+ goal_poses = goal_pc_poses
+
+ datum = {
+ "pcs": pcs,
+ "sentence": sentence,
+ "goal_poses": goal_poses,
+ "type_index": type_index,
+ "position_index": position_index,
+ "pad_mask": pad_mask,
+ "t": step_t,
+ "filename": filename
+ }
+
+ if inference_mode:
+ datum["rgb"] = rgb
+ datum["goal_obj_poses"] = goal_obj_poses
+ datum["current_obj_poses"] = current_obj_poses
+ datum["target_objs"] = target_objs
+ datum["initial_scene"] = initial_scene
+ datum["ids"] = ids
+ datum["goal_specification"] = goal_specification
+ datum["current_pc_poses"] = current_pc_poses
+
+ return datum
+
+ @staticmethod
+ def convert_to_tensors(datum, tokenizer):
+ tensors = {
+ "pcs": torch.stack(datum["pcs"], dim=0),
+ "sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])),
+ "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
+ "type_index": torch.LongTensor(np.array(datum["type_index"])),
+ "position_index": torch.LongTensor(np.array(datum["position_index"])),
+ "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
+ "t": datum["t"],
+ "filename": datum["filename"]
+ }
+ return tensors
+
+ def __getitem__(self, idx):
+
+ datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
+ self.tokenizer)
+
+ return datum
+
+ def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
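+        # tile one datum along a new batch dimension so the sampler can generate num_samples arrangements in parallel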
+ tensor_x = {}
+
+ tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
+ tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
+ if not inference_mode:
+ tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
+
+ tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
+ tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
+ tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)
+
+ return tensor_x
+
+
+def compute_min_max(dataloader):
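+    # scans the dataloader and tracks the element-wise min/max of the flattened 4x4 goal poses,
+    # e.g. to pick normalization bounds for the diffusion variables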
+
+ # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
+ # -0.9079, -0.8668, -0.9105, -0.4186])
+ # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
+ # 0.4787, 0.6421, 1.0000])
+ # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
+ # -0.0000, 0.0000, 0.0000, 1.0000])
+ # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
+ # 0.0000, 0.0000, 1.0000])
+
+ min_value = torch.ones(16) * 10000
+ max_value = torch.ones(16) * -10000
+ for d in tqdm(dataloader):
+ goal_poses = d["goal_poses"]
+ goal_poses = goal_poses.reshape(-1, 16)
+ current_max, _ = torch.max(goal_poses, dim=0)
+ current_min, _ = torch.min(goal_poses, dim=0)
+ max_value[max_value < current_max] = current_max[max_value < current_max]
+        min_value[min_value > current_min] = current_min[min_value > current_min]
+ print(f"{min_value} - {max_value}")
+
+
+if __name__ == "__main__":
+
+ tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
+
+ data_roots = []
+ index_roots = []
+ for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
+ data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
+ index_roots.append(index)
+
+    # the demo dataset takes a single directory of .h5 files rather than per-shape data/index roots
+    dataset = SemanticArrangementDataset(data_root=data_roots[0],
+                                         tokenizer=tokenizer,
+ max_num_target_objects=7,
+ max_num_distractor_objects=5,
+ max_num_shape_parameters=5,
+ max_num_rearrange_features=0,
+ max_num_anchor_features=0,
+ num_pts=1024,
+ use_virtual_structure_frame=True,
+ ignore_distractor_objects=True,
+ ignore_rgb=True,
+ filter_num_moved_objects_range=None, # [5, 5]
+ data_augmentation=False,
+ shuffle_object_index=False,
+ debug=False)
+
+ # print(len(dataset))
+ # for d in dataset:
+ # print("\n\n" + "="*100)
+
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
+ for i, d in enumerate(tqdm(dataloader)):
+ pass
+ # for k in d:
+ # if isinstance(d[k], torch.Tensor):
+ # print("--size", k, d[k].shape)
+ # for k in d:
+ # print(k, d[k])
+ #
+ # input("next?")
\ No newline at end of file
diff --git a/src/StructDiffusion/diffusion/__init__.py b/src/StructDiffusion/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15e1d8365cad74bebf2f131d9e8646c529b9974d
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27f93c52a01731df348737534c20838286322db6
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c50945974f35e2d0b727238215e8adc0400b8644
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..494884c7815431f08db9cd82825ee5fabb0b4a36
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4549ee4cedac8ab2704f4e3ed6481343fdf8594f
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7945175edead17f2bfc5c6e27ed83105e7506532
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3caa501f32838b8807f2bb72a9e7ed280fff72dc
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc differ
diff --git a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef5a255067f922aa21dc539da6665bdf20f19c04
Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc differ
diff --git a/src/StructDiffusion/diffusion/noise_schedule.py b/src/StructDiffusion/diffusion/noise_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c2b7986fdc0c88d09acd16117d3d059cf67f96b
--- /dev/null
+++ b/src/StructDiffusion/diffusion/noise_schedule.py
@@ -0,0 +1,81 @@
+import math
+import torch
+import torch.nn.functional as F
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+ """
+ cosine schedule as proposed in https://arxiv.org/abs/2102.09672
+ """
+ steps = timesteps + 1
+ x = torch.linspace(0, timesteps, steps)
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0.0001, 0.9999)
+
+
+def linear_beta_schedule(timesteps):
+ beta_start = 0.0001
+ beta_end = 0.02
+ return torch.linspace(beta_start, beta_end, timesteps)
+
+
+def quadratic_beta_schedule(timesteps):
+ beta_start = 0.0001
+ beta_end = 0.02
+ return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
+
+
+def sigmoid_beta_schedule(timesteps):
+ beta_start = 0.0001
+ beta_end = 0.02
+ betas = torch.linspace(-6, 6, timesteps)
+ return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+class NoiseSchedule:
+
+ def __init__(self, timesteps=200):
+
+ self.timesteps = timesteps
+
+ # define beta schedule
+ self.betas = linear_beta_schedule(timesteps=timesteps)
+ # self.betas = cosine_beta_schedule(timesteps=timesteps)
+
+ # define alphas
+ self.alphas = 1. - self.betas
+ # alphas_cumprod: alpha bar
+ self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
+ self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
+ self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
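+        # the posterior variance has the closed form beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t), often written beta_tilde_t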
+ self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
+
+
+def extract(a, t, x_shape):
+ batch_size = t.shape[0]
+ out = a.gather(-1, t.cpu())
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+
+
+# forward diffusion (using the nice property)
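+# q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I), so a noisy sample at any timestep
+# can be drawn directly from x_0 without stepping through the chain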
+def q_sample(x_start, t, noise_schedule, noise=None):
+ if noise is None:
+ noise = torch.randn_like(x_start)
+
+ sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
+ # print("sqrt_alphas_cumprod_t", sqrt_alphas_cumprod_t)
+ sqrt_one_minus_alphas_cumprod_t = extract(
+ noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ # print("sqrt_one_minus_alphas_cumprod_t", sqrt_one_minus_alphas_cumprod_t)
+ # print("noise", noise)
+
+ return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
\ No newline at end of file
diff --git a/src/StructDiffusion/diffusion/pose_conversion.py b/src/StructDiffusion/diffusion/pose_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf43bae2bc6af9fed8338c255d0bdcc06447465
--- /dev/null
+++ b/src/StructDiffusion/diffusion/pose_conversion.py
@@ -0,0 +1,103 @@
+import os
+import torch
+import pytorch3d.transforms as tra3d
+
+from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d
+
+
+def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):
+
+ # important: we need to get the first two columns, not first two rows
+ # array([[ 3, 4, 5],
+ # [ 6, 7, 8],
+ # [ 9, 10, 11]])
+ xyz_6d_idxs = [0, 1, 2, 3, 6, 9, 4, 7, 10]
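+    # indices 0-2 select the xyz translation; 3, 6, 9 and 4, 7, 10 select the first and second columns
+    # of the row-major 3x3 rotation block, i.e. the continuous 6D (ortho6d) rotation representation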
+
+ # print(batch_data["obj_xyztheta_inputs"].shape)
+ # print(batch_data["struct_xyztheta_inputs"].shape)
+
+ # only get the first and second columns of rotation
+ obj_xyztheta_inputs = obj_xyztheta_inputs[:, :, xyz_6d_idxs] # B, N, 9
+ struct_xyztheta_inputs = struct_xyztheta_inputs[:, :, xyz_6d_idxs] # B, 1, 9
+
+ x = torch.cat([struct_xyztheta_inputs, obj_xyztheta_inputs], dim=1) # B, 1 + N, 9
+
+ # print(x.shape)
+
+ return x
+
+
+def get_diffusion_variables_from_H(poses):
+ """
+ [[0,1,2,3],
+ [4,5,6,7],
+ [8,9,10,11],
+ [12,13,14,15]
+ :param obj_xyztheta_inputs: B, N, 4, 4
+ :return:
+ """
+
+ xyz_6d_idxs = [3, 7, 11, 0, 4, 8, 1, 5, 9]
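+    # for a row-major flattened 4x4 pose, indices 3, 7, 11 are the translation entries and
+    # 0, 4, 8 / 1, 5, 9 are the first and second columns of the rotation block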
+
+ B, N, _, _ = poses.shape
+ x = poses.reshape(B, N, 16)[:, :, xyz_6d_idxs] # B, N, 9
+ return x
+
+
+def get_struct_objs_poses(x):
+
+ on_gpu = x.is_cuda
+ if not on_gpu:
+ x = x.cuda()
+
+ # assert x.is_cuda, "compute_rotation_matrix_from_ortho6d requires input to be on gpu"
+ device = x.device
+
+ # important: the noisy x can go out of bounds
+ x = torch.clamp(x, min=-1, max=1)
+
+ # x: B, 1 + N, 9
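+    # index 0 along the second dimension is the virtual structure frame; the remaining N entries are the
+    # object poses in that frame, each packed as [translation, first two rotation-matrix columns]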
+ B = x.shape[0]
+ N = x.shape[1] - 1
+
+ # compute_rotation_matrix_from_ortho6d takes in [B, 6], outputs [B, 3, 3]
+ x_6d = x[:, :, 3:].reshape(-1, 6)
+ x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, N+1, 3, 3) # B, 1 + N, 3, 3
+
+ x_trans = x[:, :, :3] # B, 1 + N, 3
+
+ x_full = torch.eye(4).repeat(B, 1 + N, 1, 1).to(device)
+ x_full[:, :, :3, :3] = x_rot
+ x_full[:, :, :3, 3] = x_trans
+
+ struct_pose = x_full[:, 0].unsqueeze(1) # B, 1, 4, 4
+ pc_poses_in_struct = x_full[:, 1:] # B, N, 4, 4
+
+ if not on_gpu:
+ struct_pose = struct_pose.cpu()
+ pc_poses_in_struct = pc_poses_in_struct.cpu()
+
+ return struct_pose, pc_poses_in_struct
+
+
+def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
+
+ device = obj_xyzs.device
+
+ # obj_xyzs: B, N, P, 3
+ # struct_pose: B, 1, 4, 4
+ # pc_poses_in_struct: B, N, 4, 4
+ B, N, _, _ = pc_poses_in_struct.shape
+ _, _, P, _ = obj_xyzs.shape
+
+ current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
+ # print(torch.mean(obj_xyzs, dim=2).shape)
+ current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2) # B, N, 4, 4
+
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
+ struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
+ pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
+
+ goal_pc_poses = struct_pose @ pc_poses_in_struct # B x N, 4, 4
+ goal_pc_poses = goal_pc_poses.reshape(B, N, 4, 4) # B, N, 4, 4
+ return current_pc_poses, goal_pc_poses
\ No newline at end of file
diff --git a/src/StructDiffusion/diffusion/sampler.py b/src/StructDiffusion/diffusion/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..06ee7a6258771dd01e35c075661b5a7f0e686035
--- /dev/null
+++ b/src/StructDiffusion/diffusion/sampler.py
@@ -0,0 +1,296 @@
+import torch
+from tqdm import tqdm
+import pytorch3d.transforms as tra3d
+
+from StructDiffusion.diffusion.noise_schedule import extract
+from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
+from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs, move_pc_and_create_scene_new
+
+class Sampler:
+
+ def __init__(self, model_class, checkpoint_path, device, debug=False):
+
+ self.debug = debug
+ self.device = device
+
+ self.model = model_class.load_from_checkpoint(checkpoint_path)
+ self.backbone = self.model.model
+ self.backbone.to(device)
+ self.backbone.eval()
+
+ def sample(self, batch, num_poses):
+
+ noise_schedule = self.model.noise_schedule
+
+ B = batch["pcs"].shape[0]
+
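+        # start the reverse process from pure Gaussian noise over the 9D poses (xyz + 6D rotation)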
+ x_noisy = torch.randn((B, num_poses, 9), device=self.device)
+
+ xs = []
+ for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
+ desc='sampling loop time step', total=noise_schedule.timesteps):
+
+ t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
+
+ # noise schedule
+ betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
+ sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+ sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
+
+ # predict noise
+ pcs = batch["pcs"]
+ sentence = batch["sentence"]
+ type_index = batch["type_index"]
+ position_index = batch["position_index"]
+ pad_mask = batch["pad_mask"]
+ # calling the backbone instead of the pytorch-lightning model
+ with torch.no_grad():
+ predicted_noise = self.backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+
+ # compute noisy x at t
+ model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
+ if t_index == 0:
+ x_noisy = model_mean
+ else:
+ posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
+ noise = torch.randn_like(x_noisy)
+ x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
+
+ xs.append(x_noisy)
+
+ xs = list(reversed(xs))
+ return xs
+
+class SamplerV2:
+
+ def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
+ collision_model_class, collision_checkpoint_path,
+ device, debug=False):
+
+ self.debug = debug
+ self.device = device
+
+ self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
+ self.diffusion_backbone = self.diffusion_model.model
+ self.diffusion_backbone.to(device)
+ self.diffusion_backbone.eval()
+
+ self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
+ self.collision_backbone = self.collision_model.model
+ self.collision_backbone.to(device)
+ self.collision_backbone.eval()
+
+ def sample(self, batch, num_poses):
+
+ noise_schedule = self.diffusion_model.noise_schedule
+
+ B = batch["pcs"].shape[0]
+
+ x_noisy = torch.randn((B, num_poses, 9), device=self.device)
+
+ xs = []
+ for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
+ desc='sampling loop time step', total=noise_schedule.timesteps):
+
+ t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
+
+ # noise schedule
+ betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
+ sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+ sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
+
+ # predict noise
+ pcs = batch["pcs"]
+ sentence = batch["sentence"]
+ type_index = batch["type_index"]
+ position_index = batch["position_index"]
+ pad_mask = batch["pad_mask"]
+ # calling the backbone instead of the pytorch-lightning model
+ with torch.no_grad():
+ predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+
+ # compute noisy x at t
+ model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
+ if t_index == 0:
+ x_noisy = model_mean
+ else:
+ posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
+ noise = torch.randn_like(x_noisy)
+ x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
+
+ xs.append(x_noisy)
+
+ xs = list(reversed(xs))
+
+ visualize = True
+
+ struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
+ # struct_pose: B, 1, 4, 4
+ # pc_poses_in_struct: B, N, 4, 4
+
+ S = B
+ num_elite = 10
+ ####################################################
+ # only keep one copy
+
+ # N, P, 3
+ obj_xyzs = batch["pcs"][0][:, :, :3]
+ print("obj_xyzs shape", obj_xyzs.shape)
+
+ # 1, N
+ # object_pad_mask: padding location has 1
+ num_target_objs = num_poses
+ if self.diffusion_backbone.use_virtual_structure_frame:
+ num_target_objs -= 1
+ object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
+ target_object_inds = 1 - object_pad_mask
+ print("target_object_inds shape", target_object_inds.shape)
+ print("target_object_inds", target_object_inds)
+
+ N, P, _ = obj_xyzs.shape
+ print("S, N, P: {}, {}, {}".format(S, N, P))
+
+ ####################################################
+ # S, N, ...
+
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # S, N, 4, 4
+ struct_pose = struct_pose.reshape(S * N, 4, 4) # S x N, 4, 4
+
+ new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1) # S, N, P, 3
+ current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device) # S, N, 4, 4
+ current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2) # S, N, 4, 4
+ current_pc_pose = current_pc_pose.reshape(S * N, 4, 4) # S x N, 4, 4
+
+ # optimize xyzrpy
+ obj_params = torch.zeros((S, N, 6)).to(self.device)
+ obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
+ obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ") # S, N, 6
+ #
+ # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
+ #
+ # if visualize:
+ # print("visualizing rearrangements predicted by the generator")
+ # visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
+
+ ####################################################
+ # rank
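+        # score every sampled arrangement with the pairwise collision model and keep the num_elite best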
+
+ # evaluate in batches
+ scores = torch.zeros(S).to(self.device)
+ no_intersection_scores = torch.zeros(S).to(self.device) # the higher the better
+ num_batches = int(S / B)
+ if S % B != 0:
+ num_batches += 1
+ for b in range(num_batches):
+ if b + 1 == num_batches:
+ cur_batch_idxs_start = b * B
+ cur_batch_idxs_end = S
+ else:
+ cur_batch_idxs_start = b * B
+ cur_batch_idxs_end = (b + 1) * B
+ cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
+
+ # print("current batch idxs start", cur_batch_idxs_start)
+ # print("current batch idxs end", cur_batch_idxs_end)
+ # print("size of the current batch", cur_batch_size)
+
+ batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
+ batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
+ batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
+
+ new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
+ move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
+ target_object_inds, self.device,
+ return_scene_pts=False,
+ return_scene_pts_and_pc_idxs=False,
+ num_scene_pts=False,
+ normalize_pc=False,
+ return_pair_pc=True,
+ num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
+ normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
+
+ #######################################
+ # predict whether there are pairwise collisions
+ # if collision_score_weight > 0:
+ with torch.no_grad():
+ _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
+ # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
+ collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
+ collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb) # cur_batch_size, num_comb
+
+ # debug
+ # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+ # print("batch id", bi)
+ # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+ # print("pair", pi)
+ # # obj_pair_xyzs: 2 * P, 5
+ # print("collision score", collision_scores[bi, pi])
+ # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+
+ # 1 - mean() since the collision model predicts 1 if there is a collision
+ no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
+ if visualize:
+ print("no intersection scores", no_intersection_scores)
+ # #######################################
+ # if discriminator_score_weight > 0:
+ # # # debug:
+ # # print(subsampled_scene_xyz.shape)
+ # # print(subsampled_scene_xyz[0])
+ # # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
+ # #
+ # with torch.no_grad():
+ #
+ # # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
+ # # local_sentence = sentence[:, [0, 4]]
+ # # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
+ # # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
+ #
+ # sentence_disc = torch.LongTensor(
+ # [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
+ # sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
+ # position_index_dic = torch.LongTensor(raw_position_index_discriminator)
+ #
+ # preds = discriminator_model.forward(subsampled_scene_xyz,
+ # sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
+ # sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size,
+ # 1).to(device),
+ # position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(
+ # device))
+ # # preds = discriminator_model.forward(subsampled_scene_xyz)
+ # preds = discriminator_model.convert_logits(preds)
+ # preds = preds["is_circle"] # cur_batch_size,
+ # scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
+ # if visualize:
+ # print("discriminator scores", scores)
+
+ # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
+ scores = no_intersection_scores
+ sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
+ elite_obj_params = obj_params[sort_idx] # num_elite, N, 6
+ elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx] # num_elite, N, 4, 4
+ elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4) # num_elite x N, 4, 4
+ elite_scores = scores[sort_idx]
+ print("elite scores:", elite_scores)
+
+ ####################################################
+ # # visualize best samples
+ # num_scene_pts = 4096 # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
+ # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
+ # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
+ # move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
+ # target_object_inds, self.device,
+ # return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
+ # if visualize:
+ # print("visualizing elite rearrangements ranked by collision model/discriminator")
+ # visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
+
+ # num_elite, N, 6
+ elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
+ pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
+ pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
+ pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
+ pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4) # num_elite, N, 4, 4
+
+        struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0].unsqueeze(1)  # num_elite, 1, 4, 4
+
+ return struct_pose, pc_poses_in_struct
\ No newline at end of file
diff --git a/src/StructDiffusion/language/__init__.py b/src/StructDiffusion/language/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4aed88d9ff03518b98ae5ed16652d45bf0e5e28c
Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83ba0f5e6c27b15d0de3f53b686d3267c861ddfa
Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09854272d48643a552100b005d109099ba6e11bb
Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc differ
diff --git a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3090fb1a1f4014f8ede1eec9758f552adc21a8cd
Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc differ
diff --git a/src/StructDiffusion/language/tokenizer.py b/src/StructDiffusion/language/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef678dbfcf20ef2649d3491169b2f9f5517495c
--- /dev/null
+++ b/src/StructDiffusion/language/tokenizer.py
@@ -0,0 +1,541 @@
+import json
+import numpy as np
+import re
+
+# def add_pad_to_vocab(vocab):
+# new_vocab = {"PAD": 0}
+# for k in vocab:
+# new_vocab[k] = vocab[k] + 1
+# return new_vocab
+#
+#
+# def combine_vocabs(vocabs, vocab_types):
+# new_vocab = {}
+# for type, vocab in zip(vocab_types, vocabs):
+# for k in vocab:
+# new_vocab["{}:{}".format(type, k)] = len(new_vocab)
+# return new_vocab
+#
+#
+# def add_token_to_vocab(vocab):
+# new_vocab = {"MASK": 0}
+# for k in vocab:
+# new_vocab[k] = vocab[k] + 1
+# return new_vocab
+#
+#
+# def tokenize_circle_specification(circle_specification):
+# tokenized = {}
+# # min 0, max 0.5, increment 0.05, 10 discrete values
+# tokenized["radius"] = int(circle_specification["radius"] / 0.05)
+#
+# # min 0, max 1, increment 0.10, 10 discrete values
+# tokenized["position_x"] = int(circle_specification["position"][0] / 0.10)
+#
+# # min -0.5, max 0.5, increment 0.10, 10 discrete values
+# tokenized["position_y"] = int(circle_specification["position"][1] / 0.10)
+#
+# # min -3.14, max 3.14, increment 3.14 / 18, 36 discrete values
+# tokenized["rotation"] = int((circle_specification["rotation"][2] + 3.14) / (3.14 / 18))
+#
+# uniform_angle_vocab = {"False": 0, "True": 1}
+# tokenized["uniform_angle"] = uniform_angle_vocab[circle_specification["uniform_angle"]]
+#
+# face_center_vocab = {"False": 0, "True": 1}
+# tokenized["face_center"] = face_center_vocab[circle_specification["face_center"]]
+#
+# angle_ratio_vocab = {0.5: 0, 1.0: 1}
+# tokenized["angle_ratio"] = angle_ratio_vocab[circle_specification["angle_ratio"]]
+#
+# # heights min 0.0, max 0.5
+# # volumn min 0.0, max 0.012
+#
+# return tokenized
+#
+#
+# def build_vocab(old_vocab_file, new_vocab_file):
+# with open(old_vocab_file, "r") as fh:
+# vocab_json = json.load(fh)
+#
+# vocabs = {}
+# vocabs["class"] = vocab_json["class_to_idx"]
+# vocabs["size"] = vocab_json["size_to_idx"]
+# vocabs["color"] = vocab_json["color_to_idx"]
+# vocabs["material"] = vocab_json["material_to_idx"]
+# vocabs["comparator"] = {"less": 1, "greater": 2, "equal": 3}
+#
+# vocabs["radius"] = (0.0, 0.5, 10)
+# vocabs["position_x"] = (0.0, 1.0, 10)
+# vocabs["position_y"] = (-0.5, 0.5, 10)
+# vocabs["rotation"] = (-3.14, 3.14, 36)
+# vocabs["height"] = (0.0, 0.5, 10)
+# vocabs["volumn"] = (0.0, 0.012, 10)
+#
+# vocabs["uniform_angle"] = {"False": 0, "True": 1}
+# vocabs["face_center"] = {"False": 0, "True": 1}
+# vocabs["angle_ratio"] = {0.5: 0, 1.0: 1}
+#
+# with open(new_vocab_file, "w") as fh:
+# json.dump(vocabs, fh)
+
+
+class Tokenizer:
+ """
+    We want to build a tokenizer that tokenizes words, features, and numbers.
+
+ This tokenizer should also allow us to sample random values.
+
+ For discrete values, we store mapping from the value to an id
+ For continuous values, we store min, max, and number of bins after discretization
+
+ """
+
+ def __init__(self, vocab_file):
+
+ self.vocab_file = vocab_file
+ with open(self.vocab_file, "r") as fh:
+ self.type_vocabs = json.load(fh)
+
+ self.vocab = {"PAD": 0, "CLS": 1}
+ self.discrete_types = set()
+ self.continuous_types = set()
+ self.build_one_vocab()
+
+ self.object_position_vocabs = {}
+ self.build_object_position_vocabs()
+
+ def build_one_vocab(self):
+ print("\nBuild one vacab for everything...")
+
+ for typ, vocab in self.type_vocabs.items():
+ if typ == "comparator":
+ continue
+
+ if typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry",
+ "struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"]:
+ continue
+
+ if type(vocab) == dict:
+ self.vocab["{}:{}".format(typ, "MASK")] = len(self.vocab)
+
+ for v in vocab:
+ assert ":" not in v
+ self.vocab["{}:{}".format(typ, v)] = len(self.vocab)
+ self.discrete_types.add(typ)
+
+ elif type(vocab) == tuple or type(vocab) == list:
+ self.vocab["{}:{}".format(typ, "MASK")] = len(self.vocab)
+
+ for c in self.type_vocabs["comparator"]:
+ self.vocab["{}:{}".format(typ, c)] = len(self.vocab)
+
+ min_value, max_value, num_bins = vocab
+ for i in range(num_bins):
+ self.vocab["{}:{}".format(typ, i)] = len(self.vocab)
+ self.continuous_types.add(typ)
+ else:
+ raise TypeError("The dtype of the vocab cannot be handled: {}".format(vocab))
+
+ print("The vocab has {} tokens: {}".format(len(self.vocab), self.vocab))
+
+ def build_object_position_vocabs(self):
+ print("\nBuild vocabs for object position")
+ for typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry",
+ "struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"]:
+ self.object_position_vocabs[typ] = {"PAD": 0, "MASK": 1}
+
+ if typ not in self.type_vocabs:
+ continue
+ min_value, max_value, num_bins = self.type_vocabs[typ]
+ for i in range(num_bins):
+ self.object_position_vocabs[typ]["{}".format(i)] = len(self.object_position_vocabs[typ])
+ print("The {} vocab has {} tokens: {}".format(typ, len(self.object_position_vocabs[typ]), self.object_position_vocabs[typ]))
+
+    def get_object_position_vocab_sizes(self):
+        # note: the per-axis vocabs are keyed by obj_x / obj_y / obj_ry (see build_object_position_vocabs)
+        return len(self.object_position_vocabs["obj_x"]), len(self.object_position_vocabs["obj_y"]), len(self.object_position_vocabs["obj_ry"])
+
+ def get_vocab_size(self):
+ return len(self.vocab)
+
+ def tokenize_object_position(self, value, typ):
+ assert typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry",
+ "struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"]
+ if value == "MASK" or value == "PAD":
+ return self.object_position_vocabs[typ][value]
+ elif value == "IGNORE":
+ # Important: used to avoid computing loss. -100 is the default ignore_index for NLLLoss
+ return -100
+ else:
+ min_value, max_value, num_bins = self.type_vocabs[typ]
+ assert min_value <= value <= max_value, value
+ dv = min(int((value - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1)
+ return self.object_position_vocabs[typ]["{}".format(dv)]
+
+ def tokenize(self, value, typ=None):
+ if value in ["PAD", "CLS"]:
+ idx = self.vocab[value]
+ else:
+ if typ is None:
+ raise KeyError("Type cannot be None")
+
+ if typ[-2:] == "_c" or typ[-2:] == "_d":
+ typ = typ[:-2]
+
+ if typ in self.discrete_types:
+ idx = self.vocab["{}:{}".format(typ, value)]
+ elif typ in self.continuous_types:
+ if value == "MASK" or value in self.type_vocabs["comparator"]:
+ idx = self.vocab["{}:{}".format(typ, "MASK")]
+ else:
+ min_value, max_value, num_bins = self.type_vocabs[typ]
+ assert min_value <= value <= max_value, "type {} value {} exceeds {} and {}".format(typ, value, min_value, max_value)
+ dv = min(int((value - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1)
+ # print(value, dv, "{}:{}".format(typ, dv))
+ idx = self.vocab["{}:{}".format(typ, dv)]
+ else:
+ raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
+ return idx
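+
+    # Worked example of the binning above (hypothetical numbers): with
+    # type_vocabs["height"] = (0.0, 0.5, 10), a value of 0.17 falls into bin
+    # min(int(0.17 / 0.05), 9) = 3 and is tokenized as "height:3".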
+
+ def get_valid_random_value(self, typ):
+ """
+        Sample a random valid value for the given type.
+ :param typ:
+ :return:
+ """
+ if typ[-2:] == "_c" or typ[-2:] == "_d":
+            typ = typ[:-2]
+
+ candidate_values = []
+ for v in self.vocab:
+ if v in ["PAD", "CLS"]:
+ continue
+ ft, fv = v.split(":")
+ if typ == ft and fv != "MASK" and fv not in self.type_vocabs["comparator"]:
+ candidate_values.append(v)
+ assert len(candidate_values) != 0
+ typed_v = np.random.choice(candidate_values)
+ value = typed_v.split(":")[1]
+
+ if typ in self.discrete_types:
+ return value
+ elif typ in self.continuous_types:
+ min_value, max_value, num_bins = self.type_vocabs[typ]
+ return min_value + ((max_value - min_value) / num_bins) * int(value)
+ else:
+ raise KeyError("Do not recognize the type {} of the given token".format(typ))
+
+ def get_all_values_of_type(self, typ):
+ """
+        Get all valid values for the given type.
+ :param typ:
+ :return:
+ """
+ if typ[-2:] == "_c" or typ[-2:] == "_d":
+            typ = typ[:-2]
+
+ candidate_values = []
+ for v in self.vocab:
+ if v in ["PAD", "CLS"]:
+ continue
+ ft, fv = v.split(":")
+ if typ == ft and fv != "MASK" and fv not in self.type_vocabs["comparator"]:
+ candidate_values.append(v)
+ assert len(candidate_values) != 0
+ values = [typed_v.split(":")[1] for typed_v in candidate_values]
+
+ if typ in self.discrete_types:
+ return values
+ else:
+ raise KeyError("Do not recognize the type {} of the given token".format(typ))
+
+ def convert_to_natural_sentence(self, template_sentence):
+
+ # select objects that are [red, metal]
+ # select objects that are [larger, taller] than the [], [], [] object
+    # select objects that have the same [color, material] as the [], [], [] object
+
+ natural_sentence_templates = ["select objects that are {}.",
+ "select objects that have {} {} {} the {}.",
+ "select objects that have the same {} as the {}."]
+
+ v, t = template_sentence[0]
+ if t[-2:] == "_c" or t[-2:] == "_d":
+ t = t[:-2]
+
+ if v != "MASK" and t in self.discrete_types:
+ natural_sentence_template = natural_sentence_templates[0]
+ if t == "class":
+ natural_sentence = natural_sentence_template.format(re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', v)[0].lower())
+ else:
+ natural_sentence = natural_sentence_template.format(v)
+ else:
+ anchor_obj_properties = []
+ class_reference = None
+ for token in template_sentence[1:]:
+ if token[0] != "PAD":
+ if token[1] == "class":
+ class_reference = token[0]
+ else:
+ anchor_obj_properties.append(token[0])
+ # order the properties
+ anchor_obj_des = ", ".join(anchor_obj_properties)
+ if class_reference is None:
+ anchor_obj_des += " object"
+ else:
+ anchor_obj_des += " {}".format(re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', class_reference)[0].lower())
+
+ if v == "MASK":
+ natural_sentence_template = natural_sentence_templates[2]
+ anchor_type = t
+ natural_sentence = natural_sentence_template.format(anchor_type, anchor_obj_des)
+ elif t in self.continuous_types:
+ natural_sentence_template = natural_sentence_templates[1]
+ if v == "equal":
+ jun = "as"
+ else:
+ jun = "than"
+ natural_sentence = natural_sentence_template.format(v, t, jun, anchor_obj_des)
+ else:
+ raise NotImplementedError
+
+ return natural_sentence
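+
+    # Hypothetical example: a template sentence [("greater", "height_c"),
+    # ("red", "color"), ("Mug", "class")] would be rendered roughly as
+    # "select objects that have greater height than the red mug."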
+
+ def prepare_grounding_reference(self):
+ goal = {"rearrange": {"features": []},
+ "anchor": {"features": []}}
+ discrete_type = ["class", "material", "color"]
+ continuous_type = ["volumn", "height"]
+
+ print("#"*50)
+ print("Preparing referring expression")
+
+ refer_type = verify_input("direct (1) or relational reference (2)? ", [1, 2], int)
+ if refer_type == 1:
+
+ # 1. no anchor
+ t = verify_input("desired type: ", discrete_type, None)
+ v = verify_input("desired value: ", self.get_all_values_of_type(t), None)
+
+ goal["rearrange"]["features"].append({"comparator": None, "type": t, "value": v})
+
+ elif refer_type == 2:
+
+ value_type = verify_input("discrete (1) or continuous relational reference (2)? ", [1, 2], int)
+ if value_type == 1:
+ t = verify_input("desired type: ", discrete_type, None)
+ # 2. discrete
+ goal["rearrange"]["features"].append({"comparator": None, "type": t, "value": None})
+ elif value_type == 2:
+ comp = verify_input("desired comparator: ", list(self.type_vocabs["comparator"].keys()), None)
+ t = verify_input("desired type: ", continuous_type, None)
+ # 3. continuous
+ goal["rearrange"]["features"].append({"comparator": comp, "type": t, "value": None})
+
+ num_f = verify_input("desired number of features for the anchor object: ", [1, 2, 3], int)
+ for i in range(num_f):
+ t = verify_input("desired type: ", discrete_type, None)
+ v = verify_input("desired value: ", self.get_all_values_of_type(t), None)
+ goal["anchor"]["features"].append({"comparator": None, "type": t, "value": v})
+
+ return goal
+
+ def convert_structure_params_to_natural_language(self, sentence):
+
+ # ('circle', 'shape'), (-1.3430555575431449, 'rotation'), (0.3272675147405848, 'position_x'), (-0.03104362197706456, 'position_y'), (0.04674859577847633, 'radius')
+
+ shape = None
+ x = None
+ y = None
+ rot = None
+ size = None
+
+ for param in sentence:
+ if param[0] == "PAD":
+ continue
+
+ v, t = param
+ if t == "shape":
+ shape = v
+ elif t == "position_x":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ x = "bottom"
+ elif dv == 1:
+ x = "middle"
+ elif dv == 2:
+ x = "top"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "position_y":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ y = "right"
+ elif dv == 1:
+ y = "center"
+ elif dv == 2:
+ y = "left"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "radius":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ size = "small"
+ elif dv == 1:
+ size = "medium"
+ elif dv == 2:
+ size = "large"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "rotation":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ rot = "north"
+ elif dv == 1:
+ rot = "east"
+ elif dv == 2:
+ rot = "south"
+ elif dv == 3:
+ rot = "west"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+
+ natural_sentence = "" # "{} {} in the {} {} of the table facing {}".format(size, shape, x, y, rot)
+
+ if size:
+ natural_sentence += "{}".format(size)
+ if shape:
+ natural_sentence += " {}".format(shape)
+ if x:
+ natural_sentence += " in the {}".format(x)
+ if y:
+ natural_sentence += " {} of the table".format(y)
+ if rot:
+ natural_sentence += " facing {}".format(rot)
+
+ natural_sentence = natural_sentence.strip()
+
+ return natural_sentence
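+
+    # Hypothetical example: for structure parameters like the tuple shown in the
+    # comment above, this yields a phrase of the form
+    # "<size> <shape> in the <x> <y> of the table facing <rotation>",
+    # e.g. "small circle in the bottom center of the table facing east"
+    # (the exact words depend on the bin boundaries in the vocab file).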
+
+ def convert_structure_params_to_type_value_tuple(self, sentence):
+
+ # ('circle', 'shape'), (-1.3430555575431449, 'rotation'), (0.3272675147405848, 'position_x'), (-0.03104362197706456, 'position_y'), (0.04674859577847633, 'radius')
+
+ shape = None
+ x = None
+ y = None
+ rot = None
+ size = None
+
+ for param in sentence:
+ if param[0] == "PAD":
+ continue
+
+ v, t = param
+ if t == "shape":
+ shape = v
+ elif t == "position_x":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ x = "bottom"
+ elif dv == 1:
+ x = "middle"
+ elif dv == 2:
+ x = "top"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "position_y":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ y = "right"
+ elif dv == 1:
+ y = "center"
+ elif dv == 2:
+ y = "left"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "radius":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ size = "small"
+ elif dv == 1:
+ size = "medium"
+ elif dv == 2:
+ size = "large"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+ elif t == "rotation":
+ dv = self.discretize(v, t)
+ if dv == 0:
+ rot = "north"
+ elif dv == 1:
+ rot = "east"
+ elif dv == 2:
+ rot = "south"
+ elif dv == 3:
+ rot = "west"
+ else:
+ raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t]))
+
+ # rotation, shape, size, x, y
+ type_value_tuple_init = [("rotation", rot), ("shape", shape), ("size", size), ("x", x), ("y", y)]
+ type_value_tuple = []
+ for type_value in type_value_tuple_init:
+ if type_value[1] is not None:
+ type_value_tuple.append(type_value)
+
+ type_value_tuple = tuple(sorted(type_value_tuple))
+ return type_value_tuple
+
+ def discretize(self, v, t):
+ min_value, max_value, num_bins = self.type_vocabs[t]
+ assert min_value <= v <= max_value, "type {} value {} exceeds {} and {}".format(t, v, min_value, max_value)
+ dv = min(int((v - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1)
+ return dv
+
+
+class ContinuousTokenizer:
+ """
+    This tokenizer is for testing what happens when structure parameters are not discretized.
+ """
+
+ def __init__(self):
+
+ print("WARNING: Current continous tokenizer does not support multiple shapes")
+
+ self.continuous_types = ["rotation", "position_x", "position_y", "radius"]
+ self.discrete_types = ["shape"]
+
+ def tokenize(self, value, typ=None):
+ if value == "PAD":
+ idx = 0.0
+ else:
+ if typ is None:
+ raise KeyError("Type cannot be None")
+ elif typ in self.discrete_types:
+ idx = 1.0
+ elif typ in self.continuous_types:
+ idx = value
+ else:
+ raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
+ return idx
+
+
+if __name__ == "__main__":
+ tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
+ # print(tokenizer.get_all_values_of_type("class"))
+ # print(tokenizer.get_all_values_of_type("color"))
+ # print(tokenizer.get_all_values_of_type("material"))
+ #
+ # for type in tokenizer.type_vocabs:
+ # print(type, tokenizer.type_vocabs[type])
+
+ tokenizer.prepare_grounding_reference()
+
+ # for i in range(100):
+ # types = list(tokenizer.continuous_types) + list(tokenizer.discrete_types)
+ # for t in types:
+ # v = tokenizer.get_valid_random_value(t)
+ # print(v)
+ # print(tokenizer.tokenize(v, t))
+
+ # build_vocab("/home/weiyu/data_drive/examples_v4/leonardo/vocab.json", "/home/weiyu/data_drive/examples_v4/leonardo/type_vocabs.json")
\ No newline at end of file
diff --git a/src/StructDiffusion/models/__init__.py b/src/StructDiffusion/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcd68754e149fcfa74b798e7cede414dc3d0ab00
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4f6076a0eef068852f173506d0a538d52340a1c
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec884bc615594979b02e0e163c1e5546ea83b429
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59681495d38034504267993477bebcf776d852ab
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54ff2c4a74799cc7866b831b395d4927544da19b
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf835600dc9bcbe557378578187fc3b8a050759f
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4aac35dc9aa36ed0b2a9dbeaf0e016e5c7ecd779
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dcbc7fb1943c708a3ae4476d1b9d15f78fdb50c
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08eed8a045e62263826d431993619fe5231e7703
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4c2e685a196894089b6bdc038444f2e34f39507
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5de3117d4b31d989f6aa0ed3e9775e072ccf1619
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc differ
diff --git a/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4422f79019f73c9c9b64a3c0f30675ee5cfea6c2
Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc differ
diff --git a/src/StructDiffusion/models/encoders.py b/src/StructDiffusion/models/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c46cfc0dde57e8d1ef916494ffbf4177047b878
--- /dev/null
+++ b/src/StructDiffusion/models/encoders.py
@@ -0,0 +1,97 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, time):
+ device = time.device
+ half_dim = self.dim // 2
+ embeddings = math.log(10000) / (half_dim - 1)
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+ embeddings = time[:, None] * embeddings[None, :]
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+ return embeddings
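+
+    # Standard transformer-style sinusoidal embedding: a batch of scalar
+    # timesteps of shape (B,) is mapped to embeddings of shape (B, dim).
+    # Note: this assumes dim is even (half sine, half cosine channels).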
+
+
+class DropoutSampler(torch.nn.Module):
+ def __init__(self, num_features, num_outputs, dropout_rate = 0.5):
+ super(DropoutSampler, self).__init__()
+ self.linear = nn.Linear(num_features, num_features)
+ self.linear2 = nn.Linear(num_features, num_features)
+ self.predict = nn.Linear(num_features, num_outputs)
+ self.num_features = num_features
+ self.num_outputs = num_outputs
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x):
+ x = F.relu(self.linear(x))
+ if self.dropout_rate > 0:
+ x = F.dropout(x, self.dropout_rate)
+ x = F.relu(self.linear2(x))
+ return self.predict(x)
+
+
+class EncoderMLP(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True):
+ super(EncoderMLP, self).__init__()
+ self.uses_pt = uses_pt
+ self.output = out_dim
+ d5 = int(in_dim)
+ d6 = int(2 * self.output)
+ d7 = self.output
+ self.encode_position = nn.Sequential(
+ nn.Linear(pt_dim, in_dim),
+ nn.LayerNorm(in_dim),
+ nn.ReLU(),
+ nn.Linear(in_dim, in_dim),
+ nn.LayerNorm(in_dim),
+ nn.ReLU(),
+ )
+ d5 = 2 * in_dim if self.uses_pt else in_dim
+ self.fc_block = nn.Sequential(
+ nn.Linear(int(d5), d6),
+ nn.LayerNorm(int(d6)),
+ nn.ReLU(),
+ nn.Linear(int(d6), d6),
+ nn.LayerNorm(int(d6)),
+ nn.ReLU(),
+ nn.Linear(d6, d7))
+
+ def forward(self, x, pt=None):
+ if self.uses_pt:
+ if pt is None: raise RuntimeError('did not provide pt')
+ y = self.encode_position(pt)
+ x = torch.cat([x, y], dim=-1)
+ return self.fc_block(x)
+
+
+class MeanEncoder(torch.nn.Module):
+ def __init__(self, input_channels=3, use_xyz=True, output=512, scale=0.04, factor=1):
+ super(MeanEncoder, self).__init__()
+ self.uses_rgb = False
+ self.dim = 3
+
+ def forward(self, xyz, f=None):
+
+ # Fix shape
+ if f is not None:
+ if len(f.shape) < 3:
+ f = f.transpose(0,1).contiguous()
+ f = f[None]
+ else:
+ f = f.transpose(1,2).contiguous()
+ if len(xyz.shape) == 3:
+ center = torch.mean(xyz, dim=1)
+ elif len(xyz.shape) == 2:
+ center = torch.mean(xyz, dim=0)
+ else:
+ raise RuntimeError('not sure what to do with points of shape ' + str(xyz.shape))
+        assert xyz.shape[-1] == 3
+        assert center.shape[-1] == 3
+ return center, center
\ No newline at end of file
diff --git a/src/StructDiffusion/models/models.py b/src/StructDiffusion/models/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8db381307eb8dc64e55308e021b73510ed52ad
--- /dev/null
+++ b/src/StructDiffusion/models/models.py
@@ -0,0 +1,184 @@
+import torch
+import torch.nn as nn
+import math
+import torch.nn.functional as F
+
+from StructDiffusion.models.encoders import EncoderMLP, DropoutSampler, SinusoidalPositionEmbeddings
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+from StructDiffusion.models.point_transformer import PointTransformerEncoderSmall
+from StructDiffusion.models.point_transformer_large import PointTransformerCls
+
+
+class TransformerDiffusionModel(torch.nn.Module):
+
+ def __init__(self, vocab_size,
+ # transformer params
+ encoder_input_dim=256,
+ num_attention_heads=8, encoder_hidden_dim=16, encoder_dropout=0.1, encoder_activation="relu",
+ encoder_num_layers=8,
+ # output head params
+ structure_dropout=0.0, object_dropout=0.0,
+ # pc encoder params
+ ignore_rgb=False, pc_emb_dim=256, posed_pc_emb_dim=80,
+ pose_emb_dim=80,
+ max_seq_size=7, max_token_type_size=4,
+ seq_pos_emb_dim=8, seq_type_emb_dim=8,
+ word_emb_dim=160,
+ time_emb_dim=80,
+ use_virtual_structure_frame=True,
+ ):
+ super(TransformerDiffusionModel, self).__init__()
+
+ assert posed_pc_emb_dim + pose_emb_dim == word_emb_dim
+ assert encoder_input_dim == word_emb_dim + time_emb_dim + seq_pos_emb_dim + seq_type_emb_dim
+
+ # 3D translation + 6D rotation
+ action_dim = 3 + 6
+
+ # default:
+ # 256 = 80 (point cloud) + 80 (position) + 80 (time) + 8 (position idx) + 8 (token idx)
+ # 256 = 160 (word embedding) + 80 (time) + 8 (position idx) + 8 (token idx)
+
+ # PC
+ self.ignore_rgb = ignore_rgb
+ if ignore_rgb:
+ self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=3, mean_center=True)
+ else:
+ self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=6, mean_center=True)
+ self.posed_pc_encoder = EncoderMLP(pc_emb_dim, posed_pc_emb_dim, uses_pt=True)
+
+ # for virtual structure frame
+ self.use_virtual_structure_frame = use_virtual_structure_frame
+ if use_virtual_structure_frame:
+ self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim
+
+ # for language
+ self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)
+
+ # for diffusion
+ self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
+ self.time_embeddings = nn.Sequential(
+ SinusoidalPositionEmbeddings(time_emb_dim),
+ nn.Linear(time_emb_dim, time_emb_dim),
+ nn.GELU(),
+ nn.Linear(time_emb_dim, time_emb_dim),
+ )
+
+ # for transformer
+ self.position_embeddings = torch.nn.Embedding(max_seq_size, seq_pos_emb_dim)
+ self.type_embeddings = torch.nn.Embedding(max_token_type_size, seq_type_emb_dim)
+
+ encoder_layers = TransformerEncoderLayer(encoder_input_dim, num_attention_heads,
+ encoder_hidden_dim, encoder_dropout, encoder_activation)
+ self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers)
+
+ self.struct_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=structure_dropout)
+ self.obj_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=object_dropout)
+
+ def encode_posed_pc(self, pcs, batch_size, num_objects):
+ if self.ignore_rgb:
+ center_xyz, x = self.pc_encoder(pcs[:, :, :3], None)
+ else:
+ center_xyz, x = self.pc_encoder(pcs[:, :, :3], pcs[:, :, 3:])
+ posed_pc_embed = self.posed_pc_encoder(x, center_xyz)
+ posed_pc_embed = posed_pc_embed.reshape(batch_size, num_objects, -1)
+ return posed_pc_embed
+
+ def forward(self, t, pcs, sentence, poses, type_index, position_index, pad_mask):
+
+ batch_size, num_objects, num_pts, _ = pcs.shape
+ _, num_poses, _ = poses.shape
+ _, sentence_len = sentence.shape
+ _, total_len = type_index.shape
+
+ pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
+ posed_pc_embed = self.encode_posed_pc(pcs, batch_size, num_objects)
+
+ pose_embed = self.pose_encoder(poses)
+
+ if self.use_virtual_structure_frame:
+ virtual_frame_embed = self.virtual_frame_embed.repeat(batch_size, 1, 1)
+ posed_pc_embed = torch.cat([virtual_frame_embed, posed_pc_embed], dim=1)
+ tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)
+
+ #########################
+ sentence_embed = self.word_embeddings(sentence)
+
+ #########################
+
+ # transformer time dim: sentence, struct, obj
+ # transformer feat dim: obj pc + pose / word, time, token type, position
+
+ time_embed = self.time_embeddings(t) # B, dim
+ time_embed = time_embed.unsqueeze(1).repeat(1, total_len, 1) # B, L, dim
+
+ position_embed = self.position_embeddings(position_index)
+ type_embed = self.type_embeddings(type_index)
+
+ tgt_sequence_encode = torch.cat([sentence_embed, tgt_obj_embed], dim=1)
+ tgt_sequence_encode = torch.cat([tgt_sequence_encode, time_embed, position_embed, type_embed], dim=-1)
+
+ tgt_pad_mask = pad_mask
+
+ #########################
+ # sequence_encode: [batch size, sequence_length, encoder input dimension]
+        # input to transformer needs to have dimension [sequence_length, batch size, encoder input dimension]
+ tgt_sequence_encode = tgt_sequence_encode.transpose(1, 0)
+ # convert to bool
+ tgt_pad_mask = (tgt_pad_mask == 1)
+ # encode: [sequence_length, batch_size, embedding_size]
+ encode = self.encoder(tgt_sequence_encode, src_key_padding_mask=tgt_pad_mask)
+ encode = encode.transpose(1, 0)
+ #########################
+
+ target_encodes = encode[:, -num_poses:, :]
+ if self.use_virtual_structure_frame:
+ obj_encodes = target_encodes[:, 1:, :]
+ pred_obj_poses = self.obj_head(obj_encodes) # B, N, 3 + 6
+ struct_encode = encode[:, 0, :].unsqueeze(1)
+ # use a different sampler for struct prediction since it should have larger variance than object predictions
+ pred_struct_pose = self.struct_head(struct_encode) # B, 1, 3 + 6
+ pred_poses = torch.cat([pred_struct_pose, pred_obj_poses], dim=1)
+ else:
+ pred_poses = self.obj_head(target_encodes) # B, N, 3 + 6
+
+ assert pred_poses.shape == poses.shape
+
+ return pred_poses
+
+
+class FocalLoss(nn.Module):
+ def __init__(self, gamma=2, alpha=.25):
+ super(FocalLoss, self).__init__()
+ # self.alpha = torch.tensor([alpha, 1-alpha])
+ self.gamma = gamma
+
+ def forward(self, inputs, targets):
+ BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+ pt = torch.exp(-BCE_loss)
+ # targets = targets.type(torch.long)
+ # at = self.alpha.gather(0, targets.data.view(-1))
+ # F_loss = at*(1-pt)**self.gamma * BCE_loss
+ F_loss = (1 - pt)**self.gamma * BCE_loss
+ return F_loss.mean()
+
+
+class PCTDiscriminator(torch.nn.Module):
+
+ def __init__(self, max_num_objects, include_env_pc=False, pct_random_sampling=False):
+
+ super(PCTDiscriminator, self).__init__()
+
+ # input_dim: xyz + one hot for each object
+ if include_env_pc:
+ self.classifier = PointTransformerCls(input_dim=max_num_objects + 1 + 3, output_dim=1, use_random_sampling=pct_random_sampling)
+ else:
+ self.classifier = PointTransformerCls(input_dim=max_num_objects + 3, output_dim=1, use_random_sampling=pct_random_sampling)
+
+ def forward(self, scene_xyz):
+ label = self.classifier(scene_xyz)
+ return label
+
+ def convert_logits(self, logits):
+ return torch.sigmoid(logits)
+
diff --git a/src/StructDiffusion/models/pl_models.py b/src/StructDiffusion/models/pl_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..237ef12741758dbabe84bae424fd33d5eb50f5c2
--- /dev/null
+++ b/src/StructDiffusion/models/pl_models.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from StructDiffusion.models.models import TransformerDiffusionModel, PCTDiscriminator, FocalLoss
+
+from StructDiffusion.diffusion.noise_schedule import NoiseSchedule, q_sample
+from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H, get_diffusion_variables_from_9D_actions
+
+
+class ConditionalPoseDiffusionModel(pl.LightningModule):
+
+ def __init__(self, vocab_size, model_cfg, loss_cfg, noise_scheduler_cfg, optimizer_cfg):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.model = TransformerDiffusionModel(vocab_size, **model_cfg)
+
+ self.noise_schedule = NoiseSchedule(**noise_scheduler_cfg)
+
+ self.loss_type = loss_cfg.type
+
+ self.optimizer_cfg = optimizer_cfg
+ self.configure_optimizers()
+
+ self.batch_size = None
+
+ def forward(self, batch):
+
+ # input
+ pcs = batch["pcs"]
+ B = pcs.shape[0]
+ self.batch_size = B
+ sentence = batch["sentence"]
+ goal_poses = batch["goal_poses"]
+ type_index = batch["type_index"]
+ position_index = batch["position_index"]
+ pad_mask = batch["pad_mask"]
+
+ t = torch.randint(0, self.noise_schedule.timesteps, (B,), dtype=torch.long).to(self.device)
+
+ # --------------
+ x_start = get_diffusion_variables_from_H(goal_poses)
+ noise = torch.randn_like(x_start, device=self.device)
+ x_noisy = q_sample(x_start=x_start, t=t, noise_schedule=self.noise_schedule, noise=noise)
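+        # q_sample is assumed to implement the standard DDPM forward process,
+        # i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
+        # so the network below is trained to predict the injected noise eps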
+
+ predicted_noise = self.model.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+
+ # important: skip computing loss for masked positions
+ num_poses = goal_poses.shape[1] # B, N, 4, 4
+ pose_pad_mask = pad_mask[:, -num_poses:]
+ keep_mask = (pose_pad_mask == 0)
+ noise = noise[keep_mask] # dim: number of positions that need loss calculation
+ predicted_noise = predicted_noise[keep_mask]
+
+ return noise, predicted_noise
+
+ def compute_loss(self, noise, predicted_noise, prefix="train/"):
+ if self.loss_type == 'l1':
+ loss = F.l1_loss(noise, predicted_noise)
+ elif self.loss_type == 'l2':
+ loss = F.mse_loss(noise, predicted_noise)
+ elif self.loss_type == "huber":
+ loss = F.smooth_l1_loss(noise, predicted_noise)
+ else:
+ raise NotImplementedError()
+
+ self.log(prefix + "loss", loss, prog_bar=True, batch_size=self.batch_size)
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ noise, pred_noise = self.forward(batch)
+ loss = self.compute_loss(noise, pred_noise, prefix="train/")
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ noise, pred_noise = self.forward(batch)
+ loss = self.compute_loss(noise, pred_noise, prefix="val/")
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay) # 1e-5
+ return optimizer
+
+
+class PairwiseCollisionModel(pl.LightningModule):
+
+ def __init__(self, model_cfg, loss_cfg, optimizer_cfg, data_cfg):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.model = PCTDiscriminator(**model_cfg)
+
+ self.loss_cfg = loss_cfg
+ self.loss = None
+ self.configure_loss()
+
+ self.optimizer_cfg = optimizer_cfg
+ self.configure_optimizers()
+
+ # this is stored, because some of the data parameters affect the model behavior
+ self.data_cfg = data_cfg
+
+ def forward(self, batch):
+ label = batch["label"]
+ predicted_label = self.model.forward(batch["scene_xyz"])
+ return label, predicted_label
+
+ def compute_loss(self, label, predicted_label, prefix="train/"):
+ if self.loss_cfg.type == "MSE":
+ predicted_label = torch.sigmoid(predicted_label)
+ loss = self.loss(predicted_label, label)
+ self.log(prefix + "loss", loss, prog_bar=True)
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ label, predicted_label = self.forward(batch)
+ loss = self.compute_loss(label, predicted_label, prefix="train/")
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ label, predicted_label = self.forward(batch)
+ loss = self.compute_loss(label, predicted_label, prefix="val/")
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay) # 1e-5
+ return optimizer
+
+ def configure_loss(self):
+ if self.loss_cfg.type == "Focal":
+ print("use focal loss with gamma {}".format(self.loss_cfg.focal_gamma))
+ self.loss = FocalLoss(gamma=self.loss_cfg.focal_gamma)
+ elif self.loss_cfg.type == "MSE":
+ print("use regression L2 loss")
+ self.loss = torch.nn.MSELoss()
+ elif self.loss_cfg.type == "BCE":
+ print("use standard BCE logit loss")
+ self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
\ No newline at end of file
diff --git a/src/StructDiffusion/models/point_transformer.py b/src/StructDiffusion/models/point_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d8f3ac790b8811dc7352842e05b3caa5fd7756
--- /dev/null
+++ b/src/StructDiffusion/models/point_transformer.py
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+from StructDiffusion.utils.pointnet import farthest_point_sample, index_points, square_distance
+
+# adapted from https://github.com/qq456cvb/Point-Transformers
+
+
+def sample_and_group(npoint, nsample, xyz, points):
+ B, N, C = xyz.shape
+ S = npoint
+
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
+
+ new_xyz = index_points(xyz, fps_idx)
+ new_points = index_points(points, fps_idx)
+
+ dists = square_distance(new_xyz, xyz) # B x npoint x N
+ idx = dists.argsort()[:, :, :nsample] # B x npoint x K
+
+ grouped_points = index_points(points, idx)
+ grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
+ new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1)
+ return new_xyz, new_points
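+
+# Shape summary: farthest point sampling selects `npoint` centers, the `nsample`
+# nearest neighbours of each center are gathered, and the returned features have
+# shape (B, npoint, nsample, 2 * D), i.e. the per-neighbour offset from its
+# center concatenated with the repeated center feature.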
+
+
+class Local_op(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(out_channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ b, n, s, d = x.size() # torch.Size([32, 512, 32, 6])
+ x = x.permute(0, 1, 3, 2)
+ x = x.reshape(-1, d, s)
+ batch_size, _, N = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+ x = x.reshape(b, n, -1).permute(0, 2, 1)
+ return x
+
+
+class SA_Layer(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
+ self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
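+        # the query and key projections intentionally share weights (weight tying),
+        # kept from the Point-Transformers implementation this file is adapted from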
+ self.q_conv.weight = self.k_conv.weight
+ self.v_conv = nn.Conv1d(channels, channels, 1)
+ self.trans_conv = nn.Conv1d(channels, channels, 1)
+ self.after_norm = nn.BatchNorm1d(channels)
+ self.act = nn.ReLU()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x):
+ x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c
+ x_k = self.k_conv(x)# b, c, n
+ x_v = self.v_conv(x)
+ energy = x_q @ x_k # b, n, n
+ attention = self.softmax(energy)
+ attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
+ x_r = x_v @ attention # b, c, n
+ x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
+ x = x + x_r
+ return x
+
+
+class StackedAttention(nn.Module):
+ def __init__(self, channels=64):
+ super().__init__()
+ self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
+
+ self.bn1 = nn.BatchNorm1d(channels)
+ self.bn2 = nn.BatchNorm1d(channels)
+
+ self.sa1 = SA_Layer(channels)
+ self.sa2 = SA_Layer(channels)
+
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ #
+ # b, 3, npoint, nsample
+ # conv2d 3 -> 128 channels 1, 1
+ # b * npoint, c, nsample
+ # permute reshape
+ batch_size, _, N = x.size()
+
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x)))
+
+ x1 = self.sa1(x)
+ x2 = self.sa2(x1)
+
+ x = torch.cat((x1, x2), dim=1)
+
+ return x
+
+
+class PointTransformerEncoderSmall(nn.Module):
+
+ def __init__(self, output_dim=256, input_dim=6, mean_center=True):
+ super(PointTransformerEncoderSmall, self).__init__()
+
+ self.mean_center = mean_center
+
+ # map the second dim of the input from input_dim to 64
+ self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(64)
+ self.gather_local_0 = Local_op(in_channels=128, out_channels=64)
+ self.gather_local_1 = Local_op(in_channels=128, out_channels=64)
+ self.pt_last = StackedAttention(channels=64)
+
+ self.relu = nn.ReLU()
+ self.conv_fuse = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
+ nn.BatchNorm1d(256),
+ nn.LeakyReLU(negative_slope=0.2))
+
+ self.linear1 = nn.Linear(256, 256, bias=False)
+ self.bn6 = nn.BatchNorm1d(256)
+ self.dp1 = nn.Dropout(p=0.5)
+ self.linear2 = nn.Linear(256, 256)
+
+ def forward(self, xyz, f=None):
+ # xyz: B, N, 3
+ # f: B, N, D
+ center = torch.mean(xyz, dim=1)
+ if self.mean_center:
+ xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1)
+ if f is None:
+ x = self.pct(xyz)
+ else:
+ x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim
+
+ return center, x
+
+ def pct(self, x):
+
+ # x: B, N, D
+ xyz = x[..., :3]
+ x = x.permute(0, 2, 1)
+ batch_size, _, _ = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = x.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=128, nsample=32, xyz=xyz, points=x)
+ feature_0 = self.gather_local_0(new_feature)
+ feature = feature_0.permute(0, 2, 1) # B, nsamples, D
+ new_xyz, new_feature = sample_and_group(npoint=32, nsample=16, xyz=new_xyz, points=feature)
+ feature_1 = self.gather_local_1(new_feature) # B, D, nsamples
+
+ x = self.pt_last(feature_1) # B, D * 2, nsamples
+ x = torch.cat([x, feature_1], dim=1) # B, D * 3, nsamples
+ x = self.conv_fuse(x)
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+
+ x = self.relu(self.bn6(self.linear1(x)))
+ x = self.dp1(x)
+ x = self.linear2(x)
+
+ return x
+
+
+class SampleAndGroup(nn.Module):
+
+ def __init__(self, output_dim=64, input_dim=6, mean_center=True, npoints=(128, 32), nsamples=(32, 16)):
+ super(SampleAndGroup, self).__init__()
+
+ self.mean_center = mean_center
+ self.npoints = npoints
+ self.nsamples = nsamples
+
+ # map the second dim of the input from input_dim to 64
+ self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(output_dim)
+ self.gather_local_0 = Local_op(in_channels=output_dim * 2, out_channels=output_dim)
+ self.gather_local_1 = Local_op(in_channels=output_dim * 2, out_channels=output_dim)
+ self.relu = nn.ReLU()
+
+ def forward(self, xyz, f):
+ # xyz: B, N, 3
+ # f: B, N, D
+ center = torch.mean(xyz, dim=1)
+ if self.mean_center:
+ xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1)
+ x = self.sg(torch.cat([xyz, f], dim=2)) # B, nsamples, output_dim
+
+ return center, x
+
+ def sg(self, x):
+
+ # x: B, N, D
+ xyz = x[..., :3]
+ x = x.permute(0, 2, 1)
+ batch_size, _, _ = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = x.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=self.npoints[0], nsample=self.nsamples[0], xyz=xyz, points=x)
+ feature_0 = self.gather_local_0(new_feature)
+ feature = feature_0.permute(0, 2, 1) # B, nsamples, D
+ new_xyz, new_feature = sample_and_group(npoint=self.npoints[1], nsample=self.nsamples[1], xyz=new_xyz, points=feature)
+ feature_1 = self.gather_local_1(new_feature) # B, D, nsamples
+ x = feature_1.permute(0, 2, 1) # B, nsamples, D
+
+ return x
\ No newline at end of file
diff --git a/src/StructDiffusion/models/point_transformer_large.py b/src/StructDiffusion/models/point_transformer_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e2a3d9f65c922b18641064e001c6fcc428d30cf
--- /dev/null
+++ b/src/StructDiffusion/models/point_transformer_large.py
@@ -0,0 +1,306 @@
+import torch
+import torch.nn as nn
+from StructDiffusion.utils.pointnet import farthest_point_sample, index_points, square_distance, random_point_sample
+
+
+def sample_and_group(npoint, nsample, xyz, points, use_random_sampling=False):
+ B, N, C = xyz.shape
+ S = npoint
+
+ if not use_random_sampling:
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
+ else:
+ fps_idx = random_point_sample(xyz, npoint) # [B, npoint]
+
+ new_xyz = index_points(xyz, fps_idx)
+ new_points = index_points(points, fps_idx)
+
+ dists = square_distance(new_xyz, xyz) # B x npoint x N
+ idx = dists.argsort()[:, :, :nsample] # B x npoint x K
+
+ grouped_points = index_points(points, idx)
+ grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
+ new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1)
+ return new_xyz, new_points
+
+
+class Local_op(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(out_channels)
+ self.bn2 = nn.BatchNorm1d(out_channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ b, n, s, d = x.size() # torch.Size([32, 512, 32, 6])
+ x = x.permute(0, 1, 3, 2)
+ x = x.reshape(-1, d, s)
+ batch_size, _, N = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x))) # B, D, N
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+ x = x.reshape(b, n, -1).permute(0, 2, 1)
+ return x
+
+
+class SA_Layer(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
+ self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
+ self.q_conv.weight = self.k_conv.weight
+ self.v_conv = nn.Conv1d(channels, channels, 1)
+ self.trans_conv = nn.Conv1d(channels, channels, 1)
+ self.after_norm = nn.BatchNorm1d(channels)
+ self.act = nn.ReLU()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x):
+ x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c
+ x_k = self.k_conv(x)# b, c, n
+ x_v = self.v_conv(x)
+ energy = x_q @ x_k # b, n, n
+ attention = self.softmax(energy)
+ attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
+ x_r = x_v @ attention # b, c, n
+ x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
+ x = x + x_r
+ return x
+
+
+class StackedAttention(nn.Module):
+ def __init__(self, channels=256):
+ super().__init__()
+ self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
+
+ self.bn1 = nn.BatchNorm1d(channels)
+ self.bn2 = nn.BatchNorm1d(channels)
+
+ self.sa1 = SA_Layer(channels)
+ self.sa2 = SA_Layer(channels)
+ self.sa3 = SA_Layer(channels)
+ self.sa4 = SA_Layer(channels)
+
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ #
+ # b, 3, npoint, nsample
+ # conv2d 3 -> 128 channels 1, 1
+ # b * npoint, c, nsample
+ # permute reshape
+ batch_size, _, N = x.size()
+
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x)))
+
+ x1 = self.sa1(x)
+ x2 = self.sa2(x1)
+ x3 = self.sa3(x2)
+ x4 = self.sa4(x3)
+
+ x = torch.cat((x1, x2, x3, x4), dim=1)
+
+ return x
+
+
+class PointTransformerCls(nn.Module):
+ def __init__(self, input_dim, output_dim, use_random_sampling=False):
+ super().__init__()
+
+ self.use_random_sampling = use_random_sampling
+
+ self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(64)
+ self.bn2 = nn.BatchNorm1d(64)
+ self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
+ self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
+ self.pt_last = StackedAttention()
+
+ self.relu = nn.ReLU()
+ self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
+ nn.BatchNorm1d(1024),
+ nn.LeakyReLU(negative_slope=0.2))
+
+ self.linear1 = nn.Linear(1024, 512, bias=False)
+ self.bn6 = nn.BatchNorm1d(512)
+ self.dp1 = nn.Dropout(p=0.5)
+ self.linear2 = nn.Linear(512, 256)
+ self.bn7 = nn.BatchNorm1d(256)
+ self.dp2 = nn.Dropout(p=0.5)
+ self.linear3 = nn.Linear(256, output_dim)
+
+ def forward(self, x):
+ xyz = x[..., :3]
+ x = x.permute(0, 2, 1)
+ batch_size, _, _ = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x))) # B, D, N
+ x = x.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x,
+ use_random_sampling=self.use_random_sampling)
+ feature_0 = self.gather_local_0(new_feature)
+ feature = feature_0.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature,
+ use_random_sampling=self.use_random_sampling)
+
+ # debug: visualize
+ # # new_xyz: B, N, 3
+ # from rearrangement_utils import show_pcs
+ # import numpy as np
+ #
+ # new_xyz_copy = new_xyz.detach().cpu().numpy()
+ # for i in range(new_xyz_copy.shape[0]):
+ # print(new_xyz_copy[i].shape)
+ # show_pcs([new_xyz_copy[i]], [np.tile(np.array([0, 1, 0], dtype=np.float), (new_xyz_copy[i].shape[0], 1))])
+
+ feature_1 = self.gather_local_1(new_feature)
+
+ x = self.pt_last(feature_1)
+ x = torch.cat([x, feature_1], dim=1)
+ x = self.conv_fuse(x)
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+
+ x = self.relu(self.bn6(self.linear1(x)))
+ x = self.dp1(x)
+ x = self.relu(self.bn7(self.linear2(x)))
+ x = self.dp2(x)
+ x = self.linear3(x)
+
+ return x
+
+
+class PointTransformerClsLarge(nn.Module):
+ def __init__(self, input_dim, output_dim, use_random_sampling=False):
+ super().__init__()
+
+ self.use_random_sampling = use_random_sampling
+
+ self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(64)
+ self.bn2 = nn.BatchNorm1d(64)
+ self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
+ self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
+ self.pt_last = StackedAttention()
+
+ self.relu = nn.ReLU()
+ self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
+ nn.BatchNorm1d(1024),
+ nn.LeakyReLU(negative_slope=0.2))
+
+ self.linear1 = nn.Linear(1024, 1024, bias=False)
+ self.bn6 = nn.BatchNorm1d(1024)
+ self.dp1 = nn.Dropout(p=0.5)
+ self.linear2 = nn.Linear(1024, 512)
+ self.bn7 = nn.BatchNorm1d(512)
+ self.dp2 = nn.Dropout(p=0.5)
+ self.linear3 = nn.Linear(512, output_dim)
+
+ def forward(self, x):
+ xyz = x[..., :3]
+ x = x.permute(0, 2, 1)
+ batch_size, _, _ = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x))) # B, D, N
+ x = x.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x,
+ use_random_sampling=self.use_random_sampling)
+ feature_0 = self.gather_local_0(new_feature)
+ feature = feature_0.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature,
+ use_random_sampling=self.use_random_sampling)
+
+ # debug: visualize
+ # # new_xyz: B, N, 3
+ # from rearrangement_utils import show_pcs
+ # import numpy as np
+ #
+ # new_xyz_copy = new_xyz.detach().cpu().numpy()
+ # for i in range(new_xyz_copy.shape[0]):
+ # print(new_xyz_copy[i].shape)
+ # show_pcs([new_xyz_copy[i]], [np.tile(np.array([0, 1, 0], dtype=np.float), (new_xyz_copy[i].shape[0], 1))])
+
+ feature_1 = self.gather_local_1(new_feature)
+
+ x = self.pt_last(feature_1)
+ x = torch.cat([x, feature_1], dim=1)
+ x = self.conv_fuse(x)
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+
+ x = self.relu(self.bn6(self.linear1(x)))
+ x = self.dp1(x)
+ x = self.relu(self.bn7(self.linear2(x)))
+ x = self.dp2(x)
+ x = self.linear3(x)
+
+ return x
+
+
+class PointTransformerEncoderLarge(nn.Module):
+ def __init__(self, output_dim=256, input_dim=6, mean_center=True):
+ super(PointTransformerEncoderLarge, self).__init__()
+
+ self.mean_center = mean_center
+
+ self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm1d(64)
+ self.bn2 = nn.BatchNorm1d(64)
+ self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
+ self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
+ self.pt_last = StackedAttention()
+
+ self.relu = nn.ReLU()
+ self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
+ nn.BatchNorm1d(1024),
+ nn.LeakyReLU(negative_slope=0.2))
+
+ self.linear1 = nn.Linear(1024, 512, bias=False)
+ self.bn6 = nn.BatchNorm1d(512)
+ self.dp1 = nn.Dropout(p=0.5)
+ self.linear2 = nn.Linear(512, 256)
+
+ def forward(self, xyz, f):
+ # xyz: B, N, 3
+ # f: B, N, D
+ center = torch.mean(xyz, dim=1)
+ if self.mean_center:
+ xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1)
+ x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim
+
+ return center, x
+
+ def pct(self, x):
+
+ xyz = x[..., :3]
+ x = x.permute(0, 2, 1)
+ batch_size, _, _ = x.size()
+ x = self.relu(self.bn1(self.conv1(x))) # B, D, N
+ x = self.relu(self.bn2(self.conv2(x))) # B, D, N
+ x = x.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x)
+ feature_0 = self.gather_local_0(new_feature)
+ feature = feature_0.permute(0, 2, 1)
+ new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature)
+ feature_1 = self.gather_local_1(new_feature)
+
+ x = self.pt_last(feature_1)
+ x = torch.cat([x, feature_1], dim=1)
+ x = self.conv_fuse(x)
+ x = torch.max(x, 2)[0]
+ x = x.view(batch_size, -1)
+
+ x = self.relu(self.bn6(self.linear1(x)))
+ x = self.dp1(x)
+ x = self.linear2(x)
+
+ return x
+
diff --git a/src/StructDiffusion/utils/__init__.py b/src/StructDiffusion/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cae37618e6d32d718366dffd61e9b763f6c9ce5f
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4727c42b4a7ac2a745e70e0ea443820e1e2a4f94
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df26d00a967aeeb79278f647634af346c43a0743
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ecc3e99b8d0692f36305d31b7cc84a9e275d799
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..159b7bdeaae322d391f3f7ba279e732a7f857a19
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5a63e4fc58eefe3762552512a001ca8f9cab83e
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..951dc53d8b891232b7812c5fb5041acc0e1cf92b
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e5e1f5b5d5ce557c861482a10811e48342fca05
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60c8e1e1607fc3b2ef51626f3833350afdbc73a5
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4ee8032d620ef9ee48c65bb0708352c4fd71bf6
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c256ed0456e4a1b552ddbb4595b4366bf90fb637
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1754983c6f0a3a2d7c5160573123f64637422360
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1644b1ba1c64215da2ced9538e5d98f81a15613e
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fac52a646e04870c1698bfb9fd092c58e802e39e
Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/batch_inference.py b/src/StructDiffusion/utils/batch_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ae39a6b8220383fe2765ab8455c0dffb38f24d
--- /dev/null
+++ b/src/StructDiffusion/utils/batch_inference.py
@@ -0,0 +1,490 @@
+import os
+import torch
+import numpy as np
+import pytorch3d.transforms as tra3d
+
+from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
+from StructDiffusion.utils.pointnet import random_point_sample, index_points
+
+
+def move_pc_and_create_scene(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds,
+ num_scene_pts, device, normalize_pc=False,
+ return_pair_pc=False, normalize_pair_pc=False, num_pair_pc_pts=None,
+ return_scene_pts=True, return_scene_pts_and_pc_idxs=False):
+
+ # obj_xyzs: N, P, 3
+ # obj_params: B, N, 6
+ # struct_pose: B x N, 4, 4
+ # current_pc_pose: B x N, 4, 4
+ # target_object_inds: 1, N
+
+ B, N, _ = obj_params.shape
+ _, P, _ = obj_xyzs.shape
+
+ # B, N, 6
+ flat_obj_params = obj_params.reshape(B * N, -1)
+ goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
+ goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
+ goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ # obj_xyzs: N, P, 3
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
+ new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+ # visualize_batch_pcs(new_obj_xyzs, S, N, P)
+
+ # ===================================
+ # Pass to discriminator
+ subsampled_scene_xyz = None
+ if return_scene_pts:
+
+ num_indicator = N
+
+ # add one hot
+ indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
+ # print(indicator_variables.shape)
+ # print(new_obj_xyzs.shape)
+ new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
+
+ # ToDo: maybe convert this to a batch operation
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
+ for si, scene_xyz in enumerate(scene_xyzs):
+ # scene_xyz: N*P, 3+N
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+
+ # # debug:
+ # print("-"*50)
+ # if si < 10:
+ # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3+N
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # # debug:
+ # for si in range(10):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ if return_scene_pts_and_pc_idxs:
+ num_indicator = N
+ pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
+ # new_obj_xyzs: B, N, P, 3 + 1
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
+ pc_idxs = pc_idxs.reshape(B, N*P)
+
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
+ subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
+ for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
+ # scene_xyz: N*P, 3+1
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+ subsampled_pc_idxs[si] = pc_idx[subsample_idx]
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3
+ # subsampled_pc_idxs: B, num_scene_pts
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # TODO: visualize each individual object
+ # debug
+ # print(subsampled_scene_xyz.shape)
+ # print(subsampled_pc_idxs.shape)
+ # print("visualize subsampled scene")
+ # for si in range(5):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ ###############################################
+ # Create input for pairwise collision detector
+ if return_pair_pc:
+
+ assert num_pair_pc_pts is not None
+
+ # new_obj_xyzs: B, N, P, 3 + N
+ # target_object_inds: 1, N
+ # ignore paddings
+ num_objs = torch.sum(target_object_inds[0])
+ obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
+
+ # use [:, :, :, :3] to get obj_xyzs without object-wise indicator
+ obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
+ num_comb = obj_pair_xyzs.shape[1]
+ pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
+ obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
+
+ # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
+ # random_point_sample() input dim: B, N, C
+ rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
+ obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
+
+ if normalize_pair_pc:
+ # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
+ # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
+ obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
+
+ # # debug
+ # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+ # print("batch id", bi)
+ # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+ # print("pair", pi)
+ # # obj_pair_xyzs: 2 * P, 5
+ # print(obj_pair_xyz[:, :3].shape)
+ # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+
+ # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
+ goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
+
+ # TODO: update the return logic, a mess right now
+ if return_scene_pts_and_pc_idxs:
+ return subsampled_scene_xyz, subsampled_pc_idxs, new_obj_xyzs, goal_pc_pose
+
+ if return_pair_pc:
+ return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose, obj_pair_xyzs
+ else:
+ return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose
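+
+# Illustrative call with the default flags (shapes follow the comments at the top of
+# the function; the tensor names are placeholders):
+#
+#   scene_pts, new_xyzs, goal_poses = move_pc_and_create_scene(
+#       obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds,
+#       num_scene_pts=1024, device=obj_xyzs.device)
+#   # scene_pts: B, 1024, 3 + N    new_xyzs: B, N, P, 3 + N    goal_poses: B, N, 4, 4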
+
+
+def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
+ return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
+ return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False):
+
+ # obj_xyzs: N, P, 3
+ # obj_params: B, N, 6
+ # struct_pose: B x N, 4, 4
+ # current_pc_pose: B x N, 4, 4
+ # target_object_inds: 1, N
+
+ B, N, _ = obj_params.shape
+ _, P, _ = obj_xyzs.shape
+
+ # B, N, 6
+ flat_obj_params = obj_params.reshape(B * N, -1)
+ goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
+ goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
+ goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ # obj_xyzs: N, P, 3
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
+ new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+ # visualize_batch_pcs(new_obj_xyzs, S, N, P)
+
+
+ # initialize the additional outputs
+ subsampled_scene_xyz = None
+ subsampled_pc_idxs = None
+ obj_pair_xyzs = None
+
+ # ===================================
+ # Pass to discriminator
+ if return_scene_pts:
+
+ num_indicator = N
+
+ # add one hot
+ indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
+ # print(indicator_variables.shape)
+ # print(new_obj_xyzs.shape)
+ new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
+
+ # ToDo: maybe convert this to a batch operation
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
+ for si, scene_xyz in enumerate(scene_xyzs):
+ # scene_xyz: N*P, 3+N
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+
+ # # debug:
+ # print("-"*50)
+ # if si < 10:
+ # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3+N
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # # debug:
+ # for si in range(10):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ if return_scene_pts_and_pc_idxs:
+ num_indicator = N
+ pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
+ # new_obj_xyzs: B, N, P, 3 + 1
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
+ pc_idxs = pc_idxs.reshape(B, N*P)
+
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
+ subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
+ for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
+ # scene_xyz: N*P, 3+1
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+ subsampled_pc_idxs[si] = pc_idx[subsample_idx]
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3
+ # subsampled_pc_idxs: B, num_scene_pts
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # TODO: visualize each individual object
+ # debug
+ # print(subsampled_scene_xyz.shape)
+ # print(subsampled_pc_idxs.shape)
+ # print("visualize subsampled scene")
+ # for si in range(5):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ ###############################################
+ # Create input for pairwise collision detector
+ if return_pair_pc:
+
+ assert num_pair_pc_pts is not None
+
+ # new_obj_xyzs: B, N, P, 3 + N
+ # target_object_inds: 1, N
+ # ignore paddings
+ num_objs = torch.sum(target_object_inds[0])
+ obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
+
+ # use [:, :, :, :3] to get obj_xyzs without object-wise indicator
+ obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
+ num_comb = obj_pair_xyzs.shape[1]
+ pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
+ obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
+
+ # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
+ # random_point_sample() input dim: B, N, C
+ rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
+ obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
+
+ if normalize_pair_pc:
+ # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
+ # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
+ obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
+
+ # # debug
+ # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+ # print("batch id", bi)
+ # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+ # print("pair", pi)
+ # # obj_pair_xyzs: 2 * P, 5
+ # print(obj_pair_xyz[:, :3].shape)
+ # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+
+ # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
+ goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
+
+ return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
+
+
+def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
+
+ # obj_xyzs: N, P, 3
+ # obj_params: B, N, 6
+ # struct_pose: B x N, 4, 4
+ # current_pc_pose: B x N, 4, 4
+ # target_object_inds: 1, N
+
+ B, N, _ = obj_params.shape
+ _, P, _ = obj_xyzs.shape
+
+ # B, N, 6
+ flat_obj_params = obj_params.reshape(B * N, -1)
+ goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
+ goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
+ goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ # obj_xyzs: N, P, 3
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
+ new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+ # visualize_batch_pcs(new_obj_xyzs, S, N, P)
+
+    # new_obj_xyzs: B, N, P, 3
+    # goal_pc_pose: B, N, 4, 4
+
+ goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
+ return new_obj_xyzs, goal_pc_pose
+
+
+def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
+
+ device = obj_xyzs.device
+
+ # obj_xyzs: B, N, P, 3 or 6
+ # struct_pose: B, 1, 4, 4
+ # pc_poses_in_struct: B, N, 4, 4
+
+ B, N, _, _ = pc_poses_in_struct.shape
+ _, _, P, _ = obj_xyzs.shape
+
+ current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
+ # print(torch.mean(obj_xyzs, dim=2).shape)
+ current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
+ current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
+
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
+ struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
+ pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
+ # print("goal pc poses")
+ # print(goal_pc_pose)
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
+ new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+
+ return new_obj_xyzs
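+
+# Illustrative call (shapes mirror the comments above; the tensors are placeholders):
+#
+#   new_obj_xyzs = move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct)
+#   # obj_xyzs: B, N, P, 3 (or 6), struct_pose: B, 1, 4, 4, pc_poses_in_struct: B, N, 4, 4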
+
+
+def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
+
+ device = obj_xyzs.device
+
+ # obj_xyzs: B, N, P, 3
+ # struct_pose: B, 1, 4, 4
+ # pc_poses_in_struct: B, N, 4, 4
+ B, N, _, _ = pc_poses_in_struct.shape
+ _, _, P, _ = obj_xyzs.shape
+
+ current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
+ # print(torch.mean(obj_xyzs, dim=2).shape)
+ current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2) # B, N, 4, 4
+
+ struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
+ struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
+ pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
+
+ goal_pc_poses = struct_pose @ pc_poses_in_struct # B x N, 4, 4
+ goal_pc_poses = goal_pc_poses.reshape(B, N, 4, 4) # B, N, 4, 4
+ return current_pc_poses, goal_pc_poses
+
+
+def sample_gaussians(mus, sigmas, sample_size):
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ normal = torch.distributions.Normal(mus, sigmas)
+ samples = normal.sample((sample_size,))
+ # samples: [sample_size, number of individual gaussians]
+ return samples
+
+def fit_gaussians(samples, sigma_eps=0.01):
+ device = samples.device
+
+ # samples: [sample_size, number of individual gaussians]
+ num_gs = samples.shape[1]
+ mus = torch.mean(samples, dim=0).to(device)
+ sigmas = torch.std(samples, dim=0).to(device) + sigma_eps * torch.ones(num_gs).to(device)
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ return mus, sigmas
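+
+# Illustrative round trip (assumes samples has shape [sample_size, number of gaussians]):
+#
+#   mus, sigmas = fit_gaussians(samples)                         # per-dimension mean / std (+ sigma_eps)
+#   new_samples = sample_gaussians(mus, sigmas, sample_size=16)  # 16, number of gaussians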
+
+
+def visualize_batch_pcs(obj_xyzs, B, verbose=False, limit_B=None, save_dir=None, trimesh=False):
+ if limit_B is None:
+ limit_B = B
+
+ vis_obj_xyzs = obj_xyzs[:limit_B]
+
+ if torch.is_tensor(vis_obj_xyzs):
+ if vis_obj_xyzs.is_cuda:
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
+
+ for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
+ if verbose:
+ print("example {}".format(bi))
+ print(vis_obj_xyz.shape)
+
+ if trimesh:
+ show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
+ else:
+ if save_dir:
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ save_path = os.path.join(save_dir, "b{}.jpg".format(bi))
+ show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=False, add_coordinate_frame=False,
+ side_view=True, save_path=save_path)
+ else:
+ show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=False,
+ side_view=True)
+
+
+def pc_normalize_batch(pc):
+ # pc: B, num_scene_pts, 3
+ centroid = torch.mean(pc, dim=1) # B, 3
+ pc = pc - centroid[:, None, :]
+ m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0]
+ pc = pc / m[:, None, None]
+ return pc
\ No newline at end of file
diff --git a/src/StructDiffusion/utils/brain2/__init__.py b/src/StructDiffusion/utils/brain2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a865e957b0416b8531ff96968d43a42e02d0c5c8
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b61b09a1b77329c30f59eaffd4d878dc0d0f02bf
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85dd8dfbe93b525792a1dda7ffc56fc69ac43830
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9af852d208bc53275e669633ebb0c883123427de
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5e36a52f7c675b572fa3e1ceeedfb06f0009fcb
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ef499255c77b1d8fa3cc65a32a9e4a6b9d9f11a
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b910883f1c4dd2fbc1a5475965faa9cc9ed50d3
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1325a12669cdb6cfbe1c0b12fed60e2b9a79536e
Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc differ
diff --git a/src/StructDiffusion/utils/brain2/camera.py b/src/StructDiffusion/utils/brain2/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac75323f0468815b9f25066d66a458f3e0dc771
--- /dev/null
+++ b/src/StructDiffusion/utils/brain2/camera.py
@@ -0,0 +1,175 @@
+from __future__ import print_function
+
+import numpy as np
+import open3d
+import trimesh
+
+import StructDiffusion.utils.transformations as tra
+from StructDiffusion.utils.brain2.pose import make_pose
+
+
+def get_camera_from_h5(h5):
+ """ Simple reference to help make these """
+ proj_near = h5['cam_near'][()]
+ proj_far = h5['cam_far'][()]
+ proj_fov = h5['cam_fov'][()]
+ width = h5['cam_width'][()]
+ height = h5['cam_height'][()]
+ return GenericCameraReference(proj_near, proj_far, proj_fov, width, height)
+
+
+class GenericCameraReference(object):
+ """ Class storing camera information and providing easy image capture """
+
+ def __init__(self, proj_near=0.01, proj_far=5., proj_fov=60., img_width=640,
+ img_height=480):
+
+ self.proj_near = proj_near
+ self.proj_far = proj_far
+ self.proj_fov = proj_fov
+ self.img_width = img_width
+ self.img_height = img_height
+ self.x_offset = self.img_width / 2.
+ self.y_offset = self.img_height / 2.
+
+ # Compute focal params
+ aspect_ratio = self.img_width / self.img_height
+ e = 1 / (np.tan(np.radians(self.proj_fov/2.)))
+ t = self.proj_near / e
+ b = -t
+ r = t * aspect_ratio
+ l = -r
+ # pixels per meter
+ alpha = self.img_width / (r-l)
+ self.focal_length = self.proj_near * alpha
+ self.fx = self.focal_length
+ self.fy = self.focal_length
+ self.pose = None
+ self.inv_pose = None
+
+ def set_pose(self, trans, rot):
+ self.pose = make_pose(trans, rot)
+ self.inv_pose = tra.inverse_matrix(self.pose)
+
+ def set_pose_matrix(self, matrix):
+ self.pose = matrix
+ self.inv_pose = tra.inverse_matrix(matrix)
+
+ def transform_to_world_coords(self, xyz):
+ """ transform xyz into world coordinates """
+ #cam_pose = tra.inverse_matrix(self.pose).dot(tra.euler_matrix(np.pi, 0, 0))
+ #xyz = trimesh.transform_points(xyz, self.inv_pose)
+ #xyz = trimesh.transform_points(xyz, cam_pose)
+ #pose = tra.euler_matrix(np.pi, 0, 0) @ self.pose
+ pose = self.pose
+ xyz = trimesh.transform_points(xyz, pose)
+ return xyz
+
+def get_camera_presets():
+ return [
+ "n/a",
+ "azure_depth_nfov",
+ "realsense",
+ "azure_720p",
+ "simple256",
+ "simple512",
+ ]
+
+
+def get_camera_preset(name):
+
+ if name == "azure_depth_nfov":
+ # Setting for depth camera is pretty different from RGB
+ height, width, fov = 576, 640, 75
+ if name == "azure_720p":
+ # This is actually the 720p RGB setting
+ # Used for our color camera most of the time
+ #height, width, fov = 720, 1280, 90
+ height, width, fov = 720, 1280, 60
+ elif name == "realsense":
+ height, width, fov = 480, 640, 60
+ elif name == "simple256":
+ height, width, fov = 256, 256, 60
+ elif name == "simple512":
+ height, width, fov = 512, 512, 60
+ else:
+ raise RuntimeError(('camera "%s" not supported, choose from: ' +
+ str(get_camera_presets())) % str(name))
+ return height, width, fov
+
+
+def get_generic_camera(name):
+ h, w, fov = get_camera_preset(name)
+ return GenericCameraReference(img_height=h, img_width=w, proj_fov=fov)
+
+
+def get_matrix_of_indices(height, width):
+ """ Get indices """
+ return np.indices((height, width), dtype=np.float32).transpose(1,2,0)
+
+# --------------------------------------------------------
+# NOTE: this code taken from Arsalan and modified
+def compute_xyz(depth_img, camera, visualize_xyz=False,
+ xmap=None, ymap=None, max_clip_depth=5):
+ """ Compute xyz image from depth for a camera """
+
+    # We need these parameters
+ height = camera.img_height
+ width = camera.img_width
+ assert depth_img.shape[0] == camera.img_height
+ assert depth_img.shape[1] == camera.img_width
+ fx = camera.fx
+ fy = camera.fy
+ cx = camera.x_offset
+ cy = camera.y_offset
+
+ """
+ # Create the matrix of parameters
+ indices = np.indices((height, width), dtype=np.float32).transpose(1,2,0)
+ # pixel indices start at top-left corner. for these equations, it starts at bottom-left
+ # indices[..., 0] = np.flipud(indices[..., 0])
+ z_e = depth_img
+ x_e = (indices[..., 1] - x_offset) * z_e / fx
+ y_e = (indices[..., 0] - y_offset) * z_e / fy
+ xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3]
+ """
+
+ height = depth_img.shape[0]
+ width = depth_img.shape[1]
+ input_x = np.arange(width)
+ input_y = np.arange(height)
+ input_x, input_y = np.meshgrid(input_x, input_y)
+ input_x = input_x.flatten()
+ input_y = input_y.flatten()
+ input_z = depth_img.flatten()
+ # clip points that are farther than max distance
+ input_z[input_z > max_clip_depth] = 0
+ output_x = (input_x * input_z - cx * input_z) / fx
+ output_y = (input_y * input_z - cy * input_z) / fy
+    xyz_img = np.stack([output_x, output_y, input_z], -1).reshape(
+        height, width, 3
+    )
+
+    if visualize_xyz:
+        unordered_pc = xyz_img.reshape(-1, 3)
+        pcd = open3d.geometry.PointCloud()
+        pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+        pcd.transform([[1,0,0,0], [0,1,0,0], [0,0,-1,0], [0,0,0,1]]) # Transform it so it's not upside down
+        open3d.visualization.draw_geometries([pcd])
+
+    return xyz_img
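+
+# Illustrative sketch (assumes `depth` is a metric depth image whose resolution matches
+# the chosen preset; the camera pose here is a placeholder):
+#
+#   cam = get_generic_camera("realsense")      # 480 x 640, 60 deg FOV
+#   xyz_img = compute_xyz(depth, cam)          # H x W x 3 points in the camera frame
+#   cam.set_pose_matrix(np.eye(4))
+#   world_pts = cam.transform_to_world_coords(xyz_img.reshape(-1, 3))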
+
+def show_pcs(xyz, rgb):
+ """ Display point clouds """
+ if len(xyz.shape) > 2:
+ unordered_pc = xyz.reshape(-1, 3)
+ unordered_rgb = rgb.reshape(-1, 3) / 255.
+ assert(unordered_rgb.shape[0] == unordered_pc.shape[0])
+ assert(unordered_pc.shape[1] == 3)
+ assert(unordered_rgb.shape[1] == 3)
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+ pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+ pcd.transform([[1,0,0,0],[0,1,0,0],[0,0,-1,0],[0,0,0,1]]) # Transform it so it's not upside down
+ open3d.visualization.draw_geometries([pcd])
diff --git a/src/StructDiffusion/utils/brain2/image.py b/src/StructDiffusion/utils/brain2/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..668a551760f2548e1dec1e318ef54073244a7383
--- /dev/null
+++ b/src/StructDiffusion/utils/brain2/image.py
@@ -0,0 +1,154 @@
+"""
+By Chris Paxton.
+
+Copyright (c) 2018, Johns Hopkins University
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of the Johns Hopkins University nor the
+ names of its contributors may be used to endorse or promote products
+ derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL JOHNS HOPKINS UNIVERSITY BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import numpy as np
+import io
+import cv2
+from PIL import Image
+
+def GetJpeg(img):
+ '''
+ Save a numpy array as a Jpeg, then get it out as a binary blob
+ '''
+ im = Image.fromarray(np.uint8(img))
+ output = io.BytesIO()
+ im.save(output, format="JPEG", quality=80)
+ return output.getvalue()
+
+def JpegToNumpy(jpeg):
+ stream = io.BytesIO(jpeg)
+ im = Image.open(stream)
+ return np.asarray(im, dtype=np.uint8)
+
+def ConvertJpegListToNumpy(data):
+ length = len(data)
+ imgs = []
+ for raw in data:
+ imgs.append(JpegToNumpy(raw))
+ arr = np.array(imgs)
+ return arr
+
+def DepthToZBuffer(img, z_near, z_far):
+ real_depth = z_near * z_far / (z_far - img * (z_far - z_near))
+ return real_depth
+
+def ZBufferToRGB(img, z_near, z_far):
+ real_depth = z_near * z_far / (z_far - img * (z_far - z_near))
+ depth_m = np.uint8(real_depth)
+ depth_cm = np.uint8((real_depth-depth_m)*100)
+ depth_tmm = np.uint8((real_depth-depth_m-0.01*depth_cm)*10000)
+ return np.dstack([depth_m, depth_cm, depth_tmm])
+
+def RGBToDepth(img, min_dist=0., max_dist=2.,):
+ return (img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]).clip(min_dist, max_dist)
+ #return img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]
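+
+# Illustrative round trip (assumes `zbuf` is the renderer's normalized z-buffer, as in
+# DepthToZBuffer above):
+#
+#   rgb = ZBufferToRGB(zbuf, z_near=0.01, z_far=5.)  # pack metric depth into m / cm / 0.1 mm channels
+#   depth_m = RGBToDepth(rgb)                        # recover metric depth, clipped to [0, 2] m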
+
+def MaskToRGBA(img):
+ buf = img.astype(np.int32)
+ A = buf.astype(np.uint8)
+ buf = np.right_shift(buf, 8)
+ B = buf.astype(np.uint8)
+ buf = np.right_shift(buf, 8)
+ G = buf.astype(np.uint8)
+ buf = np.right_shift(buf, 8)
+ R = buf.astype(np.uint8)
+
+ dims = [np.expand_dims(d, -1) for d in [R,G,B,A]]
+ return np.concatenate(dims, axis=-1)
+
+def RGBAToMask(img):
+ mask = np.zeros(img.shape[:-1], dtype=np.int32)
+ buf = img.astype(np.int32)
+ for i, dim in enumerate([3,2,1,0]):
+ shift = 8*i
+ #print(i, dim, shift, buf[0,0,dim], np.left_shift(buf[0,0,dim], shift))
+ mask += np.left_shift(buf[:,:, dim], shift)
+ return mask
+
+def RGBAArrayToMasks(img):
+ mask = np.zeros(img.shape[:-1], dtype=np.int32)
+ buf = img.astype(np.int32)
+ for i, dim in enumerate([3,2,1,0]):
+ shift = 8*i
+ mask += np.left_shift(buf[:,:,:, dim], shift)
+ return mask
+
+def GetPNG(img):
+ '''
+ Save a numpy array as a PNG, then get it out as a binary blob
+ '''
+ im = Image.fromarray(np.uint8(img))
+ output = io.BytesIO()
+ im.save(output, format="PNG")#, quality=80)
+ return output.getvalue()
+
+def PNGToNumpy(png):
+ stream = io.BytesIO(png)
+ im = Image.open(stream)
+ return np.array(im, dtype=np.uint8)
+
+def ConvertPNGListToNumpy(data):
+ length = len(data)
+ imgs = []
+ for raw in data:
+ imgs.append(PNGToNumpy(raw))
+ arr = np.array(imgs)
+ return arr
+
+def ConvertDepthPNGListToNumpy(data):
+ length = len(data)
+ imgs = []
+ for raw in data:
+ imgs.append(RGBToDepth(PNGToNumpy(raw)))
+ arr = np.array(imgs)
+ return arr
+
+def Shrink(img, nw=64):
+ h,w = img.shape[:2]
+ ratio = float(nw) / w
+ nh = int(ratio * h)
+ img2 = cv2.resize(img, dsize=(nw, nh),
+ interpolation=cv2.INTER_NEAREST)
+ return img2
+
+def ShrinkSmooth(img, nw=64):
+ h,w = img.shape[:2]
+ ratio = float(nw) / w
+ nh = int(ratio * h)
+ img2 = cv2.resize(img, dsize=(nw, nh),
+ interpolation=cv2.INTER_LINEAR)
+ return img2
+
+def CropCenter(img, cropx, cropy):
+ y = img.shape[0]
+ x = img.shape[1]
+ startx = (x // 2) - (cropx // 2)
+ starty = (y // 2) - (cropy // 2)
+ return img[starty: starty + cropy, startx : startx + cropx]
+
diff --git a/src/StructDiffusion/utils/brain2/pose.py b/src/StructDiffusion/utils/brain2/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..e314f58f6c1942b1f62e827130a681b829d4e5f1
--- /dev/null
+++ b/src/StructDiffusion/utils/brain2/pose.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+from __future__ import print_function
+
+import StructDiffusion.utils.transformations as tra
+
+
+def make_pose(trans, rot):
+ """Make 4x4 matrix from (trans, rot)"""
+ pose = tra.quaternion_matrix(rot)
+ pose[:3, 3] = trans
+ return pose
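+
+# Illustrative use (quaternion order follows transformations.quaternion_matrix, i.e. (w, x, y, z);
+# the numbers are placeholders):
+#
+#   pose = make_pose([0.1, 0.0, 0.5], [1.0, 0.0, 0.0, 0.0])  # identity rotation
+#   # pose[:3, 3] holds the translation of the resulting 4x4 homogeneous matrix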
diff --git a/src/StructDiffusion/utils/files.py b/src/StructDiffusion/utils/files.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffd77d30a73ade07f627ebe87dc0b4ce2e7230ed
--- /dev/null
+++ b/src/StructDiffusion/utils/files.py
@@ -0,0 +1,9 @@
+import os
+
+def get_checkpoint_path_from_dir(checkpoint_dir):
+ checkpoint_path = None
+ for file in os.listdir(checkpoint_dir):
+ if "ckpt" in file:
+ checkpoint_path = os.path.join(checkpoint_dir, file)
+ assert checkpoint_path is not None
+ return checkpoint_path
\ No newline at end of file
diff --git a/src/StructDiffusion/utils/physics_eval.py b/src/StructDiffusion/utils/physics_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa14d5f92c6239bf7b86039d5020fa54eb8eae58
--- /dev/null
+++ b/src/StructDiffusion/utils/physics_eval.py
@@ -0,0 +1,295 @@
+import sys
+import os
+import h5py
+import torch
+import pytorch3d.transforms as tra3d
+
+from StructDiffusion.utils.rearrangement import show_pcs_color_order
+from StructDiffusion.utils.pointnet import random_point_sample, index_points
+
+
+def switch_stdout(stdout_filename=None):
+ if stdout_filename:
+ print("setting stdout to {}".format(stdout_filename))
+ if os.path.exists(stdout_filename):
+ sys.stdout = open(stdout_filename, 'a')
+ else:
+ sys.stdout = open(stdout_filename, 'w')
+ else:
+ sys.stdout = sys.__stdout__
+
+
+def visualize_batch_pcs(obj_xyzs, B, N, P, verbose=True, limit_B=None):
+ if limit_B is None:
+ limit_B = B
+
+ vis_obj_xyzs = obj_xyzs.reshape(B, N, P, -1)
+ vis_obj_xyzs = vis_obj_xyzs[:limit_B]
+
+ if type(vis_obj_xyzs).__module__ == torch.__name__:
+ if vis_obj_xyzs.is_cuda:
+ vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
+ vis_obj_xyzs = vis_obj_xyzs.numpy()
+
+ for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
+ if verbose:
+ print("example {}".format(bi))
+ print(vis_obj_xyz.shape)
+ show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=True, add_table=False)
+
+
+def convert_bool(d):
+ for k in d:
+ if type(d[k]) == list:
+ d[k] = [bool(i) for i in d[k]]
+ else:
+ d[k] = bool(d[k])
+ return d
+
+
+def save_dict_to_h5(dict_data, filename):
+ fh = h5py.File(filename, 'w')
+ for k in dict_data:
+ key_data = dict_data[k]
+ if key_data is None:
+ raise RuntimeError('data was not properly populated')
+ # if type(key_data) is dict:
+ # key_data = json.dumps(key_data, sort_keys=True)
+ try:
+ fh.create_dataset(k, data=key_data)
+ except TypeError as e:
+ print("Failure on key", k)
+ print(key_data)
+ print(e)
+ raise e
+ fh.close()
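+
+# Illustrative call (keys and values are placeholders; h5py handles numpy arrays and
+# scalars, but not nested dicts or None values):
+#
+#   save_dict_to_h5({"obj_xyzs": new_obj_xyzs.cpu().numpy(), "success": 1}, "eval_result.h5")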
+
+
+def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
+ return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
+ return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False):
+
+ # obj_xyzs: N, P, 3
+ # obj_params: B, N, 6
+ # struct_pose: B x N, 4, 4
+ # current_pc_pose: B x N, 4, 4
+ # target_object_inds: 1, N
+
+ B, N, _ = obj_params.shape
+ _, P, _ = obj_xyzs.shape
+
+ # B, N, 6
+ flat_obj_params = obj_params.reshape(B * N, -1)
+ goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
+ goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
+ goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ # obj_xyzs: N, P, 3
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
+ new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+ # visualize_batch_pcs(new_obj_xyzs, S, N, P)
+
+
+ # initialize the additional outputs
+ subsampled_scene_xyz = None
+ subsampled_pc_idxs = None
+ obj_pair_xyzs = None
+
+ # ===================================
+ # Pass to discriminator
+ if return_scene_pts:
+
+ num_indicator = N
+
+ # add one hot
+ indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
+ # print(indicator_variables.shape)
+ # print(new_obj_xyzs.shape)
+ new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
+
+ # ToDo: maybe convert this to a batch operation
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
+ for si, scene_xyz in enumerate(scene_xyzs):
+ # scene_xyz: N*P, 3+N
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+
+ # # debug:
+ # print("-"*50)
+ # if si < 10:
+ # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3+N
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # # debug:
+ # for si in range(10):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ if return_scene_pts_and_pc_idxs:
+ num_indicator = N
+ pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
+ # new_obj_xyzs: B, N, P, 3 + 1
+
+ # combine pcs in each scene
+ scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
+ pc_idxs = pc_idxs.reshape(B, N*P)
+
+ subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
+ subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
+ for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
+ # scene_xyz: N*P, 3+1
+ # target_object_inds: 1, N
+ subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
+ subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
+ subsampled_pc_idxs[si] = pc_idx[subsample_idx]
+
+ # subsampled_scene_xyz: B, num_scene_pts, 3
+ # subsampled_pc_idxs: B, num_scene_pts
+ # new_obj_xyzs: B, N, P, 3
+ # goal_pc_pose: B, N, 4, 4
+
+ # important:
+ if normalize_pc:
+ subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
+
+ # TODO: visualize each individual object
+ # debug
+ # print(subsampled_scene_xyz.shape)
+ # print(subsampled_pc_idxs.shape)
+ # print("visualize subsampled scene")
+ # for si in range(5):
+ # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
+
+ ###############################################
+ # Create input for pairwise collision detector
+ if return_pair_pc:
+
+ assert num_pair_pc_pts is not None
+
+ # new_obj_xyzs: B, N, P, 3 + N
+ # target_object_inds: 1, N
+ # ignore paddings
+ num_objs = torch.sum(target_object_inds[0])
+ obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
+
+ # use [:, :, :, :3] to get obj_xyzs without object-wise indicator
+ obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
+ num_comb = obj_pair_xyzs.shape[1]
+ pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
+ obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
+
+ # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
+ # random_point_sample() input dim: B, N, C
+ rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
+ obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
+
+ if normalize_pair_pc:
+ # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
+ # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
+ obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
+ obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
+
+ # # debug
+ # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+ # print("batch id", bi)
+ # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+ # print("pair", pi)
+ # # obj_pair_xyzs: 2 * P, 5
+ # print(obj_pair_xyz[:, :3].shape)
+ # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+
+ # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
+ goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
+
+ return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
+
+
+
+def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
+
+ # obj_xyzs: N, P, 3
+ # obj_params: B, N, 6
+ # struct_pose: B x N, 4, 4
+ # current_pc_pose: B x N, 4, 4
+ # target_object_inds: 1, N
+
+ B, N, _ = obj_params.shape
+ _, P, _ = obj_xyzs.shape
+
+ # B, N, 6
+ flat_obj_params = obj_params.reshape(B * N, -1)
+ goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
+ goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
+ goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
+
+ goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
+
+ # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+ transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+
+ # obj_xyzs: N, P, 3
+ new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
+ new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
+
+ # put it back to B, N, P, 3
+ new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
+ # visualize_batch_pcs(new_obj_xyzs, S, N, P)
+
+    # new_obj_xyzs: B, N, P, 3
+    # goal_pc_pose: B, N, 4, 4
+
+ goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
+ return new_obj_xyzs, goal_pc_pose
+
+
+def sample_gaussians(mus, sigmas, sample_size):
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ normal = torch.distributions.Normal(mus, sigmas)
+ samples = normal.sample((sample_size,))
+ # samples: [sample_size, number of individual gaussians]
+ return samples
+
+def fit_gaussians(samples, sigma_eps=0.01):
+ device = samples.device
+
+ # samples: [sample_size, number of individual gaussians]
+ num_gs = samples.shape[1]
+ mus = torch.mean(samples, dim=0).to(device)
+ sigmas = torch.std(samples, dim=0).to(device) + sigma_eps * torch.ones(num_gs).to(device)
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ return mus, sigmas
+
+
+def pc_normalize_batch(pc):
+ # pc: B, num_scene_pts, 3
+ centroid = torch.mean(pc, dim=1) # B, 3
+ pc = pc - centroid[:, None, :]
+ m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0]
+ pc = pc / m[:, None, None]
+ return pc
diff --git a/src/StructDiffusion/utils/pointnet.py b/src/StructDiffusion/utils/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7371b6d4250d86668a16a396572286a3f91bbc
--- /dev/null
+++ b/src/StructDiffusion/utils/pointnet.py
@@ -0,0 +1,583 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from time import time
+import numpy as np
+
+
+# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You
+
+
+def timeit(tag, t):
+ print("{}: {}s".format(tag, time() - t))
+ return time()
+
+def pc_normalize(pc):
+ if type(pc).__module__ == np.__name__:
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+ pc = pc / m
+ else:
+ centroid = torch.mean(pc, dim=0)
+ pc = pc - centroid
+ m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=1)))
+ pc = pc / m
+ return pc
+
+def square_distance(src, dst):
+ """
+    Calculate the squared Euclidean distance between each pair of points.
+    src^T * dst = xn * xm + yn * ym + zn * zm;
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+ Input:
+ src: source points, [B, N, C]
+ dst: target points, [B, M, C]
+ Output:
+ dist: per-point square distance, [B, N, M]
+ """
+ return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)
+
+
+def index_points(points, idx):
+ """
+ Input:
+ points: input points data, [B, N, C]
+ idx: sample index data, [B, S, [K]]
+ Return:
+ new_points:, indexed points data, [B, S, [K], C]
+ """
+ raw_size = idx.size()
+ idx = idx.reshape(raw_size[0], -1)
+ res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
+ return res.reshape(*raw_size, -1)
+
+
+def farthest_point_sample(xyz, npoint):
+ """
+ Input:
+ xyz: pointcloud data, [B, N, 3]
+ npoint: number of samples
+ Return:
+ centroids: sampled pointcloud index, [B, npoint]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
+ distance = torch.ones(B, N).to(device) * 1e10
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
+ for i in range(npoint):
+ centroids[:, i] = farthest
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
+ dist = torch.sum((xyz - centroid) ** 2, -1)
+ distance = torch.min(distance, dist)
+ farthest = torch.max(distance, -1)[1]
+ return centroids
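+
+# Illustrative pairing with index_points (shapes follow the docstrings; tensors are placeholders):
+#
+#   xyz = torch.rand(8, 2048, 3)              # B, N, 3
+#   idx = farthest_point_sample(xyz, 512)     # B, 512
+#   sampled = index_points(xyz, idx)          # B, 512, 3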
+
+
+def random_point_sample(xyz, npoint):
+ """
+ Input:
+ xyz: pointcloud data, [B, N, 3]
+ npoint: number of samples
+ Return:
+ idxs: sampled pointcloud index, [B, npoint]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ idxs = torch.randint(0, N, (B, npoint), dtype=torch.long).to(device)
+ return idxs
+
+
+def query_ball_point(radius, nsample, xyz, new_xyz):
+ """
+ Input:
+ radius: local region radius
+ nsample: max sample number in local region
+ xyz: all points, [B, N, 3]
+ new_xyz: query points, [B, S, 3]
+ Return:
+ group_idx: grouped points index, [B, S, nsample]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ _, S, _ = new_xyz.shape
+ group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
+ sqrdists = square_distance(new_xyz, xyz)
+ group_idx[sqrdists > radius ** 2] = N
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
+ group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
+ mask = group_idx == N
+ group_idx[mask] = group_first[mask]
+ return group_idx
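+
+# Illustrative grouping (radius and nsample values are placeholders):
+#
+#   group_idx = query_ball_point(0.2, 32, xyz, new_xyz)   # B, S, 32
+#   grouped = index_points(xyz, group_idx)                # B, S, 32, 3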
+
+
+def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
+ """
+ Input:
+ npoint:
+ radius:
+ nsample:
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+ Return:
+ new_xyz: sampled points position data, [B, npoint, nsample, 3]
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
+ """
+ B, N, C = xyz.shape
+ S = npoint
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
+ torch.cuda.empty_cache()
+ new_xyz = index_points(xyz, fps_idx)
+ torch.cuda.empty_cache()
+ if knn:
+ dists = square_distance(new_xyz, xyz) # B x npoint x N
+ idx = dists.argsort()[:, :, :nsample] # B x npoint x K
+ else:
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
+ torch.cuda.empty_cache()
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
+ torch.cuda.empty_cache()
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
+ torch.cuda.empty_cache()
+
+ if points is not None:
+ grouped_points = index_points(points, idx)
+ new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
+ else:
+ new_points = grouped_xyz_norm
+ if returnfps:
+ return new_xyz, new_points, grouped_xyz, fps_idx
+ else:
+ return new_xyz, new_points
+
+
+def sample_and_group_all(xyz, points):
+ """
+ Input:
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+ Return:
+ new_xyz: sampled points position data, [B, 1, 3]
+ new_points: sampled points data, [B, 1, N, 3+D]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ new_xyz = torch.zeros(B, 1, C).to(device)
+ grouped_xyz = xyz.view(B, 1, N, C)
+ if points is not None:
+ new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
+ else:
+ new_points = grouped_xyz
+ return new_xyz, new_points
+
+
+class PointNetSetAbstraction(nn.Module):
+ def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
+ super(PointNetSetAbstraction, self).__init__()
+ self.npoint = npoint
+ self.radius = radius
+ self.nsample = nsample
+ self.knn = knn
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+ last_channel = in_channel
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
+ self.mlp_bns.append(nn.BatchNorm2d(out_channel))
+ last_channel = out_channel
+ self.group_all = group_all
+
+ def forward(self, xyz, points):
+ """
+ Input:
+ xyz: input points position data, [B, N, C]
+ points: input points data, [B, N, C]
+ Return:
+ new_xyz: sampled points position data, [B, S, C]
+ new_points_concat: sample points feature data, [B, S, D']
+ """
+ if self.group_all:
+ new_xyz, new_points = sample_and_group_all(xyz, points)
+ else:
+ new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
+ # new_xyz: sampled points position data, [B, npoint, C]
+ # new_points: sampled points data, [B, npoint, nsample, C+D]
+ new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
+ for i, conv in enumerate(self.mlp_convs):
+ bn = self.mlp_bns[i]
+ new_points = F.relu(bn(conv(new_points)))
+
+ new_points = torch.max(new_points, 2)[0].transpose(1, 2)
+ return new_xyz, new_points
+
+
+class PointNetSetAbstractionMsg(nn.Module):
+ def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
+ super(PointNetSetAbstractionMsg, self).__init__()
+ self.npoint = npoint
+ self.radius_list = radius_list
+ self.nsample_list = nsample_list
+ self.knn = knn
+ self.conv_blocks = nn.ModuleList()
+ self.bn_blocks = nn.ModuleList()
+ for i in range(len(mlp_list)):
+ convs = nn.ModuleList()
+ bns = nn.ModuleList()
+ last_channel = in_channel + 3
+ for out_channel in mlp_list[i]:
+ convs.append(nn.Conv2d(last_channel, out_channel, 1))
+ bns.append(nn.BatchNorm2d(out_channel))
+ last_channel = out_channel
+ self.conv_blocks.append(convs)
+ self.bn_blocks.append(bns)
+
+ def forward(self, xyz, points, seed_idx=None):
+ """
+ Input:
+ xyz: input points position data, [B, C, N]
+ points: input points data, [B, D, N]
+ Return:
+ new_xyz: sampled points position data, [B, C, S]
+ new_points_concat: sample points feature data, [B, D', S]
+ """
+
+ B, N, C = xyz.shape
+ S = self.npoint
+ new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx)
+ new_points_list = []
+ for i, radius in enumerate(self.radius_list):
+ K = self.nsample_list[i]
+ if self.knn:
+ dists = square_distance(new_xyz, xyz) # B x npoint x N
+ group_idx = dists.argsort()[:, :, :K] # B x npoint x K
+ else:
+ group_idx = query_ball_point(radius, K, xyz, new_xyz)
+ grouped_xyz = index_points(xyz, group_idx)
+ grouped_xyz -= new_xyz.view(B, S, 1, C)
+ if points is not None:
+ grouped_points = index_points(points, group_idx)
+ grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
+ else:
+ grouped_points = grouped_xyz
+
+ grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
+ for j in range(len(self.conv_blocks[i])):
+ conv = self.conv_blocks[i][j]
+ bn = self.bn_blocks[i][j]
+ grouped_points = F.relu(bn(conv(grouped_points)))
+ new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
+ new_points_list.append(new_points)
+
+ new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
+ return new_xyz, new_points_concat
+
+
+# Note: this function swaps N and C
+class PointNetFeaturePropagation(nn.Module):
+ def __init__(self, in_channel, mlp):
+ super(PointNetFeaturePropagation, self).__init__()
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+ last_channel = in_channel
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
+ self.mlp_bns.append(nn.BatchNorm1d(out_channel))
+ last_channel = out_channel
+
+ def forward(self, xyz1, xyz2, points1, points2):
+ """
+ Input:
+ xyz1: input points position data, [B, C, N]
+ xyz2: sampled input points position data, [B, C, S]
+ points1: input points data, [B, D, N]
+ points2: input points data, [B, D, S]
+ Return:
+ new_points: upsampled points data, [B, D', N]
+ """
+ xyz1 = xyz1.permute(0, 2, 1)
+ xyz2 = xyz2.permute(0, 2, 1)
+
+ points2 = points2.permute(0, 2, 1)
+ B, N, C = xyz1.shape
+ _, S, _ = xyz2.shape
+
+ if S == 1:
+ interpolated_points = points2.repeat(1, N, 1)
+ else:
+ dists = square_distance(xyz1, xyz2)
+ dists, idx = dists.sort(dim=-1)
+ dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
+
+ dist_recip = 1.0 / (dists + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+ interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
+
+ if points1 is not None:
+ points1 = points1.permute(0, 2, 1)
+ new_points = torch.cat([points1, interpolated_points], dim=-1)
+ else:
+ new_points = interpolated_points
+
+ new_points = new_points.permute(0, 2, 1)
+ for i, conv in enumerate(self.mlp_convs):
+ bn = self.mlp_bns[i]
+ new_points = F.relu(bn(conv(new_points)))
+ return new_points
+
+
+# reference https://github.com/qq456cvb/Point-Transformers
+
+def normalize_data(batch_data):
+ """ Normalize the batch data, use coordinates of the block centered at origin,
+ Input:
+ BxNxC array
+ Output:
+ BxNxC array
+ """
+ B, N, C = batch_data.shape
+ normal_data = np.zeros((B, N, C))
+ for b in range(B):
+ pc = batch_data[b]
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+ pc = pc / m
+ normal_data[b] = pc
+ return normal_data
+
+
+def shuffle_data(data, labels):
+ """ Shuffle data and labels.
+ Input:
+ data: B,N,... numpy array
+ label: B,... numpy array
+ Return:
+ shuffled data, label and shuffle indices
+ """
+ idx = np.arange(len(labels))
+ np.random.shuffle(idx)
+ return data[idx, ...], labels[idx], idx
+
+def shuffle_points(batch_data):
+ """ Shuffle orders of points in each point cloud -- changes FPS behavior.
+ Use the same shuffling idx for the entire batch.
+ Input:
+ BxNxC array
+ Output:
+ BxNxC array
+ """
+ idx = np.arange(batch_data.shape[1])
+ np.random.shuffle(idx)
+ return batch_data[:,idx,:]
+
+def rotate_point_cloud(batch_data):
+ """ Randomly rotate the point clouds to augument the dataset
+ rotation is per shape based along up direction
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, rotated batch of point clouds
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ rotation_angle = np.random.uniform() * 2 * np.pi
+ cosval = np.cos(rotation_angle)
+ sinval = np.sin(rotation_angle)
+ rotation_matrix = np.array([[cosval, 0, sinval],
+ [0, 1, 0],
+ [-sinval, 0, cosval]])
+ shape_pc = batch_data[k, ...]
+ rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+ return rotated_data
+
+def rotate_point_cloud_z(batch_data):
+ """ Randomly rotate the point clouds to augument the dataset
+ rotation is per shape based along up direction
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, rotated batch of point clouds
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ rotation_angle = np.random.uniform() * 2 * np.pi
+ cosval = np.cos(rotation_angle)
+ sinval = np.sin(rotation_angle)
+ rotation_matrix = np.array([[cosval, sinval, 0],
+ [-sinval, cosval, 0],
+ [0, 0, 1]])
+ shape_pc = batch_data[k, ...]
+ rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+ return rotated_data
+
+def rotate_point_cloud_with_normal(batch_xyz_normal):
+ ''' Randomly rotate XYZ, normal point cloud.
+ Input:
+ batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
+ Output:
+ B,N,6, rotated XYZ, normal point cloud
+ '''
+ for k in range(batch_xyz_normal.shape[0]):
+ rotation_angle = np.random.uniform() * 2 * np.pi
+ cosval = np.cos(rotation_angle)
+ sinval = np.sin(rotation_angle)
+ rotation_matrix = np.array([[cosval, 0, sinval],
+ [0, 1, 0],
+ [-sinval, 0, cosval]])
+ shape_pc = batch_xyz_normal[k,:,0:3]
+ shape_normal = batch_xyz_normal[k,:,3:6]
+ batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+ batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
+ return batch_xyz_normal
+
+def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
+ """ Randomly perturb the point clouds by small rotations
+ Input:
+ BxNx6 array, original batch of point clouds and point normals
+ Return:
+ BxNx3 array, rotated batch of point clouds
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
+ Rx = np.array([[1,0,0],
+ [0,np.cos(angles[0]),-np.sin(angles[0])],
+ [0,np.sin(angles[0]),np.cos(angles[0])]])
+ Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
+ [0,1,0],
+ [-np.sin(angles[1]),0,np.cos(angles[1])]])
+ Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
+ [np.sin(angles[2]),np.cos(angles[2]),0],
+ [0,0,1]])
+ R = np.dot(Rz, np.dot(Ry,Rx))
+ shape_pc = batch_data[k,:,0:3]
+ shape_normal = batch_data[k,:,3:6]
+ rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
+ rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
+ return rotated_data
+
+
+def rotate_point_cloud_by_angle(batch_data, rotation_angle):
+ """ Rotate the point cloud along up direction with certain angle.
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, rotated batch of point clouds
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ #rotation_angle = np.random.uniform() * 2 * np.pi
+ cosval = np.cos(rotation_angle)
+ sinval = np.sin(rotation_angle)
+ rotation_matrix = np.array([[cosval, 0, sinval],
+ [0, 1, 0],
+ [-sinval, 0, cosval]])
+ shape_pc = batch_data[k,:,0:3]
+ rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+ return rotated_data
+
+def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
+ """ Rotate the point cloud along up direction with certain angle.
+ Input:
+ BxNx6 array, original batch of point clouds with normal
+ scalar, angle of rotation
+ Return:
+        BxNx6 array, rotated batch of point clouds with normal
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ #rotation_angle = np.random.uniform() * 2 * np.pi
+ cosval = np.cos(rotation_angle)
+ sinval = np.sin(rotation_angle)
+ rotation_matrix = np.array([[cosval, 0, sinval],
+ [0, 1, 0],
+ [-sinval, 0, cosval]])
+ shape_pc = batch_data[k,:,0:3]
+ shape_normal = batch_data[k,:,3:6]
+ rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+ rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
+ return rotated_data
+
+
+
+def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
+ """ Randomly perturb the point clouds by small rotations
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, rotated batch of point clouds
+ """
+ rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+ for k in range(batch_data.shape[0]):
+ angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
+ Rx = np.array([[1,0,0],
+ [0,np.cos(angles[0]),-np.sin(angles[0])],
+ [0,np.sin(angles[0]),np.cos(angles[0])]])
+ Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
+ [0,1,0],
+ [-np.sin(angles[1]),0,np.cos(angles[1])]])
+ Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
+ [np.sin(angles[2]),np.cos(angles[2]),0],
+ [0,0,1]])
+ R = np.dot(Rz, np.dot(Ry,Rx))
+ shape_pc = batch_data[k, ...]
+ rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
+ return rotated_data
+
+
+def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
+ """ Randomly jitter points. jittering is per point.
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, jittered batch of point clouds
+ """
+ B, N, C = batch_data.shape
+ assert(clip > 0)
+ jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
+ jittered_data += batch_data
+ return jittered_data
+
+def shift_point_cloud(batch_data, shift_range=0.1):
+ """ Randomly shift point cloud. Shift is per point cloud.
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, shifted batch of point clouds
+ """
+ B, N, C = batch_data.shape
+ shifts = np.random.uniform(-shift_range, shift_range, (B,3))
+ for batch_index in range(B):
+ batch_data[batch_index,:,:] += shifts[batch_index,:]
+ return batch_data
+
+
+def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
+ """ Randomly scale the point cloud. Scale is per point cloud.
+ Input:
+ BxNx3 array, original batch of point clouds
+ Return:
+ BxNx3 array, scaled batch of point clouds
+ """
+ B, N, C = batch_data.shape
+ scales = np.random.uniform(scale_low, scale_high, B)
+ for batch_index in range(B):
+ batch_data[batch_index,:,:] *= scales[batch_index]
+ return batch_data
+
+def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
+ ''' batch_pc: BxNx3 '''
+ for b in range(batch_pc.shape[0]):
+ dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
+ drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
+ if len(drop_idx)>0:
+ batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
+ return batch_pc
+
+
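+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only; the batch below is made-up random data).
+    # It chains a few of the augmentations defined above on a B x N x 3 batch.
+    demo_batch = np.random.rand(4, 1024, 3).astype(np.float32)
+    demo_batch = rotate_point_cloud_z(demo_batch)
+    demo_batch = jitter_point_cloud(demo_batch, sigma=0.01, clip=0.05)
+    demo_batch = random_scale_point_cloud(demo_batch)
+    demo_batch = shift_point_cloud(demo_batch)
+    demo_batch = random_point_dropout(demo_batch)
+    print("augmented batch shape:", demo_batch.shape)  # expected: (4, 1024, 3)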
diff --git a/src/StructDiffusion/utils/random.py b/src/StructDiffusion/utils/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/StructDiffusion/utils/rearrangement.py b/src/StructDiffusion/utils/rearrangement.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3957666076f0ce471745a9e8bf37f45a0f60382
--- /dev/null
+++ b/src/StructDiffusion/utils/rearrangement.py
@@ -0,0 +1,1201 @@
+import copy
+import os
+import torch
+import trimesh
+import numpy as np
+import open3d
+from PIL import Image, ImageDraw, ImageFont
+from sklearn.metrics import classification_report
+from collections import defaultdict
+import matplotlib.pyplot as plt
+import itertools
+import matplotlib
+import h5py
+import json
+
+import StructDiffusion.utils.transformations as tra
+from StructDiffusion.utils.rotation_continuity import compute_geodesic_distance_from_two_matrices
+
+# from pointnet_utils import farthest_point_sample, index_points
+
+
+def flatten1d(img):
+ return img.reshape(-1)
+
+
+def flatten3d(img):
+ hw = img.shape[0] * img.shape[1]
+ return img.reshape(hw, -1)
+
+
+def array_to_tensor(array):
+ """ Assume arrays are in numpy (channels-last) format and put them into the right one """
+ if array.ndim == 4: # NHWC
+ tensor = torch.from_numpy(array).permute(0,3,1,2).float()
+ elif array.ndim == 3: # HWC
+ tensor = torch.from_numpy(array).permute(2,0,1).float()
+ else: # everything else - just keep it as-is
+ tensor = torch.from_numpy(array).float()
+ return tensor
+
+
+def get_pts(xyz_in, rgb_in, mask, bg_mask=None, num_pts=1024, center=None,
+ radius=0.5, filename=None, to_tensor=True):
+
+ # Get the XYZ and RGB
+ mask = flatten1d(mask)
+ assert(np.sum(mask) > 0)
+ xyz = flatten3d(xyz_in)[mask > 0]
+ if rgb_in is not None:
+ rgb = flatten3d(rgb_in)[mask > 0]
+
+ if xyz.shape[0] == 0:
+ raise RuntimeError('this should not happen')
+ ok = False
+ xyz = flatten3d(xyz_in)
+ if rgb_in is not None:
+ rgb = flatten3d(rgb_in)
+ else:
+ ok = True
+
+ # prune to this region
+ if center is not None:
+ # numpy matrix
+ # use the full xyz point cloud to determine what is close enough
+ # now that we have the closest background point we can place the object on it
+ # Just center on the point
+ center = center.numpy()
+ center = center[None].repeat(xyz.shape[0], axis=0)
+ dists = np.linalg.norm(xyz - center, axis=-1)
+ idx = dists < radius
+ xyz = xyz[idx]
+ if rgb_in is not None:
+ rgb = rgb[idx]
+ center = center[0]
+ else:
+ center = None
+
+ # Compute number of points we are using
+ if num_pts is not None:
+ if xyz.shape[0] < 1:
+ print("!!!! bad shape:", xyz.shape, filename, "!!!!")
+ return (None, None, None, None)
+ idx = np.random.randint(0, xyz.shape[0], num_pts)
+ xyz = xyz[idx]
+ if rgb_in is not None:
+ rgb = rgb[idx]
+
+ # Shuffle the points
+ if rgb_in is not None:
+ rgb = array_to_tensor(rgb) if to_tensor else rgb
+ else:
+ rgb = None
+ xyz = array_to_tensor(xyz) if to_tensor else xyz
+ return (ok, xyz, rgb, center)
+
+
+def align(y_true, y_pred):
+ """ Add or remove 2*pi to predicted angle to minimize difference from GT"""
+ y_pred = y_pred.copy()
+ y_pred[y_true - y_pred > np.pi] += np.pi * 2
+ y_pred[y_true - y_pred < -np.pi] -= np.pi * 2
+ return y_pred
+
+
+def random_move_obj_xyz(obj_xyz,
+ min_translation, max_translation,
+ min_rotation, max_rotation, mode,
+ visualize=False, return_perturbed_obj_xyzs=True):
+
+ assert mode in ["planar", "6d", "3d_planar"]
+
+ if mode == "planar":
+ random_translation = np.random.uniform(low=min_translation, high=max_translation, size=2) * np.random.choice(
+ [-1, 1], size=2)
+ random_rotation = np.random.uniform(low=min_rotation, high=max_rotation) * np.random.choice([-1, 1])
+ random_rotation = tra.euler_matrix(0, 0, random_rotation)
+ elif mode == "6d":
+ random_rotation = np.random.uniform(low=min_rotation, high=max_rotation, size=3) * np.random.choice([-1, 1], size=3)
+ random_rotation = tra.euler_matrix(*random_rotation)
+ random_translation = np.random.uniform(low=min_translation, high=max_translation, size=3) * np.random.choice([-1, 1], size=3)
+ elif mode == "3d_planar":
+ random_translation = np.random.uniform(low=min_translation, high=max_translation, size=3) * np.random.choice(
+ [-1, 1], size=3)
+ random_rotation = np.random.uniform(low=min_rotation, high=max_rotation) * np.random.choice([-1, 1])
+ random_rotation = tra.euler_matrix(0, 0, random_rotation)
+
+ if return_perturbed_obj_xyzs:
+ raise Exception("return_perturbed_obj_xyzs=True is no longer supported")
+ # xyz_mean = np.mean(obj_xyz, axis=0)
+ # new_obj_xyz = obj_xyz - xyz_mean
+ # new_obj_xyz = trimesh.transform_points(new_obj_xyz, random_rotation, translate=False)
+ # new_obj_xyz = new_obj_xyz + xyz_mean + random_translation
+ else:
+ new_obj_xyz = obj_xyz
+
+ # test moving the perturbed obj pc back
+ # new_xyz_mean = np.mean(new_obj_xyz, axis=0)
+ # old_obj_xyz = new_obj_xyz - new_xyz_mean
+ # old_obj_xyz = trimesh.transform_points(old_obj_xyz, np.linalg.inv(random_rotation), translate=False)
+ # old_obj_xyz = old_obj_xyz + new_xyz_mean - random_translation
+
+ # even though we are putting perturbation rotation and translation in the same matrix, they should be applied
+ # independently. More specifically, rotate the object pc in place and then translate it.
+ perturbation_matrix = random_rotation
+ perturbation_matrix[:3, 3] = random_translation
+
+ if visualize:
+ show_pcs([new_obj_xyz, obj_xyz],
+                 [np.tile(np.array([1, 0, 0], dtype=float), (obj_xyz.shape[0], 1)),
+                  np.tile(np.array([0, 1, 0], dtype=float), (obj_xyz.shape[0], 1))], add_coordinate_frame=True)
+
+ return new_obj_xyz, perturbation_matrix
+
+
+def random_move_obj_xyzs(obj_xyzs,
+ min_translation, max_translation,
+ min_rotation, max_rotation, mode, move_obj_idxs=None, visualize=False, return_moved_obj_idxs=False,
+ return_perturbation=False, return_perturbed_obj_xyzs=True):
+ """
+
+ :param obj_xyzs:
+ :param min_translation:
+ :param max_translation:
+ :param min_rotation:
+ :param max_rotation:
+ :param mode:
+ :param move_obj_idxs:
+ :param visualize:
+ :param return_moved_obj_idxs:
+ :param return_perturbation:
+ :param return_perturbed_obj_xyzs:
+ :return:
+ """
+
+ new_obj_xyzs = []
+ new_obj_rgbs = []
+ old_obj_rgbs = []
+ perturbation_matrices = []
+
+ if move_obj_idxs is None:
+ move_obj_idxs = list(range(len(obj_xyzs)))
+
+    # randomly pick a subset of objects that will be kept stationary (not perturbed)
+ stationary_obj_idxs = np.random.choice(move_obj_idxs, np.random.randint(0, len(move_obj_idxs)), replace=False).tolist()
+
+ moved_obj_idxs = []
+ for obj_idx, obj_xyz in enumerate(obj_xyzs):
+
+ if obj_idx in stationary_obj_idxs:
+ new_obj_xyzs.append(obj_xyz)
+ perturbation_matrices.append(np.eye(4))
+ if visualize:
+                new_obj_rgbs.append(np.tile(np.array([1, 0, 0], dtype=float), (obj_xyz.shape[0], 1)))
+                old_obj_rgbs.append(np.tile(np.array([0, 0, 1], dtype=float), (obj_xyz.shape[0], 1)))
+ else:
+ new_obj_xyz, perturbation_matrix = random_move_obj_xyz(obj_xyz,
+ min_translation=min_translation, max_translation=max_translation,
+ min_rotation=min_rotation, max_rotation=max_rotation, mode=mode,
+ return_perturbed_obj_xyzs=return_perturbed_obj_xyzs)
+ new_obj_xyzs.append(new_obj_xyz)
+ moved_obj_idxs.append(obj_idx)
+ perturbation_matrices.append(perturbation_matrix)
+ if visualize:
+                new_obj_rgbs.append(np.tile(np.array([1, 0, 0], dtype=float), (obj_xyz.shape[0], 1)))
+                old_obj_rgbs.append(np.tile(np.array([0, 1, 0], dtype=float), (obj_xyz.shape[0], 1)))
+ if visualize:
+ show_pcs(new_obj_xyzs + obj_xyzs,
+ new_obj_rgbs + old_obj_rgbs, add_coordinate_frame=True)
+
+ if return_moved_obj_idxs:
+ if return_perturbation:
+ return new_obj_xyzs, moved_obj_idxs, perturbation_matrices
+ else:
+ return new_obj_xyzs, moved_obj_idxs
+ else:
+ if return_perturbation:
+ return new_obj_xyzs, perturbation_matrices
+ else:
+ return new_obj_xyzs
+
+
+def check_pairwise_collision(pcs, visualize=False):
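+    # Approach (as implemented below): each point cloud is downsampled to 100 points with
+    # farthest point sampling, every sampled point becomes a 5 mm box primitive, and
+    # trimesh's CollisionManager checks the boxes of each pair of clouds against each other.
+    # Note: farthest_point_sample / index_points come from the pointnet_utils import that is
+    # commented out near the top of this file, so that import must be restored for this to run.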
+
+ voxel_extents = [0.005] * 3
+
+ collision_managers = []
+ collision_objects = []
+
+ for pc in pcs:
+
+ # farthest point sample
+ pc = pc.unsqueeze(0)
+ fps_idx = farthest_point_sample(pc, 100) # [B, npoint]
+ pc = index_points(pc, fps_idx).squeeze(0)
+
+ pc = np.asanyarray(pc)
+ # ignore empty pc
+ if np.all(pc == 0):
+ continue
+
+ n_points = pc.shape[0]
+ collision_object = []
+ collision_manager = trimesh.collision.CollisionManager()
+
+ # Construct collision objects
+ for i in range(n_points):
+ extents = voxel_extents
+ transform = np.eye(4)
+ transform[:3, 3] = pc[i, :3]
+ voxel = trimesh.primitives.Box(extents=extents, transform=transform)
+ collision_object.append((voxel, extents, transform))
+
+ # Add to collision manager
+ for i, (voxel, _, _) in enumerate(collision_object):
+ collision_manager.add_object("voxel_{}".format(i), voxel)
+
+ collision_managers.append(collision_manager)
+ collision_objects.append(collision_object)
+
+ in_collision = False
+ for i, cm_i in enumerate(collision_managers):
+ for j, cm_j in enumerate(collision_managers):
+ if i == j:
+ continue
+ if cm_i.in_collision_other(cm_j):
+ in_collision = True
+
+ if visualize:
+ visualize_collision_objects(collision_objects[i] + collision_objects[j])
+
+ break
+
+ if in_collision:
+ break
+
+ return in_collision
+
+
+def check_collision_with(this_pc, other_pcs, visualize=False):
+
+ voxel_extents = [0.005] * 3
+
+ this_collision_manager = None
+ this_collision_object = None
+ other_collision_managers = []
+ other_collision_objects = []
+
+ for oi, pc in enumerate([this_pc] + other_pcs):
+
+ # farthest point sample
+ pc = pc.unsqueeze(0)
+ fps_idx = farthest_point_sample(pc, 100) # [B, npoint]
+ pc = index_points(pc, fps_idx).squeeze(0)
+
+ pc = np.asanyarray(pc)
+ # ignore empty pc
+ if np.all(pc == 0):
+ continue
+
+ n_points = pc.shape[0]
+ collision_object = []
+ collision_manager = trimesh.collision.CollisionManager()
+
+ # Construct collision objects
+ for i in range(n_points):
+ extents = voxel_extents
+ transform = np.eye(4)
+ transform[:3, 3] = pc[i, :3]
+ voxel = trimesh.primitives.Box(extents=extents, transform=transform)
+ collision_object.append((voxel, extents, transform))
+
+ # Add to collision manager
+ for i, (voxel, _, _) in enumerate(collision_object):
+ collision_manager.add_object("voxel_{}".format(i), voxel)
+
+ if oi == 0:
+ this_collision_manager = collision_manager
+ this_collision_object = collision_object
+ else:
+ other_collision_managers.append(collision_manager)
+ other_collision_objects.append(collision_object)
+
+ collisions = []
+ for i, cm_i in enumerate(other_collision_managers):
+ if this_collision_manager.in_collision_other(cm_i):
+ collisions.append(i)
+
+ if visualize:
+ visualize_collision_objects(this_collision_object + other_collision_objects[i])
+
+ return collisions
+
+
+def visualize_collision_objects(collision_objects):
+
+ # Convert from trimesh to open3d
+ meshes_o3d = []
+ for elem in collision_objects:
+ (voxel, extents, transform) = elem
+ voxel_o3d = open3d.geometry.TriangleMesh.create_box(width=extents[0], height=extents[1],
+ depth=extents[2])
+ voxel_o3d.compute_vertex_normals()
+ voxel_o3d.paint_uniform_color([0.8, 0.2, 0])
+ voxel_o3d.transform(transform)
+ meshes_o3d.append(voxel_o3d)
+ meshes = meshes_o3d
+
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+
+ for mesh in meshes:
+ vis.add_geometry(mesh)
+
+ vis.run()
+ vis.destroy_window()
+
+
+# def test_collision(pc):
+# n_points = pc.shape[0]
+# voxel_extents = [0.005] * 3
+# collision_objects = []
+# collision_manager = trimesh.collision.CollisionManager()
+#
+# # Construct collision objects
+# for i in range(n_points):
+# extents = voxel_extents
+# transform = np.eye(4)
+# transform[:3, 3] = pc[i, :3]
+# voxel = trimesh.primitives.Box(extents=extents, transform=transform)
+# collision_objects.append((voxel, extents, transform))
+#
+# # Add to collision manager
+# for i, (voxel, _, _) in enumerate(collision_objects):
+# collision_manager.add_object("voxel_{}".format(i), voxel)
+#
+# for i, (voxel, _, _) in enumerate(collision_objects):
+# c, names = collision_manager.in_collision_single(voxel, return_names=True)
+# if c:
+# print(i, names)
+#
+# # Convert from trimesh to open3d
+# meshes_o3d = []
+# for elem in collision_objects:
+# (voxel, extents, transform) = elem
+# voxel_o3d = open3d.geometry.TriangleMesh.create_box(width=extents[0], height=extents[1],
+# depth=extents[2])
+# voxel_o3d.compute_vertex_normals()
+# voxel_o3d.paint_uniform_color([0.8, 0.2, 0])
+# voxel_o3d.transform(transform)
+# meshes_o3d.append(voxel_o3d)
+# meshes = meshes_o3d
+#
+# vis = open3d.visualization.Visualizer()
+# vis.create_window()
+#
+# for mesh in meshes:
+# vis.add_geometry(mesh)
+#
+# vis.run()
+# vis.destroy_window()
+#
+#
+# def test_collision2(pc):
+# pcd = open3d.geometry.PointCloud()
+# pcd.points = open3d.utility.Vector3dVector(pc)
+# pcd.estimate_normals()
+# open3d.visualization.draw_geometries([pcd])
+#
+# # poisson_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8, width=0, scale=1.1, linear_fit=False)[0]
+# # bbox = pcd.get_axis_aligned_bounding_box()
+# # p_mesh_crop = poisson_mesh.crop(bbox)
+# # open3d.visualization.draw_geometries([p_mesh_crop, pcd])
+#
+# distances = pcd.compute_nearest_neighbor_distance()
+# avg_dist = np.mean(distances)
+# radius = 3 * avg_dist
+# bpa_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, open3d.utility.DoubleVector(
+# [radius, radius * 2]))
+# dec_mesh = bpa_mesh.simplify_quadric_decimation(100000)
+# dec_mesh.remove_degenerate_triangles()
+# dec_mesh.remove_duplicated_triangles()
+# dec_mesh.remove_duplicated_vertices()
+# dec_mesh.remove_non_manifold_edges()
+# open3d.visualization.draw_geometries([dec_mesh, pcd])
+# open3d.visualization.draw_geometries([dec_mesh])
+
+
+def make_gifs(imgs, save_path, texts=None, numpy_img=True, duration=10):
+    gif_filename = save_path
+ pil_imgs = []
+ for i, img in enumerate(imgs):
+ if numpy_img:
+ img = Image.fromarray(img)
+ if texts:
+ text = texts[i]
+ draw = ImageDraw.Draw(img)
+ font = ImageFont.truetype("FreeMono.ttf", 40)
+ draw.text((0, 0), text, (120, 120, 120), font=font)
+ pil_imgs.append(img)
+
+ pil_imgs[0].save(gif_filename, save_all=True,
+ append_images=pil_imgs[1:], optimize=True,
+ duration=duration*len(pil_imgs), loop=0)
+
+
+def save_img(img, save_path, text=None, numpy_img=True):
+ if numpy_img:
+ img = Image.fromarray(img)
+ if text:
+ draw = ImageDraw.Draw(img)
+ font = ImageFont.truetype("FreeMono.ttf", 40)
+ draw.text((0, 0), text, (120, 120, 120), font=font)
+ img.save(save_path)
+
+
+def move_one_object_pc(obj_xyz, obj_rgb, struct_params, object_params, euler_angles=False):
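+    # Parameter layout (inferred from the reshape below): struct_params and object_params are
+    # [x, y, z] followed by a flattened 3x3 rotation matrix (9 values), or by three Euler
+    # angles when euler_angles=True.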
+ struct_params = np.asanyarray(struct_params)
+ object_params = np.asanyarray(object_params)
+
+ R_struct = np.eye(4)
+ if not euler_angles:
+ R_struct[:3, :3] = struct_params[3:].reshape(3, 3)
+ else:
+ R_struct[:3, :3] = tra.euler_matrix(*struct_params[3:])[:3, :3]
+ R_obj = np.eye(4)
+ if not euler_angles:
+ R_obj[:3, :3] = object_params[3:].reshape(3, 3)
+ else:
+ R_obj[:3, :3] = tra.euler_matrix(*object_params[3:])[:3, :3]
+
+ T_struct = R_struct
+ T_struct[:3, 3] = [struct_params[0], struct_params[1], struct_params[2]]
+
+ # translate to structure frame
+ t = np.eye(4)
+ obj_center = torch.mean(obj_xyz, dim=0)
+ t[:3, 3] = [object_params[0] - obj_center[0], object_params[1] - obj_center[1], object_params[2] - obj_center[2]]
+ new_obj_xyz = trimesh.transform_points(obj_xyz, t)
+
+ # rotate in place
+ R = R_obj
+ obj_center = np.mean(new_obj_xyz, axis=0)
+ centered_obj_xyz = new_obj_xyz - obj_center
+ new_centered_obj_xyz = trimesh.transform_points(centered_obj_xyz, R, translate=True)
+ new_obj_xyz = new_centered_obj_xyz + obj_center
+
+ # transform to the global frame from the structure frame
+ new_obj_xyz = trimesh.transform_points(new_obj_xyz, T_struct)
+
+ # convert back to torch
+ new_obj_xyz = torch.tensor(new_obj_xyz, dtype=obj_xyz.dtype)
+
+ return new_obj_xyz, obj_rgb
+
+
+def move_one_object_pc_no_struct(obj_xyz, obj_rgb, object_params, euler_angles=False):
+ object_params = np.asanyarray(object_params)
+
+ R_obj = np.eye(4)
+ if not euler_angles:
+ R_obj[:3, :3] = object_params[3:].reshape(3, 3)
+ else:
+ R_obj[:3, :3] = tra.euler_matrix(*object_params[3:])[:3, :3]
+
+ t = np.eye(4)
+ obj_center = torch.mean(obj_xyz, dim=0)
+ t[:3, 3] = [object_params[0] - obj_center[0], object_params[1] - obj_center[1], object_params[2] - obj_center[2]]
+ new_obj_xyz = trimesh.transform_points(obj_xyz, t)
+
+ # rotate in place
+ R = R_obj
+ obj_center = np.mean(new_obj_xyz, axis=0)
+ centered_obj_xyz = new_obj_xyz - obj_center
+ new_centered_obj_xyz = trimesh.transform_points(centered_obj_xyz, R, translate=True)
+ new_obj_xyz = new_centered_obj_xyz + obj_center
+
+ # convert back to torch
+ new_obj_xyz = torch.tensor(new_obj_xyz, dtype=obj_xyz.dtype)
+
+ return new_obj_xyz, obj_rgb
+
+
+def modify_language(sentence, radius=None, position_x=None, position_y=None, rotation=None, shape=None):
+ # "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4]
+
+ sentence = copy.deepcopy(sentence)
+ for pi, pair in enumerate(sentence):
+ if radius is not None and len(pair) == 2 and pair[1] == "radius":
+ sentence[pi] = (radius, 'radius')
+ if position_y is not None and len(pair) == 2 and pair[1] == "position_y":
+ sentence[pi] = (position_y, 'position_y')
+ if position_x is not None and len(pair) == 2 and pair[1] == "position_x":
+ sentence[pi] = (position_x, 'position_x')
+ if rotation is not None and len(pair) == 2 and pair[1] == "rotation":
+ sentence[pi] = (rotation, 'rotation')
+ if shape is not None and len(pair) == 2 and pair[1] == "shape":
+ sentence[pi] = (shape, 'shape')
+
+ return sentence
+
+
+def sample_gaussians(mus, sigmas, sample_size):
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ normal = torch.distributions.Normal(mus, sigmas)
+ samples = normal.sample((sample_size,))
+ # samples: [sample_size, number of individual gaussians]
+ return samples
+
+
+def fit_gaussians(samples, sigma_eps=0.01):
+ # samples: [sample_size, number of individual gaussians]
+ num_gs = samples.shape[1]
+ mus = torch.mean(samples, dim=0)
+ sigmas = torch.std(samples, dim=0) + sigma_eps * torch.ones(num_gs)
+ # mus: [number of individual gaussians]
+ # sigmas: [number of individual gaussians]
+ return mus, sigmas
+
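+# Illustrative round trip (hypothetical values, not from the original code):
+#   samples = sample_gaussians(torch.zeros(5), torch.ones(5), sample_size=100)  # [100, 5]
+#   mus, sigmas = fit_gaussians(samples)                                         # each [5]
+# sigma_eps keeps the refit standard deviations strictly positive.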
+
+def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False):
+ vis_pcs = [trimesh.PointCloud(obj_xyz, colors=np.concatenate([obj_rgb * 255, np.ones([obj_rgb.shape[0], 1]) * 255], axis=-1)) for
+ obj_xyz, obj_rgb in zip(obj_xyzs, obj_rgbs)]
+ scene = trimesh.Scene()
+ # add the coordinate frame first
+ geom = trimesh.creation.axis(0.01)
+ # scene.add_geometry(geom)
+ table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
+ table.apply_translation([0.5, 0, -0.01])
+ table.visual.vertex_colors = [150, 111, 87, 125]
+ scene.add_geometry(table)
+ # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
+ bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
+ bounds.apply_translation([0, 0, 0])
+ bounds.visual.vertex_colors = [30, 30, 30, 30]
+ # scene.add_geometry(bounds)
+ scene.add_geometry(vis_pcs)
+ RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
+ [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
+ [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
+ [0.0, 0.0, 0.0, 1.0]])
+ RT_4x4 = np.linalg.inv(RT_4x4)
+ RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
+ scene.camera_transform = RT_4x4
+ if return_scene:
+ return scene
+ else:
+ scene.show()
+
+
+def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
+ """ Display point clouds """
+
+ assert len(gts) == len(predictions) == len(xyz) == len(rgb)
+
+ unordered_pc = np.concatenate(xyz, axis=0)
+ unordered_rgb = np.concatenate(rgb, axis=0)
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+ pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ vis.add_geometry(pcd)
+
+ if add_table:
+ table_color = [0.7, 0.7, 0.7]
+ origin = [0, -0.5, -0.05]
+ table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02)
+ table.paint_uniform_color(table_color)
+ table.translate(origin)
+ vis.add_geometry(table)
+
+ if add_coordinate_frame:
+ mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
+ vis.add_geometry(mesh_frame)
+
+ for i in range(len(xyz)):
+ pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0]
+ gt_color = [0.0, 1.0, 0] if gts[i] else [1.0, 0.0, 0]
+ origin = torch.mean(xyz[i], dim=0)
+ origin[2] += 0.02
+ pred_vis = open3d.geometry.TriangleMesh.create_torus(torus_radius=0.02, tube_radius=0.01)
+ pred_vis.paint_uniform_color(pred_color)
+ pred_vis.translate(origin)
+ gt_vis = open3d.geometry.TriangleMesh.create_sphere(radius=0.01)
+ gt_vis.paint_uniform_color(gt_color)
+ gt_vis.translate(origin)
+ vis.add_geometry(pred_vis)
+ vis.add_geometry(gt_vis)
+
+ if side_view:
+ open3d_set_side_view(vis)
+
+ if return_buffer:
+ vis.poll_events()
+ vis.update_renderer()
+ buffer = vis.capture_screen_float_buffer(False)
+ vis.destroy_window()
+ return buffer
+ else:
+ vis.run()
+ vis.destroy_window()
+
+
+def show_pcs_with_only_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
+ """ Display point clouds """
+
+ assert len(gts) == len(predictions) == len(xyz) == len(rgb)
+
+ unordered_pc = np.concatenate(xyz, axis=0)
+ unordered_rgb = np.concatenate(rgb, axis=0)
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+ pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ vis.add_geometry(pcd)
+
+ if add_table:
+ table_color = [0.7, 0.7, 0.7]
+ origin = [0, -0.5, -0.05]
+ table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02)
+ table.paint_uniform_color(table_color)
+ table.translate(origin)
+ vis.add_geometry(table)
+
+ if add_coordinate_frame:
+ mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
+ vis.add_geometry(mesh_frame)
+
+ for i in range(len(xyz)):
+ pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0]
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(xyz[i])
+        pcd.colors = open3d.utility.Vector3dVector(np.tile(np.array(pred_color, dtype=float), (xyz[i].shape[0], 1)))
+ # pcd = pcd.uniform_down_sample(10)
+ # vis.add_geometry(pcd)
+
+ obb = pcd.get_axis_aligned_bounding_box()
+ obb.color = pred_color
+ vis.add_geometry(obb)
+
+
+ # origin = torch.mean(xyz[i], dim=0)
+ # origin[2] += 0.02
+ # pred_vis = open3d.geometry.TriangleMesh.create_torus(torus_radius=0.02, tube_radius=0.01)
+ # pred_vis.paint_uniform_color(pred_color)
+ # pred_vis.translate(origin)
+ # gt_vis = open3d.geometry.TriangleMesh.create_sphere(radius=0.01)
+ # gt_vis.paint_uniform_color(gt_color)
+ # gt_vis.translate(origin)
+ # vis.add_geometry(pred_vis)
+ # vis.add_geometry(gt_vis)
+
+ if side_view:
+ open3d_set_side_view(vis)
+
+ if return_buffer:
+ vis.poll_events()
+ vis.update_renderer()
+ buffer = vis.capture_screen_float_buffer(False)
+ vis.destroy_window()
+ return buffer
+ else:
+ vis.run()
+ vis.destroy_window()
+
+
+def test_new_vis(xyz, rgb):
+ pass
+# unordered_pc = np.concatenate(xyz, axis=0)
+# unordered_rgb = np.concatenate(rgb, axis=0)
+# pcd = open3d.geometry.PointCloud()
+# pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+# pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+#
+# # Some platforms do not require OpenGL implementations to support wide lines,
+# # so the renderer requires a custom shader to implement this: "unlitLine".
+# # The line_width field is only used by this shader; all other shaders ignore
+# # it.
+# # mat = o3d.visualization.rendering.Material()
+# # mat.shader = "unlitLine"
+# # mat.line_width = 10 # note that this is scaled with respect to pixels,
+# # # so will give different results depending on the
+# # # scaling values of your system
+# # mat.transmission = 0.5
+# open3d.visualization.draw({
+# "name": "pcd",
+# "geometry": pcd,
+# # "material": mat
+# })
+#
+# for i in range(len(xyz)):
+# pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0]
+# pcd = open3d.geometry.PointCloud()
+# pcd.points = open3d.utility.Vector3dVector(xyz[i])
+# pcd.colors = open3d.utility.Vector3dVector(np.tile(np.array(pred_color, dtype=np.float), (xyz[i].shape[0], 1)))
+# # pcd = pcd.uniform_down_sample(10)
+# # vis.add_geometry(pcd)
+#
+# obb = pcd.get_axis_aligned_bounding_box()
+# obb.color = pred_color
+# vis.add_geometry(obb)
+
+
+def show_pcs(xyz, rgb, add_coordinate_frame=False, side_view=False, add_table=True):
+ """ Display point clouds """
+
+ unordered_pc = np.concatenate(xyz, axis=0)
+ unordered_rgb = np.concatenate(rgb, axis=0)
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+ pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+
+ if add_table:
+ table_color = [0.78, 0.64, 0.44]
+ origin = [0, -0.5, -0.02]
+ table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.001)
+ table.paint_uniform_color(table_color)
+ table.translate(origin)
+
+ if not add_coordinate_frame:
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ vis.add_geometry(pcd)
+ if add_table:
+ vis.add_geometry(table)
+ if side_view:
+ open3d_set_side_view(vis)
+ vis.run()
+ vis.destroy_window()
+ else:
+ mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
+ # open3d.visualization.draw_geometries([pcd, mesh_frame])
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ vis.add_geometry(pcd)
+ vis.add_geometry(mesh_frame)
+ if add_table:
+ vis.add_geometry(table)
+ if side_view:
+ open3d_set_side_view(vis)
+ vis.run()
+ vis.destroy_window()
+
+
+def show_pcs_color_order(xyzs, rgbs, add_coordinate_frame=False, side_view=False, add_table=True, save_path=None, texts=None, visualize=False):
+
+ rgb_colors = get_rgb_colors()
+
+ order_rgbs = []
+ for i, xyz in enumerate(xyzs):
+        order_rgbs.append(np.tile(np.array(rgb_colors[i][1], dtype=float), (xyz.shape[0], 1)))
+
+ if visualize:
+ show_pcs(xyzs, order_rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table)
+ if save_path:
+ if not texts:
+ save_pcs(xyzs, order_rgbs, save_path=save_path, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table)
+ if texts:
+ buffer = save_pcs(xyzs, order_rgbs, add_coordinate_frame=add_coordinate_frame,
+ side_view=side_view, add_table=add_table, return_buffer=True)
+ img = np.uint8(np.asarray(buffer) * 255)
+ img = Image.fromarray(img)
+ draw = ImageDraw.Draw(img)
+ font = ImageFont.truetype("FreeMono.ttf", 20)
+ for it, text in enumerate(texts):
+ draw.text((0, it*20), text, (120, 120, 120), font=font)
+ img.save(save_path)
+
+
+def get_rgb_colors():
+ rgb_colors = []
+ # each color is a tuple of (name, (r,g,b))
+ for name, hex in matplotlib.colors.cnames.items():
+ rgb_colors.append((name, matplotlib.colors.to_rgb(hex)))
+
+ rgb_colors = sorted(rgb_colors, key=lambda x: x[0])
+
+ priority_colors = [('red', (1.0, 0.0, 0.0)), ('green', (0.0, 1.0, 0.0)), ('blue', (0.0, 0.0, 1.0)), ('orange', (1.0, 0.6470588235294118, 0.0)), ('purple', (0.5019607843137255, 0.0, 0.5019607843137255)), ('magenta', (1.0, 0.0, 1.0)),]
+ rgb_colors = priority_colors + rgb_colors
+
+ return rgb_colors
+
+
+def open3d_set_side_view(vis):
+ ctr = vis.get_view_control()
+ # ctr.set_front([-0.61959040621518757, 0.46765094085676973, 0.63040489055992976])
+ # ctr.set_lookat([0.28810001969337462, 0.10746435821056366, 0.23499999999999999])
+ # ctr.set_up([0.64188154672853504, -0.16037991603449936, 0.74984422549096852])
+ # ctr.set_zoom(0.7)
+ # ctr.rotate(10.0, 0.0)
+
+ # ctr.set_front([ -0.51720189814974493, 0.55636089622063711, 0.65035740151617438 ])
+ # ctr.set_lookat([ 0.23103321183824999, 0.26154772406860449, 0.15131956132592411 ])
+ # ctr.set_up([ 0.47073865286968591, -0.44969907810742304, 0.75906248744340343 ])
+ # ctr.set_zoom(3)
+
+ # ctr.set_front([-0.86019269757539152, 0.40355968763418076, 0.31178213796587784])
+ # ctr.set_lookat([0.28810001969337462, 0.10746435821056366, 0.23499999999999999])
+ # ctr.set_up([0.30587875107201218, -0.080905438599338214, 0.94862663869811026])
+ # ctr.set_zoom(0.69999999999999996)
+
+ # ctr.set_front([0.40466417238365116, 0.019007526352692254, 0.91426780624224468])
+ # ctr.set_lookat([0.61287602731590907, 0.010181152776318789, -0.073166629933366326])
+ # ctr.set_up([-0.91444954965885639, 0.0025306059632757057, 0.40469200283941076])
+ # ctr.set_zoom(0.84000000000000008)
+
+ ctr.set_front([-0.45528412367406523, 0.20211727782362851, 0.86710147776111224])
+ ctr.set_lookat([0.48308104105920047, 0.078726411326627957, -0.27298814087096795])
+ ctr.set_up([0.79763037008159798, -0.34013406176163907, 0.49809096835118638])
+ ctr.set_zoom(0.80000000000000004)
+
+ init_param = ctr.convert_to_pinhole_camera_parameters()
+ print("camera extrinsic", init_param.extrinsic.tolist())
+
+
+def save_pcs(xyz, rgb, save_path=None, return_buffer=False, add_coordinate_frame=False, side_view=False, add_table=True):
+
+ assert save_path or return_buffer, "provide path to save or set return_buffer to true"
+
+ unordered_pc = np.concatenate(xyz, axis=0)
+ unordered_rgb = np.concatenate(rgb, axis=0)
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(unordered_pc)
+ pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
+
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+
+ vis.add_geometry(pcd)
+ vis.update_geometry(pcd)
+
+ if add_table:
+ table_color = [0.7, 0.7, 0.7]
+ origin = [0, -0.5, -0.03]
+ table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02)
+ table.paint_uniform_color(table_color)
+ table.translate(origin)
+ vis.add_geometry(table)
+
+ if add_coordinate_frame:
+ mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
+ vis.add_geometry(mesh_frame)
+ vis.update_geometry(mesh_frame)
+
+ if side_view:
+ open3d_set_side_view(vis)
+
+ vis.poll_events()
+ vis.update_renderer()
+ if save_path:
+ vis.capture_screen_image(save_path)
+ elif return_buffer:
+ buffer = vis.capture_screen_float_buffer(False)
+
+ vis.destroy_window()
+
+ if return_buffer:
+ return buffer
+ else:
+ return None
+
+
+def get_initial_scene_idxs(dataset):
+ """
+ This function finds initial scenes from the dataset
+ :param dataset:
+ :return:
+ """
+
+ initial_scene2idx_t = {}
+ for idx in range(len(dataset)):
+ filename, t = dataset.get_data_index(idx)
+ if filename not in initial_scene2idx_t:
+ initial_scene2idx_t[filename] = (idx, t)
+ else:
+ if t > initial_scene2idx_t[filename][1]:
+ initial_scene2idx_t[filename] = (idx, t)
+ initial_scene_idxs = [initial_scene2idx_t[f][0] for f in initial_scene2idx_t]
+ return initial_scene_idxs
+
+
+def get_initial_scene_idxs_raw_data(data):
+ """
+ This function finds initial scenes from the dataset
+    :param data:
+ :return:
+ """
+
+ initial_scene2idx_t = {}
+ for idx in range(len(data)):
+ filename, t = data[idx]
+ if filename not in initial_scene2idx_t:
+ initial_scene2idx_t[filename] = (idx, t)
+ else:
+ if t > initial_scene2idx_t[filename][1]:
+ initial_scene2idx_t[filename] = (idx, t)
+ initial_scene_idxs = [initial_scene2idx_t[f][0] for f in initial_scene2idx_t]
+ return initial_scene_idxs
+
+
+def evaluate_target_object_predictions(all_gts, all_predictions, all_sentences, initial_scene_idxs, tokenizer):
+ """
+ This function evaluates target object predictions
+
+ :param all_gts: a list of predictions for scenes. Each element is a list of booleans for objects in the scene
+ :param all_predictions:
+ :param all_sentences: a list of descriptions for scenes
+ :param initial_scene_idxs:
+ :param tokenizer:
+ :return:
+ """
+
+ # overall accuracy
+ print("\noverall accuracy")
+ report = classification_report(list(itertools.chain(*all_gts)), list(itertools.chain(*all_predictions)),
+ output_dict=True)
+ print(report)
+
+ # scene average
+ print("\naccuracy per scene")
+ acc_per_scene = []
+ for gts, preds in zip(all_gts, all_predictions):
+ acc_per_scene.append(sum(np.array(gts) == np.array(preds)) * 1.0 / len(gts))
+ print(np.mean(acc_per_scene))
+ plt.hist(acc_per_scene, 10, range=(0, 1), facecolor='g', alpha=0.75)
+ plt.xlabel('Accuracy')
+ plt.ylabel('# Scene')
+ plt.title('Predicting objects to be rearranged')
+ plt.xticks(np.linspace(0, 1, 11), np.linspace(0, 1, 11).round(1))
+ plt.grid(True)
+ plt.show()
+
+ # initial scene accuracy
+ print("\noverall accuracy for initial scenes")
+ tested_initial_scene_idxs = [i for i in initial_scene_idxs if i < len(all_gts)]
+ initial_gts = [all_gts[i] for i in tested_initial_scene_idxs]
+ initial_predictions = [all_predictions[i] for i in tested_initial_scene_idxs]
+ report = classification_report(list(itertools.chain(*initial_gts)), list(itertools.chain(*initial_predictions)),
+ output_dict=True)
+ print(report)
+
+ # break down by the number of objects
+ print("\naccuracy for # objects in scene")
+ num_objects_in_scenes = np.array([len(gts) for gts in all_gts])
+ unique_num_objects = np.unique(num_objects_in_scenes)
+ acc_per_scene = np.array(acc_per_scene)
+ assert len(acc_per_scene) == len(num_objects_in_scenes)
+ for num_objects in unique_num_objects:
+ this_scene_idxs = [i for i in range(len(all_gts)) if len(all_gts[i]) == num_objects]
+ this_num_obj_gts = [all_gts[i] for i in this_scene_idxs]
+ this_num_obj_predictions = [all_predictions[i] for i in this_scene_idxs]
+ report = classification_report(list(itertools.chain(*this_num_obj_gts)), list(itertools.chain(*this_num_obj_predictions)),
+ output_dict=True)
+ print("{} objects".format(num_objects))
+ print(report)
+
+ # reference
+ print("\noverall accuracy break down")
+ direct_gts_by_type = defaultdict(list)
+ direct_preds_by_type = defaultdict(list)
+ d_anchor_gts_by_type = defaultdict(list)
+ d_anchor_preds_by_type = defaultdict(list)
+ c_anchor_gts_by_type = defaultdict(list)
+ c_anchor_preds_by_type = defaultdict(list)
+
+ for i, s in enumerate(all_sentences):
+ v, t = s[0]
+ if t[-2:] == "_c" or t[-2:] == "_d":
+ t = t[:-2]
+ if v != "MASK" and t in tokenizer.discrete_types:
+ # direct reference
+ direct_gts_by_type[t].extend(all_gts[i])
+ direct_preds_by_type[t].extend(all_predictions[i])
+ else:
+ if v == "MASK":
+ # discrete anchor
+ d_anchor_gts_by_type[t].extend(all_gts[i])
+ d_anchor_preds_by_type[t].extend(all_predictions[i])
+ else:
+ c_anchor_gts_by_type[t].extend(all_gts[i])
+ c_anchor_preds_by_type[t].extend(all_predictions[i])
+
+ print("direct")
+ for t in direct_gts_by_type:
+ report = classification_report(direct_gts_by_type[t], direct_preds_by_type[t], output_dict=True)
+ print(t, report)
+
+ print("discrete anchor")
+ for t in d_anchor_gts_by_type:
+ report = classification_report(d_anchor_gts_by_type[t], d_anchor_preds_by_type[t], output_dict=True)
+ print(t, report)
+
+ print("continuous anchor")
+ for t in c_anchor_gts_by_type:
+ report = classification_report(c_anchor_gts_by_type[t], c_anchor_preds_by_type[t], output_dict=True)
+ print(t, report)
+
+ # break down by object class
+
+
+def combine_and_sample_xyzs(xyzs, rgbs, center=None, radius=0.5, num_pts=1024):
+ xyz = torch.cat(xyzs, dim=0)
+ rgb = torch.cat(rgbs, dim=0)
+
+ if center is not None:
+ center = center.repeat(xyz.shape[0], 1)
+ dists = torch.linalg.norm(xyz - center, dim=-1)
+ idx = dists < radius
+ xyz = xyz[idx]
+ rgb = rgb[idx]
+
+ idx = np.random.randint(0, xyz.shape[0], num_pts)
+ xyz = xyz[idx]
+ rgb = rgb[idx]
+
+ return xyz, rgb
+
+
+def evaluate_prior_prediction(gts, predictions, keys, debug=False):
+ """
+ :param gts: expect a list of tensors
+ :param predictions: expect a list of tensor
+ :return:
+ """
+
+ total_mses = 0
+ obj_dists = []
+ struct_dists = []
+ for key in keys:
+ # predictions[key][0]: [batch_size * number_of_objects, dim]
+ predictions_for_key = torch.cat(predictions[key], dim=0)
+ # gts[key][0]: [batch_size * number_of_objects, dim]
+ gts_for_key = torch.cat(gts[key], dim=0)
+
+ assert gts_for_key.shape == predictions_for_key.shape
+
+ target_indices = gts_for_key != -100
+ gts_for_key = gts_for_key[target_indices]
+ predictions_for_key = predictions_for_key[target_indices]
+ num_objects = len(predictions_for_key)
+
+ distances = predictions_for_key - gts_for_key
+
+ me = torch.mean(torch.abs(distances))
+ mse = torch.mean(distances ** 2)
+ med = torch.median(torch.abs(distances))
+
+ if "obj_x" in key or "obj_y" in key or "obj_z" in key:
+ obj_dists.append(distances)
+ if "struct_x" in key or "struct_y" in key or "struct_z" in key:
+ struct_dists.append(distances)
+
+ if debug:
+ print("Groundtruths:")
+ print(gts_for_key[:100])
+ print("Predictions")
+ print(predictions_for_key[:100])
+
+ print("{} ME for {} objects: {}".format(key, num_objects, me))
+ print("{} MSE for {} objects: {}".format(key, num_objects, mse))
+ print("{} MEDIAN for {} objects: {}".format(key, num_objects, med))
+ total_mses += mse
+
+ if "theta" in key:
+ predictions_for_key = predictions_for_key.reshape(-1, 3, 3)
+ gts_for_key = gts_for_key.reshape(-1, 3, 3)
+ geodesic_distance = compute_geodesic_distance_from_two_matrices(predictions_for_key, gts_for_key)
+ geodesic_distance = torch.rad2deg(geodesic_distance)
+ mgd = torch.mean(geodesic_distance)
+ stdgd = torch.std(geodesic_distance)
+ megd = torch.median(geodesic_distance)
+ print("{} Mean and std Geodesic Distance for {} objects: {} +- {}".format(key, num_objects, mgd, stdgd))
+ print("{} Median Geodesic Distance for {} objects: {}".format(key, num_objects, megd))
+
+ if obj_dists:
+ euclidean_dists = torch.sqrt(obj_dists[0]**2 + obj_dists[1]**2 + obj_dists[2]**2)
+ me = torch.mean(euclidean_dists)
+ stde = torch.std(euclidean_dists)
+ med = torch.median(euclidean_dists)
+ print("Mean and std euclidean dist for {} objects: {} +- {}".format(len(euclidean_dists), me, stde))
+ print("Median euclidean dist for {} objects: {}".format(len(euclidean_dists), med))
+ if struct_dists:
+ euclidean_dists = torch.sqrt(struct_dists[0] ** 2 + struct_dists[1] ** 2 + struct_dists[2] ** 2)
+ me = torch.mean(euclidean_dists)
+ stde = torch.std(euclidean_dists)
+ med = torch.median(euclidean_dists)
+ print("Mean euclidean dist for {} structures: {} +- {}".format(len(euclidean_dists), me, stde))
+ print("Median euclidean dist for {} structures: {}".format(len(euclidean_dists), med))
+
+ return -total_mses
+
+
+def generate_square_subsequent_mask(sz):
+ mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+ return mask
+
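+# Example (illustrative): generate_square_subsequent_mask(3) returns
+#   tensor([[0., -inf, -inf],
+#           [0.,   0., -inf],
+#           [0.,   0.,   0.]])
+# i.e. an additive attention mask where position i can only attend to positions j <= i.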
+
+def visualize_occ(points, occupancies, in_num_pts=1000, out_num_pts=1000, visualize=False, threshold=0.5):
+
+ rix = np.random.permutation(points.shape[0])
+ vis_points = points[rix]
+ vis_occupancies = occupancies[rix]
+ in_pc = vis_points[vis_occupancies.squeeze() > threshold, :][:in_num_pts]
+ out_pc = vis_points[vis_occupancies.squeeze() < threshold, :][:out_num_pts]
+
+ if len(in_pc) == 0:
+ print("no in points")
+ if len(out_pc) == 0:
+ print("no out points")
+
+ in_pc = trimesh.PointCloud(in_pc)
+ out_pc = trimesh.PointCloud(out_pc)
+ in_pc.colors = np.tile((255, 0, 0, 255), (in_pc.vertices.shape[0], 1))
+ out_pc.colors = np.tile((255, 255, 0, 120), (out_pc.vertices.shape[0], 1))
+
+ if visualize:
+ scene = trimesh.Scene([in_pc, out_pc])
+ scene.show()
+
+ return in_pc, out_pc
+
+
+def save_dict_to_h5(dict_data, filename):
+ fh = h5py.File(filename, 'w')
+ for k in dict_data:
+ key_data = dict_data[k]
+ if key_data is None:
+ raise RuntimeError('data was not properly populated')
+ # if type(key_data) is dict:
+ # key_data = json.dumps(key_data, sort_keys=True)
+ try:
+ fh.create_dataset(k, data=key_data)
+ except TypeError as e:
+ print("Failure on key", k)
+ print(key_data)
+ print(e)
+ raise e
+ fh.close()
+
+
+def load_h5_key(h5, key):
+ if key in h5:
+ return h5[key][()]
+ elif "json_" + key in h5:
+ return json.loads(h5["json_" + key][()])
+ else:
+ return None
+
+
+def load_dict_from_h5(filename):
+ h5 = h5py.File(filename, "r")
+ data_dict = {}
+ for k in h5:
+ data_dict[k] = h5[k][()]
+    h5.close()
+    return data_dict
+
diff --git a/src/StructDiffusion/utils/rotation_continuity.py b/src/StructDiffusion/utils/rotation_continuity.py
new file mode 100755
index 0000000000000000000000000000000000000000..7e0e92ea44c590d95e2dc39b9f3641963ad4eab7
--- /dev/null
+++ b/src/StructDiffusion/utils/rotation_continuity.py
@@ -0,0 +1,395 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import numpy as np
+
+# Code adapted from the rotation continuity repo (https://github.com/papagina/RotationContinuity)
+
+#T_poses num*3
+#r_matrix batch*3*3
+def compute_pose_from_rotation_matrix(T_pose, r_matrix):
+ batch=r_matrix.shape[0]
+ joint_num = T_pose.shape[0]
+ r_matrices = r_matrix.view(batch,1, 3,3).expand(batch,joint_num, 3,3).contiguous().view(batch*joint_num,3,3)
+ src_poses = T_pose.view(1,joint_num,3,1).expand(batch,joint_num,3,1).contiguous().view(batch*joint_num,3,1)
+
+ out_poses = torch.matmul(r_matrices, src_poses) #(batch*joint_num)*3*1
+
+ return out_poses.view(batch, joint_num, 3)
+
+# batch*n
+def normalize_vector( v, return_mag =False):
+ batch=v.shape[0]
+ v_mag = torch.sqrt(v.pow(2).sum(1))# batch
+ v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda()))
+ v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
+ v = v/v_mag
+ if(return_mag==True):
+ return v, v_mag[:,0]
+ else:
+ return v
+
+# u, v batch*n
+def cross_product( u, v):
+ batch = u.shape[0]
+ #print (u.shape)
+ #print (v.shape)
+ i = u[:,1]*v[:,2] - u[:,2]*v[:,1]
+ j = u[:,2]*v[:,0] - u[:,0]*v[:,2]
+ k = u[:,0]*v[:,1] - u[:,1]*v[:,0]
+
+ out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3
+
+ return out
+
+
+#poses batch*6
+#poses
+def compute_rotation_matrix_from_ortho6d(ortho6d):
+ x_raw = ortho6d[:,0:3]#batch*3
+ y_raw = ortho6d[:,3:6]#batch*3
+
+ x = normalize_vector(x_raw) #batch*3
+ z = cross_product(x,y_raw) #batch*3
+ z = normalize_vector(z)#batch*3
+ y = cross_product(z,x)#batch*3
+
+ x = x.view(-1,3,1)
+ y = y.view(-1,3,1)
+ z = z.view(-1,3,1)
+ matrix = torch.cat((x,y,z), 2) #batch*3*3
+ return matrix
+
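+# Illustrative usage (note: normalize_vector above moves tensors to .cuda(), so a GPU is assumed):
+#   ortho6d = torch.randn(8, 6).cuda()
+#   R = compute_rotation_matrix_from_ortho6d(ortho6d)  # [8, 3, 3] rotation matrices
+# The six inputs are read as two 3-vectors that are orthonormalized (Gram-Schmidt) into the
+# first two columns of the rotation; the third column is their cross product.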
+
+#in batch*6
+#out batch*5
+def stereographic_project(a):
+ dim = a.shape[1]
+ a = normalize_vector(a)
+ out = a[:,0:dim-1]/(1-a[:,dim-1])
+ return out
+
+
+
+#in a batch*5, axis int
+def stereographic_unproject(a, axis=None):
+ """
+ Inverse of stereographic projection: increases dimension by one.
+ """
+ batch=a.shape[0]
+ if axis is None:
+ axis = a.shape[1]
+ s2 = torch.pow(a,2).sum(1) #batch
+ ans = torch.autograd.Variable(torch.zeros(batch, a.shape[1]+1).cuda()) #batch*6
+ unproj = 2*a/(s2+1).view(batch,1).repeat(1,a.shape[1]) #batch*5
+ if(axis>0):
+ ans[:,:axis] = unproj[:,:axis] #batch*(axis-0)
+ ans[:,axis] = (s2-1)/(s2+1) #batch
+ ans[:,axis+1:] = unproj[:,axis:] #batch*(5-axis) # Note that this is a no-op if the default option (last axis) is used
+ return ans
+
+
+#a batch*5
+#out batch*3*3
+def compute_rotation_matrix_from_ortho5d(a):
+ batch = a.shape[0]
+ proj_scale_np = np.array([np.sqrt(2)+1, np.sqrt(2)+1, np.sqrt(2)]) #3
+ proj_scale = torch.autograd.Variable(torch.FloatTensor(proj_scale_np).cuda()).view(1,3).repeat(batch,1) #batch,3
+
+ u = stereographic_unproject(a[:, 2:5] * proj_scale, axis=0)#batch*4
+ norm = torch.sqrt(torch.pow(u[:,1:],2).sum(1)) #batch
+ u = u/ norm.view(batch,1).repeat(1,u.shape[1]) #batch*4
+ b = torch.cat((a[:,0:2], u),1)#batch*6
+ matrix = compute_rotation_matrix_from_ortho6d(b)
+ return matrix
+
+
+#quaternion batch*4
+def compute_rotation_matrix_from_quaternion( quaternion):
+ batch=quaternion.shape[0]
+
+
+ quat = normalize_vector(quaternion).contiguous()
+
+ qw = quat[...,0].contiguous().view(batch, 1)
+ qx = quat[...,1].contiguous().view(batch, 1)
+ qy = quat[...,2].contiguous().view(batch, 1)
+ qz = quat[...,3].contiguous().view(batch, 1)
+
+    # Unit quaternion rotation matrices computation
+ xx = qx*qx
+ yy = qy*qy
+ zz = qz*qz
+ xy = qx*qy
+ xz = qx*qz
+ yz = qy*qz
+ xw = qx*qw
+ yw = qy*qw
+ zw = qz*qw
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+ return matrix
+
+#axisAngle batch*4 angle, x,y,z
+def compute_rotation_matrix_from_axisAngle( axisAngle):
+ batch = axisAngle.shape[0]
+
+ theta = torch.tanh(axisAngle[:,0])*np.pi #[-180, 180]
+ sin = torch.sin(theta*0.5)
+ axis = normalize_vector(axisAngle[:,1:4]) #batch*3
+ qw = torch.cos(theta*0.5)
+ qx = axis[:,0]*sin
+ qy = axis[:,1]*sin
+ qz = axis[:,2]*sin
+
+    # Unit quaternion rotation matrices computation
+ xx = (qx*qx).view(batch,1)
+ yy = (qy*qy).view(batch,1)
+ zz = (qz*qz).view(batch,1)
+ xy = (qx*qy).view(batch,1)
+ xz = (qx*qz).view(batch,1)
+ yz = (qy*qz).view(batch,1)
+ xw = (qx*qw).view(batch,1)
+ yw = (qy*qw).view(batch,1)
+ zw = (qz*qw).view(batch,1)
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+ return matrix
+
+#axisAngle batch*3 (x,y,z)*theta
+def compute_rotation_matrix_from_Rodriguez( rod):
+ batch = rod.shape[0]
+
+ axis, theta = normalize_vector(rod, return_mag=True)
+
+ sin = torch.sin(theta)
+
+
+ qw = torch.cos(theta)
+ qx = axis[:,0]*sin
+ qy = axis[:,1]*sin
+ qz = axis[:,2]*sin
+
+    # Unit quaternion rotation matrices computation
+ xx = (qx*qx).view(batch,1)
+ yy = (qy*qy).view(batch,1)
+ zz = (qz*qz).view(batch,1)
+ xy = (qx*qy).view(batch,1)
+ xz = (qx*qz).view(batch,1)
+ yz = (qy*qz).view(batch,1)
+ xw = (qx*qw).view(batch,1)
+ yw = (qy*qw).view(batch,1)
+ zw = (qz*qw).view(batch,1)
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+ return matrix
+
+#axisAngle batch*3 a,b,c
+def compute_rotation_matrix_from_hopf( hopf):
+ batch = hopf.shape[0]
+
+ theta = (torch.tanh(hopf[:,0])+1.0)*np.pi/2.0 #[0, pi]
+ phi = (torch.tanh(hopf[:,1])+1.0)*np.pi #[0,2pi)
+ tao = (torch.tanh(hopf[:,2])+1.0)*np.pi #[0,2pi)
+
+ qw = torch.cos(theta/2)*torch.cos(tao/2)
+ qx = torch.cos(theta/2)*torch.sin(tao/2)
+ qy = torch.sin(theta/2)*torch.cos(phi+tao/2)
+ qz = torch.sin(theta/2)*torch.sin(phi+tao/2)
+
+    # Unit quaternion rotation matrices computation
+ xx = (qx*qx).view(batch,1)
+ yy = (qy*qy).view(batch,1)
+ zz = (qz*qz).view(batch,1)
+ xy = (qx*qy).view(batch,1)
+ xz = (qx*qz).view(batch,1)
+ yz = (qy*qz).view(batch,1)
+ xw = (qx*qw).view(batch,1)
+ yw = (qy*qw).view(batch,1)
+ zw = (qz*qw).view(batch,1)
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+ return matrix
+
+
+#euler batch*3
+#output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic)
+def compute_rotation_matrix_from_euler(euler):
+ batch=euler.shape[0]
+
+ c1=torch.cos(euler[:,0]).view(batch,1)#batch*1
+ s1=torch.sin(euler[:,0]).view(batch,1)#batch*1
+ c2=torch.cos(euler[:,2]).view(batch,1)#batch*1
+ s2=torch.sin(euler[:,2]).view(batch,1)#batch*1
+ c3=torch.cos(euler[:,1]).view(batch,1)#batch*1
+ s3=torch.sin(euler[:,1]).view(batch,1)#batch*1
+
+ row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3
+ row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3
+ row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3
+
+ matrix = torch.cat((row1, row2, row3), 1) #batch*3*3
+
+
+ return matrix
+
+
+#euler_sin_cos batch*6
+#output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic)
+def compute_rotation_matrix_from_euler_sin_cos(euler_sin_cos):
+ batch=euler_sin_cos.shape[0]
+
+ s1 = euler_sin_cos[:,0].view(batch,1)
+ c1 = euler_sin_cos[:,1].view(batch,1)
+ s2 = euler_sin_cos[:,2].view(batch,1)
+ c2 = euler_sin_cos[:,3].view(batch,1)
+ s3 = euler_sin_cos[:,4].view(batch,1)
+ c3 = euler_sin_cos[:,5].view(batch,1)
+
+
+ row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3
+ row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3
+ row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3
+
+ matrix = torch.cat((row1, row2, row3), 1) #batch*3*3
+
+
+ return matrix
+
+
+#matrices batch*3*3
+#both matrix are orthogonal rotation matrices
+#out theta between 0 to 180 degree batch
+def compute_geodesic_distance_from_two_matrices(m1, m2):
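+    # theta = arccos((trace(m1 @ m2^T) - 1) / 2); the cosine is clamped to [-1, 1] for
+    # numerical stability and the result is returned in radians.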
+ batch=m1.shape[0]
+ m = torch.bmm(m1, m2.transpose(1,2)) #batch*3*3
+
+ cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2
+ cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) )
+ cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 )
+
+
+ theta = torch.acos(cos)
+
+ #theta = torch.min(theta, 2*np.pi - theta)
+
+
+ return theta
+
+
+#matrices batch*3*3
+#both matrix are orthogonal rotation matrices
+#out theta between 0 to 180 degree batch
+def compute_angle_from_r_matrices(m):
+
+ batch=m.shape[0]
+
+ cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2
+ cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) )
+ cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 )
+
+ theta = torch.acos(cos)
+
+ return theta
+
+def get_sampled_rotation_matrices_by_quat(batch):
+ #quat = torch.autograd.Variable(torch.rand(batch,4).cuda())
+ quat = torch.autograd.Variable(torch.randn(batch, 4).cuda())
+ matrix = compute_rotation_matrix_from_quaternion(quat)
+ return matrix
+
+def get_sampled_rotation_matrices_by_hpof(batch):
+
+ theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,1, batch)*np.pi).cuda()) #[0, pi]
+ phi = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi)
+ tao = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi)
+
+
+ qw = torch.cos(theta/2)*torch.cos(tao/2)
+ qx = torch.cos(theta/2)*torch.sin(tao/2)
+ qy = torch.sin(theta/2)*torch.cos(phi+tao/2)
+ qz = torch.sin(theta/2)*torch.sin(phi+tao/2)
+
+    # Unit quaternion rotation matrices computation
+ xx = (qx*qx).view(batch,1)
+ yy = (qy*qy).view(batch,1)
+ zz = (qz*qz).view(batch,1)
+ xy = (qx*qy).view(batch,1)
+ xz = (qx*qz).view(batch,1)
+ yz = (qy*qz).view(batch,1)
+ xw = (qx*qw).view(batch,1)
+ yw = (qy*qw).view(batch,1)
+ zw = (qz*qw).view(batch,1)
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+ return matrix
+
+#axisAngle batch*4 angle, x,y,z
+def get_sampled_rotation_matrices_by_axisAngle( batch, return_quaternion=False):
+
+    theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(-1,1, batch)*np.pi).cuda()) # [-pi, pi]; used as the half-angle of the rotation below
+ sin = torch.sin(theta)
+ axis = torch.autograd.Variable(torch.randn(batch, 3).cuda())
+ axis = normalize_vector(axis) #batch*3
+ qw = torch.cos(theta)
+ qx = axis[:,0]*sin
+ qy = axis[:,1]*sin
+ qz = axis[:,2]*sin
+
+ quaternion = torch.cat((qw.view(batch,1), qx.view(batch,1), qy.view(batch,1), qz.view(batch,1)), 1 )
+
+    # Unit quaternion rotation matrices computation
+ xx = (qx*qx).view(batch,1)
+ yy = (qy*qy).view(batch,1)
+ zz = (qz*qz).view(batch,1)
+ xy = (qx*qy).view(batch,1)
+ xz = (qx*qz).view(batch,1)
+ yz = (qy*qz).view(batch,1)
+ xw = (qx*qw).view(batch,1)
+ yw = (qy*qw).view(batch,1)
+ zw = (qz*qw).view(batch,1)
+
+ row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
+ row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
+ row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
+
+ matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
+
+    if return_quaternion:
+ return matrix, quaternion
+ else:
+ return matrix
+
+
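+# Illustrative sketch (ours): sample random rotation matrices with the helpers above and
+# read back their rotation angles. Requires a CUDA device because the samplers allocate
+# CUDA tensors internally.
+def _demo_sample_random_rotations(batch=16):
+    matrices = get_sampled_rotation_matrices_by_axisAngle(batch)  # batch*3*3 on the GPU
+    angles = compute_angle_from_r_matrices(matrices)              # batch, radians in [0, pi]
+    return matrices, angles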
+
+
+
+
+
+
+
diff --git a/src/StructDiffusion/utils/torch_data.py b/src/StructDiffusion/utils/torch_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae8f23fd7eec4f6e2a49f52f25fad0f8bfc0bbf
--- /dev/null
+++ b/src/StructDiffusion/utils/torch_data.py
@@ -0,0 +1,185 @@
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
+collate samples fetched from dataset into Tensor(s).
+
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+
+`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
+"""
+
+import torch
+import re
+import collections
+try:
+    from torch._six import string_classes
+except ImportError:
+    # torch._six was removed in recent PyTorch releases; fall back to the built-in types.
+    string_classes = (str, bytes)
+
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+
+def default_convert(data):
+ r"""
+ Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
+ `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
+    If the input is not a NumPy array, it is left unchanged.
+ This is used as the default function for collation when both `batch_sampler` and
+ `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.
+
+ The general input type to output type mapping is similar to that
+ of :func:`~torch.utils.data.default_collate`. See the description there for more details.
+
+ Args:
+ data: a single data point to be converted
+
+ Examples:
+ >>> # Example with `int`
+ >>> default_convert(0)
+ 0
+ >>> # Example with NumPy array
+ >>> # xdoctest: +SKIP
+ >>> default_convert(np.array([0, 1]))
+ tensor([0, 1])
+ >>> # Example with NamedTuple
+ >>> Point = namedtuple('Point', ['x', 'y'])
+ >>> default_convert(Point(0, 0))
+ Point(x=0, y=0)
+ >>> default_convert(Point(np.array(0), np.array(0)))
+ Point(x=tensor(0), y=tensor(0))
+ >>> # Example with List
+ >>> default_convert([np.array([0, 1]), np.array([2, 3])])
+ [tensor([0, 1]), tensor([2, 3])]
+ """
+ elem_type = type(data)
+ if isinstance(data, torch.Tensor):
+ return data
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ # array of string classes and object
+ if elem_type.__name__ == 'ndarray' \
+ and np_str_obj_array_pattern.search(data.dtype.str) is not None:
+ return data
+ return torch.as_tensor(data)
+ elif isinstance(data, collections.abc.Mapping):
+ try:
+ return elem_type({key: default_convert(data[key]) for key in data})
+ except TypeError:
+ # The mapping type may not support `__init__(iterable)`.
+ return {key: default_convert(data[key]) for key in data}
+ elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
+ return elem_type(*(default_convert(d) for d in data))
+ elif isinstance(data, tuple):
+ return [default_convert(d) for d in data] # Backwards compatibility.
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes):
+ try:
+ return elem_type([default_convert(d) for d in data])
+ except TypeError:
+ # The sequence type may not support `__init__(iterable)` (e.g., `range`).
+ return [default_convert(d) for d in data]
+ else:
+ return data
+
+
+default_collate_err_msg_format = (
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
+ "dicts or lists; found {}")
+
+
+def default_collate(batch):
+ r"""
+ Function that takes in a batch of data and puts the elements within the batch
+ into a tensor with an additional outer dimension - batch size. The exact output type can be
+ a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
+ Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
+ This is used as the default function for collation when
+ `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
+
+ Here is the general input type (based on the type of the element within the batch) to output type mapping:
+
+ * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
+ * NumPy Arrays -> :class:`torch.Tensor`
+ * `float` -> :class:`torch.Tensor`
+ * `int` -> :class:`torch.Tensor`
+ * `str` -> `str` (unchanged)
+ * `bytes` -> `bytes` (unchanged)
+ * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
+ * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
+ default_collate([V2_1, V2_2, ...]), ...]`
+ * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
+ default_collate([V2_1, V2_2, ...]), ...]`
+
+ Args:
+ batch: a single batch to be collated
+
+ Examples:
+ >>> # Example with a batch of `int`s:
+ >>> default_collate([0, 1, 2, 3])
+ tensor([0, 1, 2, 3])
+ >>> # Example with a batch of `str`s:
+ >>> default_collate(['a', 'b', 'c'])
+ ['a', 'b', 'c']
+ >>> # Example with `Map` inside the batch:
+ >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
+ {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
+ >>> # Example with `NamedTuple` inside the batch:
+ >>> # xdoctest: +SKIP
+ >>> Point = namedtuple('Point', ['x', 'y'])
+ >>> default_collate([Point(0, 0), Point(1, 1)])
+ Point(x=tensor([0, 1]), y=tensor([0, 1]))
+ >>> # Example with `Tuple` inside the batch:
+ >>> default_collate([(0, 1), (2, 3)])
+ [tensor([0, 2]), tensor([1, 3])]
+ >>> # Example with `List` inside the batch:
+ >>> default_collate([[0, 1], [2, 3]])
+ [tensor([0, 2]), tensor([1, 3])]
+ """
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum(x.numel() for x in batch)
+ storage = elem.storage()._new_shared(numel, device=elem.device)
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return default_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ try:
+ return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
+ except TypeError:
+ # The mapping type may not support `__init__(iterable)`.
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
+
+ if isinstance(elem, tuple):
+ return [default_collate(samples) for samples in transposed] # Backwards compatibility.
+ else:
+ try:
+ return elem_type([default_collate(samples) for samples in transposed])
+ except TypeError:
+ # The sequence type may not support `__init__(iterable)` (e.g., `range`).
+ return [default_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
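+
+
+# Illustrative sketch (ours, not part of the vendored PyTorch code): default_collate on a
+# list of dict samples, the same way a DataLoader worker would use it (or by passing it
+# explicitly as collate_fn=default_collate).
+def _demo_default_collate():
+    samples = [{"xyz": torch.zeros(128, 3), "label": i} for i in range(4)]
+    batch = default_collate(samples)
+    # batch["xyz"] has shape (4, 128, 3); batch["label"] is tensor([0, 1, 2, 3])
+    return batch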
\ No newline at end of file
diff --git a/src/StructDiffusion/utils/transformations.py b/src/StructDiffusion/utils/transformations.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6f19e5b909a4aa4a024b3356aaabf960071792
--- /dev/null
+++ b/src/StructDiffusion/utils/transformations.py
@@ -0,0 +1,1705 @@
+# -*- coding: utf-8 -*-
+# transformations.py
+
+# Copyright (c) 2006, Christoph Gohlke
+# Copyright (c) 2006-2009, The Regents of the University of California
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of the copyright holders nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+"""Homogeneous Transformation Matrices and Quaternions.
+
+A library for calculating 4x4 matrices for translating, rotating, reflecting,
+scaling, shearing, projecting, orthogonalizing, and superimposing arrays of
+3D homogeneous coordinates as well as for converting between rotation matrices,
+Euler angles, and quaternions. Also includes an Arcball control object and
+functions to decompose transformation matrices.
+
+:Authors:
+ `Christoph Gohlke `__,
+ Laboratory for Fluorescence Dynamics, University of California, Irvine
+
+:Version: 20090418
+
+Requirements
+------------
+
+* `Python 2.6 `__
+* `Numpy 1.3 `__
+* `transformations.c 20090418 `__
+ (optional implementation of some functions in C)
+
+Notes
+-----
+
+Matrices (M) can be inverted using numpy.linalg.inv(M), concatenated using
+numpy.dot(M0, M1), or used to transform homogeneous coordinates (v) using
+numpy.dot(M, v) for shape (4, \*) "point of arrays", respectively
+numpy.dot(v, M.T) for shape (\*, 4) "array of points".
+
+Calculations are carried out with numpy.float64 precision.
+
+This Python implementation is not optimized for speed.
+
+Vector, point, quaternion, and matrix function arguments are expected to be
+"array like", i.e. tuple, list, or numpy arrays.
+
+Return types are numpy arrays unless specified otherwise.
+
+Angles are in radians unless specified otherwise.
+
+Quaternions ix+jy+kz+w are represented as [x, y, z, w].
+
+Use the transpose of transformation matrices for OpenGL glMultMatrixd().
+
+A triple of Euler angles can be applied/interpreted in 24 ways, which can
+be specified using a 4 character string or encoded 4-tuple:
+
+ *Axes 4-string*: e.g. 'sxyz' or 'ryxy'
+
+ - first character : rotations are applied to 's'tatic or 'r'otating frame
+ - remaining characters : successive rotation axis 'x', 'y', or 'z'
+
+ *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
+
+ - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
+ - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
+ by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
+ - repetition : first and last axis are same (1) or different (0).
+ - frame : rotations are applied to static (0) or rotating (1) frame.
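+
+For example, the axes string 'sxyz' and the encoded tuple (0, 0, 0, 0) select the same
+axis sequence:
+
+>>> numpy.allclose(euler_matrix(1, 2, 3, 'sxyz'), euler_matrix(1, 2, 3, (0, 0, 0, 0)))
+True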
+
+References
+----------
+
+(1) Matrices and transformations. Ronald Goldman.
+ In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
+(2) More matrices and transformations: shear and pseudo-perspective.
+ Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(3) Decomposing a matrix into simple transformations. Spencer Thomas.
+ In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(4) Recovering the data from the transformation matrix. Ronald Goldman.
+ In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
+(5) Euler angle conversion. Ken Shoemake.
+ In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
+(6) Arcball rotation control. Ken Shoemake.
+ In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
+(7) Representing attitude: Euler angles, unit quaternions, and rotation
+ vectors. James Diebel. 2006.
+(8) A discussion of the solution for the best rotation to relate two sets
+ of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
+(9) Closed-form solution of absolute orientation using unit quaternions.
+ BKP Horn. J Opt Soc Am A. 1987. 4(4), 629-642.
+(10) Quaternions. Ken Shoemake.
+ http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
+(11) From quaternion to matrix and back. JMP van Waveren. 2005.
+ http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
+(12) Uniform random rotations. Ken Shoemake.
+ In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
+
+
+Examples
+--------
+
+>>> alpha, beta, gamma = 0.123, -1.234, 2.345
+>>> origin, xaxis, yaxis, zaxis = (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)
+>>> I = identity_matrix()
+>>> Rx = rotation_matrix(alpha, xaxis)
+>>> Ry = rotation_matrix(beta, yaxis)
+>>> Rz = rotation_matrix(gamma, zaxis)
+>>> R = concatenate_matrices(Rx, Ry, Rz)
+>>> euler = euler_from_matrix(R, 'rxyz')
+>>> numpy.allclose([alpha, beta, gamma], euler)
+True
+>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
+>>> is_same_transform(R, Re)
+True
+>>> al, be, ga = euler_from_matrix(Re, 'rxyz')
+>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
+True
+>>> qx = quaternion_about_axis(alpha, xaxis)
+>>> qy = quaternion_about_axis(beta, yaxis)
+>>> qz = quaternion_about_axis(gamma, zaxis)
+>>> q = quaternion_multiply(qx, qy)
+>>> q = quaternion_multiply(q, qz)
+>>> Rq = quaternion_matrix(q)
+>>> is_same_transform(R, Rq)
+True
+>>> S = scale_matrix(1.23, origin)
+>>> T = translation_matrix((1, 2, 3))
+>>> Z = shear_matrix(beta, xaxis, origin, zaxis)
+>>> R = random_rotation_matrix(numpy.random.rand(3))
+>>> M = concatenate_matrices(T, R, Z, S)
+>>> scale, shear, angles, trans, persp = decompose_matrix(M)
+>>> numpy.allclose(scale, 1.23)
+True
+>>> numpy.allclose(trans, (1, 2, 3))
+True
+>>> numpy.allclose(shear, (0, math.tan(beta), 0))
+True
+>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
+True
+>>> M1 = compose_matrix(scale, shear, angles, trans, persp)
+>>> is_same_transform(M, M1)
+True
+
+"""
+
+from __future__ import division
+
+import warnings
+import math
+
+import numpy
+
+# Documentation in HTML format can be generated with Epydoc
+__docformat__ = "restructuredtext en"
+
+
+def identity_matrix():
+ """Return 4x4 identity/unit matrix.
+
+ >>> I = identity_matrix()
+ >>> numpy.allclose(I, numpy.dot(I, I))
+ True
+ >>> numpy.sum(I), numpy.trace(I)
+ (4.0, 4.0)
+ >>> numpy.allclose(I, numpy.identity(4, dtype=numpy.float64))
+ True
+
+ """
+ return numpy.identity(4, dtype=numpy.float64)
+
+
+def translation_matrix(direction):
+ """Return matrix to translate by direction vector.
+
+ >>> v = numpy.random.random(3) - 0.5
+ >>> numpy.allclose(v, translation_matrix(v)[:3, 3])
+ True
+
+ """
+ M = numpy.identity(4)
+ M[:3, 3] = direction[:3]
+ return M
+
+
+def translation_from_matrix(matrix):
+ """Return translation vector from translation matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = translation_from_matrix(translation_matrix(v0))
+ >>> numpy.allclose(v0, v1)
+ True
+
+ """
+ return numpy.array(matrix, copy=False)[:3, 3].copy()
+
+
+def reflection_matrix(point, normal):
+ """Return matrix to mirror at plane defined by point and normal vector.
+
+ >>> v0 = numpy.random.random(4) - 0.5
+ >>> v0[3] = 1.0
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> R = reflection_matrix(v0, v1)
+ >>> numpy.allclose(2., numpy.trace(R))
+ True
+ >>> numpy.allclose(v0, numpy.dot(R, v0))
+ True
+ >>> v2 = v0.copy()
+ >>> v2[:3] += v1
+ >>> v3 = v0.copy()
+    >>> v3[:3] -= v1
+ >>> numpy.allclose(v2, numpy.dot(R, v3))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ M = numpy.identity(4)
+ M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
+ M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
+ return M
+
+
+def reflection_from_matrix(matrix):
+ """Return mirror plane point and normal vector from reflection matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> M0 = reflection_matrix(v0, v1)
+ >>> point, normal = reflection_from_matrix(M0)
+ >>> M1 = reflection_matrix(point, normal)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ # normal: unit eigenvector corresponding to eigenvalue -1
+ l, V = numpy.linalg.eig(M[:3, :3])
+ i = numpy.where(abs(numpy.real(l) + 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue -1")
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ # point: any unit eigenvector corresponding to eigenvalue 1
+ l, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return point, normal
+
+
+def rotation_matrix(angle, direction, point=None):
+ """Return matrix to rotate about axis defined by point and direction.
+
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(-angle, -direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> I = numpy.identity(4, numpy.float64)
+ >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
+ True
+ >>> numpy.allclose(2., numpy.trace(rotation_matrix(math.pi/2,
+ ... direc, point)))
+ True
+
+ """
+ sina = math.sin(angle)
+ cosa = math.cos(angle)
+ direction = unit_vector(direction[:3])
+ # rotation matrix around unit vector
+ R = numpy.array(((cosa, 0.0, 0.0),
+ (0.0, cosa, 0.0),
+ (0.0, 0.0, cosa)), dtype=numpy.float64)
+ R += numpy.outer(direction, direction) * (1.0 - cosa)
+ direction *= sina
+ R += numpy.array((( 0.0, -direction[2], direction[1]),
+ ( direction[2], 0.0, -direction[0]),
+ (-direction[1], direction[0], 0.0)),
+ dtype=numpy.float64)
+ M = numpy.identity(4)
+ M[:3, :3] = R
+ if point is not None:
+ # rotation not around origin
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ M[:3, 3] = point - numpy.dot(R, point)
+ return M
+
+
+def rotation_from_matrix(matrix):
+ """Return rotation angle and axis from rotation matrix.
+
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> angle, direc, point = rotation_from_matrix(R0)
+ >>> R1 = rotation_matrix(angle, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+
+ """
+ R = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ R33 = R[:3, :3]
+ # direction: unit eigenvector of R33 corresponding to eigenvalue of 1
+ l, W = numpy.linalg.eig(R33.T)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ direction = numpy.real(W[:, i[-1]]).squeeze()
+ # point: unit eigenvector of R33 corresponding to eigenvalue of 1
+ l, Q = numpy.linalg.eig(R)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(Q[:, i[-1]]).squeeze()
+ point /= point[3]
+ # rotation angle depending on direction
+ cosa = (numpy.trace(R33) - 1.0) / 2.0
+ if abs(direction[2]) > 1e-8:
+ sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
+ elif abs(direction[1]) > 1e-8:
+ sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
+ else:
+ sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
+ angle = math.atan2(sina, cosa)
+ return angle, direction, point
+
+
+def scale_matrix(factor, origin=None, direction=None):
+ """Return matrix to scale by factor around origin in direction.
+
+ Use factor -1 for point symmetry.
+
+ >>> v = (numpy.random.rand(4, 5) - 0.5) * 20.0
+ >>> v[3] = 1.0
+ >>> S = scale_matrix(-1.234)
+ >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
+ True
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S = scale_matrix(factor, origin)
+ >>> S = scale_matrix(factor, origin, direct)
+
+ """
+ if direction is None:
+ # uniform scaling
+ M = numpy.array(((factor, 0.0, 0.0, 0.0),
+ (0.0, factor, 0.0, 0.0),
+ (0.0, 0.0, factor, 0.0),
+ (0.0, 0.0, 0.0, 1.0)), dtype=numpy.float64)
+ if origin is not None:
+ M[:3, 3] = origin[:3]
+ M[:3, 3] *= 1.0 - factor
+ else:
+ # nonuniform scaling
+ direction = unit_vector(direction[:3])
+ factor = 1.0 - factor
+ M = numpy.identity(4)
+ M[:3, :3] -= factor * numpy.outer(direction, direction)
+ if origin is not None:
+ M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
+ return M
+
+
+def scale_from_matrix(matrix):
+ """Return scaling factor, origin and direction from scaling matrix.
+
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S0 = scale_matrix(factor, origin)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+ >>> S0 = scale_matrix(factor, origin, direct)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ factor = numpy.trace(M33) - 2.0
+ try:
+ # direction: unit eigenvector corresponding to eigenvalue factor
+ l, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(l) - factor) < 1e-8)[0][0]
+ direction = numpy.real(V[:, i]).squeeze()
+ direction /= vector_norm(direction)
+ except IndexError:
+ # uniform scaling
+ factor = (factor + 2.0) / 3.0
+ direction = None
+ # origin: any eigenvector corresponding to eigenvalue 1
+ l, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
+ origin = numpy.real(V[:, i[-1]]).squeeze()
+ origin /= origin[3]
+ return factor, origin, direction
+
+
+def projection_matrix(point, normal, direction=None,
+ perspective=None, pseudo=False):
+ """Return matrix to project onto plane defined by point and normal.
+
+    Using either a perspective point, a projection direction, or neither.
+
+ If pseudo is True, perspective projections will preserve relative depth
+ such that Perspective = dot(Orthogonal, PseudoPerspective).
+
+ >>> P = projection_matrix((0, 0, 0), (1, 0, 0))
+ >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
+ True
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> P1 = projection_matrix(point, normal, direction=direct)
+ >>> P2 = projection_matrix(point, normal, perspective=persp)
+ >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> is_same_transform(P2, numpy.dot(P0, P3))
+ True
+ >>> P = projection_matrix((3, 0, 0), (1, 1, 0), (1, 0, 0))
+ >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20.0
+ >>> v0[3] = 1.0
+ >>> v1 = numpy.dot(P, v0)
+ >>> numpy.allclose(v1[1], v0[1])
+ True
+ >>> numpy.allclose(v1[0], 3.0-v1[1])
+ True
+
+ """
+ M = numpy.identity(4)
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ normal = unit_vector(normal[:3])
+ if perspective is not None:
+ # perspective projection
+ perspective = numpy.array(perspective[:3], dtype=numpy.float64,
+ copy=False)
+ M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
+ M[:3, :3] -= numpy.outer(perspective, normal)
+ if pseudo:
+ # preserve relative depth
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
+ else:
+ M[:3, 3] = numpy.dot(point, normal) * perspective
+ M[3, :3] = -normal
+ M[3, 3] = numpy.dot(perspective, normal)
+ elif direction is not None:
+ # parallel projection
+ direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
+ scale = numpy.dot(direction, normal)
+ M[:3, :3] -= numpy.outer(direction, normal) / scale
+ M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
+ else:
+ # orthogonal projection
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * normal
+ return M
+
+
+def projection_from_matrix(matrix, pseudo=False):
+ """Return projection plane and perspective point from projection matrix.
+
+ Return values are same as arguments for projection_matrix function:
+ point, normal, direction, perspective, and pseudo.
+
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, direct)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
+ >>> result = projection_from_matrix(P0, pseudo=False)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> result = projection_from_matrix(P0, pseudo=True)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ l, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not pseudo and len(i):
+ # point: any eigenvector corresponding to eigenvalue 1
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ # direction: unit eigenvector corresponding to eigenvalue 0
+ l, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 0")
+ direction = numpy.real(V[:, i[0]]).squeeze()
+ direction /= vector_norm(direction)
+ # normal: unit eigenvector of M33.T corresponding to eigenvalue 0
+ l, V = numpy.linalg.eig(M33.T)
+ i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
+ if len(i):
+ # parallel projection
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ normal /= vector_norm(normal)
+ return point, normal, direction, None, False
+ else:
+ # orthogonal projection, where normal equals direction vector
+ return point, direction, None, None, False
+ else:
+ # perspective projection
+ i = numpy.where(abs(numpy.real(l)) > 1e-8)[0]
+ if not len(i):
+ raise ValueError(
+ "no eigenvector not corresponding to eigenvalue 0")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ normal = - M[3, :3]
+ perspective = M[:3, 3] / numpy.dot(point[:3], normal)
+ if pseudo:
+ perspective -= normal
+ return point, normal, None, perspective, pseudo
+
+
+def clip_matrix(left, right, bottom, top, near, far, perspective=False):
+ """Return matrix to obtain normalized device coordinates from frustrum.
+
+ The frustrum bounds are axis-aligned along x (left, right),
+ y (bottom, top) and z (near, far).
+
+ Normalized device coordinates are in range [-1, 1] if coordinates are
+ inside the frustrum.
+
+ If perspective is True the frustrum is a truncated pyramid with the
+ perspective point at origin and direction along z axis, otherwise an
+ orthographic canonical view volume (a box).
+
+ Homogeneous coordinates transformed by the perspective clip matrix
+    need to be dehomogenized (divided by the w coordinate).
+
+ >>> frustrum = numpy.random.rand(6)
+ >>> frustrum[1] += frustrum[0]
+ >>> frustrum[3] += frustrum[2]
+ >>> frustrum[5] += frustrum[4]
+ >>> M = clip_matrix(*frustrum, perspective=False)
+ >>> numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
+ array([-1., -1., -1., 1.])
+ >>> numpy.dot(M, [frustrum[1], frustrum[3], frustrum[5], 1.0])
+ array([ 1., 1., 1., 1.])
+ >>> M = clip_matrix(*frustrum, perspective=True)
+ >>> v = numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
+ >>> v / v[3]
+ array([-1., -1., -1., 1.])
+ >>> v = numpy.dot(M, [frustrum[1], frustrum[3], frustrum[4], 1.0])
+ >>> v / v[3]
+ array([ 1., 1., -1., 1.])
+
+ """
+ if left >= right or bottom >= top or near >= far:
+ raise ValueError("invalid frustrum")
+ if perspective:
+ if near <= _EPS:
+ raise ValueError("invalid frustrum: near <= 0")
+ t = 2.0 * near
+ M = ((-t/(right-left), 0.0, (right+left)/(right-left), 0.0),
+ (0.0, -t/(top-bottom), (top+bottom)/(top-bottom), 0.0),
+ (0.0, 0.0, -(far+near)/(far-near), t*far/(far-near)),
+ (0.0, 0.0, -1.0, 0.0))
+ else:
+ M = ((2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)),
+ (0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)),
+ (0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)),
+ (0.0, 0.0, 0.0, 1.0))
+ return numpy.array(M, dtype=numpy.float64)
+
+
+def shear_matrix(angle, direction, point, normal):
+ """Return matrix to shear by angle along direction vector on shear plane.
+
+ The shear plane is defined by a point and normal vector. The direction
+ vector must be orthogonal to the plane's normal vector.
+
+ A point P is transformed by the shear matrix into P" such that
+ the vector P-P" is parallel to the direction vector and its extent is
+ given by the angle of P-P'-P", where P' is the orthogonal projection
+ of P onto the shear plane.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S = shear_matrix(angle, direct, point, normal)
+ >>> numpy.allclose(1.0, numpy.linalg.det(S))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ direction = unit_vector(direction[:3])
+ if abs(numpy.dot(normal, direction)) > 1e-6:
+ raise ValueError("direction and normal vectors are not orthogonal")
+ angle = math.tan(angle)
+ M = numpy.identity(4)
+ M[:3, :3] += angle * numpy.outer(direction, normal)
+ M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
+ return M
+
+
+def shear_from_matrix(matrix):
+ """Return shear angle, direction and plane from shear matrix.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S0 = shear_matrix(angle, direct, point, normal)
+ >>> angle, direct, point, normal = shear_from_matrix(S0)
+ >>> S1 = shear_matrix(angle, direct, point, normal)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ # normal: cross independent eigenvectors corresponding to the eigenvalue 1
+ l, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-4)[0]
+ if len(i) < 2:
+ raise ValueError("No two linear independent eigenvectors found %s" % l)
+ V = numpy.real(V[:, i]).squeeze().T
+ lenorm = -1.0
+ for i0, i1 in ((0, 1), (0, 2), (1, 2)):
+ n = numpy.cross(V[i0], V[i1])
+ l = vector_norm(n)
+ if l > lenorm:
+ lenorm = l
+ normal = n
+ normal /= lenorm
+ # direction and angle
+ direction = numpy.dot(M33 - numpy.identity(3), normal)
+ angle = vector_norm(direction)
+ direction /= angle
+ angle = math.atan(angle)
+ # point: eigenvector corresponding to eigenvalue 1
+ l, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return angle, direction, point, normal
+
+
+def decompose_matrix(matrix):
+ """Return sequence of transformations from transformation matrix.
+
+ matrix : array_like
+ Non-degenerative homogeneous transformation matrix
+
+ Return tuple of:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+ Raise ValueError if matrix is of wrong type or degenerative.
+
+ >>> T0 = translation_matrix((1, 2, 3))
+ >>> scale, shear, angles, trans, persp = decompose_matrix(T0)
+ >>> T1 = translation_matrix(trans)
+ >>> numpy.allclose(T0, T1)
+ True
+ >>> S = scale_matrix(0.123)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(S)
+ >>> scale[0]
+ 0.123
+ >>> R0 = euler_matrix(1, 2, 3)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(R0)
+ >>> R1 = euler_matrix(*angles)
+ >>> numpy.allclose(R0, R1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
+ if abs(M[3, 3]) < _EPS:
+ raise ValueError("M[3, 3] is zero")
+ M /= M[3, 3]
+ P = M.copy()
+ P[:, 3] = 0, 0, 0, 1
+ if not numpy.linalg.det(P):
+ raise ValueError("Matrix is singular")
+
+ scale = numpy.zeros((3, ), dtype=numpy.float64)
+ shear = [0, 0, 0]
+ angles = [0, 0, 0]
+
+ if any(abs(M[:3, 3]) > _EPS):
+ perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
+ M[:, 3] = 0, 0, 0, 1
+ else:
+ perspective = numpy.array((0, 0, 0, 1), dtype=numpy.float64)
+
+ translate = M[3, :3].copy()
+ M[3, :3] = 0
+
+ row = M[:3, :3].copy()
+ scale[0] = vector_norm(row[0])
+ row[0] /= scale[0]
+ shear[0] = numpy.dot(row[0], row[1])
+ row[1] -= row[0] * shear[0]
+ scale[1] = vector_norm(row[1])
+ row[1] /= scale[1]
+ shear[0] /= scale[1]
+ shear[1] = numpy.dot(row[0], row[2])
+ row[2] -= row[0] * shear[1]
+ shear[2] = numpy.dot(row[1], row[2])
+ row[2] -= row[1] * shear[2]
+ scale[2] = vector_norm(row[2])
+ row[2] /= scale[2]
+ shear[1:] /= scale[2]
+
+ if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
+ scale *= -1
+ row *= -1
+
+ angles[1] = math.asin(-row[0, 2])
+ if math.cos(angles[1]):
+ angles[0] = math.atan2(row[1, 2], row[2, 2])
+ angles[2] = math.atan2(row[0, 1], row[0, 0])
+ else:
+ #angles[0] = math.atan2(row[1, 0], row[1, 1])
+ angles[0] = math.atan2(-row[2, 1], row[1, 1])
+ angles[2] = 0.0
+
+ return scale, shear, angles, translate, perspective
+
+
+def compose_matrix(scale=None, shear=None, angles=None, translate=None,
+ perspective=None):
+ """Return transformation matrix from sequence of transformations.
+
+ This is the inverse of the decompose_matrix function.
+
+ Sequence of transformations:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+ >>> scale = numpy.random.random(3) - 0.5
+ >>> shear = numpy.random.random(3) - 0.5
+ >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
+ >>> trans = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(4) - 0.5
+ >>> M0 = compose_matrix(scale, shear, angles, trans, persp)
+ >>> result = decompose_matrix(M0)
+ >>> M1 = compose_matrix(*result)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
+ M = numpy.identity(4)
+ if perspective is not None:
+ P = numpy.identity(4)
+ P[3, :] = perspective[:4]
+ M = numpy.dot(M, P)
+ if translate is not None:
+ T = numpy.identity(4)
+ T[:3, 3] = translate[:3]
+ M = numpy.dot(M, T)
+ if angles is not None:
+ R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
+ M = numpy.dot(M, R)
+ if shear is not None:
+ Z = numpy.identity(4)
+ Z[1, 2] = shear[2]
+ Z[0, 2] = shear[1]
+ Z[0, 1] = shear[0]
+ M = numpy.dot(M, Z)
+ if scale is not None:
+ S = numpy.identity(4)
+ S[0, 0] = scale[0]
+ S[1, 1] = scale[1]
+ S[2, 2] = scale[2]
+ M = numpy.dot(M, S)
+ M /= M[3, 3]
+ return M
+
+
+def orthogonalization_matrix(lengths, angles):
+ """Return orthogonalization matrix for crystallographic cell coordinates.
+
+ Angles are expected in degrees.
+
+ The de-orthogonalization matrix is the inverse.
+
+ >>> O = orthogonalization_matrix((10., 10., 10.), (90., 90., 90.))
+ >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
+ True
+ >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
+ >>> numpy.allclose(numpy.sum(O), 43.063229)
+ True
+
+ """
+ a, b, c = lengths
+ angles = numpy.radians(angles)
+ sina, sinb, _ = numpy.sin(angles)
+ cosa, cosb, cosg = numpy.cos(angles)
+ co = (cosa * cosb - cosg) / (sina * sinb)
+ return numpy.array((
+ ( a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0),
+ (-a*sinb*co, b*sina, 0.0, 0.0),
+ ( a*cosb, b*cosa, c, 0.0),
+ ( 0.0, 0.0, 0.0, 1.0)),
+ dtype=numpy.float64)
+
+
+def superimposition_matrix(v0, v1, scaling=False, usesvd=True):
+ """Return matrix to transform given vector set into second vector set.
+
+ v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 vectors.
+
+ If usesvd is True, the weighted sum of squared deviations (RMSD) is
+ minimized according to the algorithm by W. Kabsch [8]. Otherwise the
+ quaternion based algorithm by B. Horn [9] is used (slower when using
+ this Python implementation).
+
+ The returned matrix performs rotation, translation and uniform scaling
+ (if specified).
+
+ >>> v0 = numpy.random.rand(3, 10)
+ >>> M = superimposition_matrix(v0, v0)
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> R = random_rotation_matrix(numpy.random.random(3))
+ >>> v0 = ((1,0,0), (0,1,0), (0,0,1), (1,1,1))
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20.0
+ >>> v0[3] = 1.0
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> S = scale_matrix(random.random())
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
+ >>> M = concatenate_matrices(T, R, S)
+ >>> v1 = numpy.dot(M, v0)
+ >>> v0[:3] += numpy.random.normal(0.0, 1e-9, 300).reshape(3, -1)
+ >>> M = superimposition_matrix(v0, v1, scaling=True)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v = numpy.empty((4, 100, 3), dtype=numpy.float64)
+ >>> v[:, :, 0] = v0
+ >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
+ True
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
+
+ if v0.shape != v1.shape or v0.shape[1] < 3:
+ raise ValueError("Vector sets are of wrong shape or type.")
+
+ # move centroids to origin
+ t0 = numpy.mean(v0, axis=1)
+ t1 = numpy.mean(v1, axis=1)
+ v0 = v0 - t0.reshape(3, 1)
+ v1 = v1 - t1.reshape(3, 1)
+
+ if usesvd:
+ # Singular Value Decomposition of covariance matrix
+ u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
+ # rotation matrix from SVD orthonormal bases
+ R = numpy.dot(u, vh)
+ if numpy.linalg.det(R) < 0.0:
+ # R does not constitute right handed system
+ R -= numpy.outer(u[:, 2], vh[2, :]*2.0)
+ s[-1] *= -1.0
+ # homogeneous transformation matrix
+ M = numpy.identity(4)
+ M[:3, :3] = R
+ else:
+ # compute symmetric matrix N
+ xx, yy, zz = numpy.sum(v0 * v1, axis=1)
+ xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
+ xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
+ N = ((xx+yy+zz, yz-zy, zx-xz, xy-yx),
+ (yz-zy, xx-yy-zz, xy+yx, zx+xz),
+ (zx-xz, xy+yx, -xx+yy-zz, yz+zy),
+ (xy-yx, zx+xz, yz+zy, -xx-yy+zz))
+ # quaternion: eigenvector corresponding to most positive eigenvalue
+ l, V = numpy.linalg.eig(N)
+ q = V[:, numpy.argmax(l)]
+ q /= vector_norm(q) # unit quaternion
+ q = numpy.roll(q, -1) # move w component to end
+ # homogeneous transformation matrix
+ M = quaternion_matrix(q)
+
+ # scale: ratio of rms deviations from centroid
+ if scaling:
+ v0 *= v0
+ v1 *= v1
+ M[:3, :3] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
+
+ # translation
+ M[:3, 3] = t1
+ T = numpy.identity(4)
+ T[:3, 3] = -t0
+ M = numpy.dot(M, T)
+ return M
+
+
+def euler_matrix(ai, aj, ak, axes='sxyz'):
+ """Return homogeneous rotation matrix from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> R = euler_matrix(1, 2, 3, 'syxz')
+ >>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
+ True
+ >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
+ >>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
+ True
+ >>> ai, aj, ak = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+ >>> for axes in _TUPLE2AXES.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
+ except (AttributeError, KeyError):
+ _ = _TUPLE2AXES[axes]
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ ai, aj, ak = -ai, -aj, -ak
+
+ si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
+ ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
+ cc, cs = ci*ck, ci*sk
+ sc, ss = si*ck, si*sk
+
+ M = numpy.identity(4)
+ if repetition:
+ M[i, i] = cj
+ M[i, j] = sj*si
+ M[i, k] = sj*ci
+ M[j, i] = sj*sk
+ M[j, j] = -cj*ss+cc
+ M[j, k] = -cj*cs-sc
+ M[k, i] = -sj*ck
+ M[k, j] = cj*sc+cs
+ M[k, k] = cj*cc-ss
+ else:
+ M[i, i] = cj*ck
+ M[i, j] = sj*sc-cs
+ M[i, k] = sj*cc+ss
+ M[j, i] = cj*sk
+ M[j, j] = sj*ss+cc
+ M[j, k] = sj*cs-sc
+ M[k, i] = -sj
+ M[k, j] = cj*si
+ M[k, k] = cj*ci
+ return M
+
+
+def euler_from_matrix(matrix, axes='sxyz'):
+ """Return Euler angles from rotation matrix for specified axis sequence.
+
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ Note that many Euler angle triplets can describe one matrix.
+
+ >>> R0 = euler_matrix(1, 2, 3, 'syxz')
+ >>> al, be, ga = euler_from_matrix(R0, 'syxz')
+ >>> R1 = euler_matrix(al, be, ga, 'syxz')
+ >>> numpy.allclose(R0, R1)
+ True
+ >>> angles = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R0 = euler_matrix(axes=axes, *angles)
+ ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
+    ... if not numpy.allclose(R0, R1): print(axes, "failed")
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _ = _TUPLE2AXES[axes]
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
+ if repetition:
+ sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
+ if sy > _EPS:
+ ax = math.atan2( M[i, j], M[i, k])
+ ay = math.atan2( sy, M[i, i])
+ az = math.atan2( M[j, i], -M[k, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2( sy, M[i, i])
+ az = 0.0
+ else:
+ cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
+ if cy > _EPS:
+ ax = math.atan2( M[k, j], M[k, k])
+ ay = math.atan2(-M[k, i], cy)
+ az = math.atan2( M[j, i], M[i, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(-M[k, i], cy)
+ az = 0.0
+
+ if parity:
+ ax, ay, az = -ax, -ay, -az
+ if frame:
+ ax, az = az, ax
+ return ax, ay, az
+
+
+def euler_from_quaternion(quaternion, axes='sxyz'):
+ """Return Euler angles from quaternion for specified axis sequence.
+
+ >>> angles = euler_from_quaternion([0.06146124, 0, 0, 0.99810947])
+ >>> numpy.allclose(angles, [0.123, 0, 0])
+ True
+
+ """
+ return euler_from_matrix(quaternion_matrix(quaternion), axes)
+
+
+def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
+ """Return quaternion from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
+ >>> numpy.allclose(q, [0.310622, -0.718287, 0.444435, 0.435953])
+ True
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _ = _TUPLE2AXES[axes]
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ aj = -aj
+
+ ai /= 2.0
+ aj /= 2.0
+ ak /= 2.0
+ ci = math.cos(ai)
+ si = math.sin(ai)
+ cj = math.cos(aj)
+ sj = math.sin(aj)
+ ck = math.cos(ak)
+ sk = math.sin(ak)
+ cc = ci*ck
+ cs = ci*sk
+ sc = si*ck
+ ss = si*sk
+
+ quaternion = numpy.empty((4, ), dtype=numpy.float64)
+ if repetition:
+ quaternion[i] = cj*(cs + sc)
+ quaternion[j] = sj*(cc + ss)
+ quaternion[k] = sj*(cs - sc)
+ quaternion[3] = cj*(cc - ss)
+ else:
+ quaternion[i] = cj*sc - sj*cs
+ quaternion[j] = cj*ss + sj*cc
+ quaternion[k] = cj*cs - sj*sc
+ quaternion[3] = cj*cc + sj*ss
+ if parity:
+ quaternion[j] *= -1
+
+ return quaternion
+
+
+def quaternion_about_axis(angle, axis):
+ """Return quaternion for rotation about axis.
+
+ >>> q = quaternion_about_axis(0.123, (1, 0, 0))
+ >>> numpy.allclose(q, [0.06146124, 0, 0, 0.99810947])
+ True
+
+ """
+ quaternion = numpy.zeros((4, ), dtype=numpy.float64)
+ quaternion[:3] = axis[:3]
+ qlen = vector_norm(quaternion)
+ if qlen > _EPS:
+ quaternion *= math.sin(angle/2.0) / qlen
+ quaternion[3] = math.cos(angle/2.0)
+ return quaternion
+
+
+def quaternion_matrix(quaternion):
+ """Return homogeneous rotation matrix from quaternion.
+
+ >>> R = quaternion_matrix([0.06146124, 0, 0, 0.99810947])
+ >>> numpy.allclose(R, rotation_matrix(0.123, (1, 0, 0)))
+ True
+
+ """
+ q = numpy.array(quaternion[:4], dtype=numpy.float64, copy=True)
+ nq = numpy.dot(q, q)
+ if nq < _EPS:
+ return numpy.identity(4)
+ q *= math.sqrt(2.0 / nq)
+ q = numpy.outer(q, q)
+ return numpy.array((
+ (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0),
+ ( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0),
+ ( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0),
+ ( 0.0, 0.0, 0.0, 1.0)
+ ), dtype=numpy.float64)
+
+
+def quaternion_from_matrix(matrix):
+ """Return quaternion from rotation matrix.
+
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.0164262, 0.0328524, 0.0492786, 0.9981095])
+ True
+
+ """
+ q = numpy.empty((4, ), dtype=numpy.float64)
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
+ t = numpy.trace(M)
+ if t > M[3, 3]:
+ q[3] = t
+ q[2] = M[1, 0] - M[0, 1]
+ q[1] = M[0, 2] - M[2, 0]
+ q[0] = M[2, 1] - M[1, 2]
+ else:
+ i, j, k = 0, 1, 2
+ if M[1, 1] > M[0, 0]:
+ i, j, k = 1, 2, 0
+ if M[2, 2] > M[i, i]:
+ i, j, k = 2, 0, 1
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
+ q[i] = t
+ q[j] = M[i, j] + M[j, i]
+ q[k] = M[k, i] + M[i, k]
+ q[3] = M[k, j] - M[j, k]
+ q *= 0.5 / math.sqrt(t * M[3, 3])
+ return q
+
+
+def quaternion_multiply(quaternion1, quaternion0):
+ """Return multiplication of two quaternions.
+
+ >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
+ >>> numpy.allclose(q, [-44, -14, 48, 28])
+ True
+
+ """
+ x0, y0, z0, w0 = quaternion0
+ x1, y1, z1, w1 = quaternion1
+ return numpy.array((
+ x1*w0 + y1*z0 - z1*y0 + w1*x0,
+ -x1*z0 + y1*w0 + z1*x0 + w1*y0,
+ x1*y0 - y1*x0 + z1*w0 + w1*z0,
+ -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)
+
+
+def quaternion_conjugate(quaternion):
+ """Return conjugate of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_conjugate(q0)
+ >>> q1[3] == q0[3] and all(q1[:3] == -q0[:3])
+ True
+
+ """
+ return numpy.array((-quaternion[0], -quaternion[1],
+ -quaternion[2], quaternion[3]), dtype=numpy.float64)
+
+
+def quaternion_inverse(quaternion):
+ """Return inverse of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_inverse(q0)
+ >>> numpy.allclose(quaternion_multiply(q0, q1), [0, 0, 0, 1])
+ True
+
+ """
+ return quaternion_conjugate(quaternion) / numpy.dot(quaternion, quaternion)
+
+
+def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
+ """Return spherical linear interpolation between two quaternions.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = random_quaternion()
+ >>> q = quaternion_slerp(q0, q1, 0.0)
+ >>> numpy.allclose(q, q0)
+ True
+ >>> q = quaternion_slerp(q0, q1, 1.0, 1)
+ >>> numpy.allclose(q, q1)
+ True
+ >>> q = quaternion_slerp(q0, q1, 0.5)
+ >>> angle = math.acos(numpy.dot(q0, q))
+ >>> numpy.allclose(2.0, math.acos(numpy.dot(q0, q1)) / angle) or \
+ numpy.allclose(2.0, math.acos(-numpy.dot(q0, q1)) / angle)
+ True
+
+ """
+ q0 = unit_vector(quat0[:4])
+ q1 = unit_vector(quat1[:4])
+ if fraction == 0.0:
+ return q0
+ elif fraction == 1.0:
+ return q1
+ d = numpy.dot(q0, q1)
+ if abs(abs(d) - 1.0) < _EPS:
+ return q0
+ if shortestpath and d < 0.0:
+ # invert rotation
+ d = -d
+ q1 *= -1.0
+ angle = math.acos(d) + spin * math.pi
+ if abs(angle) < _EPS:
+ return q0
+ isin = 1.0 / math.sin(angle)
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
+ q1 *= math.sin(fraction * angle) * isin
+ q0 += q1
+ return q0
+
+
+def random_quaternion(rand=None):
+ """Return uniform random unit quaternion.
+
+ rand: array like or None
+ Three independent random variables that are uniformly distributed
+ between 0 and 1.
+
+ >>> q = random_quaternion()
+ >>> numpy.allclose(1.0, vector_norm(q))
+ True
+ >>> q = random_quaternion(numpy.random.random(3))
+ >>> q.shape
+ (4,)
+
+ """
+ if rand is None:
+ rand = numpy.random.rand(3)
+ else:
+ assert len(rand) == 3
+ r1 = numpy.sqrt(1.0 - rand[0])
+ r2 = numpy.sqrt(rand[0])
+ pi2 = math.pi * 2.0
+ t1 = pi2 * rand[1]
+ t2 = pi2 * rand[2]
+ return numpy.array((numpy.sin(t1)*r1,
+ numpy.cos(t1)*r1,
+ numpy.sin(t2)*r2,
+ numpy.cos(t2)*r2), dtype=numpy.float64)
+
+
+def random_rotation_matrix(rand=None):
+ """Return uniform random rotation matrix.
+
+    rand: array like
+ Three independent random variables that are uniformly distributed
+ between 0 and 1 for each returned quaternion.
+
+ >>> R = random_rotation_matrix()
+ >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
+ True
+
+ """
+ return quaternion_matrix(random_quaternion(rand))
+
+
+class Arcball(object):
+ """Virtual Trackball Control.
+
+ >>> ball = Arcball()
+ >>> ball = Arcball(initial=numpy.identity(4))
+ >>> ball.place([320, 320], 320)
+ >>> ball.down([500, 250])
+ >>> ball.drag([475, 275])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 3.90583455)
+ True
+ >>> ball = Arcball(initial=[0, 0, 0, 1])
+ >>> ball.place([320, 320], 320)
+ >>> ball.setaxes([1,1,0], [-1, 1, 0])
+ >>> ball.setconstrain(True)
+ >>> ball.down([400, 200])
+ >>> ball.drag([200, 400])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 0.2055924)
+ True
+ >>> ball.next()
+
+ """
+
+ def __init__(self, initial=None):
+ """Initialize virtual trackball control.
+
+ initial : quaternion or rotation matrix
+
+ """
+ self._axis = None
+ self._axes = None
+ self._radius = 1.0
+ self._center = [0.0, 0.0]
+ self._vdown = numpy.array([0, 0, 1], dtype=numpy.float64)
+ self._constrain = False
+
+ if initial is None:
+ self._qdown = numpy.array([0, 0, 0, 1], dtype=numpy.float64)
+ else:
+ initial = numpy.array(initial, dtype=numpy.float64)
+ if initial.shape == (4, 4):
+ self._qdown = quaternion_from_matrix(initial)
+ elif initial.shape == (4, ):
+ initial /= vector_norm(initial)
+ self._qdown = initial
+ else:
+ raise ValueError("initial not a quaternion or matrix.")
+
+ self._qnow = self._qpre = self._qdown
+
+ def place(self, center, radius):
+ """Place Arcball, e.g. when window size changes.
+
+ center : sequence[2]
+ Window coordinates of trackball center.
+ radius : float
+ Radius of trackball in window coordinates.
+
+ """
+ self._radius = float(radius)
+ self._center[0] = center[0]
+ self._center[1] = center[1]
+
+ def setaxes(self, *axes):
+ """Set axes to constrain rotations."""
+ if axes is None:
+ self._axes = None
+ else:
+ self._axes = [unit_vector(axis) for axis in axes]
+
+ def setconstrain(self, constrain):
+ """Set state of constrain to axis mode."""
+ self._constrain = constrain == True
+
+ def getconstrain(self):
+ """Return state of constrain to axis mode."""
+ return self._constrain
+
+ def down(self, point):
+ """Set initial cursor window coordinates and pick constrain-axis."""
+ self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
+ self._qdown = self._qpre = self._qnow
+
+ if self._constrain and self._axes is not None:
+ self._axis = arcball_nearest_axis(self._vdown, self._axes)
+ self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
+ else:
+ self._axis = None
+
+ def drag(self, point):
+ """Update current cursor window coordinates."""
+ vnow = arcball_map_to_sphere(point, self._center, self._radius)
+
+ if self._axis is not None:
+ vnow = arcball_constrain_to_axis(vnow, self._axis)
+
+ self._qpre = self._qnow
+
+ t = numpy.cross(self._vdown, vnow)
+ if numpy.dot(t, t) < _EPS:
+ self._qnow = self._qdown
+ else:
+ q = [t[0], t[1], t[2], numpy.dot(self._vdown, vnow)]
+ self._qnow = quaternion_multiply(q, self._qdown)
+
+ def next(self, acceleration=0.0):
+ """Continue rotation in direction of last drag."""
+ q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
+ self._qpre, self._qnow = self._qnow, q
+
+ def matrix(self):
+ """Return homogeneous rotation matrix."""
+ return quaternion_matrix(self._qnow)
+
+
+def arcball_map_to_sphere(point, center, radius):
+ """Return unit sphere coordinates from window coordinates."""
+ v = numpy.array(((point[0] - center[0]) / radius,
+ (center[1] - point[1]) / radius,
+ 0.0), dtype=numpy.float64)
+ n = v[0]*v[0] + v[1]*v[1]
+ if n > 1.0:
+ v /= math.sqrt(n) # position outside of sphere
+ else:
+ v[2] = math.sqrt(1.0 - n)
+ return v
+
+
+def arcball_constrain_to_axis(point, axis):
+ """Return sphere point perpendicular to axis."""
+ v = numpy.array(point, dtype=numpy.float64, copy=True)
+ a = numpy.array(axis, dtype=numpy.float64, copy=True)
+ v -= a * numpy.dot(a, v) # on plane
+ n = vector_norm(v)
+ if n > _EPS:
+ if v[2] < 0.0:
+ v *= -1.0
+ v /= n
+ return v
+ if a[2] == 1.0:
+ return numpy.array([1, 0, 0], dtype=numpy.float64)
+ return unit_vector([-a[1], a[0], 0])
+
+
+def arcball_nearest_axis(point, axes):
+ """Return axis, which arc is nearest to point."""
+ point = numpy.array(point, dtype=numpy.float64, copy=False)
+ nearest = None
+ mx = -1.0
+ for axis in axes:
+ t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
+ if t > mx:
+ nearest = axis
+ mx = t
+ return nearest
+
+
+# epsilon for testing whether a number is close to zero
+_EPS = numpy.finfo(float).eps * 4.0
+
+# axis sequences for Euler angles
+_NEXT_AXIS = [1, 2, 0, 1]
+
+# map axes strings to/from tuples of inner axis, parity, repetition, frame
+_AXES2TUPLE = {
+ 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
+ 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
+ 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
+ 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
+ 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
+ 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
+ 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
+ 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
+
+_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
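+
+# For example, 'sxyz' denotes successive rotations about the x, y and z axes
+# of a 's'tatic (extrinsic) frame, while a leading 'r' selects a 'r'otating
+# (intrinsic) frame (interpretation per the upstream transformations.py
+# documentation).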
+
+# helper functions
+
+def vector_norm(data, axis=None, out=None):
+ """Return length, i.e. eucledian norm, of ndarray along axis.
+
+ >>> v = numpy.random.random(3)
+ >>> n = vector_norm(v)
+ >>> numpy.allclose(n, numpy.linalg.norm(v))
+ True
+ >>> v = numpy.random.rand(6, 5, 3)
+ >>> n = vector_norm(v, axis=-1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
+ True
+ >>> n = vector_norm(v, axis=1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> v = numpy.random.rand(5, 4, 3)
+ >>> n = numpy.empty((5, 3), dtype=numpy.float64)
+ >>> vector_norm(v, axis=1, out=n)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> vector_norm([])
+ 0.0
+ >>> vector_norm([1.0])
+ 1.0
+
+ """
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if out is None:
+ if data.ndim == 1:
+ return math.sqrt(numpy.dot(data, data))
+ data *= data
+ out = numpy.atleast_1d(numpy.sum(data, axis=axis))
+ numpy.sqrt(out, out)
+ return out
+ else:
+ data *= data
+ numpy.sum(data, axis=axis, out=out)
+ numpy.sqrt(out, out)
+
+
+def unit_vector(data, axis=None, out=None):
+ """Return ndarray normalized by length, i.e. eucledian norm, along axis.
+
+ >>> v0 = numpy.random.random(3)
+ >>> v1 = unit_vector(v0)
+ >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
+ True
+ >>> v0 = numpy.random.rand(5, 4, 3)
+ >>> v1 = unit_vector(v0, axis=-1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = unit_vector(v0, axis=1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = numpy.empty((5, 4, 3), dtype=numpy.float64)
+ >>> unit_vector(v0, axis=1, out=v1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> list(unit_vector([]))
+ []
+ >>> list(unit_vector([1.0]))
+ [1.0]
+
+ """
+ if out is None:
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if data.ndim == 1:
+ data /= math.sqrt(numpy.dot(data, data))
+ return data
+ else:
+ if out is not data:
+ out[:] = numpy.array(data, copy=False)
+ data = out
+ length = numpy.atleast_1d(numpy.sum(data*data, axis))
+ numpy.sqrt(length, length)
+ if axis is not None:
+ length = numpy.expand_dims(length, axis)
+ data /= length
+ if out is None:
+ return data
+
+
+def random_vector(size):
+ """Return array of random doubles in the half-open interval [0.0, 1.0).
+
+ >>> v = random_vector(10000)
+ >>> numpy.all(v >= 0.0) and numpy.all(v < 1.0)
+ True
+ >>> v0 = random_vector(10)
+ >>> v1 = random_vector(10)
+ >>> numpy.any(v0 == v1)
+ False
+
+ """
+ return numpy.random.random(size)
+
+
+def inverse_matrix(matrix):
+ """Return inverse of square transformation matrix.
+
+ >>> M0 = random_rotation_matrix()
+ >>> M1 = inverse_matrix(M0.T)
+ >>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
+ True
+ >>> for size in range(1, 7):
+ ... M0 = numpy.random.rand(size, size)
+ ... M1 = inverse_matrix(M0)
+    ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
+
+ """
+ return numpy.linalg.inv(matrix)
+
+
+def concatenate_matrices(*matrices):
+ """Return concatenation of series of transformation matrices.
+
+ >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
+ >>> numpy.allclose(M, concatenate_matrices(M))
+ True
+ >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
+ True
+
+ """
+ M = numpy.identity(4)
+ for i in matrices:
+ M = numpy.dot(M, i)
+ return M
+
+
+def is_same_transform(matrix0, matrix1):
+ """Return True if two matrices perform same transformation.
+
+ >>> is_same_transform(numpy.identity(4), numpy.identity(4))
+ True
+ >>> is_same_transform(numpy.identity(4), random_rotation_matrix())
+ False
+
+ """
+ matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
+ matrix0 /= matrix0[3, 3]
+ matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
+ matrix1 /= matrix1[3, 3]
+ return numpy.allclose(matrix0, matrix1)
+
+
+def _import_module(module_name, warn=True, prefix='_py_', ignore='_'):
+ """Try import all public attributes from module into global namespace.
+
+ Existing attributes with name clashes are renamed with prefix.
+ Attributes starting with underscore are ignored by default.
+
+ Return True on successful import.
+
+ """
+ try:
+ module = __import__(module_name)
+ except ImportError:
+ if warn:
+ warnings.warn("Failed to import module " + module_name)
+ else:
+ for attr in dir(module):
+ if ignore and attr.startswith(ignore):
+ continue
+ if prefix:
+ if attr in globals():
+ globals()[prefix + attr] = globals()[attr]
+ elif warn:
+ warnings.warn("No Python implementation of " + attr)
+ globals()[attr] = getattr(module, attr)
+ return True
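+
+
+# In the upstream transformations.py this helper is typically called once at
+# import time, e.g. _import_module('_transformations'), to load an optional C
+# accelerator; that accelerator module may or may not be available here.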
diff --git a/tmp_data/scene.glb b/tmp_data/scene.glb
new file mode 100644
index 0000000000000000000000000000000000000000..80c11dfc2e335f39ea2e3520eaac2ffa9fa84332
Binary files /dev/null and b/tmp_data/scene.glb differ
diff --git a/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt b/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..0776091759fcafc1a2c92801aad2b7656a9486c3
--- /dev/null
+++ b/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b05ed7f605c30c240051c22db501ede6ea2fea79b3b66a23bce28fb01defa463
+size 59444321