diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..25625008b6a8963639be071f8cbc067f692d51ce --- /dev/null +++ b/app.py @@ -0,0 +1,131 @@ +import os +import argparse +import torch +import trimesh +import numpy as np +import pytorch_lightning as pl +import gradio as gr +from omegaconf import OmegaConf + +import sys +sys.path.append('./src') + +from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset +from StructDiffusion.language.tokenizer import Tokenizer +from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel +from StructDiffusion.diffusion.sampler import Sampler +from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses +from StructDiffusion.utils.files import get_checkpoint_path_from_dir +from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs +from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh + + +class Infer_Wrapper: + + def __init__(self, args, cfg): + + # load + pl.seed_everything(args.eval_random_seed) + self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + + checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints") + checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir) + + self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir) + # override ignore_rgb for visualization + cfg.DATASET.ignore_rgb = False + self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET) + + self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device) + + def run(self, di): + + # di = np.random.choice(len(self.dataset)) + + raw_datum = self.dataset.get_raw_data(di) + print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])) + datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer) + batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True) + + num_poses = datum["goal_poses"].shape[0] + xs = self.sampler.sample(batch, num_poses) + + struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0]) + new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct) + + # vis + vis_obj_xyzs = new_obj_xyzs[:3] + if torch.is_tensor(vis_obj_xyzs): + if vis_obj_xyzs.is_cuda: + vis_obj_xyzs = vis_obj_xyzs.detach().cpu() + vis_obj_xyzs = vis_obj_xyzs.numpy() + + # for bi, vis_obj_xyz in enumerate(vis_obj_xyzs): + # if verbose: + # print("example {}".format(bi)) + # print(vis_obj_xyz.shape) + # + # if trimesh: + # show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz]) + vis_obj_xyz = vis_obj_xyzs[0] + scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True) + + scene_filename = "./tmp_data/scene.glb" + scene.export(scene_filename) + + # pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb" + # scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb" + # + # vis_obj_xyz = vis_obj_xyz.reshape(-1, 6) + # vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1)) + # vis_pc.export(pc_filename) + # + # scene = trimesh.Scene() + # # add the coordinate frame first + # # geom = trimesh.creation.axis(0.01) + # # 
scene.add_geometry(geom) + # table = trimesh.creation.box(extents=[1.0, 1.0, 0.02]) + # table.apply_translation([0.5, 0, -0.01]) + # table.visual.vertex_colors = [150, 111, 87, 125] + # scene.add_geometry(table) + # # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0]) + # # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1) + # # bounds.apply_translation([0, 0, 0]) + # # bounds.visual.vertex_colors = [30, 30, 30, 30] + # # scene.add_geometry(bounds) + # # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481], + # # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997], + # # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951], + # # [0.0, 0.0, 0.0, 1.0]]) + # # RT_4x4 = np.linalg.inv(RT_4x4) + # # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1]) + # # scene.camera_transform = RT_4x4 + # + # mesh_list = trimesh.util.concatenate(scene.dump()) + # print(mesh_list) + # trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj') + + return scene_filename + + +args = OmegaConf.create() +args.base_config_file = "./configs/base.yaml" +args.config_file = "./configs/conditional_pose_diffusion.yaml" +args.checkpoint_id = "ConditionalPoseDiffusion" +args.eval_random_seed = 42 +args.num_samples = 1 + +base_cfg = OmegaConf.load(args.base_config_file) +cfg = OmegaConf.load(args.config_file) +cfg = OmegaConf.merge(base_cfg, cfg) + +infer_wrapper = Infer_Wrapper(args, cfg) + +demo = gr.Interface( + fn=infer_wrapper.run, + inputs=gr.Slider(0, len(infer_wrapper.dataset)), + # clear color range [0-1.0] + outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model") +) + +demo.launch() \ No newline at end of file diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f57f0e51988fbb67ca874b8520ee214e12f19bb --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,3 @@ +base_dirs: + data: data + wandb_dir: wandb_logs \ No newline at end of file diff --git a/configs/conditional_pose_diffusion.yaml b/configs/conditional_pose_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71086185256cb26421c452e015047aabcaa960ce --- /dev/null +++ b/configs/conditional_pose_diffusion.yaml @@ -0,0 +1,81 @@ +random_seed: 1 + +WANDB: + project: StructDiffusion + save_dir: ${base_dirs.wandb_dir} + name: conditional_pose_diffusion + +DATASET: + data_root: ${base_dirs.data} + vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json + + # important + use_virtual_structure_frame: True + ignore_distractor_objects: True + ignore_rgb: True + + # the following are determined by the dataset + max_num_target_objects: 7 + max_num_distractor_objects: 5 + max_num_shape_parameters: 5 + # set to zeros because they are not used for now + max_num_rearrange_features: 0 + max_num_anchor_features: 0 + + num_pts: 1024 + filter_num_moved_objects_range: + data_augmentation: False + +DATALOADER: + batch_size: 64 + num_workers: 8 + pin_memory: True + +MODEL: + # transformer encoder + encoder_input_dim: 256 + num_attention_heads: 8 + encoder_hidden_dim: 512 + encoder_dropout: 0.0 + encoder_activation: relu + encoder_num_layers: 8 + # output head + structure_dropout: 0 + object_dropout: 0 + # pc encoder + ignore_rgb: ${DATASET.ignore_rgb} + pc_emb_dim: 256 + posed_pc_emb_dim: 80 + # pose encoder + pose_emb_dim: 80 + # language + word_emb_dim: 160 + # diffusion step + time_emb_dim: 80 + # sequence embeddings + # max_num_target_objects (+ 
max_num_distractor_objects if not ignore_distractor_objects) + max_seq_size: 7 + max_token_type_size: 4 + seq_pos_emb_dim: 8 + seq_type_emb_dim: 8 + # virtual frame + use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame} + +NOISE_SCHEDULE: + timesteps: 200 + +LOSS: + type: huber + +OPTIMIZER: + lr: 0.0001 + weight_decay: 0 #0.0001 + # lr_restart: 3000 + # warmup: 10 + +TRAINER: + max_epochs: 200 + gradient_clip_val: 1.0 + gpus: 1 + deterministic: False + # enable_progress_bar: False \ No newline at end of file diff --git a/configs/pairwise_collision.yaml b/configs/pairwise_collision.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d53be3bca4468d0e054ebf2f6edad233a4fb19dd --- /dev/null +++ b/configs/pairwise_collision.yaml @@ -0,0 +1,42 @@ +random_seed: 1 + +WANDB: + project: StructDiffusion + save_dir: ${base_dirs.wandb_dir} + name: pairwise_collision + +DATASET: + urdf_pc_idx_file: ${base_dirs.pairwise_collision_data}/urdf_pc_idx.pkl + collision_data_dir: ${base_dirs.pairwise_collision_data} + + # important + num_pts: 1024 + num_scene_pts: 2048 + normalize_pc: True + random_rotation: True + data_augmentation: False + +DATALOADER: + batch_size: 32 + num_workers: 8 + pin_memory: True + +MODEL: + max_num_objects: 2 + include_env_pc: False + pct_random_sampling: True + +LOSS: + type: Focal + focal_gamma: 2 + +OPTIMIZER: + lr: 0.0001 + weight_decay: 0 + +TRAINER: + max_epochs: 200 + gradient_clip_val: 1.0 + gpus: 1 + deterministic: False + # enable_progress_bar: False \ No newline at end of file diff --git a/data/data00000000.h5 b/data/data00000000.h5 new file mode 100755 index 0000000000000000000000000000000000000000..ef5df9df8fc9935abde7e9888abad26f7f72e10c --- /dev/null +++ b/data/data00000000.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:947574252625d338b9f37217eacf61f520136e27b458b6d3e65330339e8b299c +size 1271489 diff --git a/data/data00000002.h5 b/data/data00000002.h5 new file mode 100755 index 0000000000000000000000000000000000000000..17e1ce30b8ae08f51059f46d2c1350e750713627 --- /dev/null +++ b/data/data00000002.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3302432de555fed767c5b0d99c35ca01d5e4ac38cf4a0760b8ccb456b432e0e0 +size 3235242 diff --git a/data/data00000003.h5 b/data/data00000003.h5 new file mode 100755 index 0000000000000000000000000000000000000000..b0875002cae581f22752bb4da8981eb2455bf407 --- /dev/null +++ b/data/data00000003.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b907ba7c3a17f98a438617b462b2a4d3d3f8593c2dc47feb5a6cc3da8c034fc +size 2059708 diff --git a/data/data00000004.h5 b/data/data00000004.h5 new file mode 100755 index 0000000000000000000000000000000000000000..61ba41edc4070d40a6628aa154154ac513dd376b --- /dev/null +++ b/data/data00000004.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8ec0136dd4055d304e9b7f5697b79613099b8f8f1e5eec94281f22d8d47cca1 +size 2591656 diff --git a/data/data00000006.h5 b/data/data00000006.h5 new file mode 100755 index 0000000000000000000000000000000000000000..78f26a2a27967b099024348c8955bb3b92c64997 --- /dev/null +++ b/data/data00000006.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e74ebf185b0af58df0fa2483d5fd58a12b3b62ccac27ff665f35c5c7a13b8d8 +size 1572332 diff --git a/data/data00000008.h5 b/data/data00000008.h5 new file mode 100755 index 0000000000000000000000000000000000000000..0cefd1068e399ee5423bfe4e5821d5d903fbc24f --- /dev/null +++ 
b/data/data00000008.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db015354a9d53e6fbaf0b040ce226484150b0af226a5c13a0b9f5cb9961db73c +size 2167265 diff --git a/data/data00000009.h5 b/data/data00000009.h5 new file mode 100755 index 0000000000000000000000000000000000000000..944f7f52fe5a8fbe75f05b178bd1639501d8a757 --- /dev/null +++ b/data/data00000009.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:990ad13f423d9089b30de81d002d23d9d00cf3e007fd7073793cbec03c456ebb +size 3607752 diff --git a/data/data00000012.h5 b/data/data00000012.h5 new file mode 100755 index 0000000000000000000000000000000000000000..4d03a37a210f47f01463e9567c603fb9f9451d16 --- /dev/null +++ b/data/data00000012.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93161f5666c54dbc259c9efa516b67340613b592a1ed42e6c63d4cc8a495002a +size 2525622 diff --git a/data/data00000013.h5 b/data/data00000013.h5 new file mode 100755 index 0000000000000000000000000000000000000000..51389ce34ced6cab838f8898f6c9344f138eec82 --- /dev/null +++ b/data/data00000013.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94c9cfe6d9f0df176eb0a3baccdf53c7e6e5fc807e5e7ea9e138ad7159f500d9 +size 1715352 diff --git a/data/data00000015.h5 b/data/data00000015.h5 new file mode 100755 index 0000000000000000000000000000000000000000..2112d04990e9d6facec890933d59b47e2f2c8053 --- /dev/null +++ b/data/data00000015.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9aab522f9ace1a03b1705fe3bd693b589971d133d45c232ef9ec53842a540bfa +size 2647026 diff --git a/data/type_vocabs_coarse.json b/data/type_vocabs_coarse.json new file mode 100644 index 0000000000000000000000000000000000000000..91544d8abba1f3aae15fb12f5a64559bd83cab29 --- /dev/null +++ b/data/type_vocabs_coarse.json @@ -0,0 +1 @@ +{"class": {"Basket": 0, "BeerBottle": 1, "Book": 2, "Bottle": 3, "Bowl": 4, "Calculator": 5, "Candle": 6, "CellPhone": 7, "ComputerMouse": 8, "Controller": 9, "Cup": 10, "Donut": 11, "Fork": 12, "Hammer": 13, "Knife": 14, "Marker": 15, "MilkCarton": 16, "Mug": 17, "Pan": 18, "Pen": 19, "PillBottle": 20, "Plate": 21, "PowerStrip": 22, "Scissors": 23, "SoapBottle": 24, "SodaCan": 25, "Spoon": 26, "Stapler": 27, "Teapot": 28, "VideoGameController": 29, "WineBottle": 30, "CanOpener":31, "Fruit": 32}, "scene": {"dinner": 0}, "size": {"L": 0, "M": 1, "S": 2}, "color": {"blue": 0, "cyan": 1, "green": 2, "magenta": 3, "red": 4, "yellow": 5}, "material": {"glass": 0, "metal": 1, "plastic": 2}, "comparator": {"less": 1, "greater": 2, "equal": 3}, "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4], "height": [0.0, 0.5, 10], "volumn": [0.0, 0.015, 10], "uniform_angle": {"False": 0, "True": 1}, "face_center": {"False": 0, "True": 1}, "angle_ratio": {"0.5": 0, "1.0": 1}, "shape": {"circle": 0, "line": 1, "tower": 2, "dinner": 3}, "obj_x": [-1.0, 1.0, 200], "obj_y": [-1.0, 1.0, 200], "obj_z": [-1.0, 1.0, 200], "obj_rr": [-3.15, 3.15, 360], "obj_rp": [-3.15, 3.15, 360], "obj_ry": [-3.15, 3.15, 360],"struct_x": [-1.0, 1.0, 200], "struct_y": [-1.0, 1.0, 200], "struct_z": [-1.0, 1.0, 200], "struct_rr": [-3.15, 3.15, 360], "struct_rp": [-3.15, 3.15, 360], "struct_ry": [-3.15, 3.15, 360]} \ No newline at end of file diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4458ef18f1d7c83facf3d8f9653b23a58947437 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +python3-opencv \ 
No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3d14e06166201ca5f642f6b1b7c626aef6ea86aa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +numpy==1.21 +h5py==2.10.0 +opencv-python +open3d +trimesh==3.10.2 +pyglet==1.5.0 +pybullet==3.1.7 +nvisii==1.1.70 +openpyxl +pytorch_lightning==1.6.1 +wandb===0.13.10 +pytorch3d==0.3.0 +omegaconf==2.2.2 \ No newline at end of file diff --git a/scripts/infer.py b/scripts/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a332032715d1fe655876d9a36d588a1380f75d9 --- /dev/null +++ b/scripts/infer.py @@ -0,0 +1,78 @@ +import os +import argparse +import torch +import numpy as np +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset +from StructDiffusion.language.tokenizer import Tokenizer +from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel +from StructDiffusion.diffusion.sampler import Sampler +from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses +from StructDiffusion.utils.files import get_checkpoint_path_from_dir +from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs + + +def main(args, cfg): + + pl.seed_everything(args.eval_random_seed) + + device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + + checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints") + checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir) + + if args.eval_mode == "infer": + + tokenizer = Tokenizer(cfg.DATASET.vocab_dir) + # override ignore_rgb for visualization + cfg.DATASET.ignore_rgb = False + dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET) + + sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device) + + data_idxs = np.random.permutation(len(dataset)) + for di in data_idxs: + raw_datum = dataset.get_raw_data(di) + print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])) + datum = dataset.convert_to_tensors(raw_datum, tokenizer) + batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True) + + num_poses = datum["goal_poses"].shape[0] + xs = sampler.sample(batch, num_poses) + + struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0]) + new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct) + visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="infer") + parser.add_argument("--base_config_file", help='base config yaml file', + default='../configs/base.yaml', + type=str) + parser.add_argument("--config_file", help='config yaml file', + default='../configs/conditional_pose_diffusion.yaml', + type=str) + parser.add_argument("--checkpoint_id", + default="ConditionalPoseDiffusion", + type=str) + parser.add_argument("--eval_mode", + default="infer", + type=str) + parser.add_argument("--eval_random_seed", + default=42, + type=int) + parser.add_argument("--num_samples", + default=10, + type=int) + args = parser.parse_args() + + base_cfg = OmegaConf.load(args.base_config_file) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(base_cfg, cfg) + + main(args, cfg) + + diff --git a/scripts/infer_with_discriminator.py 
b/scripts/infer_with_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..3452eb7fd2acd90d4482dd388f926bbd318b3515 --- /dev/null +++ b/scripts/infer_with_discriminator.py @@ -0,0 +1,81 @@ +import os +import argparse +import torch +import numpy as np +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset +from StructDiffusion.language.tokenizer import Tokenizer +from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel +from StructDiffusion.diffusion.sampler import SamplerV2 +from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses +from StructDiffusion.utils.files import get_checkpoint_path_from_dir +from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs + + +def main(args, cfg): + + pl.seed_everything(args.eval_random_seed) + + device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) + + diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints")) + collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints")) + + if args.eval_mode == "infer": + + tokenizer = Tokenizer(cfg.DATASET.vocab_dir) + # override ignore_rgb for visualization + cfg.DATASET.ignore_rgb = False + dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET) + + sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path, + PairwiseCollisionModel, collision_checkpoint_path, device) + + data_idxs = np.random.permutation(len(dataset)) + for di in data_idxs: + raw_datum = dataset.get_raw_data(di) + print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])) + datum = dataset.convert_to_tensors(raw_datum, tokenizer) + batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True) + + num_poses = datum["goal_poses"].shape[0] + struct_pose, pc_poses_in_struct = sampler.sample(batch, num_poses) + + new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct) + visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="infer") + parser.add_argument("--base_config_file", help='base config yaml file', + default='../configs/base.yaml', + type=str) + parser.add_argument("--config_file", help='config yaml file', + default='../configs/conditional_pose_diffusion.yaml', + type=str) + parser.add_argument("--diffusion_checkpoint_id", + default="ConditionalPoseDiffusion", + type=str) + parser.add_argument("--collision_checkpoint_id", + default="curhl56k", + type=str) + parser.add_argument("--eval_mode", + default="infer", + type=str) + parser.add_argument("--eval_random_seed", + default=42, + type=int) + parser.add_argument("--num_samples", + default=10, + type=int) + args = parser.parse_args() + + base_cfg = OmegaConf.load(args.base_config_file) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(base_cfg, cfg) + + main(args, cfg) + + diff --git a/scripts/train_discriminator.py b/scripts/train_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5a476c37b2c8c2c0e4192107800cd5986f6124 --- /dev/null +++ 
b/scripts/train_discriminator.py @@ -0,0 +1,46 @@ +import argparse +import torch +from torch.utils.data import DataLoader +from omegaconf import OmegaConf +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from StructDiffusion.data.pairwise_collision import PairwiseCollisionDataset +from StructDiffusion.models.pl_models import PairwiseCollisionModel + + +def main(cfg): + + pl.seed_everything(cfg.random_seed) + + wandb_logger = WandbLogger(**cfg.WANDB) + wandb_logger.experiment.config.update(cfg) + checkpoint_callback = ModelCheckpoint() + + full_dataset = PairwiseCollisionDataset(**cfg.DATASET) + train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [int(len(full_dataset) * 0.7), len(full_dataset) - int(len(full_dataset) * 0.7)]) + train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER) + valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER) + + model = PairwiseCollisionModel(cfg.MODEL, cfg.LOSS, cfg.OPTIMIZER, cfg.DATASET) + + trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER) + + trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="train") + parser.add_argument("--base_config_file", help='base config yaml file', + default='../configs/base.yaml', + type=str) + parser.add_argument("--config_file", help='config yaml file', + default='../configs/pairwise_collision.yaml', + type=str) + args = parser.parse_args() + base_cfg = OmegaConf.load(args.base_config_file) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(base_cfg, cfg) + + main(cfg) \ No newline at end of file diff --git a/scripts/train_generator.py b/scripts/train_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..8131a9f2626b0f84bc91e1d76250a4e5819f8785 --- /dev/null +++ b/scripts/train_generator.py @@ -0,0 +1,49 @@ +from torch.utils.data import DataLoader +import argparse +from omegaconf import OmegaConf +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset +from StructDiffusion.language.tokenizer import Tokenizer +from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel + + +def main(cfg): + + pl.seed_everything(cfg.random_seed) + + wandb_logger = WandbLogger(**cfg.WANDB) + wandb_logger.experiment.config.update(cfg) + checkpoint_callback = ModelCheckpoint() + + tokenizer = Tokenizer(cfg.DATASET.vocab_dir) + vocab_size = tokenizer.get_vocab_size() + + train_dataset = SemanticArrangementDataset(split="train", tokenizer=tokenizer, **cfg.DATASET) + valid_dataset = SemanticArrangementDataset(split="valid", tokenizer=tokenizer, **cfg.DATASET) + train_dataloader = DataLoader(train_dataset, shuffle=True, **cfg.DATALOADER) + valid_dataloader = DataLoader(valid_dataset, shuffle=False, **cfg.DATALOADER) + + model = ConditionalPoseDiffusionModel(vocab_size, cfg.MODEL, cfg.LOSS, cfg.NOISE_SCHEDULE, cfg.OPTIMIZER) + + trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], **cfg.TRAINER) + + trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="train") + 
parser.add_argument("--base_config_file", help='base config yaml file', + default='../configs/base.yaml', + type=str) + parser.add_argument("--config_file", help='config yaml file', + default='../configs/conditional_pose_diffusion.yaml', + type=str) + args = parser.parse_args() + base_cfg = OmegaConf.load(args.base_config_file) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(base_cfg, cfg) + + main(cfg) \ No newline at end of file diff --git a/src/StructDiffusion/__init__.py b/src/StructDiffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73b43dd4131a1a5fe63caa5e0cc1e02ee7a25d93 Binary files /dev/null and b/src/StructDiffusion/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..962a4af7693c5c3f4471bf199241b94961d2a085 Binary files /dev/null and b/src/StructDiffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/data/__init__.py b/src/StructDiffusion/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49140fb3e29de7aa109cdcf66678a0a1e2f63717 Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..196f5127bc03b7b7fdba37a1dc8f835363f6266f Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a858b5344ce7e72d040d4a4ec23d3640534dedf Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc differ diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da8facf199b17a7cf2d255dd2bacf22774df8a1 Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc differ diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62d26963f2419f17e4dbb98fa821f224df59a9d6 Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc differ diff --git a/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0ae89fd0a7d26e8c837fd3d41b8e8f118d2c12ec Binary files /dev/null and b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc differ diff --git a/src/StructDiffusion/data/pairwise_collision.py b/src/StructDiffusion/data/pairwise_collision.py new file mode 100644 index 0000000000000000000000000000000000000000..2519d16fce3fe0df4a38b02e24ce174cfef09c25 --- /dev/null +++ b/src/StructDiffusion/data/pairwise_collision.py @@ -0,0 +1,361 @@ +import cv2 +import h5py +import numpy as np +import os +import trimesh +import torch +import json +from collections import defaultdict +import tqdm +import pickle +from random import shuffle + +# Local imports +from StructDiffusion.utils.rearrangement import show_pcs, get_pts, array_to_tensor +from StructDiffusion.utils.pointnet import pc_normalize + +import StructDiffusion.utils.brain2.camera as cam +import StructDiffusion.utils.brain2.image as img +import StructDiffusion.utils.transformations as tra + + +def load_pairwise_collision_data(h5_filename): + + fh = h5py.File(h5_filename, 'r') + data_dict = {} + data_dict["obj1_info"] = eval(fh["obj1_info"][()]) + data_dict["obj2_info"] = eval(fh["obj2_info"][()]) + data_dict["obj1_poses"] = fh["obj1_poses"][:] + data_dict["obj2_poses"] = fh["obj2_poses"][:] + data_dict["intersection_labels"] = fh["intersection_labels"][:] + + return data_dict + + +class PairwiseCollisionDataset(torch.utils.data.Dataset): + + def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True, + num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False, + debug=False): + + # load dictionary mapping from urdf to list of pc data, each sample is + # {"step_t": step_t, "obj": obj, "filename": filename} + with open(urdf_pc_idx_file, "rb") as fh: + self.urdf_to_pc_data = pickle.load(fh) + # filter out broken files + for urdf in self.urdf_to_pc_data: + valid_pc_data = [] + for pd in self.urdf_to_pc_data[urdf]: + filename = pd["filename"] + if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename: + continue + valid_pc_data.append(pd) + if valid_pc_data: + self.urdf_to_pc_data[urdf] = valid_pc_data + + # build data index + # each sample is a tuple of (collision filename, idx for the labels and poses) + if collision_data_dir is not None: + self.data_idxs = self.build_data_idxs(collision_data_dir) + else: + print("WARNING: collision_data_dir is None") + + self.num_pts = num_pts + self.debug = debug + self.normalize_pc = normalize_pc + self.num_scene_pts = num_scene_pts + self.random_rotation = random_rotation + + # Noise + self.data_augmentation = data_augmentation + # additive noise + self.gp_rescale_factor_range = [12, 20] + self.gaussian_scale_range = [0., 0.003] + # multiplicative noise + self.gamma_shape = 1000. 
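+ # gamma_shape and gamma_scale parameterize the multiplicative noise applied to the raw depth image in add_noise_to_depth; gp_rescale_factor_range and gaussian_scale_range parameterize the low-resolution additive noise that add_noise_to_xyz upsamples with cv2.resize and applies only where depth > 0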
+ self.gamma_scale = 0.001 + + def build_data_idxs(self, collision_data_dir): + print("Load collision data...") + positive_data = [] + negative_data = [] + for filename in tqdm.tqdm(os.listdir(collision_data_dir)): + if "h5" not in filename: + continue + h5_filename = os.path.join(collision_data_dir, filename) + data_dict = load_pairwise_collision_data(h5_filename) + obj1_urdf = data_dict["obj1_info"]["urdf"] + obj2_urdf = data_dict["obj2_info"]["urdf"] + if obj1_urdf not in self.urdf_to_pc_data: + print("no pc data for urdf:", obj1_urdf) + continue + if obj2_urdf not in self.urdf_to_pc_data: + print("no pc data for urdf:", obj2_urdf) + continue + for idx, l in enumerate(data_dict["intersection_labels"]): + if l: + # intersection + positive_data.append((h5_filename, idx)) + else: + negative_data.append((h5_filename, idx)) + print("Num pairwise intersections:", len(positive_data)) + print("Num pairwise no intersections:", len(negative_data)) + + if len(negative_data) != len(positive_data): + min_len = min(len(negative_data), len(positive_data)) + positive_data = [positive_data[i] for i in np.random.permutation(len(positive_data))[:min_len]] + negative_data = [negative_data[i] for i in np.random.permutation(len(negative_data))[:min_len]] + print("after balancing") + print("Num pairwise intersections:", len(positive_data)) + print("Num pairwise no intersections:", len(negative_data)) + + return positive_data + negative_data + + def create_urdf_pc_idxs(self, urdf_pc_idx_file, data_roots, index_roots): + print("Load pc data") + arrangement_steps = [] + for split in ["train"]: + for data_root, index_root in zip(data_roots, index_roots): + arrangement_indices_file = os.path.join(data_root, index_root,"{}_arrangement_indices_file_all.txt".format(split)) + if os.path.exists(arrangement_indices_file): + with open(arrangement_indices_file, "r") as fh: + arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())]) + else: + print("{} does not exist".format(arrangement_indices_file)) + + urdf_to_pc_data = defaultdict(list) + for filename, step_t in tqdm.tqdm(arrangement_steps): + h5 = h5py.File(filename, 'r') + ids = self._get_ids(h5) + # moved_objs = h5['moved_objs'][()].split(',') + all_objs = sorted([o for o in ids.keys() if "object_" in o]) + goal_specification = json.loads(str(np.array(h5["goal_specification"]))) + obj_infos = goal_specification["rearrange"]["objects"] + goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"] + for obj, obj_info in zip(all_objs, obj_infos): + urdf_to_pc_data[obj_info["urdf"]].append({"step_t": step_t, "obj": obj, "filename": filename}) + + with open(urdf_pc_idx_file, "wb") as fh: + pickle.dump(urdf_to_pc_data, fh) + + return urdf_to_pc_data + + def add_noise_to_depth(self, depth_img): + """ add depth noise """ + multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale) + depth_img = multiplicative_noise * depth_img + return depth_img + + def add_noise_to_xyz(self, xyz_img, depth_img): + """ TODO: remove this code or at least celean it up""" + xyz_img = xyz_img.copy() + H, W, C = xyz_img.shape + gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0], + self.gp_rescale_factor_range[1]) + gp_scale = np.random.uniform(self.gaussian_scale_range[0], + self.gaussian_scale_range[1]) + small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int) + additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C)) + additive_noise = 
cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC) + xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :] + return xyz_img + + def _get_images(self, h5, idx, ee=True): + if ee: + RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg" + DMIN, DMAX = "ee_depth_min", "ee_depth_max" + else: + RGB, DEPTH, SEG = "rgb", "depth", "seg" + DMIN, DMAX = "depth_min", "depth_max" + dmin = h5[DMIN][idx] + dmax = h5[DMAX][idx] + rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha + depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin + seg1 = img.PNGToNumpy(h5[SEG][idx]) + + valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.) + + # proj_matrix = h5['proj_matrix'][()] + camera = cam.get_camera_from_h5(h5) + if self.data_augmentation: + depth1 = self.add_noise_to_depth(depth1) + + xyz1 = cam.compute_xyz(depth1, camera) + if self.data_augmentation: + xyz1 = self.add_noise_to_xyz(xyz1, depth1) + + # Transform the point cloud + # Here it is... + # CAM_POSE = "ee_cam_pose" if ee else "cam_pose" + CAM_POSE = "ee_camera_view" if ee else "camera_view" + cam_pose = h5[CAM_POSE][idx] + if ee: + # ee_camera_view has 0s for x, y, z + cam_pos = h5["ee_cam_pose"][:][:3, 3] + cam_pose[:3, 3] = cam_pos + + # Get transformed point cloud + h, w, d = xyz1.shape + xyz1 = xyz1.reshape(h * w, -1) + xyz1 = trimesh.transform_points(xyz1, cam_pose) + xyz1 = xyz1.reshape(h, w, -1) + + scene1 = rgb1, depth1, seg1, valid1, xyz1 + + return scene1 + + def _get_ids(self, h5): + """ + get object ids + + @param h5: + @return: + """ + ids = {} + for k in h5.keys(): + if k.startswith("id_"): + ids[k[3:]] = h5[k][()] + return ids + + def get_obj_pc(self, h5, step_t, obj): + scene = self._get_images(h5, step_t, ee=True) + rgb, depth, seg, valid, xyz = scene + + # getting object point clouds + ids = self._get_ids(h5) + obj_mask = np.logical_and(seg == ids[obj], valid) + if np.sum(obj_mask) <= 0: + raise Exception + ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts, to_tensor=False) + obj_pc_center = np.mean(obj_xyz, axis=0) + obj_pose = h5[obj][step_t] + + obj_pc_pose = np.eye(4) + obj_pc_pose[:3, 3] = obj_pc_center[:3] + + return obj_xyz, obj_rgb, obj_pc_pose, obj_pose + + def __len__(self): + return len(self.data_idxs) + + def __getitem__(self, idx): + collision_filename, collision_idx = self.data_idxs[idx] + collision_data_dict = load_pairwise_collision_data(collision_filename) + + obj1_urdf = collision_data_dict["obj1_info"]["urdf"] + obj2_urdf = collision_data_dict["obj2_info"]["urdf"] + + # TODO: find a better way to sample pc data? 
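+ # urdf_to_pc_data maps each URDF to the segmented point-cloud observations recorded for that object (dicts with step_t, obj, filename); one observation is sampled at random for each object in the pair and then transformed below from its recorded pose into the collision-query pose via obj*_c_pose @ inv(obj*_pose)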
+ obj1_pc_data = np.random.choice(self.urdf_to_pc_data[obj1_urdf]) + obj2_pc_data = np.random.choice(self.urdf_to_pc_data[obj2_urdf]) + + obj1_xyz, obj1_rgb, obj1_pc_pose, obj1_pose = self.get_obj_pc(h5py.File(obj1_pc_data["filename"], "r"), obj1_pc_data["step_t"], obj1_pc_data["obj"]) + obj2_xyz, obj2_rgb, obj2_pc_pose, obj2_pose = self.get_obj_pc(h5py.File(obj2_pc_data["filename"], "r"), obj2_pc_data["step_t"], obj2_pc_data["obj"]) + + obj1_c_pose = collision_data_dict["obj1_poses"][collision_idx] + obj2_c_pose = collision_data_dict["obj2_poses"][collision_idx] + label = collision_data_dict["intersection_labels"][collision_idx] + + obj1_transform = obj1_c_pose @ np.linalg.inv(obj1_pose) + obj2_transform = obj2_c_pose @ np.linalg.inv(obj2_pose) + obj1_c_xyz = trimesh.transform_points(obj1_xyz, obj1_transform) + obj2_c_xyz = trimesh.transform_points(obj2_xyz, obj2_transform) + + # if self.debug: + # show_pcs([obj1_c_xyz, obj2_c_xyz], [obj1_rgb, obj2_rgb], add_coordinate_frame=True) + + ################################### + obj_xyzs = [obj1_c_xyz, obj2_c_xyz] + shuffle(obj_xyzs) + + num_indicator = 2 + new_obj_xyzs = [] + for oi, obj_xyz in enumerate(obj_xyzs): + obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1) + new_obj_xyzs.append(obj_xyz) + scene_xyz = np.concatenate(new_obj_xyzs, axis=0) + + # subsampling and normalizing pc + idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts) + scene_xyz = scene_xyz[idx] + if self.normalize_pc: + scene_xyz[:, 0:3] = pc_normalize(scene_xyz[:, 0:3]) + + if self.random_rotation: + scene_xyz[:, 0:3] = trimesh.transform_points(scene_xyz[:, 0:3], tra.euler_matrix(0, 0, np.random.uniform(low=0, high=2 * np.pi))) + + ################################### + scene_xyz = array_to_tensor(scene_xyz) + # convert to torch data + label = int(label) + + if self.debug: + print("intersection:", label) + show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))], add_coordinate_frame=True) + + datum = { + "scene_xyz": scene_xyz, + "label": torch.FloatTensor([label]), + } + return datum + + # @staticmethod + # def collate_fn(data): + # """ + # :param data: + # :return: + # """ + # + # batched_data_dict = {} + # for key in ["is_circle"]: + # batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0) + # for key in ["scene_xyz"]: + # batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0) + # + # return batched_data_dict + # + # # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False): + # # + # # new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs] + # # + # # # compute pairwise collision + # # scene_xyzs = [] + # # obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2)) + # # + # # for obj_xyz_pair_idx in obj_xyz_pair_idxs: + # # obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]] + # # num_indicator = 2 + # # obj_xyz_pair_ind = [] + # # for oi, obj_xyz in enumerate(obj_xyz_pair): + # # obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1) + # # obj_xyz_pair_ind.append(obj_xyz) + # # pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0) + # # + # # # subsampling and normalizing pc + # # rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts) + # # pair_scene_xyz = pair_scene_xyz[rand_idx] + # # if self.normalize_pc: + # # pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3]) + # # 
+ # # scene_xyzs.append(array_to_tensor(pair_scene_xyz)) + # # + # # if debug: + # # for scene_xyz in scene_xyzs: + # # show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))], + # # add_coordinate_frame=True) + # # + # # return scene_xyzs + + +if __name__ == "__main__": + dataset = PairwiseCollisionDataset(urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl", + collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data", + debug=False) + + for i in tqdm.tqdm(np.random.permutation(len(dataset))): + # print(i) + d = dataset[i] + # print(d["label"]) + + # dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8) + # for b in tqdm.tqdm(dl): + # pass diff --git a/src/StructDiffusion/data/semantic_arrangement.py b/src/StructDiffusion/data/semantic_arrangement.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac4ef12d7ebdb768b7af5b58d8263d171bb91bb --- /dev/null +++ b/src/StructDiffusion/data/semantic_arrangement.py @@ -0,0 +1,579 @@ +import copy +import cv2 +import h5py +import numpy as np +import os +import trimesh +import torch +from tqdm import tqdm +import json +import random + +from torch.utils.data import DataLoader + +# Local imports +from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs +from StructDiffusion.language.tokenizer import Tokenizer + +import StructDiffusion.utils.brain2.camera as cam +import StructDiffusion.utils.brain2.image as img +import StructDiffusion.utils.transformations as tra + + +class SemanticArrangementDataset(torch.utils.data.Dataset): + + def __init__(self, data_roots, index_roots, split, tokenizer, + max_num_target_objects=11, max_num_distractor_objects=5, + max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3, + num_pts=1024, + use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True, + filter_num_moved_objects_range=None, shuffle_object_index=False, + data_augmentation=True, debug=False, **kwargs): + """ + + Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs + + :param data_root: + :param split: train, valid, or test + :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence + :param debug: + :param max_num_shape_parameters: + :param max_num_objects: + :param max_num_rearrange_features: + :param max_num_anchor_features: + :param num_pts: + :param use_stored_arrangement_indices: + :param kwargs: + """ + + self.use_virtual_structure_frame = use_virtual_structure_frame + self.ignore_distractor_objects = ignore_distractor_objects + self.ignore_rgb = ignore_rgb and not debug + + self.num_pts = num_pts + self.debug = debug + + self.max_num_objects = max_num_target_objects + self.max_num_other_objects = max_num_distractor_objects + self.max_num_shape_parameters = max_num_shape_parameters + self.max_num_rearrange_features = max_num_rearrange_features + self.max_num_anchor_features = max_num_anchor_features + self.shuffle_object_index = shuffle_object_index + + # used to tokenize the language part + self.tokenizer = tokenizer + + # retrieve data + self.data_roots = data_roots + self.arrangement_data = [] + arrangement_steps = [] + for ddx in range(len(data_roots)): + data_root = data_roots[ddx] + index_root = index_roots[ddx] + arrangement_indices_file = os.path.join(data_root, index_root, 
"{}_arrangement_indices_file_all.txt".format(split)) + if os.path.exists(arrangement_indices_file): + with open(arrangement_indices_file, "r") as fh: + arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())]) + else: + print("{} does not exist".format(arrangement_indices_file)) + # only keep the goal, ignore the intermediate steps + for filename, step_t in arrangement_steps: + if step_t == 0: + if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename: + continue + self.arrangement_data.append((filename, step_t)) + # if specified, filter data + if filter_num_moved_objects_range is not None: + self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range) + print("{} valid sequences".format(len(self.arrangement_data))) + + # Data Aug + self.data_augmentation = data_augmentation + # additive noise + self.gp_rescale_factor_range = [12, 20] + self.gaussian_scale_range = [0., 0.003] + # multiplicative noise + self.gamma_shape = 1000. + self.gamma_scale = 0.001 + + def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range): + assert len(list(filter_num_moved_objects_range)) == 2 + min_num, max_num = filter_num_moved_objects_range + print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num)) + ok_data = [] + for filename, step_t in self.arrangement_data: + h5 = h5py.File(filename, 'r') + moved_objs = h5['moved_objs'][()].split(',') + if min_num <= len(moved_objs) <= max_num: + ok_data.append((filename, step_t)) + print("{} valid sequences left".format(len(ok_data))) + return ok_data + + def get_data_idx(self, idx): + # Create the datum to return + file_idx = np.argmax(idx < self.file_to_count) + data = h5py.File(self.data_files[file_idx], 'r') + if file_idx > 0: + # for lang2sym, idx is always 0 + idx = idx - self.file_to_count[file_idx - 1] + return data, idx, file_idx + + def add_noise_to_depth(self, depth_img): + """ add depth noise """ + multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale) + depth_img = multiplicative_noise * depth_img + return depth_img + + def add_noise_to_xyz(self, xyz_img, depth_img): + """ TODO: remove this code or at least celean it up""" + xyz_img = xyz_img.copy() + H, W, C = xyz_img.shape + gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0], + self.gp_rescale_factor_range[1]) + gp_scale = np.random.uniform(self.gaussian_scale_range[0], + self.gaussian_scale_range[1]) + small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int) + additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C)) + additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC) + xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :] + return xyz_img + + def random_index(self): + return self[np.random.randint(len(self))] + + def _get_rgb(self, h5, idx, ee=True): + RGB = "ee_rgb" if ee else "rgb" + rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. 
# remove alpha + return rgb1 + + def _get_depth(self, h5, idx, ee=True): + DEPTH = "ee_depth" if ee else "depth" + + def _get_images(self, h5, idx, ee=True): + if ee: + RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg" + DMIN, DMAX = "ee_depth_min", "ee_depth_max" + else: + RGB, DEPTH, SEG = "rgb", "depth", "seg" + DMIN, DMAX = "depth_min", "depth_max" + dmin = h5[DMIN][idx] + dmax = h5[DMAX][idx] + rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha + depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin + seg1 = img.PNGToNumpy(h5[SEG][idx]) + + valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.) + + # proj_matrix = h5['proj_matrix'][()] + camera = cam.get_camera_from_h5(h5) + if self.data_augmentation: + depth1 = self.add_noise_to_depth(depth1) + + xyz1 = cam.compute_xyz(depth1, camera) + if self.data_augmentation: + xyz1 = self.add_noise_to_xyz(xyz1, depth1) + + # Transform the point cloud + # Here it is... + # CAM_POSE = "ee_cam_pose" if ee else "cam_pose" + CAM_POSE = "ee_camera_view" if ee else "camera_view" + cam_pose = h5[CAM_POSE][idx] + if ee: + # ee_camera_view has 0s for x, y, z + cam_pos = h5["ee_cam_pose"][:][:3, 3] + cam_pose[:3, 3] = cam_pos + + # Get transformed point cloud + h, w, d = xyz1.shape + xyz1 = xyz1.reshape(h * w, -1) + xyz1 = trimesh.transform_points(xyz1, cam_pose) + xyz1 = xyz1.reshape(h, w, -1) + + scene1 = rgb1, depth1, seg1, valid1, xyz1 + + return scene1 + + def __len__(self): + return len(self.arrangement_data) + + def _get_ids(self, h5): + """ + get object ids + + @param h5: + @return: + """ + ids = {} + for k in h5.keys(): + if k.startswith("id_"): + ids[k[3:]] = h5[k][()] + return ids + + def get_positive_ratio(self): + num_pos = 0 + for d in self.arrangement_data: + filename, step_t = d + if step_t == 0: + num_pos += 1 + return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos + + def get_object_position_vocab_sizes(self): + return self.tokenizer.get_object_position_vocab_sizes() + + def get_vocab_size(self): + return self.tokenizer.get_vocab_size() + + def get_data_index(self, idx): + filename = self.arrangement_data[idx] + return filename + + def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False): + """ + + :param idx: + :param inference_mode: + :param shuffle_object_index: used to test different orders of objects + :return: + """ + + filename, _ = self.arrangement_data[idx] + + h5 = h5py.File(filename, 'r') + ids = self._get_ids(h5) + all_objs = sorted([o for o in ids.keys() if "object_" in o]) + goal_specification = json.loads(str(np.array(h5["goal_specification"]))) + num_rearrange_objs = len(goal_specification["rearrange"]["objects"]) + num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]) + assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs) + assert num_rearrange_objs <= self.max_num_objects + assert num_other_objs <= self.max_num_other_objects + + # important: only using the last step + step_t = num_rearrange_objs + + target_objs = all_objs[:num_rearrange_objs] + other_objs = all_objs[num_rearrange_objs:] + + structure_parameters = goal_specification["shape"] + + # Important: ensure the order is correct + if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line": + target_objs = target_objs[::-1] + elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner": + target_objs = target_objs + else: + raise KeyError("{} 
structure is not recognized".format(structure_parameters["type"])) + all_objs = target_objs + other_objs + + ################################### + # getting scene images and point clouds + scene = self._get_images(h5, step_t, ee=True) + rgb, depth, seg, valid, xyz = scene + if inference_mode: + initial_scene = scene + + # getting object point clouds + obj_pcs = [] + obj_pad_mask = [] + current_pc_poses = [] + other_obj_pcs = [] + other_obj_pad_mask = [] + for obj in all_objs: + obj_mask = np.logical_and(seg == ids[obj], valid) + if np.sum(obj_mask) <= 0: + raise Exception + ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts) + if not ok: + raise Exception + + if obj in target_objs: + if self.ignore_rgb: + obj_pcs.append(obj_xyz) + else: + obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1)) + obj_pad_mask.append(0) + pc_pose = np.eye(4) + pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy() + current_pc_poses.append(pc_pose) + elif obj in other_objs: + if self.ignore_rgb: + other_obj_pcs.append(obj_xyz) + else: + other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1)) + other_obj_pad_mask.append(0) + else: + raise Exception + + ################################### + # computes goal positions for objects + # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect + if self.use_virtual_structure_frame: + goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1], + structure_parameters["rotation"][2]) + goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1], + structure_parameters["position"][2]] + goal_structure_pose_inv = np.linalg.inv(goal_structure_pose) + + goal_obj_poses = [] + current_obj_poses = [] + goal_pc_poses = [] + for obj, current_pc_pose in zip(target_objs, current_pc_poses): + goal_pose = h5[obj][0] + current_pose = h5[obj][step_t] + if inference_mode: + goal_obj_poses.append(goal_pose) + current_obj_poses.append(current_pose) + + goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose + if self.use_virtual_structure_frame: + goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose + goal_pc_poses.append(goal_pc_pose) + + # transform current object point cloud to the goal point cloud in the world frame + if self.debug: + new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs] + for i, obj_pc in enumerate(new_obj_pcs): + + current_pc_pose = current_pc_poses[i] + goal_pc_pose = goal_pc_poses[i] + if self.use_virtual_structure_frame: + goal_pc_pose = goal_structure_pose @ goal_pc_pose + print("current pc pose", current_pc_pose) + print("goal pc pose", goal_pc_pose) + + goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose) + print("transform", goal_pc_transform) + new_obj_pc = copy.deepcopy(obj_pc) + new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform) + print(new_obj_pc.shape) + + # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects + new_obj_pcs[i] = new_obj_pc + new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float), (new_obj_pc.shape[0], 1)) + new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float), (new_obj_pc.shape[0], 1)) + show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]], + [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current], + add_coordinate_frame=True) + 
show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True) + + # pad data + for i in range(self.max_num_objects - len(target_objs)): + obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32)) + obj_pad_mask.append(1) + for i in range(self.max_num_other_objects - len(other_objs)): + other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32)) + other_obj_pad_mask.append(1) + + ################################### + # preparing sentence + sentence = [] + sentence_pad_mask = [] + + # structure parameters + # 5 parameters + structure_parameters = goal_specification["shape"] + if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line": + sentence.append((structure_parameters["type"], "shape")) + sentence.append((structure_parameters["rotation"][2], "rotation")) + sentence.append((structure_parameters["position"][0], "position_x")) + sentence.append((structure_parameters["position"][1], "position_y")) + if structure_parameters["type"] == "circle": + sentence.append((structure_parameters["radius"], "radius")) + elif structure_parameters["type"] == "line": + sentence.append((structure_parameters["length"] / 2.0, "radius")) + for _ in range(5): + sentence_pad_mask.append(0) + else: + sentence.append((structure_parameters["type"], "shape")) + sentence.append((structure_parameters["rotation"][2], "rotation")) + sentence.append((structure_parameters["position"][0], "position_x")) + sentence.append((structure_parameters["position"][1], "position_y")) + for _ in range(4): + sentence_pad_mask.append(0) + sentence.append(("PAD", None)) + sentence_pad_mask.append(1) + + ################################### + # paddings + for i in range(self.max_num_objects - len(target_objs)): + goal_pc_poses.append(np.eye(4)) + + ################################### + if self.debug: + print("---") + print("all objects:", all_objs) + print("target objects:", target_objs) + print("other objects:", other_objs) + print("goal specification:", goal_specification) + print("sentence:", sentence) + show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True) + + assert len(obj_pcs) == len(goal_pc_poses) + ################################### + + # shuffle the position of objects + if shuffle_object_index: + shuffle_target_object_indices = list(range(len(target_objs))) + random.shuffle(shuffle_target_object_indices) + shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects)) + obj_pcs = [obj_pcs[i] for i in shuffle_object_indices] + goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices] + if inference_mode: + goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices] + current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices] + target_objs = [target_objs[i] for i in shuffle_target_object_indices] + current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices] + + ################################### + if self.use_virtual_structure_frame: + if self.ignore_distractor_objects: + # language, structure virtual frame, target objects + pcs = obj_pcs + type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + [0] + obj_pad_mask + else: + # language, distractor objects, structure virtual frame, target objects + pcs = 
other_obj_pcs + obj_pcs + type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask + goal_poses = [goal_structure_pose] + goal_pc_poses + else: + if self.ignore_distractor_objects: + # language, target objects + pcs = obj_pcs + type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + obj_pad_mask + else: + # language, distractor objects, target objects + pcs = other_obj_pcs + obj_pcs + type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask + goal_poses = goal_pc_poses + + datum = { + "pcs": pcs, + "sentence": sentence, + "goal_poses": goal_poses, + "type_index": type_index, + "position_index": position_index, + "pad_mask": pad_mask, + "t": step_t, + "filename": filename + } + + if inference_mode: + datum["rgb"] = rgb + datum["goal_obj_poses"] = goal_obj_poses + datum["current_obj_poses"] = current_obj_poses + datum["target_objs"] = target_objs + datum["initial_scene"] = initial_scene + datum["ids"] = ids + datum["goal_specification"] = goal_specification + datum["current_pc_poses"] = current_pc_poses + + return datum + + @staticmethod + def convert_to_tensors(datum, tokenizer): + tensors = { + "pcs": torch.stack(datum["pcs"], dim=0), + "sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])), + "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])), + "type_index": torch.LongTensor(np.array(datum["type_index"])), + "position_index": torch.LongTensor(np.array(datum["position_index"])), + "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])), + "t": datum["t"], + "filename": datum["filename"] + } + return tensors + + def __getitem__(self, idx): + + datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index), + self.tokenizer) + + return datum + + def single_datum_to_batch(self, x, num_samples, device, inference_mode=True): + tensor_x = {} + + tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1) + tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1) + if not inference_mode: + tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1) + + tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1) + tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1) + tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1) + + return tensor_x + + +def compute_min_max(dataloader): + + # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000, + # -0.9079, -0.8668, -0.9105, -0.4186]) + # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194, + # 0.4787, 0.6421, 1.0000]) + # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000, + # -0.0000, 0.0000, 0.0000, 1.0000]) + # tensor([0.9199, 0.3710, 
0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000, + # 0.0000, 0.0000, 1.0000]) + + min_value = torch.ones(16) * 10000 + max_value = torch.ones(16) * -10000 + for d in tqdm(dataloader): + goal_poses = d["goal_poses"] + goal_poses = goal_poses.reshape(-1, 16) + current_max, _ = torch.max(goal_poses, dim=0) + current_min, _ = torch.min(goal_poses, dim=0) + max_value[max_value < current_max] = current_max[max_value < current_max] + max_value[max_value > current_min] = current_min[max_value > current_min] + print(f"{min_value} - {max_value}") + + +if __name__ == "__main__": + + tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json") + + data_roots = [] + index_roots = [] + for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]: + data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)) + index_roots.append(index) + + dataset = SemanticArrangementDataset(data_roots=data_roots, + index_roots=index_roots, + split="valid", tokenizer=tokenizer, + max_num_target_objects=7, + max_num_distractor_objects=5, + max_num_shape_parameters=5, + max_num_rearrange_features=0, + max_num_anchor_features=0, + num_pts=1024, + use_virtual_structure_frame=True, + ignore_distractor_objects=True, + ignore_rgb=True, + filter_num_moved_objects_range=None, # [5, 5] + data_augmentation=False, + shuffle_object_index=False, + debug=False) + + # print(len(dataset)) + # for d in dataset: + # print("\n\n" + "="*100) + + dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8) + for i, d in enumerate(tqdm(dataloader)): + pass + # for k in d: + # if isinstance(d[k], torch.Tensor): + # print("--size", k, d[k].shape) + # for k in d: + # print(k, d[k]) + # + # input("next?") \ No newline at end of file diff --git a/src/StructDiffusion/data/semantic_arrangement_demo.py b/src/StructDiffusion/data/semantic_arrangement_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..653ccfd758aefc73071dfb36ba78ea46774ac7b5 --- /dev/null +++ b/src/StructDiffusion/data/semantic_arrangement_demo.py @@ -0,0 +1,563 @@ +import copy +import cv2 +import h5py +import numpy as np +import os +import trimesh +import torch +from tqdm import tqdm +import json +import random + +from torch.utils.data import DataLoader + +# Local imports +from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs +from StructDiffusion.language.tokenizer import Tokenizer + +import StructDiffusion.utils.brain2.camera as cam +import StructDiffusion.utils.brain2.image as img +import StructDiffusion.utils.transformations as tra + + +class SemanticArrangementDataset(torch.utils.data.Dataset): + + def __init__(self, data_root, tokenizer, + max_num_target_objects=11, max_num_distractor_objects=5, + max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3, + num_pts=1024, + use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True, + filter_num_moved_objects_range=None, shuffle_object_index=False, + data_augmentation=True, debug=False, **kwargs): + """ + + Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs + + :param data_root: + :param split: train, valid, or test + :param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence + :param debug: + :param max_num_shape_parameters: + :param 
max_num_objects: + :param max_num_rearrange_features: + :param max_num_anchor_features: + :param num_pts: + :param use_stored_arrangement_indices: + :param kwargs: + """ + + self.use_virtual_structure_frame = use_virtual_structure_frame + self.ignore_distractor_objects = ignore_distractor_objects + self.ignore_rgb = ignore_rgb and not debug + + self.num_pts = num_pts + self.debug = debug + + self.max_num_objects = max_num_target_objects + self.max_num_other_objects = max_num_distractor_objects + self.max_num_shape_parameters = max_num_shape_parameters + self.max_num_rearrange_features = max_num_rearrange_features + self.max_num_anchor_features = max_num_anchor_features + self.shuffle_object_index = shuffle_object_index + + # used to tokenize the language part + self.tokenizer = tokenizer + + # retrieve data + self.data_root = data_root + self.arrangement_data = [] + for filename in os.listdir(data_root): + if ".h5" in filename: + self.arrangement_data.append((os.path.join(data_root, filename), 0)) + print("{} valid sequences".format(len(self.arrangement_data))) + + # Data Aug + self.data_augmentation = data_augmentation + # additive noise + self.gp_rescale_factor_range = [12, 20] + self.gaussian_scale_range = [0., 0.003] + # multiplicative noise + self.gamma_shape = 1000. + self.gamma_scale = 0.001 + + def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range): + assert len(list(filter_num_moved_objects_range)) == 2 + min_num, max_num = filter_num_moved_objects_range + print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num)) + ok_data = [] + for filename, step_t in self.arrangement_data: + h5 = h5py.File(filename, 'r') + moved_objs = h5['moved_objs'][()].split(',') + if min_num <= len(moved_objs) <= max_num: + ok_data.append((filename, step_t)) + print("{} valid sequences left".format(len(ok_data))) + return ok_data + + def get_data_idx(self, idx): + # Create the datum to return + file_idx = np.argmax(idx < self.file_to_count) + data = h5py.File(self.data_files[file_idx], 'r') + if file_idx > 0: + # for lang2sym, idx is always 0 + idx = idx - self.file_to_count[file_idx - 1] + return data, idx, file_idx + + def add_noise_to_depth(self, depth_img): + """ add depth noise """ + multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale) + depth_img = multiplicative_noise * depth_img + return depth_img + + def add_noise_to_xyz(self, xyz_img, depth_img): + """ TODO: remove this code or at least celean it up""" + xyz_img = xyz_img.copy() + H, W, C = xyz_img.shape + gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0], + self.gp_rescale_factor_range[1]) + gp_scale = np.random.uniform(self.gaussian_scale_range[0], + self.gaussian_scale_range[1]) + small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int) + additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C)) + additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC) + xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :] + return xyz_img + + def random_index(self): + return self[np.random.randint(len(self))] + + def _get_rgb(self, h5, idx, ee=True): + RGB = "ee_rgb" if ee else "rgb" + rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. 
# remove alpha + return rgb1 + + def _get_depth(self, h5, idx, ee=True): + DEPTH = "ee_depth" if ee else "depth" + + def _get_images(self, h5, idx, ee=True): + if ee: + RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg" + DMIN, DMAX = "ee_depth_min", "ee_depth_max" + else: + RGB, DEPTH, SEG = "rgb", "depth", "seg" + DMIN, DMAX = "depth_min", "depth_max" + dmin = h5[DMIN][idx] + dmax = h5[DMAX][idx] + rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255. # remove alpha + depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin + seg1 = img.PNGToNumpy(h5[SEG][idx]) + + valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.) + + # proj_matrix = h5['proj_matrix'][()] + camera = cam.get_camera_from_h5(h5) + if self.data_augmentation: + depth1 = self.add_noise_to_depth(depth1) + + xyz1 = cam.compute_xyz(depth1, camera) + if self.data_augmentation: + xyz1 = self.add_noise_to_xyz(xyz1, depth1) + + # Transform the point cloud + # Here it is... + # CAM_POSE = "ee_cam_pose" if ee else "cam_pose" + CAM_POSE = "ee_camera_view" if ee else "camera_view" + cam_pose = h5[CAM_POSE][idx] + if ee: + # ee_camera_view has 0s for x, y, z + cam_pos = h5["ee_cam_pose"][:][:3, 3] + cam_pose[:3, 3] = cam_pos + + # Get transformed point cloud + h, w, d = xyz1.shape + xyz1 = xyz1.reshape(h * w, -1) + xyz1 = trimesh.transform_points(xyz1, cam_pose) + xyz1 = xyz1.reshape(h, w, -1) + + scene1 = rgb1, depth1, seg1, valid1, xyz1 + + return scene1 + + def __len__(self): + return len(self.arrangement_data) + + def _get_ids(self, h5): + """ + get object ids + + @param h5: + @return: + """ + ids = {} + for k in h5.keys(): + if k.startswith("id_"): + ids[k[3:]] = h5[k][()] + return ids + + def get_positive_ratio(self): + num_pos = 0 + for d in self.arrangement_data: + filename, step_t = d + if step_t == 0: + num_pos += 1 + return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos + + def get_object_position_vocab_sizes(self): + return self.tokenizer.get_object_position_vocab_sizes() + + def get_vocab_size(self): + return self.tokenizer.get_vocab_size() + + def get_data_index(self, idx): + filename = self.arrangement_data[idx] + return filename + + def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False): + """ + + :param idx: + :param inference_mode: + :param shuffle_object_index: used to test different orders of objects + :return: + """ + + filename, _ = self.arrangement_data[idx] + + h5 = h5py.File(filename, 'r') + ids = self._get_ids(h5) + all_objs = sorted([o for o in ids.keys() if "object_" in o]) + goal_specification = json.loads(str(np.array(h5["goal_specification"]))) + num_rearrange_objs = len(goal_specification["rearrange"]["objects"]) + num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]) + assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs) + assert num_rearrange_objs <= self.max_num_objects + assert num_other_objs <= self.max_num_other_objects + + # important: only using the last step + step_t = num_rearrange_objs + + target_objs = all_objs[:num_rearrange_objs] + other_objs = all_objs[num_rearrange_objs:] + + structure_parameters = goal_specification["shape"] + + # Important: ensure the order is correct + if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line": + target_objs = target_objs[::-1] + elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner": + target_objs = target_objs + else: + raise KeyError("{} 
structure is not recognized".format(structure_parameters["type"])) + all_objs = target_objs + other_objs + + ################################### + # getting scene images and point clouds + scene = self._get_images(h5, step_t, ee=True) + rgb, depth, seg, valid, xyz = scene + if inference_mode: + initial_scene = scene + + # getting object point clouds + obj_pcs = [] + obj_pad_mask = [] + current_pc_poses = [] + other_obj_pcs = [] + other_obj_pad_mask = [] + for obj in all_objs: + obj_mask = np.logical_and(seg == ids[obj], valid) + if np.sum(obj_mask) <= 0: + raise Exception + ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts) + if not ok: + raise Exception + + if obj in target_objs: + if self.ignore_rgb: + obj_pcs.append(obj_xyz) + else: + obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1)) + obj_pad_mask.append(0) + pc_pose = np.eye(4) + pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy() + current_pc_poses.append(pc_pose) + elif obj in other_objs: + if self.ignore_rgb: + other_obj_pcs.append(obj_xyz) + else: + other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1)) + other_obj_pad_mask.append(0) + else: + raise Exception + + ################################### + # computes goal positions for objects + # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect + if self.use_virtual_structure_frame: + goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1], + structure_parameters["rotation"][2]) + goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1], + structure_parameters["position"][2]] + goal_structure_pose_inv = np.linalg.inv(goal_structure_pose) + + goal_obj_poses = [] + current_obj_poses = [] + goal_pc_poses = [] + for obj, current_pc_pose in zip(target_objs, current_pc_poses): + goal_pose = h5[obj][0] + current_pose = h5[obj][step_t] + if inference_mode: + goal_obj_poses.append(goal_pose) + current_obj_poses.append(current_pose) + + goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose + if self.use_virtual_structure_frame: + goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose + goal_pc_poses.append(goal_pc_pose) + + # transform current object point cloud to the goal point cloud in the world frame + if self.debug: + new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs] + for i, obj_pc in enumerate(new_obj_pcs): + + current_pc_pose = current_pc_poses[i] + goal_pc_pose = goal_pc_poses[i] + if self.use_virtual_structure_frame: + goal_pc_pose = goal_structure_pose @ goal_pc_pose + print("current pc pose", current_pc_pose) + print("goal pc pose", goal_pc_pose) + + goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose) + print("transform", goal_pc_transform) + new_obj_pc = copy.deepcopy(obj_pc) + new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform) + print(new_obj_pc.shape) + + # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects + new_obj_pcs[i] = new_obj_pc + new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float), (new_obj_pc.shape[0], 1)) + new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float), (new_obj_pc.shape[0], 1)) + show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]], + [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current], + add_coordinate_frame=True) + 
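# --- Editor's note (not part of the original diff) ---------------------------
# This demo dataset duplicates most of get_raw_data from the dataset module
# earlier in this diff, so the `dtype=np.float` deprecation and the
# compute_min_max issue flagged there apply here as well. A few additional
# observations about code defined earlier in this class:
#
# * `get_data_idx` indexes `self.file_to_count` and `self.data_files`, but
#   `__init__` only builds `self.arrangement_data`; calling it as written would
#   raise AttributeError. It appears to be a leftover from another dataset and
#   is not used by `get_raw_data`.
#
# * `_get_depth` sets the dataset key but never reads or returns anything. If
#   it is meant to mirror `_get_images`, a minimal sketch (assuming the same
#   depth encoding used there) would be:
#
#       def _get_depth(self, h5, idx, ee=True):
#           DEPTH = "ee_depth" if ee else "depth"
#           DMIN, DMAX = ("ee_depth_min", "ee_depth_max") if ee else ("depth_min", "depth_max")
#           dmin, dmax = h5[DMIN][idx], h5[DMAX][idx]
#           return h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
#
# * The __main__ block at the bottom of this file constructs
#   SemanticArrangementDataset with data_roots/index_roots/split keyword
#   arguments, but this class's __init__ only takes data_root (the extras are
#   swallowed by **kwargs), so running the file directly would apparently raise
#   a TypeError for the missing data_root argument.
# -----------------------------------------------------------------------------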
show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True) + + # pad data + for i in range(self.max_num_objects - len(target_objs)): + obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32)) + obj_pad_mask.append(1) + for i in range(self.max_num_other_objects - len(other_objs)): + other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32)) + other_obj_pad_mask.append(1) + + ################################### + # preparing sentence + sentence = [] + sentence_pad_mask = [] + + # structure parameters + # 5 parameters + structure_parameters = goal_specification["shape"] + if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line": + sentence.append((structure_parameters["type"], "shape")) + sentence.append((structure_parameters["rotation"][2], "rotation")) + sentence.append((structure_parameters["position"][0], "position_x")) + sentence.append((structure_parameters["position"][1], "position_y")) + if structure_parameters["type"] == "circle": + sentence.append((structure_parameters["radius"], "radius")) + elif structure_parameters["type"] == "line": + sentence.append((structure_parameters["length"] / 2.0, "radius")) + for _ in range(5): + sentence_pad_mask.append(0) + else: + sentence.append((structure_parameters["type"], "shape")) + sentence.append((structure_parameters["rotation"][2], "rotation")) + sentence.append((structure_parameters["position"][0], "position_x")) + sentence.append((structure_parameters["position"][1], "position_y")) + for _ in range(4): + sentence_pad_mask.append(0) + sentence.append(("PAD", None)) + sentence_pad_mask.append(1) + + ################################### + # paddings + for i in range(self.max_num_objects - len(target_objs)): + goal_pc_poses.append(np.eye(4)) + + ################################### + if self.debug: + print("---") + print("all objects:", all_objs) + print("target objects:", target_objs) + print("other objects:", other_objs) + print("goal specification:", goal_specification) + print("sentence:", sentence) + show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True) + + assert len(obj_pcs) == len(goal_pc_poses) + ################################### + + # shuffle the position of objects + if shuffle_object_index: + shuffle_target_object_indices = list(range(len(target_objs))) + random.shuffle(shuffle_target_object_indices) + shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects)) + obj_pcs = [obj_pcs[i] for i in shuffle_object_indices] + goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices] + if inference_mode: + goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices] + current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices] + target_objs = [target_objs[i] for i in shuffle_target_object_indices] + current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices] + + ################################### + if self.use_virtual_structure_frame: + if self.ignore_distractor_objects: + # language, structure virtual frame, target objects + pcs = obj_pcs + type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + [0] + obj_pad_mask + else: + # language, distractor objects, structure virtual frame, target objects + pcs = 
other_obj_pcs + obj_pcs + type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask + goal_poses = [goal_structure_pose] + goal_pc_poses + else: + if self.ignore_distractor_objects: + # language, target objects + pcs = obj_pcs + type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + obj_pad_mask + else: + # language, distractor objects, target objects + pcs = other_obj_pcs + obj_pcs + type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects + position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects)) + pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask + goal_poses = goal_pc_poses + + datum = { + "pcs": pcs, + "sentence": sentence, + "goal_poses": goal_poses, + "type_index": type_index, + "position_index": position_index, + "pad_mask": pad_mask, + "t": step_t, + "filename": filename + } + + if inference_mode: + datum["rgb"] = rgb + datum["goal_obj_poses"] = goal_obj_poses + datum["current_obj_poses"] = current_obj_poses + datum["target_objs"] = target_objs + datum["initial_scene"] = initial_scene + datum["ids"] = ids + datum["goal_specification"] = goal_specification + datum["current_pc_poses"] = current_pc_poses + + return datum + + @staticmethod + def convert_to_tensors(datum, tokenizer): + tensors = { + "pcs": torch.stack(datum["pcs"], dim=0), + "sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])), + "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])), + "type_index": torch.LongTensor(np.array(datum["type_index"])), + "position_index": torch.LongTensor(np.array(datum["position_index"])), + "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])), + "t": datum["t"], + "filename": datum["filename"] + } + return tensors + + def __getitem__(self, idx): + + datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index), + self.tokenizer) + + return datum + + def single_datum_to_batch(self, x, num_samples, device, inference_mode=True): + tensor_x = {} + + tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1) + tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1) + if not inference_mode: + tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1) + + tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1) + tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1) + tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1) + + return tensor_x + + +def compute_min_max(dataloader): + + # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000, + # -0.9079, -0.8668, -0.9105, -0.4186]) + # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194, + # 0.4787, 0.6421, 1.0000]) + # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000, + # -0.0000, 0.0000, 0.0000, 1.0000]) + # tensor([0.9199, 0.3710, 
0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000, + # 0.0000, 0.0000, 1.0000]) + + min_value = torch.ones(16) * 10000 + max_value = torch.ones(16) * -10000 + for d in tqdm(dataloader): + goal_poses = d["goal_poses"] + goal_poses = goal_poses.reshape(-1, 16) + current_max, _ = torch.max(goal_poses, dim=0) + current_min, _ = torch.min(goal_poses, dim=0) + max_value[max_value < current_max] = current_max[max_value < current_max] + max_value[max_value > current_min] = current_min[max_value > current_min] + print(f"{min_value} - {max_value}") + + +if __name__ == "__main__": + + tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json") + + data_roots = [] + index_roots = [] + for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]: + data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)) + index_roots.append(index) + + dataset = SemanticArrangementDataset(data_roots=data_roots, + index_roots=index_roots, + split="valid", tokenizer=tokenizer, + max_num_target_objects=7, + max_num_distractor_objects=5, + max_num_shape_parameters=5, + max_num_rearrange_features=0, + max_num_anchor_features=0, + num_pts=1024, + use_virtual_structure_frame=True, + ignore_distractor_objects=True, + ignore_rgb=True, + filter_num_moved_objects_range=None, # [5, 5] + data_augmentation=False, + shuffle_object_index=False, + debug=False) + + # print(len(dataset)) + # for d in dataset: + # print("\n\n" + "="*100) + + dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8) + for i, d in enumerate(tqdm(dataloader)): + pass + # for k in d: + # if isinstance(d[k], torch.Tensor): + # print("--size", k, d[k].shape) + # for k in d: + # print(k, d[k]) + # + # input("next?") \ No newline at end of file diff --git a/src/StructDiffusion/diffusion/__init__.py b/src/StructDiffusion/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15e1d8365cad74bebf2f131d9e8646c529b9974d Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27f93c52a01731df348737534c20838286322db6 Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c50945974f35e2d0b727238215e8adc0400b8644 Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494884c7815431f08db9cd82825ee5fabb0b4a36 Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc differ diff 
--git a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4549ee4cedac8ab2704f4e3ed6481343fdf8594f Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7945175edead17f2bfc5c6e27ed83105e7506532 Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3caa501f32838b8807f2bb72a9e7ed280fff72dc Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc differ diff --git a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef5a255067f922aa21dc539da6665bdf20f19c04 Binary files /dev/null and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc differ diff --git a/src/StructDiffusion/diffusion/noise_schedule.py b/src/StructDiffusion/diffusion/noise_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2b7986fdc0c88d09acd16117d3d059cf67f96b --- /dev/null +++ b/src/StructDiffusion/diffusion/noise_schedule.py @@ -0,0 +1,81 @@ +import math +import torch +import torch.nn.functional as F + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule as proposed in https://arxiv.org/abs/2102.09672 + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0.0001, 0.9999) + + +def linear_beta_schedule(timesteps): + beta_start = 0.0001 + beta_end = 0.02 + return torch.linspace(beta_start, beta_end, timesteps) + + +def quadratic_beta_schedule(timesteps): + beta_start = 0.0001 + beta_end = 0.02 + return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 + + +def sigmoid_beta_schedule(timesteps): + beta_start = 0.0001 + beta_end = 0.02 + betas = torch.linspace(-6, 6, timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +class NoiseSchedule: + + def __init__(self, timesteps=200): + + self.timesteps = timesteps + + # define beta schedule + self.betas = linear_beta_schedule(timesteps=timesteps) + # self.betas = cosine_beta_schedule(timesteps=timesteps) + + # define alphas + self.alphas = 1. - self.betas + # alphas_cumprod: alpha bar + self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) + self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0) + self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. 
- self.alphas_cumprod) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod) + + +def extract(a, t, x_shape): + batch_size = t.shape[0] + out = a.gather(-1, t.cpu()) + return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) + + +# forward diffusion (using the nice property) +def q_sample(x_start, t, noise_schedule, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + + sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape) + # print("sqrt_alphas_cumprod_t", sqrt_alphas_cumprod_t) + sqrt_one_minus_alphas_cumprod_t = extract( + noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + # print("sqrt_one_minus_alphas_cumprod_t", sqrt_one_minus_alphas_cumprod_t) + # print("noise", noise) + + return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise \ No newline at end of file diff --git a/src/StructDiffusion/diffusion/pose_conversion.py b/src/StructDiffusion/diffusion/pose_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..abf43bae2bc6af9fed8338c255d0bdcc06447465 --- /dev/null +++ b/src/StructDiffusion/diffusion/pose_conversion.py @@ -0,0 +1,103 @@ +import os +import torch +import pytorch3d.transforms as tra3d + +from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d + + +def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs): + + # important: we need to get the first two columns, not first two rows + # array([[ 3, 4, 5], + # [ 6, 7, 8], + # [ 9, 10, 11]]) + xyz_6d_idxs = [0, 1, 2, 3, 6, 9, 4, 7, 10] + + # print(batch_data["obj_xyztheta_inputs"].shape) + # print(batch_data["struct_xyztheta_inputs"].shape) + + # only get the first and second columns of rotation + obj_xyztheta_inputs = obj_xyztheta_inputs[:, :, xyz_6d_idxs] # B, N, 9 + struct_xyztheta_inputs = struct_xyztheta_inputs[:, :, xyz_6d_idxs] # B, 1, 9 + + x = torch.cat([struct_xyztheta_inputs, obj_xyztheta_inputs], dim=1) # B, 1 + N, 9 + + # print(x.shape) + + return x + + +def get_diffusion_variables_from_H(poses): + """ + [[0,1,2,3], + [4,5,6,7], + [8,9,10,11], + [12,13,14,15] + :param obj_xyztheta_inputs: B, N, 4, 4 + :return: + """ + + xyz_6d_idxs = [3, 7, 11, 0, 4, 8, 1, 5, 9] + + B, N, _, _ = poses.shape + x = poses.reshape(B, N, 16)[:, :, xyz_6d_idxs] # B, N, 9 + return x + + +def get_struct_objs_poses(x): + + on_gpu = x.is_cuda + if not on_gpu: + x = x.cuda() + + # assert x.is_cuda, "compute_rotation_matrix_from_ortho6d requires input to be on gpu" + device = x.device + + # important: the noisy x can go out of bounds + x = torch.clamp(x, min=-1, max=1) + + # x: B, 1 + N, 9 + B = x.shape[0] + N = x.shape[1] - 1 + + # compute_rotation_matrix_from_ortho6d takes in [B, 6], outputs [B, 3, 3] + x_6d = x[:, :, 3:].reshape(-1, 6) + x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, N+1, 3, 3) # B, 1 + N, 3, 3 + + x_trans = x[:, :, :3] # B, 1 + N, 3 + + x_full = torch.eye(4).repeat(B, 1 + N, 1, 1).to(device) + x_full[:, :, :3, :3] = x_rot + x_full[:, :, :3, 3] = x_trans + + struct_pose = x_full[:, 0].unsqueeze(1) # B, 1, 4, 4 + pc_poses_in_struct = x_full[:, 1:] # B, N, 4, 4 + + if not on_gpu: + struct_pose = struct_pose.cpu() + pc_poses_in_struct = pc_poses_in_struct.cpu() + + return struct_pose, pc_poses_in_struct + + +def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct): + + device = 
obj_xyzs.device + + # obj_xyzs: B, N, P, 3 + # struct_pose: B, 1, 4, 4 + # pc_poses_in_struct: B, N, 4, 4 + B, N, _, _ = pc_poses_in_struct.shape + _, _, P, _ = obj_xyzs.shape + + current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4 + # print(torch.mean(obj_xyzs, dim=2).shape) + current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2) # B, N, 4, 4 + + struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4 + struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4 + pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4 + + goal_pc_poses = struct_pose @ pc_poses_in_struct # B x N, 4, 4 + goal_pc_poses = goal_pc_poses.reshape(B, N, 4, 4) # B, N, 4, 4 + return current_pc_poses, goal_pc_poses \ No newline at end of file diff --git a/src/StructDiffusion/diffusion/sampler.py b/src/StructDiffusion/diffusion/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..06ee7a6258771dd01e35c075661b5a7f0e686035 --- /dev/null +++ b/src/StructDiffusion/diffusion/sampler.py @@ -0,0 +1,296 @@ +import torch +from tqdm import tqdm +import pytorch3d.transforms as tra3d + +from StructDiffusion.diffusion.noise_schedule import extract +from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses +from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs, move_pc_and_create_scene_new + +class Sampler: + + def __init__(self, model_class, checkpoint_path, device, debug=False): + + self.debug = debug + self.device = device + + self.model = model_class.load_from_checkpoint(checkpoint_path) + self.backbone = self.model.model + self.backbone.to(device) + self.backbone.eval() + + def sample(self, batch, num_poses): + + noise_schedule = self.model.noise_schedule + + B = batch["pcs"].shape[0] + + x_noisy = torch.randn((B, num_poses, 9), device=self.device) + + xs = [] + for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)), + desc='sampling loop time step', total=noise_schedule.timesteps): + + t = torch.full((B,), t_index, device=self.device, dtype=torch.long) + + # noise schedule + betas_t = extract(noise_schedule.betas, t, x_noisy.shape) + sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape) + sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape) + + # predict noise + pcs = batch["pcs"] + sentence = batch["sentence"] + type_index = batch["type_index"] + position_index = batch["position_index"] + pad_mask = batch["pad_mask"] + # calling the backbone instead of the pytorch-lightning model + with torch.no_grad(): + predicted_noise = self.backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask) + + # compute noisy x at t + model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t) + if t_index == 0: + x_noisy = model_mean + else: + posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape) + noise = torch.randn_like(x_noisy) + x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise + + xs.append(x_noisy) + + xs = list(reversed(xs)) + return xs + +class SamplerV2: + + def __init__(self, diffusion_model_class, diffusion_checkpoint_path, + collision_model_class, collision_checkpoint_path, + device, debug=False): + + self.debug = debug + self.device = device + + self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path) + self.diffusion_backbone = 
self.diffusion_model.model + self.diffusion_backbone.to(device) + self.diffusion_backbone.eval() + + self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path) + self.collision_backbone = self.collision_model.model + self.collision_backbone.to(device) + self.collision_backbone.eval() + + def sample(self, batch, num_poses): + + noise_schedule = self.diffusion_model.noise_schedule + + B = batch["pcs"].shape[0] + + x_noisy = torch.randn((B, num_poses, 9), device=self.device) + + xs = [] + for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)), + desc='sampling loop time step', total=noise_schedule.timesteps): + + t = torch.full((B,), t_index, device=self.device, dtype=torch.long) + + # noise schedule + betas_t = extract(noise_schedule.betas, t, x_noisy.shape) + sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape) + sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape) + + # predict noise + pcs = batch["pcs"] + sentence = batch["sentence"] + type_index = batch["type_index"] + position_index = batch["position_index"] + pad_mask = batch["pad_mask"] + # calling the backbone instead of the pytorch-lightning model + with torch.no_grad(): + predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask) + + # compute noisy x at t + model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t) + if t_index == 0: + x_noisy = model_mean + else: + posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape) + noise = torch.randn_like(x_noisy) + x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise + + xs.append(x_noisy) + + xs = list(reversed(xs)) + + visualize = True + + struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0]) + # struct_pose: B, 1, 4, 4 + # pc_poses_in_struct: B, N, 4, 4 + + S = B + num_elite = 10 + #################################################### + # only keep one copy + + # N, P, 3 + obj_xyzs = batch["pcs"][0][:, :, :3] + print("obj_xyzs shape", obj_xyzs.shape) + + # 1, N + # object_pad_mask: padding location has 1 + num_target_objs = num_poses + if self.diffusion_backbone.use_virtual_structure_frame: + num_target_objs -= 1 + object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0) + target_object_inds = 1 - object_pad_mask + print("target_object_inds shape", target_object_inds.shape) + print("target_object_inds", target_object_inds) + + N, P, _ = obj_xyzs.shape + print("S, N, P: {}, {}, {}".format(S, N, P)) + + #################################################### + # S, N, ... 
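# --- Editor's note (not part of the original diff) ---------------------------
# The block below broadcasts the sampled structure pose to every object slot,
# recovers each object's current point-cloud pose from the point-cloud mean,
# and re-parameterizes the per-object goal poses as xyz + XYZ Euler angles via
# pytorch3d so they can be scored (and, in the commented-out code, refined).
# Because the discriminator branch further down is commented out, `scores`
# ends up equal to `no_intersection_scores`, i.e. the elite samples are ranked
# purely by the pairwise collision model.
# -----------------------------------------------------------------------------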
+ + struct_pose = struct_pose.repeat(1, N, 1, 1) # S, N, 4, 4 + struct_pose = struct_pose.reshape(S * N, 4, 4) # S x N, 4, 4 + + new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1) # S, N, P, 3 + current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device) # S, N, 4, 4 + current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2) # S, N, 4, 4 + current_pc_pose = current_pc_pose.reshape(S * N, 4, 4) # S x N, 4, 4 + + # optimize xyzrpy + obj_params = torch.zeros((S, N, 6)).to(self.device) + obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3] + obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ") # S, N, 6 + # + # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device) + # + # if visualize: + # print("visualizing rearrangements predicted by the generator") + # visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5) + + #################################################### + # rank + + # evaluate in batches + scores = torch.zeros(S).to(self.device) + no_intersection_scores = torch.zeros(S).to(self.device) # the higher the better + num_batches = int(S / B) + if S % B != 0: + num_batches += 1 + for b in range(num_batches): + if b + 1 == num_batches: + cur_batch_idxs_start = b * B + cur_batch_idxs_end = S + else: + cur_batch_idxs_start = b * B + cur_batch_idxs_end = (b + 1) * B + cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start + + # print("current batch idxs start", cur_batch_idxs_start) + # print("current batch idxs end", cur_batch_idxs_end) + # print("size of the current batch", cur_batch_size) + + batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end] + batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N] + batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N] + + new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \ + move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose, + target_object_inds, self.device, + return_scene_pts=False, + return_scene_pts_and_pc_idxs=False, + num_scene_pts=False, + normalize_pc=False, + return_pair_pc=True, + num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts, + normalize_pair_pc=self.collision_model.data_cfg.normalize_pc) + + ####################################### + # predict whether there are pairwise collisions + # if collision_score_weight > 0: + with torch.no_grad(): + _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape + # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1) + collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)) + collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb) # cur_batch_size, num_comb + + # debug + # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs): + # print("batch id", bi) + # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs): + # print("pair", pi) + # # obj_pair_xyzs: 2 * P, 5 + # print("collision score", collision_scores[bi, pi]) + # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show() + + # 1 - mean() since the collision model predicts 1 if there is a collision + no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1) + if visualize: + print("no intersection scores", no_intersection_scores) + # ####################################### + # if 
discriminator_score_weight > 0: + # # # debug: + # # print(subsampled_scene_xyz.shape) + # # print(subsampled_scene_xyz[0]) + # # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show() + # # + # with torch.no_grad(): + # + # # Important: since this discriminator only uses local structure param, takes sentence from the first and last position + # # local_sentence = sentence[:, [0, 4]] + # # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]] + # # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator) + # + # sentence_disc = torch.LongTensor( + # [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator]) + # sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator) + # position_index_dic = torch.LongTensor(raw_position_index_discriminator) + # + # preds = discriminator_model.forward(subsampled_scene_xyz, + # sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device), + # sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size, + # 1).to(device), + # position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to( + # device)) + # # preds = discriminator_model.forward(subsampled_scene_xyz) + # preds = discriminator_model.convert_logits(preds) + # preds = preds["is_circle"] # cur_batch_size, + # scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds + # if visualize: + # print("discriminator scores", scores) + + # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight + scores = no_intersection_scores + sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite] + elite_obj_params = obj_params[sort_idx] # num_elite, N, 6 + elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx] # num_elite, N, 4, 4 + elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4) # num_elite x N, 4, 4 + elite_scores = scores[sort_idx] + print("elite scores:", elite_scores) + + #################################################### + # # visualize best samples + # num_scene_pts = 4096 # if discriminator_num_scene_pts is None else discriminator_num_scene_pts + # batch_current_pc_pose = current_pc_pose[0: num_elite * N] + # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \ + # move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose, + # target_object_inds, self.device, + # return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True) + # if visualize: + # print("visualizing elite rearrangements ranked by collision model/discriminator") + # visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite) + + # num_elite, N, 6 + elite_obj_params = elite_obj_params.reshape(num_elite * N, -1) + pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device) + pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ") + pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3] + pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4) # num_elite, N, 4, 4 + + struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0,].unsqueeze(1) # num_elite, 1, 4, 4 + + return struct_pose, pc_poses_in_struct \ No newline at end of file diff --git a/src/StructDiffusion/language/__init__.py b/src/StructDiffusion/language/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aed88d9ff03518b98ae5ed16652d45bf0e5e28c Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ba0f5e6c27b15d0de3f53b686d3267c861ddfa Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09854272d48643a552100b005d109099ba6e11bb Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc differ diff --git a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3090fb1a1f4014f8ede1eec9758f552adc21a8cd Binary files /dev/null and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/src/StructDiffusion/language/tokenizer.py b/src/StructDiffusion/language/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef678dbfcf20ef2649d3491169b2f9f5517495c --- /dev/null +++ b/src/StructDiffusion/language/tokenizer.py @@ -0,0 +1,541 @@ +import json +import numpy as np +import re + +# def add_pad_to_vocab(vocab): +# new_vocab = {"PAD": 0} +# for k in vocab: +# new_vocab[k] = vocab[k] + 1 +# return new_vocab +# +# +# def combine_vocabs(vocabs, vocab_types): +# new_vocab = {} +# for type, vocab in zip(vocab_types, vocabs): +# for k in vocab: +# new_vocab["{}:{}".format(type, k)] = len(new_vocab) +# return new_vocab +# +# +# def add_token_to_vocab(vocab): +# new_vocab = {"MASK": 0} +# for k in vocab: +# new_vocab[k] = vocab[k] + 1 +# return new_vocab +# +# +# def tokenize_circle_specification(circle_specification): +# tokenized = {} +# # min 0, max 0.5, increment 0.05, 10 discrete values +# tokenized["radius"] = int(circle_specification["radius"] / 0.05) +# +# # min 0, max 1, increment 0.10, 10 discrete values +# tokenized["position_x"] = int(circle_specification["position"][0] / 0.10) +# +# # min -0.5, max 0.5, increment 0.10, 10 discrete values +# tokenized["position_y"] = int(circle_specification["position"][1] / 0.10) +# +# # min -3.14, max 3.14, increment 3.14 / 18, 36 discrete values +# tokenized["rotation"] = int((circle_specification["rotation"][2] + 3.14) / (3.14 / 18)) +# +# uniform_angle_vocab = {"False": 0, "True": 1} +# tokenized["uniform_angle"] = uniform_angle_vocab[circle_specification["uniform_angle"]] +# +# face_center_vocab = {"False": 0, "True": 1} +# tokenized["face_center"] = face_center_vocab[circle_specification["face_center"]] +# +# angle_ratio_vocab = {0.5: 0, 1.0: 1} +# tokenized["angle_ratio"] = angle_ratio_vocab[circle_specification["angle_ratio"]] +# +# # heights min 0.0, max 0.5 +# # volumn min 0.0, max 0.012 +# +# return tokenized +# +# +# def build_vocab(old_vocab_file, new_vocab_file): +# with open(old_vocab_file, "r") as 
fh: +# vocab_json = json.load(fh) +# +# vocabs = {} +# vocabs["class"] = vocab_json["class_to_idx"] +# vocabs["size"] = vocab_json["size_to_idx"] +# vocabs["color"] = vocab_json["color_to_idx"] +# vocabs["material"] = vocab_json["material_to_idx"] +# vocabs["comparator"] = {"less": 1, "greater": 2, "equal": 3} +# +# vocabs["radius"] = (0.0, 0.5, 10) +# vocabs["position_x"] = (0.0, 1.0, 10) +# vocabs["position_y"] = (-0.5, 0.5, 10) +# vocabs["rotation"] = (-3.14, 3.14, 36) +# vocabs["height"] = (0.0, 0.5, 10) +# vocabs["volumn"] = (0.0, 0.012, 10) +# +# vocabs["uniform_angle"] = {"False": 0, "True": 1} +# vocabs["face_center"] = {"False": 0, "True": 1} +# vocabs["angle_ratio"] = {0.5: 0, 1.0: 1} +# +# with open(new_vocab_file, "w") as fh: +# json.dump(vocabs, fh) + + +class Tokenizer: + """ + We want to build a tokenizer that tokenize words, features, and numbers. + + This tokenizer should also allow us to sample random values. + + For discrete values, we store mapping from the value to an id + For continuous values, we store min, max, and number of bins after discretization + + """ + + def __init__(self, vocab_file): + + self.vocab_file = vocab_file + with open(self.vocab_file, "r") as fh: + self.type_vocabs = json.load(fh) + + self.vocab = {"PAD": 0, "CLS": 1} + self.discrete_types = set() + self.continuous_types = set() + self.build_one_vocab() + + self.object_position_vocabs = {} + self.build_object_position_vocabs() + + def build_one_vocab(self): + print("\nBuild one vacab for everything...") + + for typ, vocab in self.type_vocabs.items(): + if typ == "comparator": + continue + + if typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry", + "struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"]: + continue + + if type(vocab) == dict: + self.vocab["{}:{}".format(typ, "MASK")] = len(self.vocab) + + for v in vocab: + assert ":" not in v + self.vocab["{}:{}".format(typ, v)] = len(self.vocab) + self.discrete_types.add(typ) + + elif type(vocab) == tuple or type(vocab) == list: + self.vocab["{}:{}".format(typ, "MASK")] = len(self.vocab) + + for c in self.type_vocabs["comparator"]: + self.vocab["{}:{}".format(typ, c)] = len(self.vocab) + + min_value, max_value, num_bins = vocab + for i in range(num_bins): + self.vocab["{}:{}".format(typ, i)] = len(self.vocab) + self.continuous_types.add(typ) + else: + raise TypeError("The dtype of the vocab cannot be handled: {}".format(vocab)) + + print("The vocab has {} tokens: {}".format(len(self.vocab), self.vocab)) + + def build_object_position_vocabs(self): + print("\nBuild vocabs for object position") + for typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry", + "struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"]: + self.object_position_vocabs[typ] = {"PAD": 0, "MASK": 1} + + if typ not in self.type_vocabs: + continue + min_value, max_value, num_bins = self.type_vocabs[typ] + for i in range(num_bins): + self.object_position_vocabs[typ]["{}".format(i)] = len(self.object_position_vocabs[typ]) + print("The {} vocab has {} tokens: {}".format(typ, len(self.object_position_vocabs[typ]), self.object_position_vocabs[typ])) + + def get_object_position_vocab_sizes(self): + return len(self.object_position_vocabs["position_x"]), len(self.object_position_vocabs["position_y"]), len(self.object_position_vocabs["rotation"]) + + def get_vocab_size(self): + return len(self.vocab) + + def tokenize_object_position(self, value, typ): + assert typ in ["obj_x", "obj_y", "obj_z", "obj_rr", "obj_rp", "obj_ry", + 
"struct_x", "struct_y", "struct_z", "struct_rr", "struct_rp", "struct_ry"] + if value == "MASK" or value == "PAD": + return self.object_position_vocabs[typ][value] + elif value == "IGNORE": + # Important: used to avoid computing loss. -100 is the default ignore_index for NLLLoss + return -100 + else: + min_value, max_value, num_bins = self.type_vocabs[typ] + assert min_value <= value <= max_value, value + dv = min(int((value - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1) + return self.object_position_vocabs[typ]["{}".format(dv)] + + def tokenize(self, value, typ=None): + if value in ["PAD", "CLS"]: + idx = self.vocab[value] + else: + if typ is None: + raise KeyError("Type cannot be None") + + if typ[-2:] == "_c" or typ[-2:] == "_d": + typ = typ[:-2] + + if typ in self.discrete_types: + idx = self.vocab["{}:{}".format(typ, value)] + elif typ in self.continuous_types: + if value == "MASK" or value in self.type_vocabs["comparator"]: + idx = self.vocab["{}:{}".format(typ, "MASK")] + else: + min_value, max_value, num_bins = self.type_vocabs[typ] + assert min_value <= value <= max_value, "type {} value {} exceeds {} and {}".format(typ, value, min_value, max_value) + dv = min(int((value - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1) + # print(value, dv, "{}:{}".format(typ, dv)) + idx = self.vocab["{}:{}".format(typ, dv)] + else: + raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value)) + return idx + + def get_valid_random_value(self, typ): + """ + Get a random value for the given typ + :param typ: + :return: + """ + if typ[-2:] == "_c" or typ[-2:] == "_d": + typ = typ[-2:] + + candidate_values = [] + for v in self.vocab: + if v in ["PAD", "CLS"]: + continue + ft, fv = v.split(":") + if typ == ft and fv != "MASK" and fv not in self.type_vocabs["comparator"]: + candidate_values.append(v) + assert len(candidate_values) != 0 + typed_v = np.random.choice(candidate_values) + value = typed_v.split(":")[1] + + if typ in self.discrete_types: + return value + elif typ in self.continuous_types: + min_value, max_value, num_bins = self.type_vocabs[typ] + return min_value + ((max_value - min_value) / num_bins) * int(value) + else: + raise KeyError("Do not recognize the type {} of the given token".format(typ)) + + def get_all_values_of_type(self, typ): + """ + Get all values for the given typ + :param typ: + :return: + """ + if typ[-2:] == "_c" or typ[-2:] == "_d": + typ = typ[-2:] + + candidate_values = [] + for v in self.vocab: + if v in ["PAD", "CLS"]: + continue + ft, fv = v.split(":") + if typ == ft and fv != "MASK" and fv not in self.type_vocabs["comparator"]: + candidate_values.append(v) + assert len(candidate_values) != 0 + values = [typed_v.split(":")[1] for typed_v in candidate_values] + + if typ in self.discrete_types: + return values + else: + raise KeyError("Do not recognize the type {} of the given token".format(typ)) + + def convert_to_natural_sentence(self, template_sentence): + + # select objects that are [red, metal] + # select objects that are [larger, taller] than the [], [], [] object + # select objects that have the same [color, material] of the [], [], [] object + + natural_sentence_templates = ["select objects that are {}.", + "select objects that have {} {} {} the {}.", + "select objects that have the same {} as the {}."] + + v, t = template_sentence[0] + if t[-2:] == "_c" or t[-2:] == "_d": + t = t[:-2] + + if v != "MASK" and t in self.discrete_types: + natural_sentence_template = natural_sentence_templates[0] 
+ if t == "class": + natural_sentence = natural_sentence_template.format(re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', v)[0].lower()) + else: + natural_sentence = natural_sentence_template.format(v) + else: + anchor_obj_properties = [] + class_reference = None + for token in template_sentence[1:]: + if token[0] != "PAD": + if token[1] == "class": + class_reference = token[0] + else: + anchor_obj_properties.append(token[0]) + # order the properties + anchor_obj_des = ", ".join(anchor_obj_properties) + if class_reference is None: + anchor_obj_des += " object" + else: + anchor_obj_des += " {}".format(re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', class_reference)[0].lower()) + + if v == "MASK": + natural_sentence_template = natural_sentence_templates[2] + anchor_type = t + natural_sentence = natural_sentence_template.format(anchor_type, anchor_obj_des) + elif t in self.continuous_types: + natural_sentence_template = natural_sentence_templates[1] + if v == "equal": + jun = "as" + else: + jun = "than" + natural_sentence = natural_sentence_template.format(v, t, jun, anchor_obj_des) + else: + raise NotImplementedError + + return natural_sentence + + def prepare_grounding_reference(self): + goal = {"rearrange": {"features": []}, + "anchor": {"features": []}} + discrete_type = ["class", "material", "color"] + continuous_type = ["volumn", "height"] + + print("#"*50) + print("Preparing referring expression") + + refer_type = verify_input("direct (1) or relational reference (2)? ", [1, 2], int) + if refer_type == 1: + + # 1. no anchor + t = verify_input("desired type: ", discrete_type, None) + v = verify_input("desired value: ", self.get_all_values_of_type(t), None) + + goal["rearrange"]["features"].append({"comparator": None, "type": t, "value": v}) + + elif refer_type == 2: + + value_type = verify_input("discrete (1) or continuous relational reference (2)? ", [1, 2], int) + if value_type == 1: + t = verify_input("desired type: ", discrete_type, None) + # 2. discrete + goal["rearrange"]["features"].append({"comparator": None, "type": t, "value": None}) + elif value_type == 2: + comp = verify_input("desired comparator: ", list(self.type_vocabs["comparator"].keys()), None) + t = verify_input("desired type: ", continuous_type, None) + # 3. 
continuous + goal["rearrange"]["features"].append({"comparator": comp, "type": t, "value": None}) + + num_f = verify_input("desired number of features for the anchor object: ", [1, 2, 3], int) + for i in range(num_f): + t = verify_input("desired type: ", discrete_type, None) + v = verify_input("desired value: ", self.get_all_values_of_type(t), None) + goal["anchor"]["features"].append({"comparator": None, "type": t, "value": v}) + + return goal + + def convert_structure_params_to_natural_language(self, sentence): + + # ('circle', 'shape'), (-1.3430555575431449, 'rotation'), (0.3272675147405848, 'position_x'), (-0.03104362197706456, 'position_y'), (0.04674859577847633, 'radius') + + shape = None + x = None + y = None + rot = None + size = None + + for param in sentence: + if param[0] == "PAD": + continue + + v, t = param + if t == "shape": + shape = v + elif t == "position_x": + dv = self.discretize(v, t) + if dv == 0: + x = "bottom" + elif dv == 1: + x = "middle" + elif dv == 2: + x = "top" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "position_y": + dv = self.discretize(v, t) + if dv == 0: + y = "right" + elif dv == 1: + y = "center" + elif dv == 2: + y = "left" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "radius": + dv = self.discretize(v, t) + if dv == 0: + size = "small" + elif dv == 1: + size = "medium" + elif dv == 2: + size = "large" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "rotation": + dv = self.discretize(v, t) + if dv == 0: + rot = "north" + elif dv == 1: + rot = "east" + elif dv == 2: + rot = "south" + elif dv == 3: + rot = "west" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + + natural_sentence = "" # "{} {} in the {} {} of the table facing {}".format(size, shape, x, y, rot) + + if size: + natural_sentence += "{}".format(size) + if shape: + natural_sentence += " {}".format(shape) + if x: + natural_sentence += " in the {}".format(x) + if y: + natural_sentence += " {} of the table".format(y) + if rot: + natural_sentence += " facing {}".format(rot) + + natural_sentence = natural_sentence.strip() + + return natural_sentence + + def convert_structure_params_to_type_value_tuple(self, sentence): + + # ('circle', 'shape'), (-1.3430555575431449, 'rotation'), (0.3272675147405848, 'position_x'), (-0.03104362197706456, 'position_y'), (0.04674859577847633, 'radius') + + shape = None + x = None + y = None + rot = None + size = None + + for param in sentence: + if param[0] == "PAD": + continue + + v, t = param + if t == "shape": + shape = v + elif t == "position_x": + dv = self.discretize(v, t) + if dv == 0: + x = "bottom" + elif dv == 1: + x = "middle" + elif dv == 2: + x = "top" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "position_y": + dv = self.discretize(v, t) + if dv == 0: + y = "right" + elif dv == 1: + y = "center" + elif dv == 2: + y = "left" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "radius": + dv = self.discretize(v, t) + if dv == 0: + size = "small" + elif dv == 1: + size = "medium" + elif dv == 2: + size = "large" + else: + raise KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + elif t == "rotation": + dv = self.discretize(v, t) + if dv == 0: + rot = "north" + elif dv == 1: + rot = "east" + elif dv == 2: + rot = "south" + elif dv == 3: + rot = "west" + else: + raise 
KeyError("key {} not found in {}".format(v, self.type_vocabs[t])) + + # rotation, shape, size, x, y + type_value_tuple_init = [("rotation", rot), ("shape", shape), ("size", size), ("x", x), ("y", y)] + type_value_tuple = [] + for type_value in type_value_tuple_init: + if type_value[1] is not None: + type_value_tuple.append(type_value) + + type_value_tuple = tuple(sorted(type_value_tuple)) + return type_value_tuple + + def discretize(self, v, t): + min_value, max_value, num_bins = self.type_vocabs[t] + assert min_value <= v <= max_value, "type {} value {} exceeds {} and {}".format(t, v, min_value, max_value) + dv = min(int((v - min_value) / ((max_value - min_value) / num_bins)), num_bins - 1) + return dv + + +class ContinuousTokenizer: + """ + This tokenizer is for testing not discretizing structure parameters + """ + + def __init__(self): + + print("WARNING: Current continous tokenizer does not support multiple shapes") + + self.continuous_types = ["rotation", "position_x", "position_y", "radius"] + self.discrete_types = ["shape"] + + def tokenize(self, value, typ=None): + if value == "PAD": + idx = 0.0 + else: + if typ is None: + raise KeyError("Type cannot be None") + elif typ in self.discrete_types: + idx = 1.0 + elif typ in self.continuous_types: + idx = value + else: + raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value)) + return idx + + +if __name__ == "__main__": + tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json") + # print(tokenizer.get_all_values_of_type("class")) + # print(tokenizer.get_all_values_of_type("color")) + # print(tokenizer.get_all_values_of_type("material")) + # + # for type in tokenizer.type_vocabs: + # print(type, tokenizer.type_vocabs[type]) + + tokenizer.prepare_grounding_reference() + + # for i in range(100): + # types = list(tokenizer.continuous_types) + list(tokenizer.discrete_types) + # for t in types: + # v = tokenizer.get_valid_random_value(t) + # print(v) + # print(tokenizer.tokenize(v, t)) + + # build_vocab("/home/weiyu/data_drive/examples_v4/leonardo/vocab.json", "/home/weiyu/data_drive/examples_v4/leonardo/type_vocabs.json") \ No newline at end of file diff --git a/src/StructDiffusion/models/__init__.py b/src/StructDiffusion/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd68754e149fcfa74b798e7cede414dc3d0ab00 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f6076a0eef068852f173506d0a538d52340a1c Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec884bc615594979b02e0e163c1e5546ea83b429 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/encoders.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc 
b/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59681495d38034504267993477bebcf776d852ab Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/encoders.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54ff2c4a74799cc7866b831b395d4927544da19b Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/models.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf835600dc9bcbe557378578187fc3b8a050759f Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aac35dc9aa36ed0b2a9dbeaf0e016e5c7ecd779 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/pl_models.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dcbc7fb1943c708a3ae4476d1b9d15f78fdb50c Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/pl_models.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08eed8a045e62263826d431993619fe5231e7703 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c2e685a196894089b6bdc038444f2e34f39507 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de3117d4b31d989f6aa0ed3e9775e072ccf1619 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-37.pyc differ diff --git a/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4422f79019f73c9c9b64a3c0f30675ee5cfea6c2 Binary files /dev/null and b/src/StructDiffusion/models/__pycache__/point_transformer_large.cpython-38.pyc differ diff --git a/src/StructDiffusion/models/encoders.py b/src/StructDiffusion/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..6c46cfc0dde57e8d1ef916494ffbf4177047b878 --- /dev/null +++ b/src/StructDiffusion/models/encoders.py @@ -0,0 +1,97 @@ +import math +import torch +import torch.nn as nn +import 
torch.nn.functional as F + + +class SinusoidalPositionEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, time): + device = time.device + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings + + +class DropoutSampler(torch.nn.Module): + def __init__(self, num_features, num_outputs, dropout_rate = 0.5): + super(DropoutSampler, self).__init__() + self.linear = nn.Linear(num_features, num_features) + self.linear2 = nn.Linear(num_features, num_features) + self.predict = nn.Linear(num_features, num_outputs) + self.num_features = num_features + self.num_outputs = num_outputs + self.dropout_rate = dropout_rate + + def forward(self, x): + x = F.relu(self.linear(x)) + if self.dropout_rate > 0: + x = F.dropout(x, self.dropout_rate) + x = F.relu(self.linear2(x)) + return self.predict(x) + + +class EncoderMLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True): + super(EncoderMLP, self).__init__() + self.uses_pt = uses_pt + self.output = out_dim + d5 = int(in_dim) + d6 = int(2 * self.output) + d7 = self.output + self.encode_position = nn.Sequential( + nn.Linear(pt_dim, in_dim), + nn.LayerNorm(in_dim), + nn.ReLU(), + nn.Linear(in_dim, in_dim), + nn.LayerNorm(in_dim), + nn.ReLU(), + ) + d5 = 2 * in_dim if self.uses_pt else in_dim + self.fc_block = nn.Sequential( + nn.Linear(int(d5), d6), + nn.LayerNorm(int(d6)), + nn.ReLU(), + nn.Linear(int(d6), d6), + nn.LayerNorm(int(d6)), + nn.ReLU(), + nn.Linear(d6, d7)) + + def forward(self, x, pt=None): + if self.uses_pt: + if pt is None: raise RuntimeError('did not provide pt') + y = self.encode_position(pt) + x = torch.cat([x, y], dim=-1) + return self.fc_block(x) + + +class MeanEncoder(torch.nn.Module): + def __init__(self, input_channels=3, use_xyz=True, output=512, scale=0.04, factor=1): + super(MeanEncoder, self).__init__() + self.uses_rgb = False + self.dim = 3 + + def forward(self, xyz, f=None): + + # Fix shape + if f is not None: + if len(f.shape) < 3: + f = f.transpose(0,1).contiguous() + f = f[None] + else: + f = f.transpose(1,2).contiguous() + if len(xyz.shape) == 3: + center = torch.mean(xyz, dim=1) + elif len(xyz.shape) == 2: + center = torch.mean(xyz, dim=0) + else: + raise RuntimeError('not sure what to do with points of shape ' + str(xyz.shape)) + assert(xyz.shape[-1]) == 3 + assert(center.shape[-1]) == 3 + return center, center \ No newline at end of file diff --git a/src/StructDiffusion/models/models.py b/src/StructDiffusion/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8db381307eb8dc64e55308e021b73510ed52ad --- /dev/null +++ b/src/StructDiffusion/models/models.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F + +from StructDiffusion.models.encoders import EncoderMLP, DropoutSampler, SinusoidalPositionEmbeddings +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from StructDiffusion.models.point_transformer import PointTransformerEncoderSmall +from StructDiffusion.models.point_transformer_large import PointTransformerCls + + +class TransformerDiffusionModel(torch.nn.Module): + + def __init__(self, vocab_size, + # transformer params + encoder_input_dim=256, + num_attention_heads=8, encoder_hidden_dim=16, 
encoder_dropout=0.1, encoder_activation="relu", + encoder_num_layers=8, + # output head params + structure_dropout=0.0, object_dropout=0.0, + # pc encoder params + ignore_rgb=False, pc_emb_dim=256, posed_pc_emb_dim=80, + pose_emb_dim=80, + max_seq_size=7, max_token_type_size=4, + seq_pos_emb_dim=8, seq_type_emb_dim=8, + word_emb_dim=160, + time_emb_dim=80, + use_virtual_structure_frame=True, + ): + super(TransformerDiffusionModel, self).__init__() + + assert posed_pc_emb_dim + pose_emb_dim == word_emb_dim + assert encoder_input_dim == word_emb_dim + time_emb_dim + seq_pos_emb_dim + seq_type_emb_dim + + # 3D translation + 6D rotation + action_dim = 3 + 6 + + # default: + # 256 = 80 (point cloud) + 80 (position) + 80 (time) + 8 (position idx) + 8 (token idx) + # 256 = 160 (word embedding) + 80 (time) + 8 (position idx) + 8 (token idx) + + # PC + self.ignore_rgb = ignore_rgb + if ignore_rgb: + self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=3, mean_center=True) + else: + self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=6, mean_center=True) + self.posed_pc_encoder = EncoderMLP(pc_emb_dim, posed_pc_emb_dim, uses_pt=True) + + # for virtual structure frame + self.use_virtual_structure_frame = use_virtual_structure_frame + if use_virtual_structure_frame: + self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim + + # for language + self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0) + + # for diffusion + self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim)) + self.time_embeddings = nn.Sequential( + SinusoidalPositionEmbeddings(time_emb_dim), + nn.Linear(time_emb_dim, time_emb_dim), + nn.GELU(), + nn.Linear(time_emb_dim, time_emb_dim), + ) + + # for transformer + self.position_embeddings = torch.nn.Embedding(max_seq_size, seq_pos_emb_dim) + self.type_embeddings = torch.nn.Embedding(max_token_type_size, seq_type_emb_dim) + + encoder_layers = TransformerEncoderLayer(encoder_input_dim, num_attention_heads, + encoder_hidden_dim, encoder_dropout, encoder_activation) + self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers) + + self.struct_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=structure_dropout) + self.obj_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=object_dropout) + + def encode_posed_pc(self, pcs, batch_size, num_objects): + if self.ignore_rgb: + center_xyz, x = self.pc_encoder(pcs[:, :, :3], None) + else: + center_xyz, x = self.pc_encoder(pcs[:, :, :3], pcs[:, :, 3:]) + posed_pc_embed = self.posed_pc_encoder(x, center_xyz) + posed_pc_embed = posed_pc_embed.reshape(batch_size, num_objects, -1) + return posed_pc_embed + + def forward(self, t, pcs, sentence, poses, type_index, position_index, pad_mask): + + batch_size, num_objects, num_pts, _ = pcs.shape + _, num_poses, _ = poses.shape + _, sentence_len = sentence.shape + _, total_len = type_index.shape + + pcs = pcs.reshape(batch_size * num_objects, num_pts, -1) + posed_pc_embed = self.encode_posed_pc(pcs, batch_size, num_objects) + + pose_embed = self.pose_encoder(poses) + + if self.use_virtual_structure_frame: + virtual_frame_embed = self.virtual_frame_embed.repeat(batch_size, 1, 1) + posed_pc_embed = torch.cat([virtual_frame_embed, posed_pc_embed], dim=1) + tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1) + + ######################### + sentence_embed = self.word_embeddings(sentence) + + ######################### + + # 
transformer time dim: sentence, struct, obj + # transformer feat dim: obj pc + pose / word, time, token type, position + + time_embed = self.time_embeddings(t) # B, dim + time_embed = time_embed.unsqueeze(1).repeat(1, total_len, 1) # B, L, dim + + position_embed = self.position_embeddings(position_index) + type_embed = self.type_embeddings(type_index) + + tgt_sequence_encode = torch.cat([sentence_embed, tgt_obj_embed], dim=1) + tgt_sequence_encode = torch.cat([tgt_sequence_encode, time_embed, position_embed, type_embed], dim=-1) + + tgt_pad_mask = pad_mask + + ######################### + # sequence_encode: [batch size, sequence_length, encoder input dimension] + # input to transformer needs to have dimenion [sequence_length, batch size, encoder input dimension] + tgt_sequence_encode = tgt_sequence_encode.transpose(1, 0) + # convert to bool + tgt_pad_mask = (tgt_pad_mask == 1) + # encode: [sequence_length, batch_size, embedding_size] + encode = self.encoder(tgt_sequence_encode, src_key_padding_mask=tgt_pad_mask) + encode = encode.transpose(1, 0) + ######################### + + target_encodes = encode[:, -num_poses:, :] + if self.use_virtual_structure_frame: + obj_encodes = target_encodes[:, 1:, :] + pred_obj_poses = self.obj_head(obj_encodes) # B, N, 3 + 6 + struct_encode = encode[:, 0, :].unsqueeze(1) + # use a different sampler for struct prediction since it should have larger variance than object predictions + pred_struct_pose = self.struct_head(struct_encode) # B, 1, 3 + 6 + pred_poses = torch.cat([pred_struct_pose, pred_obj_poses], dim=1) + else: + pred_poses = self.obj_head(target_encodes) # B, N, 3 + 6 + + assert pred_poses.shape == poses.shape + + return pred_poses + + +class FocalLoss(nn.Module): + def __init__(self, gamma=2, alpha=.25): + super(FocalLoss, self).__init__() + # self.alpha = torch.tensor([alpha, 1-alpha]) + self.gamma = gamma + + def forward(self, inputs, targets): + BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') + pt = torch.exp(-BCE_loss) + # targets = targets.type(torch.long) + # at = self.alpha.gather(0, targets.data.view(-1)) + # F_loss = at*(1-pt)**self.gamma * BCE_loss + F_loss = (1 - pt)**self.gamma * BCE_loss + return F_loss.mean() + + +class PCTDiscriminator(torch.nn.Module): + + def __init__(self, max_num_objects, include_env_pc=False, pct_random_sampling=False): + + super(PCTDiscriminator, self).__init__() + + # input_dim: xyz + one hot for each object + if include_env_pc: + self.classifier = PointTransformerCls(input_dim=max_num_objects + 1 + 3, output_dim=1, use_random_sampling=pct_random_sampling) + else: + self.classifier = PointTransformerCls(input_dim=max_num_objects + 3, output_dim=1, use_random_sampling=pct_random_sampling) + + def forward(self, scene_xyz): + label = self.classifier(scene_xyz) + return label + + def convert_logits(self, logits): + return torch.sigmoid(logits) + diff --git a/src/StructDiffusion/models/pl_models.py b/src/StructDiffusion/models/pl_models.py new file mode 100644 index 0000000000000000000000000000000000000000..237ef12741758dbabe84bae424fd33d5eb50f5c2 --- /dev/null +++ b/src/StructDiffusion/models/pl_models.py @@ -0,0 +1,136 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +from StructDiffusion.models.models import TransformerDiffusionModel, PCTDiscriminator, FocalLoss + +from StructDiffusion.diffusion.noise_schedule import NoiseSchedule, q_sample +from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H, 
get_diffusion_variables_from_9D_actions + + +class ConditionalPoseDiffusionModel(pl.LightningModule): + + def __init__(self, vocab_size, model_cfg, loss_cfg, noise_scheduler_cfg, optimizer_cfg): + super().__init__() + self.save_hyperparameters() + + self.model = TransformerDiffusionModel(vocab_size, **model_cfg) + + self.noise_schedule = NoiseSchedule(**noise_scheduler_cfg) + + self.loss_type = loss_cfg.type + + self.optimizer_cfg = optimizer_cfg + self.configure_optimizers() + + self.batch_size = None + + def forward(self, batch): + + # input + pcs = batch["pcs"] + B = pcs.shape[0] + self.batch_size = B + sentence = batch["sentence"] + goal_poses = batch["goal_poses"] + type_index = batch["type_index"] + position_index = batch["position_index"] + pad_mask = batch["pad_mask"] + + t = torch.randint(0, self.noise_schedule.timesteps, (B,), dtype=torch.long).to(self.device) + + # -------------- + x_start = get_diffusion_variables_from_H(goal_poses) + noise = torch.randn_like(x_start, device=self.device) + x_noisy = q_sample(x_start=x_start, t=t, noise_schedule=self.noise_schedule, noise=noise) + + predicted_noise = self.model.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask) + + # important: skip computing loss for masked positions + num_poses = goal_poses.shape[1] # B, N, 4, 4 + pose_pad_mask = pad_mask[:, -num_poses:] + keep_mask = (pose_pad_mask == 0) + noise = noise[keep_mask] # dim: number of positions that need loss calculation + predicted_noise = predicted_noise[keep_mask] + + return noise, predicted_noise + + def compute_loss(self, noise, predicted_noise, prefix="train/"): + if self.loss_type == 'l1': + loss = F.l1_loss(noise, predicted_noise) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, predicted_noise) + elif self.loss_type == "huber": + loss = F.smooth_l1_loss(noise, predicted_noise) + else: + raise NotImplementedError() + + self.log(prefix + "loss", loss, prog_bar=True, batch_size=self.batch_size) + return loss + + def training_step(self, batch, batch_idx): + noise, pred_noise = self.forward(batch) + loss = self.compute_loss(noise, pred_noise, prefix="train/") + return loss + + def validation_step(self, batch, batch_idx): + noise, pred_noise = self.forward(batch) + loss = self.compute_loss(noise, pred_noise, prefix="val/") + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay) # 1e-5 + return optimizer + + +class PairwiseCollisionModel(pl.LightningModule): + + def __init__(self, model_cfg, loss_cfg, optimizer_cfg, data_cfg): + super().__init__() + self.save_hyperparameters() + + self.model = PCTDiscriminator(**model_cfg) + + self.loss_cfg = loss_cfg + self.loss = None + self.configure_loss() + + self.optimizer_cfg = optimizer_cfg + self.configure_optimizers() + + # this is stored, because some of the data parameters affect the model behavior + self.data_cfg = data_cfg + + def forward(self, batch): + label = batch["label"] + predicted_label = self.model.forward(batch["scene_xyz"]) + return label, predicted_label + + def compute_loss(self, label, predicted_label, prefix="train/"): + if self.loss_cfg.type == "MSE": + predicted_label = torch.sigmoid(predicted_label) + loss = self.loss(predicted_label, label) + self.log(prefix + "loss", loss, prog_bar=True) + return loss + + def training_step(self, batch, batch_idx): + label, predicted_label = self.forward(batch) + loss = self.compute_loss(label, predicted_label, prefix="train/") + return loss + + 
def validation_step(self, batch, batch_idx): + label, predicted_label = self.forward(batch) + loss = self.compute_loss(label, predicted_label, prefix="val/") + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay) # 1e-5 + return optimizer + + def configure_loss(self): + if self.loss_cfg.type == "Focal": + print("use focal loss with gamma {}".format(self.loss_cfg.focal_gamma)) + self.loss = FocalLoss(gamma=self.loss_cfg.focal_gamma) + elif self.loss_cfg.type == "MSE": + print("use regression L2 loss") + self.loss = torch.nn.MSELoss() + elif self.loss_cfg.type == "BCE": + print("use standard BCE logit loss") + self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean") \ No newline at end of file diff --git a/src/StructDiffusion/models/point_transformer.py b/src/StructDiffusion/models/point_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..08d8f3ac790b8811dc7352842e05b3caa5fd7756 --- /dev/null +++ b/src/StructDiffusion/models/point_transformer.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +from StructDiffusion.utils.pointnet import farthest_point_sample, index_points, square_distance + +# adapted from https://github.com/qq456cvb/Point-Transformers + + +def sample_and_group(npoint, nsample, xyz, points): + B, N, C = xyz.shape + S = npoint + + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] + + new_xyz = index_points(xyz, fps_idx) + new_points = index_points(points, fps_idx) + + dists = square_distance(new_xyz, xyz) # B x npoint x N + idx = dists.argsort()[:, :, :nsample] # B x npoint x K + + grouped_points = index_points(points, idx) + grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) + new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) + return new_xyz, new_points + + +class Local_op(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, d, s) + batch_size, _, N = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + x = x.reshape(b, n, -1).permute(0, 2, 1) + return x + + +class SA_Layer(nn.Module): + def __init__(self, channels): + super().__init__() + self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.q_conv.weight = self.k_conv.weight + self.v_conv = nn.Conv1d(channels, channels, 1) + self.trans_conv = nn.Conv1d(channels, channels, 1) + self.after_norm = nn.BatchNorm1d(channels) + self.act = nn.ReLU() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c + x_k = self.k_conv(x)# b, c, n + x_v = self.v_conv(x) + energy = x_q @ x_k # b, n, n + attention = self.softmax(energy) + attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) + x_r = x_v @ attention # b, c, n + x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) + x = x + x_r + return x + + +class StackedAttention(nn.Module): + def __init__(self, channels=64): + super().__init__() + self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(channels, 
channels, kernel_size=1, bias=False) + + self.bn1 = nn.BatchNorm1d(channels) + self.bn2 = nn.BatchNorm1d(channels) + + self.sa1 = SA_Layer(channels) + self.sa2 = SA_Layer(channels) + + self.relu = nn.ReLU() + + def forward(self, x): + # + # b, 3, npoint, nsample + # conv2d 3 -> 128 channels 1, 1 + # b * npoint, c, nsample + # permute reshape + batch_size, _, N = x.size() + + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) + + x1 = self.sa1(x) + x2 = self.sa2(x1) + + x = torch.cat((x1, x2), dim=1) + + return x + + +class PointTransformerEncoderSmall(nn.Module): + + def __init__(self, output_dim=256, input_dim=6, mean_center=True): + super(PointTransformerEncoderSmall, self).__init__() + + self.mean_center = mean_center + + # map the second dim of the input from input_dim to 64 + self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=64) + self.gather_local_1 = Local_op(in_channels=128, out_channels=64) + self.pt_last = StackedAttention(channels=64) + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False), + nn.BatchNorm1d(256), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(256, 256, bias=False) + self.bn6 = nn.BatchNorm1d(256) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(256, 256) + + def forward(self, xyz, f=None): + # xyz: B, N, 3 + # f: B, N, D + center = torch.mean(xyz, dim=1) + if self.mean_center: + xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1) + if f is None: + x = self.pct(xyz) + else: + x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim + + return center, x + + def pct(self, x): + + # x: B, N, D + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=128, nsample=32, xyz=xyz, points=x) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) # B, nsamples, D + new_xyz, new_feature = sample_and_group(npoint=32, nsample=16, xyz=new_xyz, points=feature) + feature_1 = self.gather_local_1(new_feature) # B, D, nsamples + + x = self.pt_last(feature_1) # B, D * 2, nsamples + x = torch.cat([x, feature_1], dim=1) # B, D * 3, nsamples + x = self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.linear2(x) + + return x + + +class SampleAndGroup(nn.Module): + + def __init__(self, output_dim=64, input_dim=6, mean_center=True, npoints=(128, 32), nsamples=(32, 16)): + super(SampleAndGroup, self).__init__() + + self.mean_center = mean_center + self.npoints = npoints + self.nsamples = nsamples + + # map the second dim of the input from input_dim to 64 + self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(output_dim) + self.gather_local_0 = Local_op(in_channels=output_dim * 2, out_channels=output_dim) + self.gather_local_1 = Local_op(in_channels=output_dim * 2, out_channels=output_dim) + self.relu = nn.ReLU() + + def forward(self, xyz, f): + # xyz: B, N, 3 + # f: B, N, D + center = torch.mean(xyz, dim=1) + if self.mean_center: + xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1) + x = self.sg(torch.cat([xyz, f], dim=2)) # B, nsamples, output_dim + + return center, x + + def sg(self, x): + + # x: B, N, D + xyz = x[..., :3] + x = x.permute(0, 
2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=self.npoints[0], nsample=self.nsamples[0], xyz=xyz, points=x) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) # B, nsamples, D + new_xyz, new_feature = sample_and_group(npoint=self.npoints[1], nsample=self.nsamples[1], xyz=new_xyz, points=feature) + feature_1 = self.gather_local_1(new_feature) # B, D, nsamples + x = feature_1.permute(0, 2, 1) # B, nsamples, D + + return x \ No newline at end of file diff --git a/src/StructDiffusion/models/point_transformer_large.py b/src/StructDiffusion/models/point_transformer_large.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2a3d9f65c922b18641064e001c6fcc428d30cf --- /dev/null +++ b/src/StructDiffusion/models/point_transformer_large.py @@ -0,0 +1,306 @@ +import torch +import torch.nn as nn +from StructDiffusion.utils.pointnet import farthest_point_sample, index_points, square_distance, random_point_sample + + +def sample_and_group(npoint, nsample, xyz, points, use_random_sampling=False): + B, N, C = xyz.shape + S = npoint + + if not use_random_sampling: + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] + else: + fps_idx = random_point_sample(xyz, npoint) # [B, npoint] + + new_xyz = index_points(xyz, fps_idx) + new_points = index_points(points, fps_idx) + + dists = square_distance(new_xyz, xyz) # B x npoint x N + idx = dists.argsort()[:, :, :nsample] # B x npoint x K + + grouped_points = index_points(points, idx) + grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) + new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) + return new_xyz, new_points + + +class Local_op(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, d, s) + batch_size, _, N = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + x = x.reshape(b, n, -1).permute(0, 2, 1) + return x + + +class SA_Layer(nn.Module): + def __init__(self, channels): + super().__init__() + self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.q_conv.weight = self.k_conv.weight + self.v_conv = nn.Conv1d(channels, channels, 1) + self.trans_conv = nn.Conv1d(channels, channels, 1) + self.after_norm = nn.BatchNorm1d(channels) + self.act = nn.ReLU() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c + x_k = self.k_conv(x)# b, c, n + x_v = self.v_conv(x) + energy = x_q @ x_k # b, n, n + attention = self.softmax(energy) + attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) + x_r = x_v @ attention # b, c, n + x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) + x = x + x_r + return x + + +class StackedAttention(nn.Module): + def __init__(self, channels=256): + super().__init__() + self.conv1 = nn.Conv1d(channels, 
channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + + self.bn1 = nn.BatchNorm1d(channels) + self.bn2 = nn.BatchNorm1d(channels) + + self.sa1 = SA_Layer(channels) + self.sa2 = SA_Layer(channels) + self.sa3 = SA_Layer(channels) + self.sa4 = SA_Layer(channels) + + self.relu = nn.ReLU() + + def forward(self, x): + # + # b, 3, npoint, nsample + # conv2d 3 -> 128 channels 1, 1 + # b * npoint, c, nsample + # permute reshape + batch_size, _, N = x.size() + + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) + + x1 = self.sa1(x) + x2 = self.sa2(x1) + x3 = self.sa3(x2) + x4 = self.sa4(x3) + + x = torch.cat((x1, x2, x3, x4), dim=1) + + return x + + +class PointTransformerCls(nn.Module): + def __init__(self, input_dim, output_dim, use_random_sampling=False): + super().__init__() + + self.use_random_sampling = use_random_sampling + + self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=128) + self.gather_local_1 = Local_op(in_channels=256, out_channels=256) + self.pt_last = StackedAttention() + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), + nn.BatchNorm1d(1024), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(1024, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(512, 256) + self.bn7 = nn.BatchNorm1d(256) + self.dp2 = nn.Dropout(p=0.5) + self.linear3 = nn.Linear(256, output_dim) + + def forward(self, x): + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x, + use_random_sampling=self.use_random_sampling) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature, + use_random_sampling=self.use_random_sampling) + + # debug: visualize + # # new_xyz: B, N, 3 + # from rearrangement_utils import show_pcs + # import numpy as np + # + # new_xyz_copy = new_xyz.detach().cpu().numpy() + # for i in range(new_xyz_copy.shape[0]): + # print(new_xyz_copy[i].shape) + # show_pcs([new_xyz_copy[i]], [np.tile(np.array([0, 1, 0], dtype=np.float), (new_xyz_copy[i].shape[0], 1))]) + + feature_1 = self.gather_local_1(new_feature) + + x = self.pt_last(feature_1) + x = torch.cat([x, feature_1], dim=1) + x = self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.relu(self.bn7(self.linear2(x))) + x = self.dp2(x) + x = self.linear3(x) + + return x + + +class PointTransformerClsLarge(nn.Module): + def __init__(self, input_dim, output_dim, use_random_sampling=False): + super().__init__() + + self.use_random_sampling = use_random_sampling + + self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=128) + self.gather_local_1 = Local_op(in_channels=256, out_channels=256) + 
self.pt_last = StackedAttention() + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), + nn.BatchNorm1d(1024), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(1024, 1024, bias=False) + self.bn6 = nn.BatchNorm1d(1024) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(1024, 512) + self.bn7 = nn.BatchNorm1d(512) + self.dp2 = nn.Dropout(p=0.5) + self.linear3 = nn.Linear(512, output_dim) + + def forward(self, x): + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x, + use_random_sampling=self.use_random_sampling) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature, + use_random_sampling=self.use_random_sampling) + + # debug: visualize + # # new_xyz: B, N, 3 + # from rearrangement_utils import show_pcs + # import numpy as np + # + # new_xyz_copy = new_xyz.detach().cpu().numpy() + # for i in range(new_xyz_copy.shape[0]): + # print(new_xyz_copy[i].shape) + # show_pcs([new_xyz_copy[i]], [np.tile(np.array([0, 1, 0], dtype=np.float), (new_xyz_copy[i].shape[0], 1))]) + + feature_1 = self.gather_local_1(new_feature) + + x = self.pt_last(feature_1) + x = torch.cat([x, feature_1], dim=1) + x = self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.relu(self.bn7(self.linear2(x))) + x = self.dp2(x) + x = self.linear3(x) + + return x + + +class PointTransformerEncoderLarge(nn.Module): + def __init__(self, output_dim=256, input_dim=6, mean_center=True): + super(PointTransformerEncoderLarge, self).__init__() + + self.mean_center = mean_center + + self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=128) + self.gather_local_1 = Local_op(in_channels=256, out_channels=256) + self.pt_last = StackedAttention() + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), + nn.BatchNorm1d(1024), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(1024, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(512, 256) + + def forward(self, xyz, f): + # xyz: B, N, 3 + # f: B, N, D + center = torch.mean(xyz, dim=1) + if self.mean_center: + xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1) + x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim + + return center, x + + def pct(self, x): + + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) + feature_1 = self.gather_local_1(new_feature) + + x = self.pt_last(feature_1) + x = torch.cat([x, feature_1], dim=1) + x = 
self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.linear2(x) + + return x + diff --git a/src/StructDiffusion/utils/__init__.py b/src/StructDiffusion/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae37618e6d32d718366dffd61e9b763f6c9ce5f Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4727c42b4a7ac2a745e70e0ea443820e1e2a4f94 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df26d00a967aeeb79278f647634af346c43a0743 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ecc3e99b8d0692f36305d31b7cc84a9e275d799 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..159b7bdeaae322d391f3f7ba279e732a7f857a19 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/files.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5a63e4fc58eefe3762552512a001ca8f9cab83e Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..951dc53d8b891232b7812c5fb5041acc0e1cf92b Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e5e1f5b5d5ce557c861482a10811e48342fca05 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/pointnet.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c8e1e1607fc3b2ef51626f3833350afdbc73a5 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-37.pyc differ diff --git 
a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ee8032d620ef9ee48c65bb0708352c4fd71bf6 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c256ed0456e4a1b552ddbb4595b4366bf90fb637 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1754983c6f0a3a2d7c5160573123f64637422360 Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc b/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1644b1ba1c64215da2ced9538e5d98f81a15613e Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/transformations.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc b/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fac52a646e04870c1698bfb9fd092c58e802e39e Binary files /dev/null and b/src/StructDiffusion/utils/__pycache__/transformations.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/batch_inference.py b/src/StructDiffusion/utils/batch_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..91ae39a6b8220383fe2765ab8455c0dffb38f24d --- /dev/null +++ b/src/StructDiffusion/utils/batch_inference.py @@ -0,0 +1,490 @@ +import os +import torch +import numpy as np +import pytorch3d.transforms as tra3d + +from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh +from StructDiffusion.utils.pointnet import random_point_sample, index_points + + +def move_pc_and_create_scene(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, + num_scene_pts, device, normalize_pc=False, + return_pair_pc=False, normalize_pair_pc=False, num_pair_pc_pts=None, + return_scene_pts=True, return_scene_pts_and_pc_idxs=False): + + # obj_xyzs: N, P, 3 + # obj_params: B, N, 6 + # struct_pose: B x N, 4, 4 + # current_pc_pose: B x N, 4, 4 + # target_object_inds: 1, N + + B, N, _ = obj_params.shape + _, P, _ = obj_xyzs.shape + + # B, N, 6 + flat_obj_params = obj_params.reshape(B * N, -1) + goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device) + goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ") + goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4 + + goal_pc_pose = struct_pose @ goal_pc_pose_in_struct + goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + # obj_xyzs: N, P, 3 + new_obj_xyzs = 
obj_xyzs.repeat(B, 1, 1) + new_obj_xyzs = transpose.transform_points(new_obj_xyzs) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + # visualize_batch_pcs(new_obj_xyzs, S, N, P) + + # =================================== + # Pass to discriminator + subsampled_scene_xyz = None + if return_scene_pts: + + num_indicator = N + + # add one hot + indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N + # print(indicator_variables.shape) + # print(new_obj_xyzs.shape) + new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N) + + # ToDo: maybe convert this to a batch operation + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device) + for si, scene_xyz in enumerate(scene_xyzs): + # scene_xyz: N*P, 3+N + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + + # # debug: + # print("-"*50) + # if si < 10: + # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show() + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show() + + # subsampled_scene_xyz: B, num_scene_pts, 3+N + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # # debug: + # for si in range(10): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + if return_scene_pts_and_pc_idxs: + num_indicator = N + pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P + # new_obj_xyzs: B, N, P, 3 + 1 + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3) + pc_idxs = pc_idxs.reshape(B, N*P) + + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device) + subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device) + for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)): + # scene_xyz: N*P, 3+1 + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + subsampled_pc_idxs[si] = pc_idx[subsample_idx] + + # subsampled_scene_xyz: B, num_scene_pts, 3 + # subsampled_pc_idxs: B, num_scene_pts + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # TODO: visualize each individual object + # debug + # print(subsampled_scene_xyz.shape) + # print(subsampled_pc_idxs.shape) + # print("visualize subsampled scene") + # for si in range(5): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + ############################################### + # Create input for pairwise collision detector + if return_pair_pc: + + assert num_pair_pc_pts is not None + + # new_obj_xyzs: B, N, P, 3 + N + # target_object_inds: 1, N + # ignore paddings + num_objs = torch.sum(target_object_inds[0]) + obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2 + + # use [:, :, :, :3] to get obj_xyzs without object-wise indicator 
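+        # torch.combinations enumerates every unordered pair of the (non-padded) objects, so indexing
+        # with obj_pair_idxs below gathers both point clouds of each pair into shape B, num_comb, 2, P, 3;
+        # e.g., for num_objs = 3 the pairs are (0, 1), (0, 2), (1, 2), giving num_comb = 3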
+ obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3 + num_comb = obj_pair_xyzs.shape[1] + pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2 + obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5) + + # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts) + obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5) + # random_point_sample() input dim: B, N, C + rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts + obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5 + + if normalize_pair_pc: + # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3 + # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5) + obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3]) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5) + + # # debug + # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs): + # print("batch id", bi) + # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs): + # print("pair", pi) + # # obj_pair_xyzs: 2 * P, 5 + # print(obj_pair_xyz[:, :3].shape) + # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show() + + # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2 + goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4) + + # TODO: update the return logic, a mess right now + if return_scene_pts_and_pc_idxs: + return subsampled_scene_xyz, subsampled_pc_idxs, new_obj_xyzs, goal_pc_pose + + if return_pair_pc: + return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose, obj_pair_xyzs + else: + return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose + + +def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device, + return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False, + return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False): + + # obj_xyzs: N, P, 3 + # obj_params: B, N, 6 + # struct_pose: B x N, 4, 4 + # current_pc_pose: B x N, 4, 4 + # target_object_inds: 1, N + + B, N, _ = obj_params.shape + _, P, _ = obj_xyzs.shape + + # B, N, 6 + flat_obj_params = obj_params.reshape(B * N, -1) + goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device) + goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ") + goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4 + + goal_pc_pose = struct_pose @ goal_pc_pose_in_struct + goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + # obj_xyzs: N, P, 3 + new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) + new_obj_xyzs = transpose.transform_points(new_obj_xyzs) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + # visualize_batch_pcs(new_obj_xyzs, S, N, P) + + + # initialize the additional outputs + subsampled_scene_xyz = None + subsampled_pc_idxs = None + obj_pair_xyzs = None + + # =================================== + # Pass to discriminator + if return_scene_pts: + + num_indicator = N + + # add one hot + 
indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N + # print(indicator_variables.shape) + # print(new_obj_xyzs.shape) + new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N) + + # ToDo: maybe convert this to a batch operation + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device) + for si, scene_xyz in enumerate(scene_xyzs): + # scene_xyz: N*P, 3+N + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + + # # debug: + # print("-"*50) + # if si < 10: + # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show() + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show() + + # subsampled_scene_xyz: B, num_scene_pts, 3+N + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # # debug: + # for si in range(10): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + if return_scene_pts_and_pc_idxs: + num_indicator = N + pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P + # new_obj_xyzs: B, N, P, 3 + 1 + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3) + pc_idxs = pc_idxs.reshape(B, N*P) + + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device) + subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device) + for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)): + # scene_xyz: N*P, 3+1 + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + subsampled_pc_idxs[si] = pc_idx[subsample_idx] + + # subsampled_scene_xyz: B, num_scene_pts, 3 + # subsampled_pc_idxs: B, num_scene_pts + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # TODO: visualize each individual object + # debug + # print(subsampled_scene_xyz.shape) + # print(subsampled_pc_idxs.shape) + # print("visualize subsampled scene") + # for si in range(5): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + ############################################### + # Create input for pairwise collision detector + if return_pair_pc: + + assert num_pair_pc_pts is not None + + # new_obj_xyzs: B, N, P, 3 + N + # target_object_inds: 1, N + # ignore paddings + num_objs = torch.sum(target_object_inds[0]) + obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2 + + # use [:, :, :, :3] to get obj_xyzs without object-wise indicator + obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3 + num_comb = obj_pair_xyzs.shape[1] + pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2 + obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 
(pc channels) + 2 (indicator for obj 1 and obj 2) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5) + + # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts) + obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5) + # random_point_sample() input dim: B, N, C + rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts + obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5 + + if normalize_pair_pc: + # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3 + # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5) + obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3]) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5) + + # # debug + # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs): + # print("batch id", bi) + # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs): + # print("pair", pi) + # # obj_pair_xyzs: 2 * P, 5 + # print(obj_pair_xyz[:, :3].shape) + # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show() + + # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2 + goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4) + + return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs + + +def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device): + + # obj_xyzs: N, P, 3 + # obj_params: B, N, 6 + # struct_pose: B x N, 4, 4 + # current_pc_pose: B x N, 4, 4 + # target_object_inds: 1, N + + B, N, _ = obj_params.shape + _, P, _ = obj_xyzs.shape + + # B, N, 6 + flat_obj_params = obj_params.reshape(B * N, -1) + goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device) + goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ") + goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4 + + goal_pc_pose = struct_pose @ goal_pc_pose_in_struct + goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + # obj_xyzs: N, P, 3 + new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) + new_obj_xyzs = transpose.transform_points(new_obj_xyzs) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + # visualize_batch_pcs(new_obj_xyzs, S, N, P) + + # subsampled_scene_xyz: B, num_scene_pts, 3+N + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4) + return new_obj_xyzs, goal_pc_pose + + +def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct): + + device = obj_xyzs.device + + # obj_xyzs: B, N, P, 3 or 6 + # struct_pose: B, 1, 4, 4 + # pc_poses_in_struct: B, N, 4, 4 + + B, N, _, _ = pc_poses_in_struct.shape + _, _, P, _ = obj_xyzs.shape + + current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4 + # print(torch.mean(obj_xyzs, dim=2).shape) + current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4 + current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4 + + struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4 + struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4 + pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4 + + goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4 + # print("goal pc poses") + # print(goal_pc_pose) 
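+    # Relative transform that carries each object's point cloud from its current pose
+    # (identity rotation placed at the cloud centroid) to its goal pose in the world frame:
+    # T = goal_pc_pose @ inverse(current_pc_pose).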
+ goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3 + new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3]) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + + return new_obj_xyzs + + +def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct): + + device = obj_xyzs.device + + # obj_xyzs: B, N, P, 3 + # struct_pose: B, 1, 4, 4 + # pc_poses_in_struct: B, N, 4, 4 + B, N, _, _ = pc_poses_in_struct.shape + _, _, P, _ = obj_xyzs.shape + + current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4 + # print(torch.mean(obj_xyzs, dim=2).shape) + current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2) # B, N, 4, 4 + + struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4 + struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4 + pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4 + + goal_pc_poses = struct_pose @ pc_poses_in_struct # B x N, 4, 4 + goal_pc_poses = goal_pc_poses.reshape(B, N, 4, 4) # B, N, 4, 4 + return current_pc_poses, goal_pc_poses + + +def sample_gaussians(mus, sigmas, sample_size): + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + normal = torch.distributions.Normal(mus, sigmas) + samples = normal.sample((sample_size,)) + # samples: [sample_size, number of individual gaussians] + return samples + +def fit_gaussians(samples, sigma_eps=0.01): + device = samples.device + + # samples: [sample_size, number of individual gaussians] + num_gs = samples.shape[1] + mus = torch.mean(samples, dim=0).to(device) + sigmas = torch.std(samples, dim=0).to(device) + sigma_eps * torch.ones(num_gs).to(device) + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + return mus, sigmas + + +def visualize_batch_pcs(obj_xyzs, B, verbose=False, limit_B=None, save_dir=None, trimesh=False): + if limit_B is None: + limit_B = B + + vis_obj_xyzs = obj_xyzs[:limit_B] + + if torch.is_tensor(vis_obj_xyzs): + if vis_obj_xyzs.is_cuda: + vis_obj_xyzs = vis_obj_xyzs.detach().cpu() + vis_obj_xyzs = vis_obj_xyzs.numpy() + + for bi, vis_obj_xyz in enumerate(vis_obj_xyzs): + if verbose: + print("example {}".format(bi)) + print(vis_obj_xyz.shape) + + if trimesh: + show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz]) + else: + if save_dir: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + save_path = os.path.join(save_dir, "b{}.jpg".format(bi)) + show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=False, add_coordinate_frame=False, + side_view=True, save_path=save_path) + else: + show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=False, + side_view=True) + + +def pc_normalize_batch(pc): + # pc: B, num_scene_pts, 3 + centroid = torch.mean(pc, dim=1) # B, 3 + pc = pc - centroid[:, None, :] + m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0] + pc = pc / m[:, None, None] + return pc \ No newline at end of file diff --git a/src/StructDiffusion/utils/brain2/__init__.py b/src/StructDiffusion/utils/brain2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff 
--git a/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a865e957b0416b8531ff96968d43a42e02d0c5c8 Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61b09a1b77329c30f59eaffd4d878dc0d0f02bf Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85dd8dfbe93b525792a1dda7ffc56fc69ac43830 Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9af852d208bc53275e669633ebb0c883123427de Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/camera.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5e36a52f7c675b572fa3e1ceeedfb06f0009fcb Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ef499255c77b1d8fa3cc65a32a9e4a6b9d9f11a Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/image.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b910883f1c4dd2fbc1a5475965faa9cc9ed50d3 Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-37.pyc differ diff --git a/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1325a12669cdb6cfbe1c0b12fed60e2b9a79536e Binary files /dev/null and b/src/StructDiffusion/utils/brain2/__pycache__/pose.cpython-38.pyc differ diff --git a/src/StructDiffusion/utils/brain2/camera.py b/src/StructDiffusion/utils/brain2/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac75323f0468815b9f25066d66a458f3e0dc771 --- /dev/null +++ b/src/StructDiffusion/utils/brain2/camera.py @@ -0,0 +1,175 @@ +from __future__ import print_function + +import numpy as np +import open3d +import trimesh + +import StructDiffusion.utils.transformations as tra +from StructDiffusion.utils.brain2.pose import make_pose + + +def get_camera_from_h5(h5): + """ Simple reference to help make these """ + proj_near = h5['cam_near'][()] + proj_far = h5['cam_far'][()] + proj_fov = h5['cam_fov'][()] 
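+    # [()] reads the stored scalar value out of each h5py dataset.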
+ width = h5['cam_width'][()] + height = h5['cam_height'][()] + return GenericCameraReference(proj_near, proj_far, proj_fov, width, height) + + +class GenericCameraReference(object): + """ Class storing camera information and providing easy image capture """ + + def __init__(self, proj_near=0.01, proj_far=5., proj_fov=60., img_width=640, + img_height=480): + + self.proj_near = proj_near + self.proj_far = proj_far + self.proj_fov = proj_fov + self.img_width = img_width + self.img_height = img_height + self.x_offset = self.img_width / 2. + self.y_offset = self.img_height / 2. + + # Compute focal params + aspect_ratio = self.img_width / self.img_height + e = 1 / (np.tan(np.radians(self.proj_fov/2.))) + t = self.proj_near / e + b = -t + r = t * aspect_ratio + l = -r + # pixels per meter + alpha = self.img_width / (r-l) + self.focal_length = self.proj_near * alpha + self.fx = self.focal_length + self.fy = self.focal_length + self.pose = None + self.inv_pose = None + + def set_pose(self, trans, rot): + self.pose = make_pose(trans, rot) + self.inv_pose = tra.inverse_matrix(self.pose) + + def set_pose_matrix(self, matrix): + self.pose = matrix + self.inv_pose = tra.inverse_matrix(matrix) + + def transform_to_world_coords(self, xyz): + """ transform xyz into world coordinates """ + #cam_pose = tra.inverse_matrix(self.pose).dot(tra.euler_matrix(np.pi, 0, 0)) + #xyz = trimesh.transform_points(xyz, self.inv_pose) + #xyz = trimesh.transform_points(xyz, cam_pose) + #pose = tra.euler_matrix(np.pi, 0, 0) @ self.pose + pose = self.pose + xyz = trimesh.transform_points(xyz, pose) + return xyz + +def get_camera_presets(): + return [ + "n/a", + "azure_depth_nfov", + "realsense", + "azure_720p", + "simple256", + "simple512", + ] + + +def get_camera_preset(name): + + if name == "azure_depth_nfov": + # Setting for depth camera is pretty different from RGB + height, width, fov = 576, 640, 75 + if name == "azure_720p": + # This is actually the 720p RGB setting + # Used for our color camera most of the time + #height, width, fov = 720, 1280, 90 + height, width, fov = 720, 1280, 60 + elif name == "realsense": + height, width, fov = 480, 640, 60 + elif name == "simple256": + height, width, fov = 256, 256, 60 + elif name == "simple512": + height, width, fov = 512, 512, 60 + else: + raise RuntimeError(('camera "%s" not supported, choose from: ' + + str(get_camera_presets())) % str(name)) + return height, width, fov + + +def get_generic_camera(name): + h, w, fov = get_camera_preset(name) + return GenericCameraReference(img_height=h, img_width=w, proj_fov=fov) + + +def get_matrix_of_indices(height, width): + """ Get indices """ + return np.indices((height, width), dtype=np.float32).transpose(1,2,0) + +# -------------------------------------------------------- +# NOTE: this code taken from Arsalan and modified +def compute_xyz(depth_img, camera, visualize_xyz=False, + xmap=None, ymap=None, max_clip_depth=5): + """ Compute xyz image from depth for a camera """ + + # We need thes eparameters + height = camera.img_height + width = camera.img_width + assert depth_img.shape[0] == camera.img_height + assert depth_img.shape[1] == camera.img_width + fx = camera.fx + fy = camera.fy + cx = camera.x_offset + cy = camera.y_offset + + """ + # Create the matrix of parameters + indices = np.indices((height, width), dtype=np.float32).transpose(1,2,0) + # pixel indices start at top-left corner. 
for these equations, it starts at bottom-left + # indices[..., 0] = np.flipud(indices[..., 0]) + z_e = depth_img + x_e = (indices[..., 1] - x_offset) * z_e / fx + y_e = (indices[..., 0] - y_offset) * z_e / fy + xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3] + """ + + height = depth_img.shape[0] + width = depth_img.shape[1] + input_x = np.arange(width) + input_y = np.arange(height) + input_x, input_y = np.meshgrid(input_x, input_y) + input_x = input_x.flatten() + input_y = input_y.flatten() + input_z = depth_img.flatten() + # clip points that are farther than max distance + input_z[input_z > max_clip_depth] = 0 + output_x = (input_x * input_z - cx * input_z) / fx + output_y = (input_y * input_z - cy * input_z) / fy + raw_pc = np.stack([output_x, output_y, input_z], -1).reshape( + height, width, 3 + ) + return raw_pc + + if visualize_xyz: + unordered_pc = xyz_img.reshape(-1, 3) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.transform([[1,0,0,0], [0,1,0,0], [0,0,-1,0], [0,0,0,1]]) # Transform it so it's not upside down + open3d.visualization.draw_geometries([pcd]) + + return xyz_img + +def show_pcs(xyz, rgb): + """ Display point clouds """ + if len(xyz.shape) > 2: + unordered_pc = xyz.reshape(-1, 3) + unordered_rgb = rgb.reshape(-1, 3) / 255. + assert(unordered_rgb.shape[0] == unordered_pc.shape[0]) + assert(unordered_pc.shape[1] == 3) + assert(unordered_rgb.shape[1] == 3) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) + pcd.transform([[1,0,0,0],[0,1,0,0],[0,0,-1,0],[0,0,0,1]]) # Transform it so it's not upside down + open3d.visualization.draw_geometries([pcd]) diff --git a/src/StructDiffusion/utils/brain2/image.py b/src/StructDiffusion/utils/brain2/image.py new file mode 100644 index 0000000000000000000000000000000000000000..668a551760f2548e1dec1e318ef54073244a7383 --- /dev/null +++ b/src/StructDiffusion/utils/brain2/image.py @@ -0,0 +1,154 @@ +""" +By Chris Paxton. + +Copyright (c) 2018, Johns Hopkins University +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the Johns Hopkins University nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL JOHNS HOPKINS UNIVERSITY BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import numpy as np +import io +from PIL import Image + +def GetJpeg(img): + ''' + Save a numpy array as a Jpeg, then get it out as a binary blob + ''' + im = Image.fromarray(np.uint8(img)) + output = io.BytesIO() + im.save(output, format="JPEG", quality=80) + return output.getvalue() + +def JpegToNumpy(jpeg): + stream = io.BytesIO(jpeg) + im = Image.open(stream) + return np.asarray(im, dtype=np.uint8) + +def ConvertJpegListToNumpy(data): + length = len(data) + imgs = [] + for raw in data: + imgs.append(JpegToNumpy(raw)) + arr = np.array(imgs) + return arr + +def DepthToZBuffer(img, z_near, z_far): + real_depth = z_near * z_far / (z_far - img * (z_far - z_near)) + return real_depth + +def ZBufferToRGB(img, z_near, z_far): + real_depth = z_near * z_far / (z_far - img * (z_far - z_near)) + depth_m = np.uint8(real_depth) + depth_cm = np.uint8((real_depth-depth_m)*100) + depth_tmm = np.uint8((real_depth-depth_m-0.01*depth_cm)*10000) + return np.dstack([depth_m, depth_cm, depth_tmm]) + +def RGBToDepth(img, min_dist=0., max_dist=2.,): + return (img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]).clip(min_dist, max_dist) + #return img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2] + +def MaskToRGBA(img): + buf = img.astype(np.int32) + A = buf.astype(np.uint8) + buf = np.right_shift(buf, 8) + B = buf.astype(np.uint8) + buf = np.right_shift(buf, 8) + G = buf.astype(np.uint8) + buf = np.right_shift(buf, 8) + R = buf.astype(np.uint8) + + dims = [np.expand_dims(d, -1) for d in [R,G,B,A]] + return np.concatenate(dims, axis=-1) + +def RGBAToMask(img): + mask = np.zeros(img.shape[:-1], dtype=np.int32) + buf = img.astype(np.int32) + for i, dim in enumerate([3,2,1,0]): + shift = 8*i + #print(i, dim, shift, buf[0,0,dim], np.left_shift(buf[0,0,dim], shift)) + mask += np.left_shift(buf[:,:, dim], shift) + return mask + +def RGBAArrayToMasks(img): + mask = np.zeros(img.shape[:-1], dtype=np.int32) + buf = img.astype(np.int32) + for i, dim in enumerate([3,2,1,0]): + shift = 8*i + mask += np.left_shift(buf[:,:,:, dim], shift) + return mask + +def GetPNG(img): + ''' + Save a numpy array as a PNG, then get it out as a binary blob + ''' + im = Image.fromarray(np.uint8(img)) + output = io.BytesIO() + im.save(output, format="PNG")#, quality=80) + return output.getvalue() + +def PNGToNumpy(png): + stream = io.BytesIO(png) + im = Image.open(stream) + return np.array(im, dtype=np.uint8) + +def ConvertPNGListToNumpy(data): + length = len(data) + imgs = [] + for raw in data: + imgs.append(PNGToNumpy(raw)) + arr = np.array(imgs) + return arr + +def ConvertDepthPNGListToNumpy(data): + length = len(data) + imgs = [] + for raw in data: + imgs.append(RGBToDepth(PNGToNumpy(raw))) + arr = np.array(imgs) + return arr + +import cv2 +def Shrink(img, nw=64): + h,w = img.shape[:2] + ratio = float(nw) / w + nh = int(ratio * h) + img2 = cv2.resize(img, dsize=(nw, nh), + interpolation=cv2.INTER_NEAREST) + return img2 + +def ShrinkSmooth(img, nw=64): + h,w = img.shape[:2] + ratio = float(nw) / w + nh = int(ratio * h) + img2 = cv2.resize(img, 
dsize=(nw, nh), + interpolation=cv2.INTER_LINEAR) + return img2 + +def CropCenter(img, cropx, cropy): + y = img.shape[0] + x = img.shape[1] + startx = (x // 2) - (cropx // 2) + starty = (y // 2) - (cropy // 2) + return img[starty: starty + cropy, startx : startx + cropx] + diff --git a/src/StructDiffusion/utils/brain2/pose.py b/src/StructDiffusion/utils/brain2/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..e314f58f6c1942b1f62e827130a681b829d4e5f1 --- /dev/null +++ b/src/StructDiffusion/utils/brain2/pose.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + + +from __future__ import print_function + +import StructDiffusion.utils.transformations as tra + + +def make_pose(trans, rot): + """Make 4x4 matrix from (trans, rot)""" + pose = tra.quaternion_matrix(rot) + pose[:3, 3] = trans + return pose diff --git a/src/StructDiffusion/utils/files.py b/src/StructDiffusion/utils/files.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd77d30a73ade07f627ebe87dc0b4ce2e7230ed --- /dev/null +++ b/src/StructDiffusion/utils/files.py @@ -0,0 +1,9 @@ +import os + +def get_checkpoint_path_from_dir(checkpoint_dir): + checkpoint_path = None + for file in os.listdir(checkpoint_dir): + if "ckpt" in file: + checkpoint_path = os.path.join(checkpoint_dir, file) + assert checkpoint_path is not None + return checkpoint_path \ No newline at end of file diff --git a/src/StructDiffusion/utils/physics_eval.py b/src/StructDiffusion/utils/physics_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..aa14d5f92c6239bf7b86039d5020fa54eb8eae58 --- /dev/null +++ b/src/StructDiffusion/utils/physics_eval.py @@ -0,0 +1,295 @@ +import sys +import os +import h5py +import torch +import pytorch3d.transforms as tra3d + +from StructDiffusion.utils.rearrangement import show_pcs_color_order +from StructDiffusion.utils.pointnet import random_point_sample, index_points + + +def switch_stdout(stdout_filename=None): + if stdout_filename: + print("setting stdout to {}".format(stdout_filename)) + if os.path.exists(stdout_filename): + sys.stdout = open(stdout_filename, 'a') + else: + sys.stdout = open(stdout_filename, 'w') + else: + sys.stdout = sys.__stdout__ + + +def visualize_batch_pcs(obj_xyzs, B, N, P, verbose=True, limit_B=None): + if limit_B is None: + limit_B = B + + vis_obj_xyzs = obj_xyzs.reshape(B, N, P, -1) + vis_obj_xyzs = vis_obj_xyzs[:limit_B] + + if type(vis_obj_xyzs).__module__ == torch.__name__: + if vis_obj_xyzs.is_cuda: + vis_obj_xyzs = vis_obj_xyzs.detach().cpu() + vis_obj_xyzs = vis_obj_xyzs.numpy() + + for bi, vis_obj_xyz in enumerate(vis_obj_xyzs): + if verbose: + print("example {}".format(bi)) + print(vis_obj_xyz.shape) + show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=True, add_table=False) + + +def convert_bool(d): + for k in d: + if type(d[k]) == list: + d[k] = [bool(i) for i in d[k]] + else: + d[k] = bool(d[k]) + return d + + +def save_dict_to_h5(dict_data, filename): + fh = h5py.File(filename, 'w') + for k in dict_data: + key_data = dict_data[k] + if key_data is None: + raise RuntimeError('data was 
not properly populated') + # if type(key_data) is dict: + # key_data = json.dumps(key_data, sort_keys=True) + try: + fh.create_dataset(k, data=key_data) + except TypeError as e: + print("Failure on key", k) + print(key_data) + print(e) + raise e + fh.close() + + +def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device, + return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False, + return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False): + + # obj_xyzs: N, P, 3 + # obj_params: B, N, 6 + # struct_pose: B x N, 4, 4 + # current_pc_pose: B x N, 4, 4 + # target_object_inds: 1, N + + B, N, _ = obj_params.shape + _, P, _ = obj_xyzs.shape + + # B, N, 6 + flat_obj_params = obj_params.reshape(B * N, -1) + goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device) + goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ") + goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4 + + goal_pc_pose = struct_pose @ goal_pc_pose_in_struct + goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + # obj_xyzs: N, P, 3 + new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) + new_obj_xyzs = transpose.transform_points(new_obj_xyzs) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + # visualize_batch_pcs(new_obj_xyzs, S, N, P) + + + # initialize the additional outputs + subsampled_scene_xyz = None + subsampled_pc_idxs = None + obj_pair_xyzs = None + + # =================================== + # Pass to discriminator + if return_scene_pts: + + num_indicator = N + + # add one hot + indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N + # print(indicator_variables.shape) + # print(new_obj_xyzs.shape) + new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N) + + # ToDo: maybe convert this to a batch operation + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device) + for si, scene_xyz in enumerate(scene_xyzs): + # scene_xyz: N*P, 3+N + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + + # # debug: + # print("-"*50) + # if si < 10: + # trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show() + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show() + + # subsampled_scene_xyz: B, num_scene_pts, 3+N + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # # debug: + # for si in range(10): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + if return_scene_pts_and_pc_idxs: + num_indicator = N + pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P + # new_obj_xyzs: B, N, P, 3 + 1 + + # combine pcs in each scene + scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3) + 
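+        # Flatten the N per-object clouds into one scene cloud of N * P points and keep a
+        # parallel array of object ids, so the random subsampling below preserves the
+        # point-to-object correspondence.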
pc_idxs = pc_idxs.reshape(B, N*P) + + subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device) + subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device) + for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)): + # scene_xyz: N*P, 3+1 + # target_object_inds: 1, N + subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device) + subsampled_scene_xyz[si] = scene_xyz[subsample_idx] + subsampled_pc_idxs[si] = pc_idx[subsample_idx] + + # subsampled_scene_xyz: B, num_scene_pts, 3 + # subsampled_pc_idxs: B, num_scene_pts + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + # important: + if normalize_pc: + subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3]) + + # TODO: visualize each individual object + # debug + # print(subsampled_scene_xyz.shape) + # print(subsampled_pc_idxs.shape) + # print("visualize subsampled scene") + # for si in range(5): + # trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show() + + ############################################### + # Create input for pairwise collision detector + if return_pair_pc: + + assert num_pair_pc_pts is not None + + # new_obj_xyzs: B, N, P, 3 + N + # target_object_inds: 1, N + # ignore paddings + num_objs = torch.sum(target_object_inds[0]) + obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2 + + # use [:, :, :, :3] to get obj_xyzs without object-wise indicator + obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3 + num_comb = obj_pair_xyzs.shape[1] + pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2 + obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5) + + # random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts) + obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5) + # random_point_sample() input dim: B, N, C + rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts + obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5 + + if normalize_pair_pc: + # pc_normalize_batch() input dim: pc: B, num_scene_pts, 3 + # obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5) + obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3]) + obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5) + + # # debug + # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs): + # print("batch id", bi) + # for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs): + # print("pair", pi) + # # obj_pair_xyzs: 2 * P, 5 + # print(obj_pair_xyz[:, :3].shape) + # trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show() + + # obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2 + goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4) + + return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs + + + +def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device): + + # obj_xyzs: N, P, 3 + # obj_params: B, N, 6 + # struct_pose: B x N, 4, 4 + # current_pc_pose: B x N, 4, 4 + # target_object_inds: 1, N + + B, N, _ = obj_params.shape + _, P, _ = obj_xyzs.shape + + # B, N, 6 + flat_obj_params = obj_params.reshape(B * N, -1) + 
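+    # Each 6-dim object parameter is an (x, y, z) translation followed by XYZ euler angles;
+    # pack it into a 4x4 homogeneous pose expressed in the structure frame.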
goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device) + goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ") + goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4 + + goal_pc_pose = struct_pose @ goal_pc_pose_in_struct + goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4 + + # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix + transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2)) + + # obj_xyzs: N, P, 3 + new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) + new_obj_xyzs = transpose.transform_points(new_obj_xyzs) + + # put it back to B, N, P, 3 + new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1) + # visualize_batch_pcs(new_obj_xyzs, S, N, P) + + # subsampled_scene_xyz: B, num_scene_pts, 3+N + # new_obj_xyzs: B, N, P, 3 + # goal_pc_pose: B, N, 4, 4 + + goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4) + return new_obj_xyzs, goal_pc_pose + + +def sample_gaussians(mus, sigmas, sample_size): + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + normal = torch.distributions.Normal(mus, sigmas) + samples = normal.sample((sample_size,)) + # samples: [sample_size, number of individual gaussians] + return samples + +def fit_gaussians(samples, sigma_eps=0.01): + device = samples.device + + # samples: [sample_size, number of individual gaussians] + num_gs = samples.shape[1] + mus = torch.mean(samples, dim=0).to(device) + sigmas = torch.std(samples, dim=0).to(device) + sigma_eps * torch.ones(num_gs).to(device) + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + return mus, sigmas + + +def pc_normalize_batch(pc): + # pc: B, num_scene_pts, 3 + centroid = torch.mean(pc, dim=1) # B, 3 + pc = pc - centroid[:, None, :] + m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=2)), dim=1)[0] + pc = pc / m[:, None, None] + return pc diff --git a/src/StructDiffusion/utils/pointnet.py b/src/StructDiffusion/utils/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7371b6d4250d86668a16a396572286a3f91bbc --- /dev/null +++ b/src/StructDiffusion/utils/pointnet.py @@ -0,0 +1,583 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from time import time +import numpy as np + + +# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You + + +def timeit(tag, t): + print("{}: {}s".format(tag, time() - t)) + return time() + +def pc_normalize(pc): + if type(pc).__module__ == np.__name__: + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc**2, axis=1))) + pc = pc / m + else: + centroid = torch.mean(pc, dim=0) + pc = pc - centroid + m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=1))) + pc = pc / m + return pc + +def square_distance(src, dst): + """ + Calculate Euclid distance between each two points. 
+ src^T * dst = xn * xm + yn * ym + zn * zmï¼› + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) + + +def index_points(points, idx): + """ + Input: + points: input points data, [B, N, C] + idx: sample index data, [B, S, [K]] + Return: + new_points:, indexed points data, [B, S, [K], C] + """ + raw_size = idx.size() + idx = idx.reshape(raw_size[0], -1) + res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1))) + return res.reshape(*raw_size, -1) + + +def farthest_point_sample(xyz, npoint): + """ + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + Return: + centroids: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + centroids[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) + dist = torch.sum((xyz - centroid) ** 2, -1) + distance = torch.min(distance, dist) + farthest = torch.max(distance, -1)[1] + return centroids + + +def random_point_sample(xyz, npoint): + """ + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + Return: + idxs: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + idxs = torch.randint(0, N, (B, npoint), dtype=torch.long).to(device) + return idxs + + +def query_ball_point(radius, nsample, xyz, new_xyz): + """ + Input: + radius: local region radius + nsample: max sample number in local region + xyz: all points, [B, N, 3] + new_xyz: query points, [B, S, 3] + Return: + group_idx: grouped points index, [B, S, nsample] + """ + device = xyz.device + B, N, C = xyz.shape + _, S, _ = new_xyz.shape + group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) + sqrdists = square_distance(new_xyz, xyz) + group_idx[sqrdists > radius ** 2] = N + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] + group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) + mask = group_idx == N + group_idx[mask] = group_first[mask] + return group_idx + + +def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False): + """ + Input: + npoint: + radius: + nsample: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, npoint, nsample, 3] + new_points: sampled points data, [B, npoint, nsample, 3+D] + """ + B, N, C = xyz.shape + S = npoint + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] + torch.cuda.empty_cache() + new_xyz = index_points(xyz, fps_idx) + torch.cuda.empty_cache() + if knn: + dists = square_distance(new_xyz, xyz) # B x npoint x N + idx = dists.argsort()[:, :, :nsample] # B x npoint x K + else: + idx = query_ball_point(radius, nsample, xyz, new_xyz) + torch.cuda.empty_cache() + grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] + torch.cuda.empty_cache() + grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) + torch.cuda.empty_cache() + + if 
points is not None: + grouped_points = index_points(points, idx) + new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] + else: + new_points = grouped_xyz_norm + if returnfps: + return new_xyz, new_points, grouped_xyz, fps_idx + else: + return new_xyz, new_points + + +def sample_and_group_all(xyz, points): + """ + Input: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, 1, 3] + new_points: sampled points data, [B, 1, N, 3+D] + """ + device = xyz.device + B, N, C = xyz.shape + new_xyz = torch.zeros(B, 1, C).to(device) + grouped_xyz = xyz.view(B, 1, N, C) + if points is not None: + new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) + else: + new_points = grouped_xyz + return new_xyz, new_points + + +class PointNetSetAbstraction(nn.Module): + def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False): + super(PointNetSetAbstraction, self).__init__() + self.npoint = npoint + self.radius = radius + self.nsample = nsample + self.knn = knn + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) + self.mlp_bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.group_all = group_all + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, N, C] + points: input points data, [B, N, C] + Return: + new_xyz: sampled points position data, [B, S, C] + new_points_concat: sample points feature data, [B, S, D'] + """ + if self.group_all: + new_xyz, new_points = sample_and_group_all(xyz, points) + else: + new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn) + # new_xyz: sampled points position data, [B, npoint, C] + # new_points: sampled points data, [B, npoint, nsample, C+D] + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + + new_points = torch.max(new_points, 2)[0].transpose(1, 2) + return new_xyz, new_points + + +class PointNetSetAbstractionMsg(nn.Module): + def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False): + super(PointNetSetAbstractionMsg, self).__init__() + self.npoint = npoint + self.radius_list = radius_list + self.nsample_list = nsample_list + self.knn = knn + self.conv_blocks = nn.ModuleList() + self.bn_blocks = nn.ModuleList() + for i in range(len(mlp_list)): + convs = nn.ModuleList() + bns = nn.ModuleList() + last_channel = in_channel + 3 + for out_channel in mlp_list[i]: + convs.append(nn.Conv2d(last_channel, out_channel, 1)) + bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.conv_blocks.append(convs) + self.bn_blocks.append(bns) + + def forward(self, xyz, points, seed_idx=None): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_xyz: sampled points position data, [B, C, S] + new_points_concat: sample points feature data, [B, D', S] + """ + + B, N, C = xyz.shape + S = self.npoint + new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx) + new_points_list = [] + for i, radius in enumerate(self.radius_list): + K = self.nsample_list[i] + if self.knn: + dists = square_distance(new_xyz, xyz) # B x 
npoint x N + group_idx = dists.argsort()[:, :, :K] # B x npoint x K + else: + group_idx = query_ball_point(radius, K, xyz, new_xyz) + grouped_xyz = index_points(xyz, group_idx) + grouped_xyz -= new_xyz.view(B, S, 1, C) + if points is not None: + grouped_points = index_points(points, group_idx) + grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) + else: + grouped_points = grouped_xyz + + grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] + for j in range(len(self.conv_blocks[i])): + conv = self.conv_blocks[i][j] + bn = self.bn_blocks[i][j] + grouped_points = F.relu(bn(conv(grouped_points))) + new_points = torch.max(grouped_points, 2)[0] # [B, D', S] + new_points_list.append(new_points) + + new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2) + return new_xyz, new_points_concat + + +# NoteL this function swaps N and C +class PointNetFeaturePropagation(nn.Module): + def __init__(self, in_channel, mlp): + super(PointNetFeaturePropagation, self).__init__() + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) + self.mlp_bns.append(nn.BatchNorm1d(out_channel)) + last_channel = out_channel + + def forward(self, xyz1, xyz2, points1, points2): + """ + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points data, [B, D, N] + points2: input points data, [B, D, S] + Return: + new_points: upsampled points data, [B, D', N] + """ + xyz1 = xyz1.permute(0, 2, 1) + xyz2 = xyz2.permute(0, 2, 1) + + points2 = points2.permute(0, 2, 1) + B, N, C = xyz1.shape + _, S, _ = xyz2.shape + + if S == 1: + interpolated_points = points2.repeat(1, N, 1) + else: + dists = square_distance(xyz1, xyz2) + dists, idx = dists.sort(dim=-1) + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] + + dist_recip = 1.0 / (dists + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) + + if points1 is not None: + points1 = points1.permute(0, 2, 1) + new_points = torch.cat([points1, interpolated_points], dim=-1) + else: + new_points = interpolated_points + + new_points = new_points.permute(0, 2, 1) + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + return new_points + + +# reference https://github.com/qq456cvb/Point-Transformers + +def normalize_data(batch_data): + """ Normalize the batch data, use coordinates of the block centered at origin, + Input: + BxNxC array + Output: + BxNxC array + """ + B, N, C = batch_data.shape + normal_data = np.zeros((B, N, C)) + for b in range(B): + pc = batch_data[b] + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) + pc = pc / m + normal_data[b] = pc + return normal_data + + +def shuffle_data(data, labels): + """ Shuffle data and labels. + Input: + data: B,N,... numpy array + label: B,... numpy array + Return: + shuffled data, label and shuffle indices + """ + idx = np.arange(len(labels)) + np.random.shuffle(idx) + return data[idx, ...], labels[idx], idx + +def shuffle_points(batch_data): + """ Shuffle orders of points in each point cloud -- changes FPS behavior. + Use the same shuffling idx for the entire batch. 
+ Input: + BxNxC array + Output: + BxNxC array + """ + idx = np.arange(batch_data.shape[1]) + np.random.shuffle(idx) + return batch_data[:,idx,:] + +def rotate_point_cloud(batch_data): + """ Randomly rotate the point clouds to augument the dataset + rotation is per shape based along up direction + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, rotated batch of point clouds + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + rotation_angle = np.random.uniform() * 2 * np.pi + cosval = np.cos(rotation_angle) + sinval = np.sin(rotation_angle) + rotation_matrix = np.array([[cosval, 0, sinval], + [0, 1, 0], + [-sinval, 0, cosval]]) + shape_pc = batch_data[k, ...] + rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) + return rotated_data + +def rotate_point_cloud_z(batch_data): + """ Randomly rotate the point clouds to augument the dataset + rotation is per shape based along up direction + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, rotated batch of point clouds + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + rotation_angle = np.random.uniform() * 2 * np.pi + cosval = np.cos(rotation_angle) + sinval = np.sin(rotation_angle) + rotation_matrix = np.array([[cosval, sinval, 0], + [-sinval, cosval, 0], + [0, 0, 1]]) + shape_pc = batch_data[k, ...] + rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) + return rotated_data + +def rotate_point_cloud_with_normal(batch_xyz_normal): + ''' Randomly rotate XYZ, normal point cloud. + Input: + batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal + Output: + B,N,6, rotated XYZ, normal point cloud + ''' + for k in range(batch_xyz_normal.shape[0]): + rotation_angle = np.random.uniform() * 2 * np.pi + cosval = np.cos(rotation_angle) + sinval = np.sin(rotation_angle) + rotation_matrix = np.array([[cosval, 0, sinval], + [0, 1, 0], + [-sinval, 0, cosval]]) + shape_pc = batch_xyz_normal[k,:,0:3] + shape_normal = batch_xyz_normal[k,:,3:6] + batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) + batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) + return batch_xyz_normal + +def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): + """ Randomly perturb the point clouds by small rotations + Input: + BxNx6 array, original batch of point clouds and point normals + Return: + BxNx3 array, rotated batch of point clouds + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) + Rx = np.array([[1,0,0], + [0,np.cos(angles[0]),-np.sin(angles[0])], + [0,np.sin(angles[0]),np.cos(angles[0])]]) + Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], + [0,1,0], + [-np.sin(angles[1]),0,np.cos(angles[1])]]) + Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], + [np.sin(angles[2]),np.cos(angles[2]),0], + [0,0,1]]) + R = np.dot(Rz, np.dot(Ry,Rx)) + shape_pc = batch_data[k,:,0:3] + shape_normal = batch_data[k,:,3:6] + rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) + rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) + return rotated_data + + +def rotate_point_cloud_by_angle(batch_data, rotation_angle): + """ Rotate the point cloud along up direction with certain angle. 
+ Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, rotated batch of point clouds + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + #rotation_angle = np.random.uniform() * 2 * np.pi + cosval = np.cos(rotation_angle) + sinval = np.sin(rotation_angle) + rotation_matrix = np.array([[cosval, 0, sinval], + [0, 1, 0], + [-sinval, 0, cosval]]) + shape_pc = batch_data[k,:,0:3] + rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) + return rotated_data + +def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): + """ Rotate the point cloud along up direction with certain angle. + Input: + BxNx6 array, original batch of point clouds with normal + scalar, angle of rotation + Return: + BxNx6 array, rotated batch of point clouds iwth normal + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + #rotation_angle = np.random.uniform() * 2 * np.pi + cosval = np.cos(rotation_angle) + sinval = np.sin(rotation_angle) + rotation_matrix = np.array([[cosval, 0, sinval], + [0, 1, 0], + [-sinval, 0, cosval]]) + shape_pc = batch_data[k,:,0:3] + shape_normal = batch_data[k,:,3:6] + rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) + rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) + return rotated_data + + + +def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): + """ Randomly perturb the point clouds by small rotations + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, rotated batch of point clouds + """ + rotated_data = np.zeros(batch_data.shape, dtype=np.float32) + for k in range(batch_data.shape[0]): + angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) + Rx = np.array([[1,0,0], + [0,np.cos(angles[0]),-np.sin(angles[0])], + [0,np.sin(angles[0]),np.cos(angles[0])]]) + Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], + [0,1,0], + [-np.sin(angles[1]),0,np.cos(angles[1])]]) + Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], + [np.sin(angles[2]),np.cos(angles[2]),0], + [0,0,1]]) + R = np.dot(Rz, np.dot(Ry,Rx)) + shape_pc = batch_data[k, ...] + rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) + return rotated_data + + +def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): + """ Randomly jitter points. jittering is per point. + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, jittered batch of point clouds + """ + B, N, C = batch_data.shape + assert(clip > 0) + jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) + jittered_data += batch_data + return jittered_data + +def shift_point_cloud(batch_data, shift_range=0.1): + """ Randomly shift point cloud. Shift is per point cloud. + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, shifted batch of point clouds + """ + B, N, C = batch_data.shape + shifts = np.random.uniform(-shift_range, shift_range, (B,3)) + for batch_index in range(B): + batch_data[batch_index,:,:] += shifts[batch_index,:] + return batch_data + + +def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): + """ Randomly scale the point cloud. Scale is per point cloud. 
+ Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, scaled batch of point clouds + """ + B, N, C = batch_data.shape + scales = np.random.uniform(scale_low, scale_high, B) + for batch_index in range(B): + batch_data[batch_index,:,:] *= scales[batch_index] + return batch_data + +def random_point_dropout(batch_pc, max_dropout_ratio=0.875): + ''' batch_pc: BxNx3 ''' + for b in range(batch_pc.shape[0]): + dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 + drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] + if len(drop_idx)>0: + batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point + return batch_pc + + diff --git a/src/StructDiffusion/utils/random.py b/src/StructDiffusion/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/StructDiffusion/utils/rearrangement.py b/src/StructDiffusion/utils/rearrangement.py new file mode 100644 index 0000000000000000000000000000000000000000..d3957666076f0ce471745a9e8bf37f45a0f60382 --- /dev/null +++ b/src/StructDiffusion/utils/rearrangement.py @@ -0,0 +1,1201 @@ +import copy +import os +import torch +import trimesh +import numpy as np +import open3d +from PIL import Image, ImageDraw, ImageFont +from sklearn.metrics import classification_report +from collections import defaultdict +import matplotlib.pyplot as plt +import itertools +import matplotlib +import h5py +import json + +import StructDiffusion.utils.transformations as tra +from StructDiffusion.utils.rotation_continuity import compute_geodesic_distance_from_two_matrices + +# from pointnet_utils import farthest_point_sample, index_points + + +def flatten1d(img): + return img.reshape(-1) + + +def flatten3d(img): + hw = img.shape[0] * img.shape[1] + return img.reshape(hw, -1) + + +def array_to_tensor(array): + """ Assume arrays are in numpy (channels-last) format and put them into the right one """ + if array.ndim == 4: # NHWC + tensor = torch.from_numpy(array).permute(0,3,1,2).float() + elif array.ndim == 3: # HWC + tensor = torch.from_numpy(array).permute(2,0,1).float() + else: # everything else - just keep it as-is + tensor = torch.from_numpy(array).float() + return tensor + + +def get_pts(xyz_in, rgb_in, mask, bg_mask=None, num_pts=1024, center=None, + radius=0.5, filename=None, to_tensor=True): + + # Get the XYZ and RGB + mask = flatten1d(mask) + assert(np.sum(mask) > 0) + xyz = flatten3d(xyz_in)[mask > 0] + if rgb_in is not None: + rgb = flatten3d(rgb_in)[mask > 0] + + if xyz.shape[0] == 0: + raise RuntimeError('this should not happen') + ok = False + xyz = flatten3d(xyz_in) + if rgb_in is not None: + rgb = flatten3d(rgb_in) + else: + ok = True + + # prune to this region + if center is not None: + # numpy matrix + # use the full xyz point cloud to determine what is close enough + # now that we have the closest background point we can place the object on it + # Just center on the point + center = center.numpy() + center = center[None].repeat(xyz.shape[0], axis=0) + dists = np.linalg.norm(xyz - center, axis=-1) + idx = dists < radius + xyz = xyz[idx] + if rgb_in is not None: + rgb = rgb[idx] + center = center[0] + else: + center = None + + # Compute number of points we are using + if num_pts is not None: + if xyz.shape[0] < 1: + print("!!!! 
bad shape:", xyz.shape, filename, "!!!!") + return (None, None, None, None) + idx = np.random.randint(0, xyz.shape[0], num_pts) + xyz = xyz[idx] + if rgb_in is not None: + rgb = rgb[idx] + + # Shuffle the points + if rgb_in is not None: + rgb = array_to_tensor(rgb) if to_tensor else rgb + else: + rgb = None + xyz = array_to_tensor(xyz) if to_tensor else xyz + return (ok, xyz, rgb, center) + + +def align(y_true, y_pred): + """ Add or remove 2*pi to predicted angle to minimize difference from GT""" + y_pred = y_pred.copy() + y_pred[y_true - y_pred > np.pi] += np.pi * 2 + y_pred[y_true - y_pred < -np.pi] -= np.pi * 2 + return y_pred + + +def random_move_obj_xyz(obj_xyz, + min_translation, max_translation, + min_rotation, max_rotation, mode, + visualize=False, return_perturbed_obj_xyzs=True): + + assert mode in ["planar", "6d", "3d_planar"] + + if mode == "planar": + random_translation = np.random.uniform(low=min_translation, high=max_translation, size=2) * np.random.choice( + [-1, 1], size=2) + random_rotation = np.random.uniform(low=min_rotation, high=max_rotation) * np.random.choice([-1, 1]) + random_rotation = tra.euler_matrix(0, 0, random_rotation) + elif mode == "6d": + random_rotation = np.random.uniform(low=min_rotation, high=max_rotation, size=3) * np.random.choice([-1, 1], size=3) + random_rotation = tra.euler_matrix(*random_rotation) + random_translation = np.random.uniform(low=min_translation, high=max_translation, size=3) * np.random.choice([-1, 1], size=3) + elif mode == "3d_planar": + random_translation = np.random.uniform(low=min_translation, high=max_translation, size=3) * np.random.choice( + [-1, 1], size=3) + random_rotation = np.random.uniform(low=min_rotation, high=max_rotation) * np.random.choice([-1, 1]) + random_rotation = tra.euler_matrix(0, 0, random_rotation) + + if return_perturbed_obj_xyzs: + raise Exception("return_perturbed_obj_xyzs=True is no longer supported") + # xyz_mean = np.mean(obj_xyz, axis=0) + # new_obj_xyz = obj_xyz - xyz_mean + # new_obj_xyz = trimesh.transform_points(new_obj_xyz, random_rotation, translate=False) + # new_obj_xyz = new_obj_xyz + xyz_mean + random_translation + else: + new_obj_xyz = obj_xyz + + # test moving the perturbed obj pc back + # new_xyz_mean = np.mean(new_obj_xyz, axis=0) + # old_obj_xyz = new_obj_xyz - new_xyz_mean + # old_obj_xyz = trimesh.transform_points(old_obj_xyz, np.linalg.inv(random_rotation), translate=False) + # old_obj_xyz = old_obj_xyz + new_xyz_mean - random_translation + + # even though we are putting perturbation rotation and translation in the same matrix, they should be applied + # independently. More specifically, rotate the object pc in place and then translate it. 
+ perturbation_matrix = random_rotation + perturbation_matrix[:3, 3] = random_translation + + if visualize: + show_pcs([new_obj_xyz, obj_xyz], + [np.tile(np.array([1, 0, 0], dtype=np.float), (obj_xyz.shape[0], 1)), + np.tile(np.array([0, 1, 0], dtype=np.float), (obj_xyz.shape[0], 1))], add_coordinate_frame=True) + + return new_obj_xyz, perturbation_matrix + + +def random_move_obj_xyzs(obj_xyzs, + min_translation, max_translation, + min_rotation, max_rotation, mode, move_obj_idxs=None, visualize=False, return_moved_obj_idxs=False, + return_perturbation=False, return_perturbed_obj_xyzs=True): + """ + + :param obj_xyzs: + :param min_translation: + :param max_translation: + :param min_rotation: + :param max_rotation: + :param mode: + :param move_obj_idxs: + :param visualize: + :param return_moved_obj_idxs: + :param return_perturbation: + :param return_perturbed_obj_xyzs: + :return: + """ + + new_obj_xyzs = [] + new_obj_rgbs = [] + old_obj_rgbs = [] + perturbation_matrices = [] + + if move_obj_idxs is None: + move_obj_idxs = list(range(len(obj_xyzs))) + + # this many objects will not be randomly moved + stationary_obj_idxs = np.random.choice(move_obj_idxs, np.random.randint(0, len(move_obj_idxs)), replace=False).tolist() + + moved_obj_idxs = [] + for obj_idx, obj_xyz in enumerate(obj_xyzs): + + if obj_idx in stationary_obj_idxs: + new_obj_xyzs.append(obj_xyz) + perturbation_matrices.append(np.eye(4)) + if visualize: + new_obj_rgbs.append(np.tile(np.array([1, 0, 0], dtype=np.float), (obj_xyz.shape[0], 1))) + old_obj_rgbs.append(np.tile(np.array([0, 0, 1], dtype=np.float), (obj_xyz.shape[0], 1))) + else: + new_obj_xyz, perturbation_matrix = random_move_obj_xyz(obj_xyz, + min_translation=min_translation, max_translation=max_translation, + min_rotation=min_rotation, max_rotation=max_rotation, mode=mode, + return_perturbed_obj_xyzs=return_perturbed_obj_xyzs) + new_obj_xyzs.append(new_obj_xyz) + moved_obj_idxs.append(obj_idx) + perturbation_matrices.append(perturbation_matrix) + if visualize: + new_obj_rgbs.append(np.tile(np.array([1, 0, 0], dtype=np.float), (obj_xyz.shape[0], 1))) + old_obj_rgbs.append(np.tile(np.array([0, 1, 0], dtype=np.float), (obj_xyz.shape[0], 1))) + if visualize: + show_pcs(new_obj_xyzs + obj_xyzs, + new_obj_rgbs + old_obj_rgbs, add_coordinate_frame=True) + + if return_moved_obj_idxs: + if return_perturbation: + return new_obj_xyzs, moved_obj_idxs, perturbation_matrices + else: + return new_obj_xyzs, moved_obj_idxs + else: + if return_perturbation: + return new_obj_xyzs, perturbation_matrices + else: + return new_obj_xyzs + + +def check_pairwise_collision(pcs, visualize=False): + + voxel_extents = [0.005] * 3 + + collision_managers = [] + collision_objects = [] + + for pc in pcs: + + # farthest point sample + pc = pc.unsqueeze(0) + fps_idx = farthest_point_sample(pc, 100) # [B, npoint] + pc = index_points(pc, fps_idx).squeeze(0) + + pc = np.asanyarray(pc) + # ignore empty pc + if np.all(pc == 0): + continue + + n_points = pc.shape[0] + collision_object = [] + collision_manager = trimesh.collision.CollisionManager() + + # Construct collision objects + for i in range(n_points): + extents = voxel_extents + transform = np.eye(4) + transform[:3, 3] = pc[i, :3] + voxel = trimesh.primitives.Box(extents=extents, transform=transform) + collision_object.append((voxel, extents, transform)) + + # Add to collision manager + for i, (voxel, _, _) in enumerate(collision_object): + collision_manager.add_object("voxel_{}".format(i), voxel) + + collision_managers.append(collision_manager) + 
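+        # Each non-empty object point cloud is downsampled to 100 points via farthest-point
+        # sampling and voxelized into 5 mm boxes held by its own trimesh CollisionManager;
+        # the nested loop below reports a collision if any two objects' managers intersect.
+        # Note: farthest_point_sample/index_points come from pointnet_utils, whose import is
+        # currently commented out at the top of this file.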
collision_objects.append(collision_object) + + in_collision = False + for i, cm_i in enumerate(collision_managers): + for j, cm_j in enumerate(collision_managers): + if i == j: + continue + if cm_i.in_collision_other(cm_j): + in_collision = True + + if visualize: + visualize_collision_objects(collision_objects[i] + collision_objects[j]) + + break + + if in_collision: + break + + return in_collision + + +def check_collision_with(this_pc, other_pcs, visualize=False): + + voxel_extents = [0.005] * 3 + + this_collision_manager = None + this_collision_object = None + other_collision_managers = [] + other_collision_objects = [] + + for oi, pc in enumerate([this_pc] + other_pcs): + + # farthest point sample + pc = pc.unsqueeze(0) + fps_idx = farthest_point_sample(pc, 100) # [B, npoint] + pc = index_points(pc, fps_idx).squeeze(0) + + pc = np.asanyarray(pc) + # ignore empty pc + if np.all(pc == 0): + continue + + n_points = pc.shape[0] + collision_object = [] + collision_manager = trimesh.collision.CollisionManager() + + # Construct collision objects + for i in range(n_points): + extents = voxel_extents + transform = np.eye(4) + transform[:3, 3] = pc[i, :3] + voxel = trimesh.primitives.Box(extents=extents, transform=transform) + collision_object.append((voxel, extents, transform)) + + # Add to collision manager + for i, (voxel, _, _) in enumerate(collision_object): + collision_manager.add_object("voxel_{}".format(i), voxel) + + if oi == 0: + this_collision_manager = collision_manager + this_collision_object = collision_object + else: + other_collision_managers.append(collision_manager) + other_collision_objects.append(collision_object) + + collisions = [] + for i, cm_i in enumerate(other_collision_managers): + if this_collision_manager.in_collision_other(cm_i): + collisions.append(i) + + if visualize: + visualize_collision_objects(this_collision_object + other_collision_objects[i]) + + return collisions + + +def visualize_collision_objects(collision_objects): + + # Convert from trimesh to open3d + meshes_o3d = [] + for elem in collision_objects: + (voxel, extents, transform) = elem + voxel_o3d = open3d.geometry.TriangleMesh.create_box(width=extents[0], height=extents[1], + depth=extents[2]) + voxel_o3d.compute_vertex_normals() + voxel_o3d.paint_uniform_color([0.8, 0.2, 0]) + voxel_o3d.transform(transform) + meshes_o3d.append(voxel_o3d) + meshes = meshes_o3d + + vis = open3d.visualization.Visualizer() + vis.create_window() + + for mesh in meshes: + vis.add_geometry(mesh) + + vis.run() + vis.destroy_window() + + +# def test_collision(pc): +# n_points = pc.shape[0] +# voxel_extents = [0.005] * 3 +# collision_objects = [] +# collision_manager = trimesh.collision.CollisionManager() +# +# # Construct collision objects +# for i in range(n_points): +# extents = voxel_extents +# transform = np.eye(4) +# transform[:3, 3] = pc[i, :3] +# voxel = trimesh.primitives.Box(extents=extents, transform=transform) +# collision_objects.append((voxel, extents, transform)) +# +# # Add to collision manager +# for i, (voxel, _, _) in enumerate(collision_objects): +# collision_manager.add_object("voxel_{}".format(i), voxel) +# +# for i, (voxel, _, _) in enumerate(collision_objects): +# c, names = collision_manager.in_collision_single(voxel, return_names=True) +# if c: +# print(i, names) +# +# # Convert from trimesh to open3d +# meshes_o3d = [] +# for elem in collision_objects: +# (voxel, extents, transform) = elem +# voxel_o3d = open3d.geometry.TriangleMesh.create_box(width=extents[0], height=extents[1], +# depth=extents[2]) 
+# voxel_o3d.compute_vertex_normals() +# voxel_o3d.paint_uniform_color([0.8, 0.2, 0]) +# voxel_o3d.transform(transform) +# meshes_o3d.append(voxel_o3d) +# meshes = meshes_o3d +# +# vis = open3d.visualization.Visualizer() +# vis.create_window() +# +# for mesh in meshes: +# vis.add_geometry(mesh) +# +# vis.run() +# vis.destroy_window() +# +# +# def test_collision2(pc): +# pcd = open3d.geometry.PointCloud() +# pcd.points = open3d.utility.Vector3dVector(pc) +# pcd.estimate_normals() +# open3d.visualization.draw_geometries([pcd]) +# +# # poisson_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8, width=0, scale=1.1, linear_fit=False)[0] +# # bbox = pcd.get_axis_aligned_bounding_box() +# # p_mesh_crop = poisson_mesh.crop(bbox) +# # open3d.visualization.draw_geometries([p_mesh_crop, pcd]) +# +# distances = pcd.compute_nearest_neighbor_distance() +# avg_dist = np.mean(distances) +# radius = 3 * avg_dist +# bpa_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, open3d.utility.DoubleVector( +# [radius, radius * 2])) +# dec_mesh = bpa_mesh.simplify_quadric_decimation(100000) +# dec_mesh.remove_degenerate_triangles() +# dec_mesh.remove_duplicated_triangles() +# dec_mesh.remove_duplicated_vertices() +# dec_mesh.remove_non_manifold_edges() +# open3d.visualization.draw_geometries([dec_mesh, pcd]) +# open3d.visualization.draw_geometries([dec_mesh]) + + +def make_gifs(imgs, save_path, texts=None, numpy_img=True, duration=10): + gif_filename = os.path.join(save_path) + pil_imgs = [] + for i, img in enumerate(imgs): + if numpy_img: + img = Image.fromarray(img) + if texts: + text = texts[i] + draw = ImageDraw.Draw(img) + font = ImageFont.truetype("FreeMono.ttf", 40) + draw.text((0, 0), text, (120, 120, 120), font=font) + pil_imgs.append(img) + + pil_imgs[0].save(gif_filename, save_all=True, + append_images=pil_imgs[1:], optimize=True, + duration=duration*len(pil_imgs), loop=0) + + +def save_img(img, save_path, text=None, numpy_img=True): + if numpy_img: + img = Image.fromarray(img) + if text: + draw = ImageDraw.Draw(img) + font = ImageFont.truetype("FreeMono.ttf", 40) + draw.text((0, 0), text, (120, 120, 120), font=font) + img.save(save_path) + + +def move_one_object_pc(obj_xyz, obj_rgb, struct_params, object_params, euler_angles=False): + struct_params = np.asanyarray(struct_params) + object_params = np.asanyarray(object_params) + + R_struct = np.eye(4) + if not euler_angles: + R_struct[:3, :3] = struct_params[3:].reshape(3, 3) + else: + R_struct[:3, :3] = tra.euler_matrix(*struct_params[3:])[:3, :3] + R_obj = np.eye(4) + if not euler_angles: + R_obj[:3, :3] = object_params[3:].reshape(3, 3) + else: + R_obj[:3, :3] = tra.euler_matrix(*object_params[3:])[:3, :3] + + T_struct = R_struct + T_struct[:3, 3] = [struct_params[0], struct_params[1], struct_params[2]] + + # translate to structure frame + t = np.eye(4) + obj_center = torch.mean(obj_xyz, dim=0) + t[:3, 3] = [object_params[0] - obj_center[0], object_params[1] - obj_center[1], object_params[2] - obj_center[2]] + new_obj_xyz = trimesh.transform_points(obj_xyz, t) + + # rotate in place + R = R_obj + obj_center = np.mean(new_obj_xyz, axis=0) + centered_obj_xyz = new_obj_xyz - obj_center + new_centered_obj_xyz = trimesh.transform_points(centered_obj_xyz, R, translate=True) + new_obj_xyz = new_centered_obj_xyz + obj_center + + # transform to the global frame from the structure frame + new_obj_xyz = trimesh.transform_points(new_obj_xyz, T_struct) + + # convert back to torch + new_obj_xyz = 
torch.tensor(new_obj_xyz, dtype=obj_xyz.dtype) + + return new_obj_xyz, obj_rgb + + +def move_one_object_pc_no_struct(obj_xyz, obj_rgb, object_params, euler_angles=False): + object_params = np.asanyarray(object_params) + + R_obj = np.eye(4) + if not euler_angles: + R_obj[:3, :3] = object_params[3:].reshape(3, 3) + else: + R_obj[:3, :3] = tra.euler_matrix(*object_params[3:])[:3, :3] + + t = np.eye(4) + obj_center = torch.mean(obj_xyz, dim=0) + t[:3, 3] = [object_params[0] - obj_center[0], object_params[1] - obj_center[1], object_params[2] - obj_center[2]] + new_obj_xyz = trimesh.transform_points(obj_xyz, t) + + # rotate in place + R = R_obj + obj_center = np.mean(new_obj_xyz, axis=0) + centered_obj_xyz = new_obj_xyz - obj_center + new_centered_obj_xyz = trimesh.transform_points(centered_obj_xyz, R, translate=True) + new_obj_xyz = new_centered_obj_xyz + obj_center + + # convert back to torch + new_obj_xyz = torch.tensor(new_obj_xyz, dtype=obj_xyz.dtype) + + return new_obj_xyz, obj_rgb + + +def modify_language(sentence, radius=None, position_x=None, position_y=None, rotation=None, shape=None): + # "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4] + + sentence = copy.deepcopy(sentence) + for pi, pair in enumerate(sentence): + if radius is not None and len(pair) == 2 and pair[1] == "radius": + sentence[pi] = (radius, 'radius') + if position_y is not None and len(pair) == 2 and pair[1] == "position_y": + sentence[pi] = (position_y, 'position_y') + if position_x is not None and len(pair) == 2 and pair[1] == "position_x": + sentence[pi] = (position_x, 'position_x') + if rotation is not None and len(pair) == 2 and pair[1] == "rotation": + sentence[pi] = (rotation, 'rotation') + if shape is not None and len(pair) == 2 and pair[1] == "shape": + sentence[pi] = (shape, 'shape') + + return sentence + + +def sample_gaussians(mus, sigmas, sample_size): + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + normal = torch.distributions.Normal(mus, sigmas) + samples = normal.sample((sample_size,)) + # samples: [sample_size, number of individual gaussians] + return samples + + +def fit_gaussians(samples, sigma_eps=0.01): + # samples: [sample_size, number of individual gaussians] + num_gs = samples.shape[1] + mus = torch.mean(samples, dim=0) + sigmas = torch.std(samples, dim=0) + sigma_eps * torch.ones(num_gs) + # mus: [number of individual gaussians] + # sigmas: [number of individual gaussians] + return mus, sigmas + + +def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False): + vis_pcs = [trimesh.PointCloud(obj_xyz, colors=np.concatenate([obj_rgb * 255, np.ones([obj_rgb.shape[0], 1]) * 255], axis=-1)) for + obj_xyz, obj_rgb in zip(obj_xyzs, obj_rgbs)] + scene = trimesh.Scene() + # add the coordinate frame first + geom = trimesh.creation.axis(0.01) + # scene.add_geometry(geom) + table = trimesh.creation.box(extents=[1.0, 1.0, 0.02]) + table.apply_translation([0.5, 0, -0.01]) + table.visual.vertex_colors = [150, 111, 87, 125] + scene.add_geometry(table) + # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0]) + bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1) + bounds.apply_translation([0, 0, 0]) + bounds.visual.vertex_colors = [30, 30, 30, 30] + # scene.add_geometry(bounds) + scene.add_geometry(vis_pcs) + RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481], + [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 
0.2225696480721997], + [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951], + [0.0, 0.0, 0.0, 1.0]]) + RT_4x4 = np.linalg.inv(RT_4x4) + RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1]) + scene.camera_transform = RT_4x4 + if return_scene: + return scene + else: + scene.show() + + +def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True): + """ Display point clouds """ + + assert len(gts) == len(predictions) == len(xyz) == len(rgb) + + unordered_pc = np.concatenate(xyz, axis=0) + unordered_rgb = np.concatenate(rgb, axis=0) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) + + vis = open3d.visualization.Visualizer() + vis.create_window() + vis.add_geometry(pcd) + + if add_table: + table_color = [0.7, 0.7, 0.7] + origin = [0, -0.5, -0.05] + table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02) + table.paint_uniform_color(table_color) + table.translate(origin) + vis.add_geometry(table) + + if add_coordinate_frame: + mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) + vis.add_geometry(mesh_frame) + + for i in range(len(xyz)): + pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0] + gt_color = [0.0, 1.0, 0] if gts[i] else [1.0, 0.0, 0] + origin = torch.mean(xyz[i], dim=0) + origin[2] += 0.02 + pred_vis = open3d.geometry.TriangleMesh.create_torus(torus_radius=0.02, tube_radius=0.01) + pred_vis.paint_uniform_color(pred_color) + pred_vis.translate(origin) + gt_vis = open3d.geometry.TriangleMesh.create_sphere(radius=0.01) + gt_vis.paint_uniform_color(gt_color) + gt_vis.translate(origin) + vis.add_geometry(pred_vis) + vis.add_geometry(gt_vis) + + if side_view: + open3d_set_side_view(vis) + + if return_buffer: + vis.poll_events() + vis.update_renderer() + buffer = vis.capture_screen_float_buffer(False) + vis.destroy_window() + return buffer + else: + vis.run() + vis.destroy_window() + + +def show_pcs_with_only_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True): + """ Display point clouds """ + + assert len(gts) == len(predictions) == len(xyz) == len(rgb) + + unordered_pc = np.concatenate(xyz, axis=0) + unordered_rgb = np.concatenate(rgb, axis=0) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) + + vis = open3d.visualization.Visualizer() + vis.create_window() + vis.add_geometry(pcd) + + if add_table: + table_color = [0.7, 0.7, 0.7] + origin = [0, -0.5, -0.05] + table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02) + table.paint_uniform_color(table_color) + table.translate(origin) + vis.add_geometry(table) + + if add_coordinate_frame: + mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) + vis.add_geometry(mesh_frame) + + for i in range(len(xyz)): + pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0] + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(xyz[i]) + pcd.colors = open3d.utility.Vector3dVector(np.tile(np.array(pred_color, dtype=np.float), (xyz[i].shape[0], 1))) + # pcd = pcd.uniform_down_sample(10) + # vis.add_geometry(pcd) + + obb = pcd.get_axis_aligned_bounding_box() + obb.color = pred_color + vis.add_geometry(obb) + 
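+        # Each object is summarized only by its axis-aligned bounding box, colored green when
+        # predictions[i] is truthy and red otherwise; the raw per-object point clouds and the
+        # torus/sphere markers are left here as commented-out alternatives.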
+ + # origin = torch.mean(xyz[i], dim=0) + # origin[2] += 0.02 + # pred_vis = open3d.geometry.TriangleMesh.create_torus(torus_radius=0.02, tube_radius=0.01) + # pred_vis.paint_uniform_color(pred_color) + # pred_vis.translate(origin) + # gt_vis = open3d.geometry.TriangleMesh.create_sphere(radius=0.01) + # gt_vis.paint_uniform_color(gt_color) + # gt_vis.translate(origin) + # vis.add_geometry(pred_vis) + # vis.add_geometry(gt_vis) + + if side_view: + open3d_set_side_view(vis) + + if return_buffer: + vis.poll_events() + vis.update_renderer() + buffer = vis.capture_screen_float_buffer(False) + vis.destroy_window() + return buffer + else: + vis.run() + vis.destroy_window() + + +def test_new_vis(xyz, rgb): + pass +# unordered_pc = np.concatenate(xyz, axis=0) +# unordered_rgb = np.concatenate(rgb, axis=0) +# pcd = open3d.geometry.PointCloud() +# pcd.points = open3d.utility.Vector3dVector(unordered_pc) +# pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) +# +# # Some platforms do not require OpenGL implementations to support wide lines, +# # so the renderer requires a custom shader to implement this: "unlitLine". +# # The line_width field is only used by this shader; all other shaders ignore +# # it. +# # mat = o3d.visualization.rendering.Material() +# # mat.shader = "unlitLine" +# # mat.line_width = 10 # note that this is scaled with respect to pixels, +# # # so will give different results depending on the +# # # scaling values of your system +# # mat.transmission = 0.5 +# open3d.visualization.draw({ +# "name": "pcd", +# "geometry": pcd, +# # "material": mat +# }) +# +# for i in range(len(xyz)): +# pred_color = [0.0, 1.0, 0] if predictions[i] else [1.0, 0.0, 0] +# pcd = open3d.geometry.PointCloud() +# pcd.points = open3d.utility.Vector3dVector(xyz[i]) +# pcd.colors = open3d.utility.Vector3dVector(np.tile(np.array(pred_color, dtype=np.float), (xyz[i].shape[0], 1))) +# # pcd = pcd.uniform_down_sample(10) +# # vis.add_geometry(pcd) +# +# obb = pcd.get_axis_aligned_bounding_box() +# obb.color = pred_color +# vis.add_geometry(obb) + + +def show_pcs(xyz, rgb, add_coordinate_frame=False, side_view=False, add_table=True): + """ Display point clouds """ + + unordered_pc = np.concatenate(xyz, axis=0) + unordered_rgb = np.concatenate(rgb, axis=0) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) + + if add_table: + table_color = [0.78, 0.64, 0.44] + origin = [0, -0.5, -0.02] + table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.001) + table.paint_uniform_color(table_color) + table.translate(origin) + + if not add_coordinate_frame: + vis = open3d.visualization.Visualizer() + vis.create_window() + vis.add_geometry(pcd) + if add_table: + vis.add_geometry(table) + if side_view: + open3d_set_side_view(vis) + vis.run() + vis.destroy_window() + else: + mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) + # open3d.visualization.draw_geometries([pcd, mesh_frame]) + vis = open3d.visualization.Visualizer() + vis.create_window() + vis.add_geometry(pcd) + vis.add_geometry(mesh_frame) + if add_table: + vis.add_geometry(table) + if side_view: + open3d_set_side_view(vis) + vis.run() + vis.destroy_window() + + +def show_pcs_color_order(xyzs, rgbs, add_coordinate_frame=False, side_view=False, add_table=True, save_path=None, texts=None, visualize=False): + + rgb_colors = get_rgb_colors() + + order_rgbs = [] + for i, xyz in enumerate(xyzs): + 
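+        # Ignore the observed RGB and color each object by its index, using the fixed palette
+        # from get_rgb_colors() so that object identity is what the rendering conveys.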
order_rgbs.append(np.tile(np.array(rgb_colors[i][1], dtype=np.float), (xyz.shape[0], 1))) + + if visualize: + show_pcs(xyzs, order_rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table) + if save_path: + if not texts: + save_pcs(xyzs, order_rgbs, save_path=save_path, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table) + if texts: + buffer = save_pcs(xyzs, order_rgbs, add_coordinate_frame=add_coordinate_frame, + side_view=side_view, add_table=add_table, return_buffer=True) + img = np.uint8(np.asarray(buffer) * 255) + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + font = ImageFont.truetype("FreeMono.ttf", 20) + for it, text in enumerate(texts): + draw.text((0, it*20), text, (120, 120, 120), font=font) + img.save(save_path) + + +def get_rgb_colors(): + rgb_colors = [] + # each color is a tuple of (name, (r,g,b)) + for name, hex in matplotlib.colors.cnames.items(): + rgb_colors.append((name, matplotlib.colors.to_rgb(hex))) + + rgb_colors = sorted(rgb_colors, key=lambda x: x[0]) + + priority_colors = [('red', (1.0, 0.0, 0.0)), ('green', (0.0, 1.0, 0.0)), ('blue', (0.0, 0.0, 1.0)), ('orange', (1.0, 0.6470588235294118, 0.0)), ('purple', (0.5019607843137255, 0.0, 0.5019607843137255)), ('magenta', (1.0, 0.0, 1.0)),] + rgb_colors = priority_colors + rgb_colors + + return rgb_colors + + +def open3d_set_side_view(vis): + ctr = vis.get_view_control() + # ctr.set_front([-0.61959040621518757, 0.46765094085676973, 0.63040489055992976]) + # ctr.set_lookat([0.28810001969337462, 0.10746435821056366, 0.23499999999999999]) + # ctr.set_up([0.64188154672853504, -0.16037991603449936, 0.74984422549096852]) + # ctr.set_zoom(0.7) + # ctr.rotate(10.0, 0.0) + + # ctr.set_front([ -0.51720189814974493, 0.55636089622063711, 0.65035740151617438 ]) + # ctr.set_lookat([ 0.23103321183824999, 0.26154772406860449, 0.15131956132592411 ]) + # ctr.set_up([ 0.47073865286968591, -0.44969907810742304, 0.75906248744340343 ]) + # ctr.set_zoom(3) + + # ctr.set_front([-0.86019269757539152, 0.40355968763418076, 0.31178213796587784]) + # ctr.set_lookat([0.28810001969337462, 0.10746435821056366, 0.23499999999999999]) + # ctr.set_up([0.30587875107201218, -0.080905438599338214, 0.94862663869811026]) + # ctr.set_zoom(0.69999999999999996) + + # ctr.set_front([0.40466417238365116, 0.019007526352692254, 0.91426780624224468]) + # ctr.set_lookat([0.61287602731590907, 0.010181152776318789, -0.073166629933366326]) + # ctr.set_up([-0.91444954965885639, 0.0025306059632757057, 0.40469200283941076]) + # ctr.set_zoom(0.84000000000000008) + + ctr.set_front([-0.45528412367406523, 0.20211727782362851, 0.86710147776111224]) + ctr.set_lookat([0.48308104105920047, 0.078726411326627957, -0.27298814087096795]) + ctr.set_up([0.79763037008159798, -0.34013406176163907, 0.49809096835118638]) + ctr.set_zoom(0.80000000000000004) + + init_param = ctr.convert_to_pinhole_camera_parameters() + print("camera extrinsic", init_param.extrinsic.tolist()) + + +def save_pcs(xyz, rgb, save_path=None, return_buffer=False, add_coordinate_frame=False, side_view=False, add_table=True): + + assert save_path or return_buffer, "provide path to save or set return_buffer to true" + + unordered_pc = np.concatenate(xyz, axis=0) + unordered_rgb = np.concatenate(rgb, axis=0) + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(unordered_pc) + pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) + + vis = open3d.visualization.Visualizer() + vis.create_window() + + 
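+    # The window is used only for capture (vis.run() is never called): geometries are added,
+    # the renderer is refreshed with poll_events()/update_renderer(), and the frame is either
+    # written to save_path or returned as a float buffer at the end of the function.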
vis.add_geometry(pcd) + vis.update_geometry(pcd) + + if add_table: + table_color = [0.7, 0.7, 0.7] + origin = [0, -0.5, -0.03] + table = open3d.geometry.TriangleMesh.create_box(width=1.0, height=1.0, depth=0.02) + table.paint_uniform_color(table_color) + table.translate(origin) + vis.add_geometry(table) + + if add_coordinate_frame: + mesh_frame = open3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) + vis.add_geometry(mesh_frame) + vis.update_geometry(mesh_frame) + + if side_view: + open3d_set_side_view(vis) + + vis.poll_events() + vis.update_renderer() + if save_path: + vis.capture_screen_image(save_path) + elif return_buffer: + buffer = vis.capture_screen_float_buffer(False) + + vis.destroy_window() + + if return_buffer: + return buffer + else: + return None + + +def get_initial_scene_idxs(dataset): + """ + This function finds initial scenes from the dataset + :param dataset: + :return: + """ + + initial_scene2idx_t = {} + for idx in range(len(dataset)): + filename, t = dataset.get_data_index(idx) + if filename not in initial_scene2idx_t: + initial_scene2idx_t[filename] = (idx, t) + else: + if t > initial_scene2idx_t[filename][1]: + initial_scene2idx_t[filename] = (idx, t) + initial_scene_idxs = [initial_scene2idx_t[f][0] for f in initial_scene2idx_t] + return initial_scene_idxs + + +def get_initial_scene_idxs_raw_data(data): + """ + This function finds initial scenes from the dataset + :param dataset: + :return: + """ + + initial_scene2idx_t = {} + for idx in range(len(data)): + filename, t = data[idx] + if filename not in initial_scene2idx_t: + initial_scene2idx_t[filename] = (idx, t) + else: + if t > initial_scene2idx_t[filename][1]: + initial_scene2idx_t[filename] = (idx, t) + initial_scene_idxs = [initial_scene2idx_t[f][0] for f in initial_scene2idx_t] + return initial_scene_idxs + + +def evaluate_target_object_predictions(all_gts, all_predictions, all_sentences, initial_scene_idxs, tokenizer): + """ + This function evaluates target object predictions + + :param all_gts: a list of predictions for scenes. 
Each element is a list of booleans for objects in the scene + :param all_predictions: + :param all_sentences: a list of descriptions for scenes + :param initial_scene_idxs: + :param tokenizer: + :return: + """ + + # overall accuracy + print("\noverall accuracy") + report = classification_report(list(itertools.chain(*all_gts)), list(itertools.chain(*all_predictions)), + output_dict=True) + print(report) + + # scene average + print("\naccuracy per scene") + acc_per_scene = [] + for gts, preds in zip(all_gts, all_predictions): + acc_per_scene.append(sum(np.array(gts) == np.array(preds)) * 1.0 / len(gts)) + print(np.mean(acc_per_scene)) + plt.hist(acc_per_scene, 10, range=(0, 1), facecolor='g', alpha=0.75) + plt.xlabel('Accuracy') + plt.ylabel('# Scene') + plt.title('Predicting objects to be rearranged') + plt.xticks(np.linspace(0, 1, 11), np.linspace(0, 1, 11).round(1)) + plt.grid(True) + plt.show() + + # initial scene accuracy + print("\noverall accuracy for initial scenes") + tested_initial_scene_idxs = [i for i in initial_scene_idxs if i < len(all_gts)] + initial_gts = [all_gts[i] for i in tested_initial_scene_idxs] + initial_predictions = [all_predictions[i] for i in tested_initial_scene_idxs] + report = classification_report(list(itertools.chain(*initial_gts)), list(itertools.chain(*initial_predictions)), + output_dict=True) + print(report) + + # break down by the number of objects + print("\naccuracy for # objects in scene") + num_objects_in_scenes = np.array([len(gts) for gts in all_gts]) + unique_num_objects = np.unique(num_objects_in_scenes) + acc_per_scene = np.array(acc_per_scene) + assert len(acc_per_scene) == len(num_objects_in_scenes) + for num_objects in unique_num_objects: + this_scene_idxs = [i for i in range(len(all_gts)) if len(all_gts[i]) == num_objects] + this_num_obj_gts = [all_gts[i] for i in this_scene_idxs] + this_num_obj_predictions = [all_predictions[i] for i in this_scene_idxs] + report = classification_report(list(itertools.chain(*this_num_obj_gts)), list(itertools.chain(*this_num_obj_predictions)), + output_dict=True) + print("{} objects".format(num_objects)) + print(report) + + # reference + print("\noverall accuracy break down") + direct_gts_by_type = defaultdict(list) + direct_preds_by_type = defaultdict(list) + d_anchor_gts_by_type = defaultdict(list) + d_anchor_preds_by_type = defaultdict(list) + c_anchor_gts_by_type = defaultdict(list) + c_anchor_preds_by_type = defaultdict(list) + + for i, s in enumerate(all_sentences): + v, t = s[0] + if t[-2:] == "_c" or t[-2:] == "_d": + t = t[:-2] + if v != "MASK" and t in tokenizer.discrete_types: + # direct reference + direct_gts_by_type[t].extend(all_gts[i]) + direct_preds_by_type[t].extend(all_predictions[i]) + else: + if v == "MASK": + # discrete anchor + d_anchor_gts_by_type[t].extend(all_gts[i]) + d_anchor_preds_by_type[t].extend(all_predictions[i]) + else: + c_anchor_gts_by_type[t].extend(all_gts[i]) + c_anchor_preds_by_type[t].extend(all_predictions[i]) + + print("direct") + for t in direct_gts_by_type: + report = classification_report(direct_gts_by_type[t], direct_preds_by_type[t], output_dict=True) + print(t, report) + + print("discrete anchor") + for t in d_anchor_gts_by_type: + report = classification_report(d_anchor_gts_by_type[t], d_anchor_preds_by_type[t], output_dict=True) + print(t, report) + + print("continuous anchor") + for t in c_anchor_gts_by_type: + report = classification_report(c_anchor_gts_by_type[t], c_anchor_preds_by_type[t], output_dict=True) + print(t, report) + + # break down by object 
class + + +def combine_and_sample_xyzs(xyzs, rgbs, center=None, radius=0.5, num_pts=1024): + xyz = torch.cat(xyzs, dim=0) + rgb = torch.cat(rgbs, dim=0) + + if center is not None: + center = center.repeat(xyz.shape[0], 1) + dists = torch.linalg.norm(xyz - center, dim=-1) + idx = dists < radius + xyz = xyz[idx] + rgb = rgb[idx] + + idx = np.random.randint(0, xyz.shape[0], num_pts) + xyz = xyz[idx] + rgb = rgb[idx] + + return xyz, rgb + + +def evaluate_prior_prediction(gts, predictions, keys, debug=False): + """ + :param gts: expect a list of tensors + :param predictions: expect a list of tensor + :return: + """ + + total_mses = 0 + obj_dists = [] + struct_dists = [] + for key in keys: + # predictions[key][0]: [batch_size * number_of_objects, dim] + predictions_for_key = torch.cat(predictions[key], dim=0) + # gts[key][0]: [batch_size * number_of_objects, dim] + gts_for_key = torch.cat(gts[key], dim=0) + + assert gts_for_key.shape == predictions_for_key.shape + + target_indices = gts_for_key != -100 + gts_for_key = gts_for_key[target_indices] + predictions_for_key = predictions_for_key[target_indices] + num_objects = len(predictions_for_key) + + distances = predictions_for_key - gts_for_key + + me = torch.mean(torch.abs(distances)) + mse = torch.mean(distances ** 2) + med = torch.median(torch.abs(distances)) + + if "obj_x" in key or "obj_y" in key or "obj_z" in key: + obj_dists.append(distances) + if "struct_x" in key or "struct_y" in key or "struct_z" in key: + struct_dists.append(distances) + + if debug: + print("Groundtruths:") + print(gts_for_key[:100]) + print("Predictions") + print(predictions_for_key[:100]) + + print("{} ME for {} objects: {}".format(key, num_objects, me)) + print("{} MSE for {} objects: {}".format(key, num_objects, mse)) + print("{} MEDIAN for {} objects: {}".format(key, num_objects, med)) + total_mses += mse + + if "theta" in key: + predictions_for_key = predictions_for_key.reshape(-1, 3, 3) + gts_for_key = gts_for_key.reshape(-1, 3, 3) + geodesic_distance = compute_geodesic_distance_from_two_matrices(predictions_for_key, gts_for_key) + geodesic_distance = torch.rad2deg(geodesic_distance) + mgd = torch.mean(geodesic_distance) + stdgd = torch.std(geodesic_distance) + megd = torch.median(geodesic_distance) + print("{} Mean and std Geodesic Distance for {} objects: {} +- {}".format(key, num_objects, mgd, stdgd)) + print("{} Median Geodesic Distance for {} objects: {}".format(key, num_objects, megd)) + + if obj_dists: + euclidean_dists = torch.sqrt(obj_dists[0]**2 + obj_dists[1]**2 + obj_dists[2]**2) + me = torch.mean(euclidean_dists) + stde = torch.std(euclidean_dists) + med = torch.median(euclidean_dists) + print("Mean and std euclidean dist for {} objects: {} +- {}".format(len(euclidean_dists), me, stde)) + print("Median euclidean dist for {} objects: {}".format(len(euclidean_dists), med)) + if struct_dists: + euclidean_dists = torch.sqrt(struct_dists[0] ** 2 + struct_dists[1] ** 2 + struct_dists[2] ** 2) + me = torch.mean(euclidean_dists) + stde = torch.std(euclidean_dists) + med = torch.median(euclidean_dists) + print("Mean euclidean dist for {} structures: {} +- {}".format(len(euclidean_dists), me, stde)) + print("Median euclidean dist for {} structures: {}".format(len(euclidean_dists), med)) + + return -total_mses + + +def generate_square_subsequent_mask(sz): + mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + +def visualize_occ(points, 
occupancies, in_num_pts=1000, out_num_pts=1000, visualize=False, threshold=0.5): + + rix = np.random.permutation(points.shape[0]) + vis_points = points[rix] + vis_occupancies = occupancies[rix] + in_pc = vis_points[vis_occupancies.squeeze() > threshold, :][:in_num_pts] + out_pc = vis_points[vis_occupancies.squeeze() < threshold, :][:out_num_pts] + + if len(in_pc) == 0: + print("no in points") + if len(out_pc) == 0: + print("no out points") + + in_pc = trimesh.PointCloud(in_pc) + out_pc = trimesh.PointCloud(out_pc) + in_pc.colors = np.tile((255, 0, 0, 255), (in_pc.vertices.shape[0], 1)) + out_pc.colors = np.tile((255, 255, 0, 120), (out_pc.vertices.shape[0], 1)) + + if visualize: + scene = trimesh.Scene([in_pc, out_pc]) + scene.show() + + return in_pc, out_pc + + +def save_dict_to_h5(dict_data, filename): + fh = h5py.File(filename, 'w') + for k in dict_data: + key_data = dict_data[k] + if key_data is None: + raise RuntimeError('data was not properly populated') + # if type(key_data) is dict: + # key_data = json.dumps(key_data, sort_keys=True) + try: + fh.create_dataset(k, data=key_data) + except TypeError as e: + print("Failure on key", k) + print(key_data) + print(e) + raise e + fh.close() + + +def load_h5_key(h5, key): + if key in h5: + return h5[key][()] + elif "json_" + key in h5: + return json.loads(h5["json_" + key][()]) + else: + return None + + +def load_dict_from_h5(filename): + h5 = h5py.File(filename, "r") + data_dict = {} + for k in h5: + data_dict[k] = h5[k][()] + return data_dict + diff --git a/src/StructDiffusion/utils/rotation_continuity.py b/src/StructDiffusion/utils/rotation_continuity.py new file mode 100755 index 0000000000000000000000000000000000000000..7e0e92ea44c590d95e2dc39b9f3641963ad4eab7 --- /dev/null +++ b/src/StructDiffusion/utils/rotation_continuity.py @@ -0,0 +1,395 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import numpy as np + +# Code adapted from the rotation continuity repo (https://github.com/papagina/RotationContinuity) + +#T_poses num*3 +#r_matrix batch*3*3 +def compute_pose_from_rotation_matrix(T_pose, r_matrix): + batch=r_matrix.shape[0] + joint_num = T_pose.shape[0] + r_matrices = r_matrix.view(batch,1, 3,3).expand(batch,joint_num, 3,3).contiguous().view(batch*joint_num,3,3) + src_poses = T_pose.view(1,joint_num,3,1).expand(batch,joint_num,3,1).contiguous().view(batch*joint_num,3,1) + + out_poses = torch.matmul(r_matrices, src_poses) #(batch*joint_num)*3*1 + + return out_poses.view(batch, joint_num, 3) + +# batch*n +def normalize_vector( v, return_mag =False): + batch=v.shape[0] + v_mag = torch.sqrt(v.pow(2).sum(1))# batch + v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda())) + v_mag = v_mag.view(batch,1).expand(batch,v.shape[1]) + v = v/v_mag + if(return_mag==True): + return v, v_mag[:,0] + else: + return v + +# u, v batch*n +def cross_product( u, v): + batch = u.shape[0] + #print (u.shape) + #print (v.shape) + i = u[:,1]*v[:,2] - u[:,2]*v[:,1] + j = u[:,2]*v[:,0] - u[:,0]*v[:,2] + k = u[:,0]*v[:,1] - u[:,1]*v[:,0] + + out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3 + + return out + + +#poses batch*6 +#poses +def compute_rotation_matrix_from_ortho6d(ortho6d): + x_raw = ortho6d[:,0:3]#batch*3 + y_raw = ortho6d[:,3:6]#batch*3 + + x = normalize_vector(x_raw) #batch*3 + z = cross_product(x,y_raw) #batch*3 + z = normalize_vector(z)#batch*3 + y = cross_product(z,x)#batch*3 + + x = x.view(-1,3,1) + y = y.view(-1,3,1) + z = z.view(-1,3,1) + matrix = 
torch.cat((x,y,z), 2) #batch*3*3 + return matrix + + +#in batch*6 +#out batch*5 +def stereographic_project(a): + dim = a.shape[1] + a = normalize_vector(a) + out = a[:,0:dim-1]/(1-a[:,dim-1]) + return out + + + +#in a batch*5, axis int +def stereographic_unproject(a, axis=None): + """ + Inverse of stereographic projection: increases dimension by one. + """ + batch=a.shape[0] + if axis is None: + axis = a.shape[1] + s2 = torch.pow(a,2).sum(1) #batch + ans = torch.autograd.Variable(torch.zeros(batch, a.shape[1]+1).cuda()) #batch*6 + unproj = 2*a/(s2+1).view(batch,1).repeat(1,a.shape[1]) #batch*5 + if(axis>0): + ans[:,:axis] = unproj[:,:axis] #batch*(axis-0) + ans[:,axis] = (s2-1)/(s2+1) #batch + ans[:,axis+1:] = unproj[:,axis:] #batch*(5-axis) # Note that this is a no-op if the default option (last axis) is used + return ans + + +#a batch*5 +#out batch*3*3 +def compute_rotation_matrix_from_ortho5d(a): + batch = a.shape[0] + proj_scale_np = np.array([np.sqrt(2)+1, np.sqrt(2)+1, np.sqrt(2)]) #3 + proj_scale = torch.autograd.Variable(torch.FloatTensor(proj_scale_np).cuda()).view(1,3).repeat(batch,1) #batch,3 + + u = stereographic_unproject(a[:, 2:5] * proj_scale, axis=0)#batch*4 + norm = torch.sqrt(torch.pow(u[:,1:],2).sum(1)) #batch + u = u/ norm.view(batch,1).repeat(1,u.shape[1]) #batch*4 + b = torch.cat((a[:,0:2], u),1)#batch*6 + matrix = compute_rotation_matrix_from_ortho6d(b) + return matrix + + +#quaternion batch*4 +def compute_rotation_matrix_from_quaternion( quaternion): + batch=quaternion.shape[0] + + + quat = normalize_vector(quaternion).contiguous() + + qw = quat[...,0].contiguous().view(batch, 1) + qx = quat[...,1].contiguous().view(batch, 1) + qy = quat[...,2].contiguous().view(batch, 1) + qz = quat[...,3].contiguous().view(batch, 1) + + # Unit quaternion rotation matrices computatation + xx = qx*qx + yy = qy*qy + zz = qz*qz + xy = qx*qy + xz = qx*qz + yz = qy*qz + xw = qx*qw + yw = qy*qw + zw = qz*qw + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + return matrix + +#axisAngle batch*4 angle, x,y,z +def compute_rotation_matrix_from_axisAngle( axisAngle): + batch = axisAngle.shape[0] + + theta = torch.tanh(axisAngle[:,0])*np.pi #[-180, 180] + sin = torch.sin(theta*0.5) + axis = normalize_vector(axisAngle[:,1:4]) #batch*3 + qw = torch.cos(theta*0.5) + qx = axis[:,0]*sin + qy = axis[:,1]*sin + qz = axis[:,2]*sin + + # Unit quaternion rotation matrices computatation + xx = (qx*qx).view(batch,1) + yy = (qy*qy).view(batch,1) + zz = (qz*qz).view(batch,1) + xy = (qx*qy).view(batch,1) + xz = (qx*qz).view(batch,1) + yz = (qy*qz).view(batch,1) + xw = (qx*qw).view(batch,1) + yw = (qy*qw).view(batch,1) + zw = (qz*qw).view(batch,1) + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + return matrix + +#axisAngle batch*3 (x,y,z)*theta +def compute_rotation_matrix_from_Rodriguez( rod): + batch = rod.shape[0] + + axis, theta = normalize_vector(rod, return_mag=True) + + sin = torch.sin(theta) + + + qw = torch.cos(theta) + qx = axis[:,0]*sin + qy = axis[:,1]*sin + qz = 
axis[:,2]*sin + + # Unit quaternion rotation matrices computatation + xx = (qx*qx).view(batch,1) + yy = (qy*qy).view(batch,1) + zz = (qz*qz).view(batch,1) + xy = (qx*qy).view(batch,1) + xz = (qx*qz).view(batch,1) + yz = (qy*qz).view(batch,1) + xw = (qx*qw).view(batch,1) + yw = (qy*qw).view(batch,1) + zw = (qz*qw).view(batch,1) + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + return matrix + +#axisAngle batch*3 a,b,c +def compute_rotation_matrix_from_hopf( hopf): + batch = hopf.shape[0] + + theta = (torch.tanh(hopf[:,0])+1.0)*np.pi/2.0 #[0, pi] + phi = (torch.tanh(hopf[:,1])+1.0)*np.pi #[0,2pi) + tao = (torch.tanh(hopf[:,2])+1.0)*np.pi #[0,2pi) + + qw = torch.cos(theta/2)*torch.cos(tao/2) + qx = torch.cos(theta/2)*torch.sin(tao/2) + qy = torch.sin(theta/2)*torch.cos(phi+tao/2) + qz = torch.sin(theta/2)*torch.sin(phi+tao/2) + + # Unit quaternion rotation matrices computatation + xx = (qx*qx).view(batch,1) + yy = (qy*qy).view(batch,1) + zz = (qz*qz).view(batch,1) + xy = (qx*qy).view(batch,1) + xz = (qx*qz).view(batch,1) + yz = (qy*qz).view(batch,1) + xw = (qx*qw).view(batch,1) + yw = (qy*qw).view(batch,1) + zw = (qz*qw).view(batch,1) + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + return matrix + + +#euler batch*4 +#output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic) +def compute_rotation_matrix_from_euler(euler): + batch=euler.shape[0] + + c1=torch.cos(euler[:,0]).view(batch,1)#batch*1 + s1=torch.sin(euler[:,0]).view(batch,1)#batch*1 + c2=torch.cos(euler[:,2]).view(batch,1)#batch*1 + s2=torch.sin(euler[:,2]).view(batch,1)#batch*1 + c3=torch.cos(euler[:,1]).view(batch,1)#batch*1 + s3=torch.sin(euler[:,1]).view(batch,1)#batch*1 + + row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3 + row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3 + row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3 + + matrix = torch.cat((row1, row2, row3), 1) #batch*3*3 + + + return matrix + + +#euler_sin_cos batch*6 +#output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic) +def compute_rotation_matrix_from_euler_sin_cos(euler_sin_cos): + batch=euler_sin_cos.shape[0] + + s1 = euler_sin_cos[:,0].view(batch,1) + c1 = euler_sin_cos[:,1].view(batch,1) + s2 = euler_sin_cos[:,2].view(batch,1) + c2 = euler_sin_cos[:,3].view(batch,1) + s3 = euler_sin_cos[:,4].view(batch,1) + c3 = euler_sin_cos[:,5].view(batch,1) + + + row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3 + row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3 + row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3 + + matrix = torch.cat((row1, row2, row3), 1) #batch*3*3 + + + return matrix + + +#matrices batch*3*3 +#both matrix are orthogonal rotation matrices +#out theta between 0 to 180 degree batch +def compute_geodesic_distance_from_two_matrices(m1, m2): + batch=m1.shape[0] + m = torch.bmm(m1, 
m2.transpose(1,2)) #batch*3*3 + + cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2 + cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) ) + cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 ) + + + theta = torch.acos(cos) + + #theta = torch.min(theta, 2*np.pi - theta) + + + return theta + + +#matrices batch*3*3 +#both matrix are orthogonal rotation matrices +#out theta between 0 to 180 degree batch +def compute_angle_from_r_matrices(m): + + batch=m.shape[0] + + cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2 + cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) ) + cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 ) + + theta = torch.acos(cos) + + return theta + +def get_sampled_rotation_matrices_by_quat(batch): + #quat = torch.autograd.Variable(torch.rand(batch,4).cuda()) + quat = torch.autograd.Variable(torch.randn(batch, 4).cuda()) + matrix = compute_rotation_matrix_from_quaternion(quat) + return matrix + +def get_sampled_rotation_matrices_by_hpof(batch): + + theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,1, batch)*np.pi).cuda()) #[0, pi] + phi = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi) + tao = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi) + + + qw = torch.cos(theta/2)*torch.cos(tao/2) + qx = torch.cos(theta/2)*torch.sin(tao/2) + qy = torch.sin(theta/2)*torch.cos(phi+tao/2) + qz = torch.sin(theta/2)*torch.sin(phi+tao/2) + + # Unit quaternion rotation matrices computatation + xx = (qx*qx).view(batch,1) + yy = (qy*qy).view(batch,1) + zz = (qz*qz).view(batch,1) + xy = (qx*qy).view(batch,1) + xz = (qx*qz).view(batch,1) + yz = (qy*qz).view(batch,1) + xw = (qx*qw).view(batch,1) + yw = (qy*qw).view(batch,1) + zw = (qz*qw).view(batch,1) + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + return matrix + +#axisAngle batch*4 angle, x,y,z +def get_sampled_rotation_matrices_by_axisAngle( batch, return_quaternion=False): + + theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(-1,1, batch)*np.pi).cuda()) #[0, pi] #[-180, 180] + sin = torch.sin(theta) + axis = torch.autograd.Variable(torch.randn(batch, 3).cuda()) + axis = normalize_vector(axis) #batch*3 + qw = torch.cos(theta) + qx = axis[:,0]*sin + qy = axis[:,1]*sin + qz = axis[:,2]*sin + + quaternion = torch.cat((qw.view(batch,1), qx.view(batch,1), qy.view(batch,1), qz.view(batch,1)), 1 ) + + # Unit quaternion rotation matrices computatation + xx = (qx*qx).view(batch,1) + yy = (qy*qy).view(batch,1) + zz = (qz*qz).view(batch,1) + xy = (qx*qy).view(batch,1) + xz = (qx*qz).view(batch,1) + yz = (qy*qz).view(batch,1) + xw = (qx*qw).view(batch,1) + yw = (qy*qw).view(batch,1) + zw = (qz*qw).view(batch,1) + + row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 + row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 + row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 + + matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 + + if(return_quaternion==True): + return matrix, quaternion + else: + return matrix + + + + + + + + + diff --git a/src/StructDiffusion/utils/torch_data.py 
b/src/StructDiffusion/utils/torch_data.py new file mode 100644 index 0000000000000000000000000000000000000000..eae8f23fd7eec4f6e2a49f52f25fad0f8bfc0bbf --- /dev/null +++ b/src/StructDiffusion/utils/torch_data.py @@ -0,0 +1,185 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to +collate samples fetched from dataset into Tensor(s). + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. + +`default_collate` and `default_convert` are exposed to users via 'dataloader.py'. +""" + +import torch +import re +import collections +from torch._six import string_classes + +np_str_obj_array_pattern = re.compile(r'[SaUO]') + + +def default_convert(data): + r""" + Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`, + `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`. + If the input is not an NumPy array, it is left unchanged. + This is used as the default function for collation when both `batch_sampler` and + `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`. + + The general input type to output type mapping is similar to that + of :func:`~torch.utils.data.default_collate`. See the description there for more details. + + Args: + data: a single data point to be converted + + Examples: + >>> # Example with `int` + >>> default_convert(0) + 0 + >>> # Example with NumPy array + >>> # xdoctest: +SKIP + >>> default_convert(np.array([0, 1])) + tensor([0, 1]) + >>> # Example with NamedTuple + >>> Point = namedtuple('Point', ['x', 'y']) + >>> default_convert(Point(0, 0)) + Point(x=0, y=0) + >>> default_convert(Point(np.array(0), np.array(0))) + Point(x=tensor(0), y=tensor(0)) + >>> # Example with List + >>> default_convert([np.array([0, 1]), np.array([2, 3])]) + [tensor([0, 1]), tensor([2, 3])] + """ + elem_type = type(data) + if isinstance(data, torch.Tensor): + return data + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + # array of string classes and object + if elem_type.__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(data.dtype.str) is not None: + return data + return torch.as_tensor(data) + elif isinstance(data, collections.abc.Mapping): + try: + return elem_type({key: default_convert(data[key]) for key in data}) + except TypeError: + # The mapping type may not support `__init__(iterable)`. + return {key: default_convert(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple + return elem_type(*(default_convert(d) for d in data)) + elif isinstance(data, tuple): + return [default_convert(d) for d in data] # Backwards compatibility. + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes): + try: + return elem_type([default_convert(d) for d in data]) + except TypeError: + # The sequence type may not support `__init__(iterable)` (e.g., `range`). + return [default_convert(d) for d in data] + else: + return data + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + + +def default_collate(batch): + r""" + Function that takes in a batch of data and puts the elements within the batch + into a tensor with an additional outer dimension - batch size. 
The exact output type can be + a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a + Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type. + This is used as the default function for collation when + `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`. + + Here is the general input type (based on the type of the element within the batch) to output type mapping: + + * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size) + * NumPy Arrays -> :class:`torch.Tensor` + * `float` -> :class:`torch.Tensor` + * `int` -> :class:`torch.Tensor` + * `str` -> `str` (unchanged) + * `bytes` -> `bytes` (unchanged) + * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]` + * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + + Args: + batch: a single batch to be collated + + Examples: + >>> # Example with a batch of `int`s: + >>> default_collate([0, 1, 2, 3]) + tensor([0, 1, 2, 3]) + >>> # Example with a batch of `str`s: + >>> default_collate(['a', 'b', 'c']) + ['a', 'b', 'c'] + >>> # Example with `Map` inside the batch: + >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) + {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} + >>> # Example with `NamedTuple` inside the batch: + >>> # xdoctest: +SKIP + >>> Point = namedtuple('Point', ['x', 'y']) + >>> default_collate([Point(0, 0), Point(1, 1)]) + Point(x=tensor([0, 1]), y=tensor([0, 1])) + >>> # Example with `Tuple` inside the batch: + >>> default_collate([(0, 1), (2, 3)]) + [tensor([0, 2]), tensor([1, 3])] + >>> # Example with `List` inside the batch: + >>> default_collate([[0, 1], [2, 3]]) + [tensor([0, 2]), tensor([1, 3])] + """ + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return default_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + try: + return elem_type({key: default_collate([d[key] for d in batch]) for key in elem}) + except TypeError: + # The mapping type may not support `__init__(iterable)`. 
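+            # Plain-dict fallback: still collate each key's values across the batch recursively.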
+ return {key: default_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(default_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + return [default_collate(samples) for samples in transposed] # Backwards compatibility. + else: + try: + return elem_type([default_collate(samples) for samples in transposed]) + except TypeError: + # The sequence type may not support `__init__(iterable)` (e.g., `range`). + return [default_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) \ No newline at end of file diff --git a/src/StructDiffusion/utils/transformations.py b/src/StructDiffusion/utils/transformations.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6f19e5b909a4aa4a024b3356aaabf960071792 --- /dev/null +++ b/src/StructDiffusion/utils/transformations.py @@ -0,0 +1,1705 @@ +# -*- coding: utf-8 -*- +# transformations.py + +# Copyright (c) 2006, Christoph Gohlke +# Copyright (c) 2006-2009, The Regents of the University of California +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the copyright holders nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +"""Homogeneous Transformation Matrices and Quaternions. + +A library for calculating 4x4 matrices for translating, rotating, reflecting, +scaling, shearing, projecting, orthogonalizing, and superimposing arrays of +3D homogeneous coordinates as well as for converting between rotation matrices, +Euler angles, and quaternions. Also includes an Arcball control object and +functions to decompose transformation matrices. 
+ +:Authors: + `Christoph Gohlke `__, + Laboratory for Fluorescence Dynamics, University of California, Irvine + +:Version: 20090418 + +Requirements +------------ + +* `Python 2.6 `__ +* `Numpy 1.3 `__ +* `transformations.c 20090418 `__ + (optional implementation of some functions in C) + +Notes +----- + +Matrices (M) can be inverted using numpy.linalg.inv(M), concatenated using +numpy.dot(M0, M1), or used to transform homogeneous coordinates (v) using +numpy.dot(M, v) for shape (4, \*) "point of arrays", respectively +numpy.dot(v, M.T) for shape (\*, 4) "array of points". + +Calculations are carried out with numpy.float64 precision. + +This Python implementation is not optimized for speed. + +Vector, point, quaternion, and matrix function arguments are expected to be +"array like", i.e. tuple, list, or numpy arrays. + +Return types are numpy arrays unless specified otherwise. + +Angles are in radians unless specified otherwise. + +Quaternions ix+jy+kz+w are represented as [x, y, z, w]. + +Use the transpose of transformation matrices for OpenGL glMultMatrixd(). + +A triple of Euler angles can be applied/interpreted in 24 ways, which can +be specified using a 4 character string or encoded 4-tuple: + + *Axes 4-string*: e.g. 'sxyz' or 'ryxy' + + - first character : rotations are applied to 's'tatic or 'r'otating frame + - remaining characters : successive rotation axis 'x', 'y', or 'z' + + *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1) + + - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix. + - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed + by 'z', or 'z' is followed by 'x'. Otherwise odd (1). + - repetition : first and last axis are same (1) or different (0). + - frame : rotations are applied to static (0) or rotating (1) frame. + +References +---------- + +(1) Matrices and transformations. Ronald Goldman. + In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990. +(2) More matrices and transformations: shear and pseudo-perspective. + Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(3) Decomposing a matrix into simple transformations. Spencer Thomas. + In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(4) Recovering the data from the transformation matrix. Ronald Goldman. + In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991. +(5) Euler angle conversion. Ken Shoemake. + In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994. +(6) Arcball rotation control. Ken Shoemake. + In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994. +(7) Representing attitude: Euler angles, unit quaternions, and rotation + vectors. James Diebel. 2006. +(8) A discussion of the solution for the best rotation to relate two sets + of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828. +(9) Closed-form solution of absolute orientation using unit quaternions. + BKP Horn. J Opt Soc Am A. 1987. 4(4), 629-642. +(10) Quaternions. Ken Shoemake. + http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf +(11) From quaternion to matrix and back. JMP van Waveren. 2005. + http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm +(12) Uniform random rotations. Ken Shoemake. + In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992. 
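Two conventions stated in the module docstring above are easy to trip over: this vendored copy stores quaternions scalar-last ([x, y, z, w]), whereas later releases of the upstream library use the scalar-first [w, x, y, z] ordering, and a static-frame axis string matches the reversed rotating-frame string with the angle order reversed. The snippet below is a quick sanity check, assuming the module is importable as `StructDiffusion.utils.transformations` (i.e. `src/` is on `sys.path`); the quaternion values come from the doctests further down in this file.

```python
import numpy
from StructDiffusion.utils import transformations as tra

# Scalar-last quaternion: w is the fourth component.
q = tra.quaternion_about_axis(0.123, (1, 0, 0))
assert numpy.allclose(q, [0.06146124, 0, 0, 0.99810947])

# 'sxyz' (static frame) equals 'rzyx' (rotating frame) with the angles reversed.
a, b, c = 0.1, 0.2, 0.3
assert numpy.allclose(tra.euler_matrix(a, b, c, 'sxyz'),
                      tra.euler_matrix(c, b, a, 'rzyx'))
```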
+ + +Examples +-------- + +>>> alpha, beta, gamma = 0.123, -1.234, 2.345 +>>> origin, xaxis, yaxis, zaxis = (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1) +>>> I = identity_matrix() +>>> Rx = rotation_matrix(alpha, xaxis) +>>> Ry = rotation_matrix(beta, yaxis) +>>> Rz = rotation_matrix(gamma, zaxis) +>>> R = concatenate_matrices(Rx, Ry, Rz) +>>> euler = euler_from_matrix(R, 'rxyz') +>>> numpy.allclose([alpha, beta, gamma], euler) +True +>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz') +>>> is_same_transform(R, Re) +True +>>> al, be, ga = euler_from_matrix(Re, 'rxyz') +>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz')) +True +>>> qx = quaternion_about_axis(alpha, xaxis) +>>> qy = quaternion_about_axis(beta, yaxis) +>>> qz = quaternion_about_axis(gamma, zaxis) +>>> q = quaternion_multiply(qx, qy) +>>> q = quaternion_multiply(q, qz) +>>> Rq = quaternion_matrix(q) +>>> is_same_transform(R, Rq) +True +>>> S = scale_matrix(1.23, origin) +>>> T = translation_matrix((1, 2, 3)) +>>> Z = shear_matrix(beta, xaxis, origin, zaxis) +>>> R = random_rotation_matrix(numpy.random.rand(3)) +>>> M = concatenate_matrices(T, R, Z, S) +>>> scale, shear, angles, trans, persp = decompose_matrix(M) +>>> numpy.allclose(scale, 1.23) +True +>>> numpy.allclose(trans, (1, 2, 3)) +True +>>> numpy.allclose(shear, (0, math.tan(beta), 0)) +True +>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles)) +True +>>> M1 = compose_matrix(scale, shear, angles, trans, persp) +>>> is_same_transform(M, M1) +True + +""" + +from __future__ import division + +import warnings +import math + +import numpy + +# Documentation in HTML format can be generated with Epydoc +__docformat__ = "restructuredtext en" + + +def identity_matrix(): + """Return 4x4 identity/unit matrix. + + >>> I = identity_matrix() + >>> numpy.allclose(I, numpy.dot(I, I)) + True + >>> numpy.sum(I), numpy.trace(I) + (4.0, 4.0) + >>> numpy.allclose(I, numpy.identity(4, dtype=numpy.float64)) + True + + """ + return numpy.identity(4, dtype=numpy.float64) + + +def translation_matrix(direction): + """Return matrix to translate by direction vector. + + >>> v = numpy.random.random(3) - 0.5 + >>> numpy.allclose(v, translation_matrix(v)[:3, 3]) + True + + """ + M = numpy.identity(4) + M[:3, 3] = direction[:3] + return M + + +def translation_from_matrix(matrix): + """Return translation vector from translation matrix. + + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = translation_from_matrix(translation_matrix(v0)) + >>> numpy.allclose(v0, v1) + True + + """ + return numpy.array(matrix, copy=False)[:3, 3].copy() + + +def reflection_matrix(point, normal): + """Return matrix to mirror at plane defined by point and normal vector. + + >>> v0 = numpy.random.random(4) - 0.5 + >>> v0[3] = 1.0 + >>> v1 = numpy.random.random(3) - 0.5 + >>> R = reflection_matrix(v0, v1) + >>> numpy.allclose(2., numpy.trace(R)) + True + >>> numpy.allclose(v0, numpy.dot(R, v0)) + True + >>> v2 = v0.copy() + >>> v2[:3] += v1 + >>> v3 = v0.copy() + >>> v2[:3] -= v1 + >>> numpy.allclose(v2, numpy.dot(R, v3)) + True + + """ + normal = unit_vector(normal[:3]) + M = numpy.identity(4) + M[:3, :3] -= 2.0 * numpy.outer(normal, normal) + M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal + return M + + +def reflection_from_matrix(matrix): + """Return mirror plane point and normal vector from reflection matrix. 
+ + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = numpy.random.random(3) - 0.5 + >>> M0 = reflection_matrix(v0, v1) + >>> point, normal = reflection_from_matrix(M0) + >>> M1 = reflection_matrix(point, normal) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + # normal: unit eigenvector corresponding to eigenvalue -1 + l, V = numpy.linalg.eig(M[:3, :3]) + i = numpy.where(abs(numpy.real(l) + 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue -1") + normal = numpy.real(V[:, i[0]]).squeeze() + # point: any unit eigenvector corresponding to eigenvalue 1 + l, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return point, normal + + +def rotation_matrix(angle, direction, point=None): + """Return matrix to rotate about axis defined by point and direction. + + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(angle-2*math.pi, direc, point) + >>> is_same_transform(R0, R1) + True + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(-angle, -direc, point) + >>> is_same_transform(R0, R1) + True + >>> I = numpy.identity(4, numpy.float64) + >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc)) + True + >>> numpy.allclose(2., numpy.trace(rotation_matrix(math.pi/2, + ... direc, point))) + True + + """ + sina = math.sin(angle) + cosa = math.cos(angle) + direction = unit_vector(direction[:3]) + # rotation matrix around unit vector + R = numpy.array(((cosa, 0.0, 0.0), + (0.0, cosa, 0.0), + (0.0, 0.0, cosa)), dtype=numpy.float64) + R += numpy.outer(direction, direction) * (1.0 - cosa) + direction *= sina + R += numpy.array((( 0.0, -direction[2], direction[1]), + ( direction[2], 0.0, -direction[0]), + (-direction[1], direction[0], 0.0)), + dtype=numpy.float64) + M = numpy.identity(4) + M[:3, :3] = R + if point is not None: + # rotation not around origin + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + M[:3, 3] = point - numpy.dot(R, point) + return M + + +def rotation_from_matrix(matrix): + """Return rotation angle and axis from rotation matrix. 
+ + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> angle, direc, point = rotation_from_matrix(R0) + >>> R1 = rotation_matrix(angle, direc, point) + >>> is_same_transform(R0, R1) + True + + """ + R = numpy.array(matrix, dtype=numpy.float64, copy=False) + R33 = R[:3, :3] + # direction: unit eigenvector of R33 corresponding to eigenvalue of 1 + l, W = numpy.linalg.eig(R33.T) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + direction = numpy.real(W[:, i[-1]]).squeeze() + # point: unit eigenvector of R33 corresponding to eigenvalue of 1 + l, Q = numpy.linalg.eig(R) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + point = numpy.real(Q[:, i[-1]]).squeeze() + point /= point[3] + # rotation angle depending on direction + cosa = (numpy.trace(R33) - 1.0) / 2.0 + if abs(direction[2]) > 1e-8: + sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2] + elif abs(direction[1]) > 1e-8: + sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1] + else: + sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0] + angle = math.atan2(sina, cosa) + return angle, direction, point + + +def scale_matrix(factor, origin=None, direction=None): + """Return matrix to scale by factor around origin in direction. + + Use factor -1 for point symmetry. + + >>> v = (numpy.random.rand(4, 5) - 0.5) * 20.0 + >>> v[3] = 1.0 + >>> S = scale_matrix(-1.234) + >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3]) + True + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S = scale_matrix(factor, origin) + >>> S = scale_matrix(factor, origin, direct) + + """ + if direction is None: + # uniform scaling + M = numpy.array(((factor, 0.0, 0.0, 0.0), + (0.0, factor, 0.0, 0.0), + (0.0, 0.0, factor, 0.0), + (0.0, 0.0, 0.0, 1.0)), dtype=numpy.float64) + if origin is not None: + M[:3, 3] = origin[:3] + M[:3, 3] *= 1.0 - factor + else: + # nonuniform scaling + direction = unit_vector(direction[:3]) + factor = 1.0 - factor + M = numpy.identity(4) + M[:3, :3] -= factor * numpy.outer(direction, direction) + if origin is not None: + M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction + return M + + +def scale_from_matrix(matrix): + """Return scaling factor, origin and direction from scaling matrix. 
+ + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S0 = scale_matrix(factor, origin) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + >>> S0 = scale_matrix(factor, origin, direct) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + factor = numpy.trace(M33) - 2.0 + try: + # direction: unit eigenvector corresponding to eigenvalue factor + l, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(l) - factor) < 1e-8)[0][0] + direction = numpy.real(V[:, i]).squeeze() + direction /= vector_norm(direction) + except IndexError: + # uniform scaling + factor = (factor + 2.0) / 3.0 + direction = None + # origin: any eigenvector corresponding to eigenvalue 1 + l, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 1") + origin = numpy.real(V[:, i[-1]]).squeeze() + origin /= origin[3] + return factor, origin, direction + + +def projection_matrix(point, normal, direction=None, + perspective=None, pseudo=False): + """Return matrix to project onto plane defined by point and normal. + + Using either perspective point, projection direction, or none of both. + + If pseudo is True, perspective projections will preserve relative depth + such that Perspective = dot(Orthogonal, PseudoPerspective). + + >>> P = projection_matrix((0, 0, 0), (1, 0, 0)) + >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:]) + True + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> P1 = projection_matrix(point, normal, direction=direct) + >>> P2 = projection_matrix(point, normal, perspective=persp) + >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> is_same_transform(P2, numpy.dot(P0, P3)) + True + >>> P = projection_matrix((3, 0, 0), (1, 1, 0), (1, 0, 0)) + >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20.0 + >>> v0[3] = 1.0 + >>> v1 = numpy.dot(P, v0) + >>> numpy.allclose(v1[1], v0[1]) + True + >>> numpy.allclose(v1[0], 3.0-v1[1]) + True + + """ + M = numpy.identity(4) + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + normal = unit_vector(normal[:3]) + if perspective is not None: + # perspective projection + perspective = numpy.array(perspective[:3], dtype=numpy.float64, + copy=False) + M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal) + M[:3, :3] -= numpy.outer(perspective, normal) + if pseudo: + # preserve relative depth + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, normal) * (perspective+normal) + else: + M[:3, 3] = numpy.dot(point, normal) * perspective + M[3, :3] = -normal + M[3, 3] = numpy.dot(perspective, normal) + elif direction is not None: + # parallel projection + direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False) + scale = numpy.dot(direction, normal) + M[:3, :3] -= numpy.outer(direction, normal) / scale + M[:3, 3] = direction * (numpy.dot(point, normal) / scale) + else: + # orthogonal projection + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, 
normal) * normal + return M + + +def projection_from_matrix(matrix, pseudo=False): + """Return projection plane and perspective point from projection matrix. + + Return values are same as arguments for projection_matrix function: + point, normal, direction, perspective, and pseudo. + + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, direct) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False) + >>> result = projection_from_matrix(P0, pseudo=False) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> result = projection_from_matrix(P0, pseudo=True) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + l, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not pseudo and len(i): + # point: any eigenvector corresponding to eigenvalue 1 + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + # direction: unit eigenvector corresponding to eigenvalue 0 + l, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(l)) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 0") + direction = numpy.real(V[:, i[0]]).squeeze() + direction /= vector_norm(direction) + # normal: unit eigenvector of M33.T corresponding to eigenvalue 0 + l, V = numpy.linalg.eig(M33.T) + i = numpy.where(abs(numpy.real(l)) < 1e-8)[0] + if len(i): + # parallel projection + normal = numpy.real(V[:, i[0]]).squeeze() + normal /= vector_norm(normal) + return point, normal, direction, None, False + else: + # orthogonal projection, where normal equals direction vector + return point, direction, None, None, False + else: + # perspective projection + i = numpy.where(abs(numpy.real(l)) > 1e-8)[0] + if not len(i): + raise ValueError( + "no eigenvector not corresponding to eigenvalue 0") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + normal = - M[3, :3] + perspective = M[:3, 3] / numpy.dot(point[:3], normal) + if pseudo: + perspective -= normal + return point, normal, None, perspective, pseudo + + +def clip_matrix(left, right, bottom, top, near, far, perspective=False): + """Return matrix to obtain normalized device coordinates from frustrum. + + The frustrum bounds are axis-aligned along x (left, right), + y (bottom, top) and z (near, far). + + Normalized device coordinates are in range [-1, 1] if coordinates are + inside the frustrum. + + If perspective is True the frustrum is a truncated pyramid with the + perspective point at origin and direction along z axis, otherwise an + orthographic canonical view volume (a box). + + Homogeneous coordinates transformed by the perspective clip matrix + need to be dehomogenized (devided by w coordinate). 
+ + >>> frustrum = numpy.random.rand(6) + >>> frustrum[1] += frustrum[0] + >>> frustrum[3] += frustrum[2] + >>> frustrum[5] += frustrum[4] + >>> M = clip_matrix(*frustrum, perspective=False) + >>> numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0]) + array([-1., -1., -1., 1.]) + >>> numpy.dot(M, [frustrum[1], frustrum[3], frustrum[5], 1.0]) + array([ 1., 1., 1., 1.]) + >>> M = clip_matrix(*frustrum, perspective=True) + >>> v = numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0]) + >>> v / v[3] + array([-1., -1., -1., 1.]) + >>> v = numpy.dot(M, [frustrum[1], frustrum[3], frustrum[4], 1.0]) + >>> v / v[3] + array([ 1., 1., -1., 1.]) + + """ + if left >= right or bottom >= top or near >= far: + raise ValueError("invalid frustrum") + if perspective: + if near <= _EPS: + raise ValueError("invalid frustrum: near <= 0") + t = 2.0 * near + M = ((-t/(right-left), 0.0, (right+left)/(right-left), 0.0), + (0.0, -t/(top-bottom), (top+bottom)/(top-bottom), 0.0), + (0.0, 0.0, -(far+near)/(far-near), t*far/(far-near)), + (0.0, 0.0, -1.0, 0.0)) + else: + M = ((2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)), + (0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)), + (0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)), + (0.0, 0.0, 0.0, 1.0)) + return numpy.array(M, dtype=numpy.float64) + + +def shear_matrix(angle, direction, point, normal): + """Return matrix to shear by angle along direction vector on shear plane. + + The shear plane is defined by a point and normal vector. The direction + vector must be orthogonal to the plane's normal vector. + + A point P is transformed by the shear matrix into P" such that + the vector P-P" is parallel to the direction vector and its extent is + given by the angle of P-P'-P", where P' is the orthogonal projection + of P onto the shear plane. + + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S = shear_matrix(angle, direct, point, normal) + >>> numpy.allclose(1.0, numpy.linalg.det(S)) + True + + """ + normal = unit_vector(normal[:3]) + direction = unit_vector(direction[:3]) + if abs(numpy.dot(normal, direction)) > 1e-6: + raise ValueError("direction and normal vectors are not orthogonal") + angle = math.tan(angle) + M = numpy.identity(4) + M[:3, :3] += angle * numpy.outer(direction, normal) + M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction + return M + + +def shear_from_matrix(matrix): + """Return shear angle, direction and plane from shear matrix. 
+ + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S0 = shear_matrix(angle, direct, point, normal) + >>> angle, direct, point, normal = shear_from_matrix(S0) + >>> S1 = shear_matrix(angle, direct, point, normal) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + # normal: cross independent eigenvectors corresponding to the eigenvalue 1 + l, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-4)[0] + if len(i) < 2: + raise ValueError("No two linear independent eigenvectors found %s" % l) + V = numpy.real(V[:, i]).squeeze().T + lenorm = -1.0 + for i0, i1 in ((0, 1), (0, 2), (1, 2)): + n = numpy.cross(V[i0], V[i1]) + l = vector_norm(n) + if l > lenorm: + lenorm = l + normal = n + normal /= lenorm + # direction and angle + direction = numpy.dot(M33 - numpy.identity(3), normal) + angle = vector_norm(direction) + direction /= angle + angle = math.atan(angle) + # point: eigenvector corresponding to eigenvalue 1 + l, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 1") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return angle, direction, point, normal + + +def decompose_matrix(matrix): + """Return sequence of transformations from transformation matrix. + + matrix : array_like + Non-degenerative homogeneous transformation matrix + + Return tuple of: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + Raise ValueError if matrix is of wrong type or degenerative. 
+ + >>> T0 = translation_matrix((1, 2, 3)) + >>> scale, shear, angles, trans, persp = decompose_matrix(T0) + >>> T1 = translation_matrix(trans) + >>> numpy.allclose(T0, T1) + True + >>> S = scale_matrix(0.123) + >>> scale, shear, angles, trans, persp = decompose_matrix(S) + >>> scale[0] + 0.123 + >>> R0 = euler_matrix(1, 2, 3) + >>> scale, shear, angles, trans, persp = decompose_matrix(R0) + >>> R1 = euler_matrix(*angles) + >>> numpy.allclose(R0, R1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=True).T + if abs(M[3, 3]) < _EPS: + raise ValueError("M[3, 3] is zero") + M /= M[3, 3] + P = M.copy() + P[:, 3] = 0, 0, 0, 1 + if not numpy.linalg.det(P): + raise ValueError("Matrix is singular") + + scale = numpy.zeros((3, ), dtype=numpy.float64) + shear = [0, 0, 0] + angles = [0, 0, 0] + + if any(abs(M[:3, 3]) > _EPS): + perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T)) + M[:, 3] = 0, 0, 0, 1 + else: + perspective = numpy.array((0, 0, 0, 1), dtype=numpy.float64) + + translate = M[3, :3].copy() + M[3, :3] = 0 + + row = M[:3, :3].copy() + scale[0] = vector_norm(row[0]) + row[0] /= scale[0] + shear[0] = numpy.dot(row[0], row[1]) + row[1] -= row[0] * shear[0] + scale[1] = vector_norm(row[1]) + row[1] /= scale[1] + shear[0] /= scale[1] + shear[1] = numpy.dot(row[0], row[2]) + row[2] -= row[0] * shear[1] + shear[2] = numpy.dot(row[1], row[2]) + row[2] -= row[1] * shear[2] + scale[2] = vector_norm(row[2]) + row[2] /= scale[2] + shear[1:] /= scale[2] + + if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0: + scale *= -1 + row *= -1 + + angles[1] = math.asin(-row[0, 2]) + if math.cos(angles[1]): + angles[0] = math.atan2(row[1, 2], row[2, 2]) + angles[2] = math.atan2(row[0, 1], row[0, 0]) + else: + #angles[0] = math.atan2(row[1, 0], row[1, 1]) + angles[0] = math.atan2(-row[2, 1], row[1, 1]) + angles[2] = 0.0 + + return scale, shear, angles, translate, perspective + + +def compose_matrix(scale=None, shear=None, angles=None, translate=None, + perspective=None): + """Return transformation matrix from sequence of transformations. + + This is the inverse of the decompose_matrix function. 
+ + Sequence of transformations: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + >>> scale = numpy.random.random(3) - 0.5 + >>> shear = numpy.random.random(3) - 0.5 + >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi) + >>> trans = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(4) - 0.5 + >>> M0 = compose_matrix(scale, shear, angles, trans, persp) + >>> result = decompose_matrix(M0) + >>> M1 = compose_matrix(*result) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.identity(4) + if perspective is not None: + P = numpy.identity(4) + P[3, :] = perspective[:4] + M = numpy.dot(M, P) + if translate is not None: + T = numpy.identity(4) + T[:3, 3] = translate[:3] + M = numpy.dot(M, T) + if angles is not None: + R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz') + M = numpy.dot(M, R) + if shear is not None: + Z = numpy.identity(4) + Z[1, 2] = shear[2] + Z[0, 2] = shear[1] + Z[0, 1] = shear[0] + M = numpy.dot(M, Z) + if scale is not None: + S = numpy.identity(4) + S[0, 0] = scale[0] + S[1, 1] = scale[1] + S[2, 2] = scale[2] + M = numpy.dot(M, S) + M /= M[3, 3] + return M + + +def orthogonalization_matrix(lengths, angles): + """Return orthogonalization matrix for crystallographic cell coordinates. + + Angles are expected in degrees. + + The de-orthogonalization matrix is the inverse. + + >>> O = orthogonalization_matrix((10., 10., 10.), (90., 90., 90.)) + >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10) + True + >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7]) + >>> numpy.allclose(numpy.sum(O), 43.063229) + True + + """ + a, b, c = lengths + angles = numpy.radians(angles) + sina, sinb, _ = numpy.sin(angles) + cosa, cosb, cosg = numpy.cos(angles) + co = (cosa * cosb - cosg) / (sina * sinb) + return numpy.array(( + ( a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0), + (-a*sinb*co, b*sina, 0.0, 0.0), + ( a*cosb, b*cosa, c, 0.0), + ( 0.0, 0.0, 0.0, 1.0)), + dtype=numpy.float64) + + +def superimposition_matrix(v0, v1, scaling=False, usesvd=True): + """Return matrix to transform given vector set into second vector set. + + v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 vectors. + + If usesvd is True, the weighted sum of squared deviations (RMSD) is + minimized according to the algorithm by W. Kabsch [8]. Otherwise the + quaternion based algorithm by B. Horn [9] is used (slower when using + this Python implementation). + + The returned matrix performs rotation, translation and uniform scaling + (if specified). 
+ + >>> v0 = numpy.random.rand(3, 10) + >>> M = superimposition_matrix(v0, v0) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> R = random_rotation_matrix(numpy.random.random(3)) + >>> v0 = ((1,0,0), (0,1,0), (0,0,1), (1,1,1)) + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20.0 + >>> v0[3] = 1.0 + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> S = scale_matrix(random.random()) + >>> T = translation_matrix(numpy.random.random(3)-0.5) + >>> M = concatenate_matrices(T, R, S) + >>> v1 = numpy.dot(M, v0) + >>> v0[:3] += numpy.random.normal(0.0, 1e-9, 300).reshape(3, -1) + >>> M = superimposition_matrix(v0, v1, scaling=True) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v = numpy.empty((4, 100, 3), dtype=numpy.float64) + >>> v[:, :, 0] = v0 + >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0])) + True + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3] + v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3] + + if v0.shape != v1.shape or v0.shape[1] < 3: + raise ValueError("Vector sets are of wrong shape or type.") + + # move centroids to origin + t0 = numpy.mean(v0, axis=1) + t1 = numpy.mean(v1, axis=1) + v0 = v0 - t0.reshape(3, 1) + v1 = v1 - t1.reshape(3, 1) + + if usesvd: + # Singular Value Decomposition of covariance matrix + u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T)) + # rotation matrix from SVD orthonormal bases + R = numpy.dot(u, vh) + if numpy.linalg.det(R) < 0.0: + # R does not constitute right handed system + R -= numpy.outer(u[:, 2], vh[2, :]*2.0) + s[-1] *= -1.0 + # homogeneous transformation matrix + M = numpy.identity(4) + M[:3, :3] = R + else: + # compute symmetric matrix N + xx, yy, zz = numpy.sum(v0 * v1, axis=1) + xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1) + xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1) + N = ((xx+yy+zz, yz-zy, zx-xz, xy-yx), + (yz-zy, xx-yy-zz, xy+yx, zx+xz), + (zx-xz, xy+yx, -xx+yy-zz, yz+zy), + (xy-yx, zx+xz, yz+zy, -xx-yy+zz)) + # quaternion: eigenvector corresponding to most positive eigenvalue + l, V = numpy.linalg.eig(N) + q = V[:, numpy.argmax(l)] + q /= vector_norm(q) # unit quaternion + q = numpy.roll(q, -1) # move w component to end + # homogeneous transformation matrix + M = quaternion_matrix(q) + + # scale: ratio of rms deviations from centroid + if scaling: + v0 *= v0 + v1 *= v1 + M[:3, :3] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0)) + + # translation + M[:3, 3] = t1 + T = numpy.identity(4) + T[:3, 3] = -t0 + M = numpy.dot(M, T) + return M + + +def euler_matrix(ai, aj, ak, axes='sxyz'): + """Return homogeneous rotation matrix from Euler angles and axis sequence. + + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> R = euler_matrix(1, 2, 3, 'syxz') + >>> numpy.allclose(numpy.sum(R[0]), -1.34786452) + True + >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1)) + >>> numpy.allclose(numpy.sum(R[0]), -0.383436184) + True + >>> ai, aj, ak = (4.0*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R = euler_matrix(ai, aj, ak, axes) + >>> for axes in _TUPLE2AXES.keys(): + ... 
R = euler_matrix(ai, aj, ak, axes) + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes] + except (AttributeError, KeyError): + _ = _TUPLE2AXES[axes] + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + if frame: + ai, ak = ak, ai + if parity: + ai, aj, ak = -ai, -aj, -ak + + si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) + ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) + cc, cs = ci*ck, ci*sk + sc, ss = si*ck, si*sk + + M = numpy.identity(4) + if repetition: + M[i, i] = cj + M[i, j] = sj*si + M[i, k] = sj*ci + M[j, i] = sj*sk + M[j, j] = -cj*ss+cc + M[j, k] = -cj*cs-sc + M[k, i] = -sj*ck + M[k, j] = cj*sc+cs + M[k, k] = cj*cc-ss + else: + M[i, i] = cj*ck + M[i, j] = sj*sc-cs + M[i, k] = sj*cc+ss + M[j, i] = cj*sk + M[j, j] = sj*ss+cc + M[j, k] = sj*cs-sc + M[k, i] = -sj + M[k, j] = cj*si + M[k, k] = cj*ci + return M + + +def euler_from_matrix(matrix, axes='sxyz'): + """Return Euler angles from rotation matrix for specified axis sequence. + + axes : One of 24 axis sequences as string or encoded tuple + + Note that many Euler angle triplets can describe one matrix. + + >>> R0 = euler_matrix(1, 2, 3, 'syxz') + >>> al, be, ga = euler_from_matrix(R0, 'syxz') + >>> R1 = euler_matrix(al, be, ga, 'syxz') + >>> numpy.allclose(R0, R1) + True + >>> angles = (4.0*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R0 = euler_matrix(axes=axes, *angles) + ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes)) + ... if not numpy.allclose(R0, R1): print axes, "failed" + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _ = _TUPLE2AXES[axes] + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] + if repetition: + sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) + if sy > _EPS: + ax = math.atan2( M[i, j], M[i, k]) + ay = math.atan2( sy, M[i, i]) + az = math.atan2( M[j, i], -M[k, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2( sy, M[i, i]) + az = 0.0 + else: + cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) + if cy > _EPS: + ax = math.atan2( M[k, j], M[k, k]) + ay = math.atan2(-M[k, i], cy) + az = math.atan2( M[j, i], M[i, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(-M[k, i], cy) + az = 0.0 + + if parity: + ax, ay, az = -ax, -ay, -az + if frame: + ax, az = az, ax + return ax, ay, az + + +def euler_from_quaternion(quaternion, axes='sxyz'): + """Return Euler angles from quaternion for specified axis sequence. + + >>> angles = euler_from_quaternion([0.06146124, 0, 0, 0.99810947]) + >>> numpy.allclose(angles, [0.123, 0, 0]) + True + + """ + return euler_from_matrix(quaternion_matrix(quaternion), axes) + + +def quaternion_from_euler(ai, aj, ak, axes='sxyz'): + """Return quaternion from Euler angles and axis sequence. 
+ + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> q = quaternion_from_euler(1, 2, 3, 'ryxz') + >>> numpy.allclose(q, [0.310622, -0.718287, 0.444435, 0.435953]) + True + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _ = _TUPLE2AXES[axes] + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + if frame: + ai, ak = ak, ai + if parity: + aj = -aj + + ai /= 2.0 + aj /= 2.0 + ak /= 2.0 + ci = math.cos(ai) + si = math.sin(ai) + cj = math.cos(aj) + sj = math.sin(aj) + ck = math.cos(ak) + sk = math.sin(ak) + cc = ci*ck + cs = ci*sk + sc = si*ck + ss = si*sk + + quaternion = numpy.empty((4, ), dtype=numpy.float64) + if repetition: + quaternion[i] = cj*(cs + sc) + quaternion[j] = sj*(cc + ss) + quaternion[k] = sj*(cs - sc) + quaternion[3] = cj*(cc - ss) + else: + quaternion[i] = cj*sc - sj*cs + quaternion[j] = cj*ss + sj*cc + quaternion[k] = cj*cs - sj*sc + quaternion[3] = cj*cc + sj*ss + if parity: + quaternion[j] *= -1 + + return quaternion + + +def quaternion_about_axis(angle, axis): + """Return quaternion for rotation about axis. + + >>> q = quaternion_about_axis(0.123, (1, 0, 0)) + >>> numpy.allclose(q, [0.06146124, 0, 0, 0.99810947]) + True + + """ + quaternion = numpy.zeros((4, ), dtype=numpy.float64) + quaternion[:3] = axis[:3] + qlen = vector_norm(quaternion) + if qlen > _EPS: + quaternion *= math.sin(angle/2.0) / qlen + quaternion[3] = math.cos(angle/2.0) + return quaternion + + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. + + >>> R = quaternion_matrix([0.06146124, 0, 0, 0.99810947]) + >>> numpy.allclose(R, rotation_matrix(0.123, (1, 0, 0))) + True + + """ + q = numpy.array(quaternion[:4], dtype=numpy.float64, copy=True) + nq = numpy.dot(q, q) + if nq < _EPS: + return numpy.identity(4) + q *= math.sqrt(2.0 / nq) + q = numpy.outer(q, q) + return numpy.array(( + (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0), + ( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0), + ( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0), + ( 0.0, 0.0, 0.0, 1.0) + ), dtype=numpy.float64) + + +def quaternion_from_matrix(matrix): + """Return quaternion from rotation matrix. + + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.0164262, 0.0328524, 0.0492786, 0.9981095]) + True + + """ + q = numpy.empty((4, ), dtype=numpy.float64) + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] + t = numpy.trace(M) + if t > M[3, 3]: + q[3] = t + q[2] = M[1, 0] - M[0, 1] + q[1] = M[0, 2] - M[2, 0] + q[0] = M[2, 1] - M[1, 2] + else: + i, j, k = 0, 1, 2 + if M[1, 1] > M[0, 0]: + i, j, k = 1, 2, 0 + if M[2, 2] > M[i, i]: + i, j, k = 2, 0, 1 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q *= 0.5 / math.sqrt(t * M[3, 3]) + return q + + +def quaternion_multiply(quaternion1, quaternion0): + """Return multiplication of two quaternions. 
+ + >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8]) + >>> numpy.allclose(q, [-44, -14, 48, 28]) + True + + """ + x0, y0, z0, w0 = quaternion0 + x1, y1, z1, w1 = quaternion1 + return numpy.array(( + x1*w0 + y1*z0 - z1*y0 + w1*x0, + -x1*z0 + y1*w0 + z1*x0 + w1*y0, + x1*y0 - y1*x0 + z1*w0 + w1*z0, + -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64) + + +def quaternion_conjugate(quaternion): + """Return conjugate of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_conjugate(q0) + >>> q1[3] == q0[3] and all(q1[:3] == -q0[:3]) + True + + """ + return numpy.array((-quaternion[0], -quaternion[1], + -quaternion[2], quaternion[3]), dtype=numpy.float64) + + +def quaternion_inverse(quaternion): + """Return inverse of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_inverse(q0) + >>> numpy.allclose(quaternion_multiply(q0, q1), [0, 0, 0, 1]) + True + + """ + return quaternion_conjugate(quaternion) / numpy.dot(quaternion, quaternion) + + +def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True): + """Return spherical linear interpolation between two quaternions. + + >>> q0 = random_quaternion() + >>> q1 = random_quaternion() + >>> q = quaternion_slerp(q0, q1, 0.0) + >>> numpy.allclose(q, q0) + True + >>> q = quaternion_slerp(q0, q1, 1.0, 1) + >>> numpy.allclose(q, q1) + True + >>> q = quaternion_slerp(q0, q1, 0.5) + >>> angle = math.acos(numpy.dot(q0, q)) + >>> numpy.allclose(2.0, math.acos(numpy.dot(q0, q1)) / angle) or \ + numpy.allclose(2.0, math.acos(-numpy.dot(q0, q1)) / angle) + True + + """ + q0 = unit_vector(quat0[:4]) + q1 = unit_vector(quat1[:4]) + if fraction == 0.0: + return q0 + elif fraction == 1.0: + return q1 + d = numpy.dot(q0, q1) + if abs(abs(d) - 1.0) < _EPS: + return q0 + if shortestpath and d < 0.0: + # invert rotation + d = -d + q1 *= -1.0 + angle = math.acos(d) + spin * math.pi + if abs(angle) < _EPS: + return q0 + isin = 1.0 / math.sin(angle) + q0 *= math.sin((1.0 - fraction) * angle) * isin + q1 *= math.sin(fraction * angle) * isin + q0 += q1 + return q0 + + +def random_quaternion(rand=None): + """Return uniform random unit quaternion. + + rand: array like or None + Three independent random variables that are uniformly distributed + between 0 and 1. + + >>> q = random_quaternion() + >>> numpy.allclose(1.0, vector_norm(q)) + True + >>> q = random_quaternion(numpy.random.random(3)) + >>> q.shape + (4,) + + """ + if rand is None: + rand = numpy.random.rand(3) + else: + assert len(rand) == 3 + r1 = numpy.sqrt(1.0 - rand[0]) + r2 = numpy.sqrt(rand[0]) + pi2 = math.pi * 2.0 + t1 = pi2 * rand[1] + t2 = pi2 * rand[2] + return numpy.array((numpy.sin(t1)*r1, + numpy.cos(t1)*r1, + numpy.sin(t2)*r2, + numpy.cos(t2)*r2), dtype=numpy.float64) + + +def random_rotation_matrix(rand=None): + """Return uniform random rotation matrix. + + rnd: array like + Three independent random variables that are uniformly distributed + between 0 and 1 for each returned quaternion. + + >>> R = random_rotation_matrix() + >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4)) + True + + """ + return quaternion_matrix(random_quaternion(rand)) + + +class Arcball(object): + """Virtual Trackball Control. 
+ + >>> ball = Arcball() + >>> ball = Arcball(initial=numpy.identity(4)) + >>> ball.place([320, 320], 320) + >>> ball.down([500, 250]) + >>> ball.drag([475, 275]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 3.90583455) + True + >>> ball = Arcball(initial=[0, 0, 0, 1]) + >>> ball.place([320, 320], 320) + >>> ball.setaxes([1,1,0], [-1, 1, 0]) + >>> ball.setconstrain(True) + >>> ball.down([400, 200]) + >>> ball.drag([200, 400]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 0.2055924) + True + >>> ball.next() + + """ + + def __init__(self, initial=None): + """Initialize virtual trackball control. + + initial : quaternion or rotation matrix + + """ + self._axis = None + self._axes = None + self._radius = 1.0 + self._center = [0.0, 0.0] + self._vdown = numpy.array([0, 0, 1], dtype=numpy.float64) + self._constrain = False + + if initial is None: + self._qdown = numpy.array([0, 0, 0, 1], dtype=numpy.float64) + else: + initial = numpy.array(initial, dtype=numpy.float64) + if initial.shape == (4, 4): + self._qdown = quaternion_from_matrix(initial) + elif initial.shape == (4, ): + initial /= vector_norm(initial) + self._qdown = initial + else: + raise ValueError("initial not a quaternion or matrix.") + + self._qnow = self._qpre = self._qdown + + def place(self, center, radius): + """Place Arcball, e.g. when window size changes. + + center : sequence[2] + Window coordinates of trackball center. + radius : float + Radius of trackball in window coordinates. + + """ + self._radius = float(radius) + self._center[0] = center[0] + self._center[1] = center[1] + + def setaxes(self, *axes): + """Set axes to constrain rotations.""" + if axes is None: + self._axes = None + else: + self._axes = [unit_vector(axis) for axis in axes] + + def setconstrain(self, constrain): + """Set state of constrain to axis mode.""" + self._constrain = constrain == True + + def getconstrain(self): + """Return state of constrain to axis mode.""" + return self._constrain + + def down(self, point): + """Set initial cursor window coordinates and pick constrain-axis.""" + self._vdown = arcball_map_to_sphere(point, self._center, self._radius) + self._qdown = self._qpre = self._qnow + + if self._constrain and self._axes is not None: + self._axis = arcball_nearest_axis(self._vdown, self._axes) + self._vdown = arcball_constrain_to_axis(self._vdown, self._axis) + else: + self._axis = None + + def drag(self, point): + """Update current cursor window coordinates.""" + vnow = arcball_map_to_sphere(point, self._center, self._radius) + + if self._axis is not None: + vnow = arcball_constrain_to_axis(vnow, self._axis) + + self._qpre = self._qnow + + t = numpy.cross(self._vdown, vnow) + if numpy.dot(t, t) < _EPS: + self._qnow = self._qdown + else: + q = [t[0], t[1], t[2], numpy.dot(self._vdown, vnow)] + self._qnow = quaternion_multiply(q, self._qdown) + + def next(self, acceleration=0.0): + """Continue rotation in direction of last drag.""" + q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False) + self._qpre, self._qnow = self._qnow, q + + def matrix(self): + """Return homogeneous rotation matrix.""" + return quaternion_matrix(self._qnow) + + +def arcball_map_to_sphere(point, center, radius): + """Return unit sphere coordinates from window coordinates.""" + v = numpy.array(((point[0] - center[0]) / radius, + (center[1] - point[1]) / radius, + 0.0), dtype=numpy.float64) + n = v[0]*v[0] + v[1]*v[1] + if n > 1.0: + v /= math.sqrt(n) # position outside of sphere + else: + v[2] = math.sqrt(1.0 - n) + return v + 
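The Arcball doctests above check specific numeric sums; as a complement, this rough sketch shows how a GUI event loop would typically drive the class (the callback names, window size, and import path are assumptions, not part of this diff):

```python
from StructDiffusion.utils.transformations import Arcball  # assumed import path

width, height = 640, 480                        # assumed window size
ball = Arcball()
ball.place([width / 2.0, height / 2.0], 0.5 * min(width, height))

def on_mouse_press(x, y):                       # hypothetical GUI callback
    ball.down([x, y])

def on_mouse_drag(x, y):                        # hypothetical GUI callback
    ball.drag([x, y])
    return ball.matrix()                        # 4x4 rotation; transpose for OpenGL
```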
+ +def arcball_constrain_to_axis(point, axis): + """Return sphere point perpendicular to axis.""" + v = numpy.array(point, dtype=numpy.float64, copy=True) + a = numpy.array(axis, dtype=numpy.float64, copy=True) + v -= a * numpy.dot(a, v) # on plane + n = vector_norm(v) + if n > _EPS: + if v[2] < 0.0: + v *= -1.0 + v /= n + return v + if a[2] == 1.0: + return numpy.array([1, 0, 0], dtype=numpy.float64) + return unit_vector([-a[1], a[0], 0]) + + +def arcball_nearest_axis(point, axes): + """Return axis, which arc is nearest to point.""" + point = numpy.array(point, dtype=numpy.float64, copy=False) + nearest = None + mx = -1.0 + for axis in axes: + t = numpy.dot(arcball_constrain_to_axis(point, axis), point) + if t > mx: + nearest = axis + mx = t + return nearest + + +# epsilon for testing whether a number is close to zero +_EPS = numpy.finfo(float).eps * 4.0 + +# axis sequences for Euler angles +_NEXT_AXIS = [1, 2, 0, 1] + +# map axes strings to/from tuples of inner axis, parity, repetition, frame +_AXES2TUPLE = { + 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), + 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), + 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), + 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), + 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), + 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), + 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), + 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} + +_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) + +# helper functions + +def vector_norm(data, axis=None, out=None): + """Return length, i.e. eucledian norm, of ndarray along axis. + + >>> v = numpy.random.random(3) + >>> n = vector_norm(v) + >>> numpy.allclose(n, numpy.linalg.norm(v)) + True + >>> v = numpy.random.rand(6, 5, 3) + >>> n = vector_norm(v, axis=-1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2))) + True + >>> n = vector_norm(v, axis=1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> v = numpy.random.rand(5, 4, 3) + >>> n = numpy.empty((5, 3), dtype=numpy.float64) + >>> vector_norm(v, axis=1, out=n) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> vector_norm([]) + 0.0 + >>> vector_norm([1.0]) + 1.0 + + """ + data = numpy.array(data, dtype=numpy.float64, copy=True) + if out is None: + if data.ndim == 1: + return math.sqrt(numpy.dot(data, data)) + data *= data + out = numpy.atleast_1d(numpy.sum(data, axis=axis)) + numpy.sqrt(out, out) + return out + else: + data *= data + numpy.sum(data, axis=axis, out=out) + numpy.sqrt(out, out) + + +def unit_vector(data, axis=None, out=None): + """Return ndarray normalized by length, i.e. eucledian norm, along axis. 
+ + >>> v0 = numpy.random.random(3) + >>> v1 = unit_vector(v0) + >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0)) + True + >>> v0 = numpy.random.rand(5, 4, 3) + >>> v1 = unit_vector(v0, axis=-1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2) + >>> numpy.allclose(v1, v2) + True + >>> v1 = unit_vector(v0, axis=1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1) + >>> numpy.allclose(v1, v2) + True + >>> v1 = numpy.empty((5, 4, 3), dtype=numpy.float64) + >>> unit_vector(v0, axis=1, out=v1) + >>> numpy.allclose(v1, v2) + True + >>> list(unit_vector([])) + [] + >>> list(unit_vector([1.0])) + [1.0] + + """ + if out is None: + data = numpy.array(data, dtype=numpy.float64, copy=True) + if data.ndim == 1: + data /= math.sqrt(numpy.dot(data, data)) + return data + else: + if out is not data: + out[:] = numpy.array(data, copy=False) + data = out + length = numpy.atleast_1d(numpy.sum(data*data, axis)) + numpy.sqrt(length, length) + if axis is not None: + length = numpy.expand_dims(length, axis) + data /= length + if out is None: + return data + + +def random_vector(size): + """Return array of random doubles in the half-open interval [0.0, 1.0). + + >>> v = random_vector(10000) + >>> numpy.all(v >= 0.0) and numpy.all(v < 1.0) + True + >>> v0 = random_vector(10) + >>> v1 = random_vector(10) + >>> numpy.any(v0 == v1) + False + + """ + return numpy.random.random(size) + + +def inverse_matrix(matrix): + """Return inverse of square transformation matrix. + + >>> M0 = random_rotation_matrix() + >>> M1 = inverse_matrix(M0.T) + >>> numpy.allclose(M1, numpy.linalg.inv(M0.T)) + True + >>> for size in range(1, 7): + ... M0 = numpy.random.rand(size, size) + ... M1 = inverse_matrix(M0) + ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print size + + """ + return numpy.linalg.inv(matrix) + + +def concatenate_matrices(*matrices): + """Return concatenation of series of transformation matrices. + + >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5 + >>> numpy.allclose(M, concatenate_matrices(M)) + True + >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T)) + True + + """ + M = numpy.identity(4) + for i in matrices: + M = numpy.dot(M, i) + return M + + +def is_same_transform(matrix0, matrix1): + """Return True if two matrices perform same transformation. + + >>> is_same_transform(numpy.identity(4), numpy.identity(4)) + True + >>> is_same_transform(numpy.identity(4), random_rotation_matrix()) + False + + """ + matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True) + matrix0 /= matrix0[3, 3] + matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True) + matrix1 /= matrix1[3, 3] + return numpy.allclose(matrix0, matrix1) + + +def _import_module(module_name, warn=True, prefix='_py_', ignore='_'): + """Try import all public attributes from module into global namespace. + + Existing attributes with name clashes are renamed with prefix. + Attributes starting with underscore are ignored by default. + + Return True on successful import. 
+ + """ + try: + module = __import__(module_name) + except ImportError: + if warn: + warnings.warn("Failed to import module " + module_name) + else: + for attr in dir(module): + if ignore and attr.startswith(ignore): + continue + if prefix: + if attr in globals(): + globals()[prefix + attr] = globals()[attr] + elif warn: + warnings.warn("No Python implementation of " + attr) + globals()[attr] = getattr(module, attr) + return True diff --git a/tmp_data/scene.glb b/tmp_data/scene.glb new file mode 100644 index 0000000000000000000000000000000000000000..80c11dfc2e335f39ea2e3520eaac2ffa9fa84332 Binary files /dev/null and b/tmp_data/scene.glb differ diff --git a/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt b/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..0776091759fcafc1a2c92801aad2b7656a9486c3 --- /dev/null +++ b/wandb_logs/StructDiffusion/ConditionalPoseDiffusion/checkpoints/epoch=191-step=96000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b05ed7f605c30c240051c22db501ede6ea2fea79b3b66a23bce28fb01defa463 +size 59444321