xiexh20 committed on
Commit
2fd6166
1 Parent(s): 12a785b

add hdm demo v1

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +15 -12
  2. app.py +177 -0
  3. configs/__init__.py +0 -0
  4. configs/structured.py +416 -0
  5. dataset/__init__.py +301 -0
  6. dataset/base_data.py +110 -0
  7. dataset/behave_paths.py +228 -0
  8. dataset/demo_dataset.py +198 -0
  9. dataset/img_utils.py +149 -0
  10. demo.py +280 -0
  11. diffusion_utils.py +313 -0
  12. examples/017450/k1.color.jpg +0 -0
  13. examples/017450/k1.obj_rend_mask.png +0 -0
  14. examples/017450/k1.person_mask.png +0 -0
  15. model/__init__.py +28 -0
  16. model/feature_model.py +160 -0
  17. model/model.py +303 -0
  18. model/model_coloring.py +84 -0
  19. model/model_diff_data.py +238 -0
  20. model/model_hoattn.py +457 -0
  21. model/model_utils.py +58 -0
  22. model/point_cloud_model.py +67 -0
  23. model/point_cloud_transformer_model.py +80 -0
  24. model/projection_model.py +273 -0
  25. model/pvcnn/__init__.py +0 -0
  26. model/pvcnn/modules/__init__.py +8 -0
  27. model/pvcnn/modules/ball_query.py +69 -0
  28. model/pvcnn/modules/frustum.py +138 -0
  29. model/pvcnn/modules/functional/__init__.py +7 -0
  30. model/pvcnn/modules/functional/backend.py +33 -0
  31. model/pvcnn/modules/functional/ball_query.py +19 -0
  32. model/pvcnn/modules/functional/devoxelization.py +42 -0
  33. model/pvcnn/modules/functional/grouping.py +32 -0
  34. model/pvcnn/modules/functional/interpolatation.py +38 -0
  35. model/pvcnn/modules/functional/loss.py +17 -0
  36. model/pvcnn/modules/functional/sampling.py +84 -0
  37. model/pvcnn/modules/functional/src/ball_query/ball_query.cpp +30 -0
  38. model/pvcnn/modules/functional/src/ball_query/ball_query.cu +59 -0
  39. model/pvcnn/modules/functional/src/ball_query/ball_query.cuh +8 -0
  40. model/pvcnn/modules/functional/src/ball_query/ball_query.hpp +10 -0
  41. model/pvcnn/modules/functional/src/bindings.cpp +37 -0
  42. model/pvcnn/modules/functional/src/cuda_utils.cuh +39 -0
  43. model/pvcnn/modules/functional/src/grouping/grouping.cpp +44 -0
  44. model/pvcnn/modules/functional/src/grouping/grouping.cu +85 -0
  45. model/pvcnn/modules/functional/src/grouping/grouping.cuh +9 -0
  46. model/pvcnn/modules/functional/src/grouping/grouping.hpp +10 -0
  47. model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp +65 -0
  48. model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu +181 -0
  49. model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh +16 -0
  50. model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp +16 -0
README.md CHANGED
@@ -1,13 +1,16 @@
- ---
- title: HDM Interaction Recon
- emoji: 🌍
- colorFrom: yellow
- colorTo: green
- sdk: gradio
- sdk_version: 4.20.1
- app_file: app.py
- pinned: false
- license: cc-by-nc-4.0
- ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # HDM
+ Official implementation of the Hierarchical Diffusion Model (HDM) from the CVPR'24 paper "Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation".
+
+ [Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/) | [Code](https://github.com/xiexh20/HDM) | [Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS) | [Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)
+
+
+ ## Citation
+ ```
+ @inproceedings{xie2023template_free,
+ title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
+ author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2024},
+ }
+ ```
app.py ADDED
@@ -0,0 +1,177 @@
1
+ """
2
+ Demo built with gradio
3
+ """
4
+ import pickle as pkl
5
+ import sys, os
6
+ import os.path as osp
7
+ from typing import Iterable, Optional
8
+ from functools import partial
9
+
10
+ import trimesh
11
+ from torch.utils.data import DataLoader
12
+ import cv2
13
+ from accelerate import Accelerator
14
+ from tqdm import tqdm
15
+ from glob import glob
16
+
17
+ sys.path.append(os.getcwd())
18
+ import hydra
19
+ import torch
20
+ import numpy as np
21
+ import imageio
22
+ import gradio as gr
23
+ import plotly.graph_objs as go
24
+ import training_utils
25
+
26
+ from configs.structured import ProjectConfig
27
+ from demo import DemoRunner
28
+ from dataset.demo_dataset import DemoDataset
29
+
30
+
31
+ md_description="""
32
+ # HDM Interaction Reconstruction Demo
33
+ ### Official Implementation of the paper \"Template Free Reconstruction of Human Object Interaction\", CVPR'24.
34
+ [Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/)|[Code](https://github.com/xiexh20/HDM)|[Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS )|[Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)
35
+
36
+
37
+ Upload your own human object interaction image and get full 3D reconstruction!
38
+
39
+ ## Citation
40
+ ```
41
+ @inproceedings{xie2023template_free,
42
+ title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
43
+ author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
44
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
45
+ month = {June},
46
+ year = {2024},
47
+ }
48
+ ```
49
+ """
50
+
51
+ def plot_points(colors, coords):
52
+ """
53
+ use plotly to visualize 3D point with colors
54
+ """
55
+ trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode='markers',
56
+ marker=dict(
57
+ size=2,
58
+ color=colors
59
+ ))
60
+ layout = go.Layout(
61
+ scene=dict(
62
+ xaxis=dict(
63
+ title="",
64
+ showgrid=False,
65
+ zeroline=False,
66
+ showline=False,
67
+ ticks='',
68
+ showticklabels=False
69
+ ),
70
+ yaxis=dict(
71
+ title="",
72
+ showgrid=False,
73
+ zeroline=False,
74
+ showline=False,
75
+ ticks='',
76
+ showticklabels=False
77
+ ),
78
+ zaxis=dict(
79
+ title="",
80
+ showgrid=False,
81
+ zeroline=False,
82
+ showline=False,
83
+ ticks='',
84
+ showticklabels=False
85
+ ),
86
+ ),
87
+ margin=dict(l=0, r=0, b=0, t=0),
88
+ showlegend=False
89
+ )
90
+ fig = go.Figure(data=[trace], layout=layout)
91
+ return fig
92
+
93
+
94
+ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed):
95
+ """
96
+ given user input, run inference
97
+ :param runner:
98
+ :param cfg:
99
+ :param rgb: (h, w, 3), np array
100
+ :param mask_hum: (h, w, 3), np array
101
+ :param mask_obj: (h, w, 3), np array
102
+ :param std_coverage: float value, used to estimate camera translation
103
+ :param input_seed: random seed
104
+ :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
105
+ """
106
+ # Set random seed
107
+ training_utils.set_seed(int(input_seed))
108
+
109
+ data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
110
+ std_coverage)
111
+ batch = data.image2batch(rgb, mask_hum, mask_obj)
112
+
113
+ out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
114
+ points = out_stage2.points_packed().cpu().numpy()
115
+ colors = out_stage2.features_packed().cpu().numpy()
116
+ fig = plot_points(colors, points)
117
+ # save tmp point cloud
118
+ outdir = './results'
119
+ os.makedirs(outdir, exist_ok=True)
120
+ trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply")
121
+ trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
122
+ out_stage1.features_packed().cpu().numpy()).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1.ply")
123
+ return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply"
124
+
125
+
126
+ @hydra.main(config_path='configs', config_name='configs', version_base='1.1')
127
+ def main(cfg: ProjectConfig):
128
+ # Setup model
129
+ runner = DemoRunner(cfg)
130
+
131
+ # Setup interface
132
+ demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
133
+ with demo:
134
+ gr.Markdown(md_description)
135
+ gr.HTML("""<h1 style="text-align:center; color:#10768c">HDM Demo</h1>""")
136
+ gr.HTML("""<h3 style="text-align:center; color:#10768c">Instruction: Upload RGB, human, object masks and then click reconstruct.</h1>""")
137
+
138
+ # Input data
139
+ with gr.Row():
140
+ input_rgb = gr.Image(label='Input RGB', type='numpy')
141
+ input_mask_hum = gr.Image(label='Human mask', type='numpy')
142
+ with gr.Row():
143
+ input_mask_obj = gr.Image(label='Object mask', type='numpy')
144
+ with gr.Column():
145
+ # TODO: add hint for this value here
146
+ input_std = gr.Number(label='Gaussian std coverage', value=3.5)
147
+ input_seed = gr.Number(label='Random seed', value=42)
148
+ # Output visualization
149
+ with gr.Row():
150
+ pc_plot = gr.Plot(label="Reconstructed point cloud")
151
+ out_pc_download = gr.File(label="3D reconstruction for download") # this allows downloading
152
+
153
+ gr.HTML("""<br/>""")
154
+ # Control
155
+ with gr.Row():
156
+ button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
157
+ button_recon.click(fn=partial(inference, runner, cfg),
158
+ inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],
159
+ outputs=[pc_plot, out_pc_download])
160
+ gr.HTML("""<br/>""")
161
+ # Example input
162
+ example_dir = cfg.run.code_dir_abs+"/examples"
163
+ rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png'
164
+ example_images = gr.Examples([
165
+ [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42],
166
+ [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42],
167
+ [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42],
168
+ [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42],
169
+
170
+ ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],)
171
+
172
+ # demo.launch(share=True)
173
+ # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
174
+ demo.queue(concurrency_count=3).launch(share=True)
175
+
176
+ if __name__ == '__main__':
177
+ main()
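The demo wires the reconstruction callback with `functools.partial`, so the model runner and Hydra config are bound once at startup and Gradio only forwards the user-facing inputs. Below is a minimal, self-contained sketch of that wiring pattern; `reconstruct` is a hypothetical stand-in for `inference` and is not part of the repository.

```python
from functools import partial
import gradio as gr

def reconstruct(runner, cfg, rgb, seed):
    # hypothetical stand-in for `inference`: runner/cfg are pre-bound,
    # rgb/seed arrive from the UI components listed in `inputs`
    return f"runner={runner}, cfg={cfg}, seed={int(seed)}"

with gr.Blocks() as blocks:
    rgb = gr.Image(label='Input RGB', type='numpy')
    seed = gr.Number(label='Random seed', value=42)
    out = gr.Textbox(label='Result')
    btn = gr.Button("Start Reconstruction")
    # pre-bind the first two arguments; Gradio supplies only [rgb, seed]
    btn.click(fn=partial(reconstruct, "demo-runner", "demo-cfg"),
              inputs=[rgb, seed], outputs=[out])
```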
configs/__init__.py ADDED
File without changes
configs/structured.py ADDED
@@ -0,0 +1,416 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Optional, Iterable
4
+ import os.path as osp
5
+
6
+ from hydra.core.config_store import ConfigStore
7
+ from hydra.conf import RunDir
8
+
9
+
10
+ @dataclass
11
+ class CustomHydraRunDir(RunDir):
12
+ dir: str = './outputs/${run.name}/single'
13
+
14
+
15
+ @dataclass
16
+ class RunConfig:
17
+ name: str = 'debug'
18
+ job: str = 'train'
19
+ mixed_precision: str = 'fp16' # 'no'
20
+ cpu: bool = False
21
+ seed: int = 42
22
+ val_before_training: bool = True
23
+ vis_before_training: bool = True
24
+ limit_train_batches: Optional[int] = None
25
+ limit_val_batches: Optional[int] = None
26
+ max_steps: int = 100_000
27
+ checkpoint_freq: int = 1_000
28
+ val_freq: int = 5_000
29
+ vis_freq: int = 5_000
30
+ # vis_freq: int = 10_000
31
+ log_step_freq: int = 20
32
+ print_step_freq: int = 100
33
+
34
+ # config to run demo
35
+ stage1_name: str = 'stage1' # experiment name to the stage 1 model
36
+ stage2_name: str = 'stage2' # experiment name to the stage 2 model
37
+ image_path: str = '' # the path to the images for running demo, can be a single file or a glob pattern
38
+
39
+ # abs path to working dir
40
+ code_dir_abs: str = osp.dirname(osp.dirname(osp.abspath(__file__)))
41
+
42
+ # Inference configs
43
+ num_inference_steps: int = 1000
44
+ diffusion_scheduler: Optional[str] = 'ddpm'
45
+ num_samples: int = 1
46
+ # num_sample_batches: Optional[int] = None
47
+ num_sample_batches: Optional[int] = 2000 # XH: change to 2
48
+ sample_from_ema: bool = False
49
+ sample_save_evolutions: bool = False # temporarily set by default
50
+ save_name: str = 'sample' # XH: additional save name
51
+ redo: bool = False
52
+
53
+ # for parallel sampling in slurm
54
+ batch_start: int = 0
55
+ batch_end: Optional[int] = None
56
+
57
+ # Training configs
58
+ freeze_feature_model: bool = True
59
+
60
+ # Coloring training configs
61
+ coloring_training_noise_std: float = 0.0
62
+ coloring_sample_dir: Optional[str] = None
63
+
64
+ sample_mode: str = 'sample' # whether from noise or from some intermediate steps
65
+ sample_noise_step: int = 500 # add noise to GT up to some steps, and then denoise
66
+ sample_save_gt: bool = True
67
+
68
+
69
+ @dataclass
70
+ class LoggingConfig:
71
+ wandb: bool = True
72
+ wandb_project: str = 'pc2'
73
+
74
+
75
+
76
+ @dataclass
77
+ class PointCloudProjectionModelConfig:
78
+ # Feature extraction arguments
79
+ image_size: int = '${dataset.image_size}'
80
+ image_feature_model: str = 'vit_base_patch16_224_mae' # or 'vit_small_patch16_224_msn' or 'identity'
81
+ use_local_colors: bool = True
82
+ use_local_features: bool = True
83
+ use_global_features: bool = False
84
+ use_mask: bool = True
85
+ use_distance_transform: bool = True
86
+
87
+ # Point cloud data arguments. Note these are here because the processing happens
88
+ # inside the model, rather than inside the dataset.
89
+ scale_factor: float = "${dataset.scale_factor}"
90
+ colors_mean: float = 0.5
91
+ colors_std: float = 0.5
92
+ color_channels: int = 3
93
+ predict_shape: bool = True
94
+ predict_color: bool = False
95
+
96
+ # added by XH
97
+ load_sample_init: bool = False # load init samples from file
98
+ sample_init_scale: float = 1.0 # scale the initial pc samples
99
+ test_init_with_gtpc: bool = False # test time init samples with GT samples
100
+ consistent_center: bool = True # use consistent center prediction by CCD-3DR
101
+ voxel_resolution_multiplier: float = 1 # increase network voxel resolution
102
+
103
+ # predict binary segmentation
104
+ predict_binary: bool = False # True for stage 1 model, False for others
105
+ lw_binary: float = 3.0 # to have roughly the same magnitude of the binary segmentation loss
106
+ # for separate model
107
+ binary_training_noise_std: float = 0.1 # from github doc for predicting color
108
+ self_conditioning: bool = False
109
+
110
+ @dataclass
111
+ class PVCNNAEModelConfig(PointCloudProjectionModelConfig):
112
+ "my own model config, must inherit parent class"
113
+ model_name: str = 'pvcnn-ae'
114
+ latent_dim: int = 1024
115
+ num_dec_blocks: int = 6
116
+ block_dims: List[int] = field(default_factory=lambda: [512, 256])
117
+ num_points: int = 1500
118
+ bottleneck_dim: int = -1 # the input dim to the last MLP layer
119
+
120
+ @dataclass
121
+ class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
122
+ model_name: str = 'pc2-diff-ho' # default as behave
123
+
124
+ # Diffusion arguments
125
+ beta_start: float = 1e-5 # 0.00085
126
+ beta_end: float = 8e-3 # 0.012
127
+ beta_schedule: str = 'linear' # 'custom'
128
+ dm_pred_type: str = 'epsilon' # diffusion model prediction type, sample (x0) or noise
129
+
130
+ # Point cloud model arguments
131
+ point_cloud_model: str = 'pvcnn'
132
+ point_cloud_model_embed_dim: int = 64
133
+
134
+ dataset_type: str = '${dataset.type}'
135
+
136
+ @dataclass
137
+ class CrossAttnHOModelConfig(PointCloudDiffusionModelConfig):
138
+ model_name: str = 'diff-ho-attn'
139
+
140
+ attn_type: str = 'coord3d+posenc-learnable'
141
+ attn_weight: float = 1.0
142
+ point_visible_test: str = 'combine' # To compute point visibility: use all points or only human/object points
143
+
144
+
145
+ @dataclass
146
+ class DirectTransModelConfig(PointCloudProjectionModelConfig):
147
+ model_name: str = 'direct-transl-ho'
148
+
149
+ pooling: str = "avg"
150
+ act: str = 'gelu'
151
+ out_act: str = 'relu'
152
+ # feat_dims_transl: Iterable[Any] = (384, 256, 128, 6) # cannot use List[int] https://github.com/facebookresearch/hydra/issues/1752#issuecomment-893174197
153
+ # feat_dims_scale: Iterable[Any] = (384, 128, 64, 2)
154
+ feat_dims_transl: List[int] = field(default_factory=lambda: [384, 256, 128, 6])
155
+ feat_dims_scale: List[int] = field(default_factory=lambda: [384, 128, 64, 2])
156
+ lw_transl: float = 10000.0
157
+ lw_scale: float = 10000.0
158
+
159
+
160
+ @dataclass
161
+ class PointCloudColoringModelConfig(PointCloudProjectionModelConfig):
162
+ # Projection arguments
163
+ predict_shape: bool = False
164
+ predict_color: bool = True
165
+
166
+ # Point cloud model arguments
167
+ point_cloud_model: str = 'pvcnn'
168
+ point_cloud_model_layers: int = 1
169
+ point_cloud_model_embed_dim: int = 64
170
+
171
+
172
+ @dataclass
173
+ class DatasetConfig:
174
+ type: str
175
+
176
+
177
+ @dataclass
178
+ class PointCloudDatasetConfig(DatasetConfig):
179
+ eval_split: str = 'val'
180
+ max_points: int = 16_384
181
+ image_size: int = 224
182
+ scale_factor: float = 1.0
183
+ restrict_model_ids: Optional[List] = None # for only running on a subset of data points
184
+
185
+
186
+ @dataclass
187
+ class CO3DConfig(PointCloudDatasetConfig):
188
+ type: str = 'co3dv2'
189
+ # root: str = os.getenv('CO3DV2_DATASET_ROOT')
190
+ root: str = "/BS/xxie-2/work/co3d/hydrant"
191
+ category: str = 'hydrant'
192
+ subset_name: str = 'fewview_dev'
193
+ mask_images: bool = '${model.use_mask}'
194
+
195
+
196
+ @dataclass
197
+ class ShapeNetR2N2Config(PointCloudDatasetConfig):
198
+ # added by XH
199
+ fix_sample: bool = True
200
+ category: str = 'chair'
201
+
202
+ type: str = 'shapenet_r2n2'
203
+ root: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
204
+ r2n2_dir: str = "/BS/databases20/3d-r2n2"
205
+ shapenet_dir: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
206
+ preprocessed_r2n2_dir: str = "${dataset.root}/r2n2_preprocessed_renders"
207
+ splits_file: str = "${dataset.root}/r2n2_standard_splits_from_ShapeNet_taxonomy.json"
208
+ # splits_file: str = "${dataset.root}/pix2mesh_splits_val05.json" # <-- incorrect
209
+ scale_factor: float = 7.0
210
+ point_cloud_filename: str = 'pointcloud_r2n2.npz' # should use 'pointcloud_mesh.npz'
211
+
212
+
213
+
214
+ @dataclass
215
+ class BehaveDatasetConfig(PointCloudDatasetConfig):
216
+ # added by XH
217
+ type: str = 'behave'
218
+
219
+ fix_sample: bool = True
220
+ behave_dir: str = "/BS/xxie-5/static00/behave_release/sequences/"
221
+ split_file: str = "" # specify you dataset split file here
222
+ scale_factor: float = 7.0 # use the same as shapenet
223
+ sample_ratio_hum: float = 0.5
224
+ image_size: int = 224
225
+
226
+ normalize_type: str = 'comb'
227
+ smpl_type: str = 'gt' # use which SMPL mesh to obtain normalization parameters
228
+ test_transl_type: str = 'norm'
229
+
230
+ load_corr_points: bool = False # load autoencoder points for object and SMPL
231
+ uniform_obj_sample: bool = False
232
+
233
+ # configs for direct translation prediction
234
+ bkg_type: str = 'none'
235
+ bbox_params: str = 'none'
236
+ ho_segm_pred_path: Optional[str] = None
237
+ use_gt_transl: bool = False
238
+
239
+ cam_noise_std: float = 0. # add noise to the camera pose
240
+ sep_same_crop: bool = False # use same input image crop to separate models
241
+ aug_blur: float = 0. # blur augmentation
242
+
243
+ std_coverage: float=3.5 # a heuristic value to estimate translation
244
+
245
+ v2v_path: str = '' # object v2v corr path
246
+
247
+ @dataclass
248
+ class ShapeDatasetConfig(BehaveDatasetConfig):
249
+ "the dataset to train AE for aligned shapes"
250
+ type: str = 'shape'
251
+ fix_sample: bool = False
252
+ split_file: str = "/BS/xxie-2/work/pc2-diff/experiments/splits/shapes-chair.pkl"
253
+
254
+
255
+ # TODO
256
+ @dataclass
257
+ class ShapeNetNMRConfig(PointCloudDatasetConfig):
258
+ type: str = 'shapenet_nmr'
259
+ shapenet_nmr_dir: str = "/work/lukemk/machine-learning-datasets/3d-reconstruction/ShapeNet_NMR/NMR_Dataset"
260
+ synset_names: str = 'chair' # comma-separated or 'all'
261
+ augmentation: str = 'all'
262
+ scale_factor: float = 7.0
263
+
264
+
265
+ @dataclass
266
+ class AugmentationConfig:
267
+ # need to specify the variable type in order to define it properly
268
+ max_radius: int = 0 # generate a random square to mask object, this is the radius for the square in pixel size, zero means no occlusion
269
+
270
+
271
+ @dataclass
272
+ class DataloaderConfig:
273
+ # batch_size: int = 8 # 2 for debug
274
+ batch_size: int = 16
275
+ num_workers: int = 14 # 0 for debug # suggested by accelerator for gpu20
276
+
277
+
278
+ @dataclass
279
+ class LossConfig:
280
+ diffusion_weight: float = 1.0
281
+ rgb_weight: float = 1.0
282
+ consistency_weight: float = 1.0
283
+
284
+
285
+ @dataclass
286
+ class CheckpointConfig:
287
+ resume: Optional[str] = "test"
288
+ resume_training: bool = True
289
+ resume_training_optimizer: bool = True
290
+ resume_training_scheduler: bool = True
291
+ resume_training_state: bool = True
292
+
293
+
294
+ @dataclass
295
+ class ExponentialMovingAverageConfig:
296
+ use_ema: bool = False
297
+ # # From Diffusers EMA (should probably switch)
298
+ # ema_inv_gamma: float = 1.0
299
+ # ema_power: float = 0.75
300
+ # ema_max_decay: float = 0.9999
301
+ decay: float = 0.999
302
+ update_every: int = 20
303
+
304
+
305
+ @dataclass
306
+ class OptimizerConfig:
307
+ type: str
308
+ name: str
309
+ lr: float = 3e-4
310
+ weight_decay: float = 0.0
311
+ scale_learning_rate_with_batch_size: bool = False
312
+ gradient_accumulation_steps: int = 1
313
+ clip_grad_norm: Optional[float] = 50.0 # 5.0
314
+ kwargs: Dict = field(default_factory=lambda: dict())
315
+
316
+
317
+ @dataclass
318
+ class AdadeltaOptimizerConfig(OptimizerConfig):
319
+ type: str = 'torch'
320
+ name: str = 'Adadelta'
321
+ kwargs: Dict = field(default_factory=lambda: dict(
322
+ weight_decay=1e-6,
323
+ ))
324
+
325
+
326
+ @dataclass
327
+ class AdamOptimizerConfig(OptimizerConfig):
328
+ type: str = 'torch'
329
+ name: str = 'AdamW'
330
+ weight_decay: float = 1e-6
331
+ kwargs: Dict = field(default_factory=lambda: dict(betas=(0.95, 0.999)))
332
+
333
+
334
+ @dataclass
335
+ class SchedulerConfig:
336
+ type: str
337
+ kwargs: Dict = field(default_factory=lambda: dict())
338
+
339
+
340
+ @dataclass
341
+ class LinearSchedulerConfig(SchedulerConfig):
342
+ type: str = 'transformers'
343
+ kwargs: Dict = field(default_factory=lambda: dict(
344
+ name='linear',
345
+ num_warmup_steps=0,
346
+ num_training_steps="${run.max_steps}",
347
+ ))
348
+
349
+
350
+ @dataclass
351
+ class CosineSchedulerConfig(SchedulerConfig):
352
+ type: str = 'transformers'
353
+ kwargs: Dict = field(default_factory=lambda: dict(
354
+ name='cosine',
355
+ num_warmup_steps=2000, # 0
356
+ num_training_steps="${run.max_steps}",
357
+ ))
358
+
359
+
360
+ @dataclass
361
+ class ProjectConfig:
362
+ run: RunConfig
363
+ logging: LoggingConfig
364
+ dataset: PointCloudDatasetConfig
365
+ augmentations: AugmentationConfig
366
+ dataloader: DataloaderConfig
367
+ loss: LossConfig
368
+ model: PointCloudProjectionModelConfig
369
+ ema: ExponentialMovingAverageConfig
370
+ checkpoint: CheckpointConfig
371
+ optimizer: OptimizerConfig
372
+ scheduler: SchedulerConfig
373
+
374
+ defaults: List[Any] = field(default_factory=lambda: [
375
+ 'custom_hydra_run_dir',
376
+ {'run': 'default'},
377
+ {'logging': 'default'},
378
+ {'model': 'ho-attn'},
379
+ # {'dataset': 'co3d'},
380
+ {'dataset': 'behave'},
381
+ {'augmentations': 'default'},
382
+ {'dataloader': 'default'},
383
+ {'ema': 'default'},
384
+ {'loss': 'default'},
385
+ {'checkpoint': 'default'},
386
+ {'optimizer': 'adam'}, # default adamw
387
+ {'scheduler': 'linear'},
388
+ # {'scheduler': 'cosine'},
389
+ ])
390
+
391
+
392
+ cs = ConfigStore.instance()
393
+ cs.store(name='custom_hydra_run_dir', node=CustomHydraRunDir, package="hydra.run")
394
+ cs.store(group='run', name='default', node=RunConfig)
395
+ cs.store(group='logging', name='default', node=LoggingConfig)
396
+ cs.store(group='model', name='diffrec', node=PointCloudDiffusionModelConfig)
397
+ cs.store(group='model', name='coloring_model', node=PointCloudColoringModelConfig)
398
+ cs.store(group='model', name='direct-transl', node=DirectTransModelConfig)
399
+ cs.store(group='model', name='ho-attn', node=CrossAttnHOModelConfig)
400
+ cs.store(group='model', name='pvcnn-ae', node=PVCNNAEModelConfig)
401
+ cs.store(group='dataset', name='co3d', node=CO3DConfig)
402
+ # TODO
403
+ cs.store(group='dataset', name='shapenet_r2n2', node=ShapeNetR2N2Config)
404
+ cs.store(group='dataset', name='behave', node=BehaveDatasetConfig)
405
+ cs.store(group='dataset', name='shape', node=ShapeDatasetConfig)
406
+ # cs.store(group='dataset', name='shapenet_nmr', node=ShapeNetNMRConfig)
407
+ cs.store(group='augmentations', name='default', node=AugmentationConfig)
408
+ cs.store(group='dataloader', name='default', node=DataloaderConfig)
409
+ cs.store(group='loss', name='default', node=LossConfig)
410
+ cs.store(group='ema', name='default', node=ExponentialMovingAverageConfig)
411
+ cs.store(group='checkpoint', name='default', node=CheckpointConfig)
412
+ cs.store(group='optimizer', name='adadelta', node=AdadeltaOptimizerConfig)
413
+ cs.store(group='optimizer', name='adam', node=AdamOptimizerConfig)
414
+ cs.store(group='scheduler', name='linear', node=LinearSchedulerConfig)
415
+ cs.store(group='scheduler', name='cosine', node=CosineSchedulerConfig)
416
+ cs.store(name='configs', node=ProjectConfig)
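Several model fields above (e.g. `image_size`, `scale_factor`, `dataset_type`) are declared as `${dataset.*}` strings. These are OmegaConf interpolations, which Hydra resolves against the composed config, so the model group always mirrors whichever dataset group is active. A minimal sketch of that mechanism, using OmegaConf directly:

```python
from omegaconf import OmegaConf

# model.image_size is written as an interpolation string and only resolves
# when accessed, against whatever dataset group was composed at runtime
cfg = OmegaConf.create({
    "dataset": {"type": "behave", "image_size": 224, "scale_factor": 7.0},
    "model": {"image_size": "${dataset.image_size}", "dataset_type": "${dataset.type}"},
})
assert cfg.model.image_size == 224
assert cfg.model.dataset_type == "behave"
```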
dataset/__init__.py ADDED
@@ -0,0 +1,301 @@
1
+ from pathlib import Path
2
+ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import pytorch3d
6
+ import torch
7
+ from torch.utils.data import SequentialSampler
8
+ from omegaconf import DictConfig
9
+ from pytorch3d.implicitron.dataset.data_loader_map_provider import \
10
+ SequenceDataLoaderMapProvider
11
+ from pytorch3d.implicitron.dataset.dataset_base import FrameData
12
+ from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
13
+ from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
14
+ JsonIndexDatasetMapProviderV2, registry)
15
+ from pytorch3d.implicitron.tools.config import expand_args_fields
16
+ from pytorch3d.renderer.cameras import CamerasBase
17
+ from torch.utils.data import DataLoader
18
+
19
+ from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
20
+ from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
21
+ from .utils import DatasetMap
22
+ from .r2n2_my import R2N2Sample, collate_batched_meshes
23
+
24
+
25
+ def get_dataset(cfg: ProjectConfig):
26
+
27
+ if cfg.dataset.type == 'co3dv2':
28
+ dataset_cfg: CO3DConfig = cfg.dataset
29
+ dataloader_cfg: DataloaderConfig = cfg.dataloader
30
+
31
+ # Exclude bad and low-quality sequences, XH: why this is needed?
32
+ exclude_sequence = []
33
+ exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, []))
34
+ exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, []))
35
+
36
+ # Whether to load pointclouds
37
+ kwargs = dict(
38
+ remove_empty_masks=True,
39
+ n_frames_per_sequence=1,
40
+ load_point_clouds=True,
41
+ max_points=dataset_cfg.max_points,
42
+ image_height=dataset_cfg.image_size,
43
+ image_width=dataset_cfg.image_size,
44
+ mask_images=dataset_cfg.mask_images,
45
+ exclude_sequence=exclude_sequence,
46
+ pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids,
47
+ )
48
+
49
+ # Get dataset mapper
50
+ dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2")
51
+ expand_args_fields(dataset_map_provider_type)
52
+ dataset_map_provider = dataset_map_provider_type(
53
+ category=dataset_cfg.category,
54
+ subset_name=dataset_cfg.subset_name,
55
+ dataset_root=dataset_cfg.root,
56
+ test_on_train=False,
57
+ only_test_set=False,
58
+ load_eval_batches=True,
59
+ dataset_JsonIndexDataset_args=DictConfig(kwargs),
60
+ )
61
+
62
+ # Get datasets
63
+ datasets = dataset_map_provider.get_dataset_map() # how to select specific frames??
64
+
65
+ # PATCH BUG WITH POINT CLOUD LOCATIONS!
66
+ for dataset in (datasets["train"], datasets["val"]):
67
+ # print(dataset.seq_annots.items())
68
+ for key, ann in dataset.seq_annots.items():
69
+ correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:])
70
+ assert correct_point_cloud_path.is_file(), correct_point_cloud_path
71
+ ann.point_cloud.path = str(correct_point_cloud_path)
72
+
73
+ # Get dataloader mapper
74
+ data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider")
75
+ expand_args_fields(data_loader_map_provider_type)
76
+ data_loader_map_provider = data_loader_map_provider_type(
77
+ batch_size=dataloader_cfg.batch_size,
78
+ num_workers=dataloader_cfg.num_workers,
79
+ )
80
+
81
+ # QUICK HACK: Patch the train dataset because it is not used but it throws an error
82
+ if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0 and
83
+ dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'):
84
+ datasets = DatasetMap(train=datasets[dataset_cfg.eval_split], val=datasets[dataset_cfg.eval_split],
85
+ test=datasets[dataset_cfg.eval_split])
86
+ # XH: why all eval split?
87
+ print('Note: You used restrict_model_ids and there were no ids in the train set.')
88
+
89
+ # Get dataloaders
90
+ dataloaders = data_loader_map_provider.get_data_loader_map(datasets)
91
+ dataloader_train = dataloaders['train']
92
+ dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split]
93
+
94
+ # Replace validation dataloader sampler with SequentialSampler
95
+ # seems to be randomly sampled? with a fixed random seed? but one cannot control which image is being sampled??
96
+ dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source)
97
+
98
+ # Modify for accelerate
99
+ dataloader_train.batch_sampler.drop_last = True
100
+ dataloader_val.batch_sampler.drop_last = False
101
+ elif cfg.dataset.type == 'shapenet_r2n2':
102
+ # from ..configs.structured import ShapeNetR2N2Config
103
+ dataset_cfg: ShapeNetR2N2Config = cfg.dataset
104
+ # for k in dataset_cfg:
105
+ # print(k)
106
+ datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample,
107
+ dataset_cfg.image_size, cfg.augmentations,
108
+ s, dataset_cfg.shapenet_dir,
109
+ dataset_cfg.r2n2_dir, dataset_cfg.splits_file,
110
+ load_textures=False, return_all_views=True) for s in ['train', 'val', 'test']]
111
+ dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
112
+ collate_fn=collate_batched_meshes,
113
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
114
+ dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
115
+ collate_fn=collate_batched_meshes,
116
+ num_workers=cfg.dataloader.num_workers, shuffle=False)
117
+ dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
118
+ collate_fn=collate_batched_meshes,
119
+ num_workers=cfg.dataloader.num_workers, shuffle=False)
120
+
121
+ elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl',
122
+ 'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn',
123
+ 'behave-test', 'behave-attn-test', 'behave-hum-pe', 'behave-hum-noscale',
124
+ 'behave-hum-surf', 'behave-objv2v']:
125
+ from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc
126
+ from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale
127
+ from .behave_dataset import BehaveHumanOnlySurfSample
128
+ from .dtransl_dataset import DirectTranslDataset
129
+ from .behave_paths import DataPaths
130
+ from configs.structured import BehaveDatasetConfig
131
+ from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest
132
+ from .behave_dataset import BehaveObjOnlyV2V
133
+
134
+ dataset_cfg: BehaveDatasetConfig = cfg.dataset
135
+ # print(dataset_cfg.behave_dir)
136
+ train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
137
+ # exit(0)
138
+
139
+ # split validation paths to only consider the selected batches
140
+ bs = cfg.dataloader.batch_size
141
+ num_batches_total = int(np.ceil(len(val_paths)/cfg.dataloader.batch_size))
142
+ end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
143
+ # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
144
+ val_paths = val_paths[cfg.run.batch_start*bs:end_idx*bs]
145
+
146
+ if cfg.dataset.type == 'behave':
147
+ train_type = BehaveDataset
148
+ val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset
149
+ elif cfg.dataset.type == 'behave-test':
150
+ train_type = BehaveDataset
151
+ val_datatype = BehaveTestOnly
152
+ elif cfg.dataset.type == 'behave-objonly':
153
+ train_type = BehaveObjOnly
154
+ val_datatype = BehaveObjOnly
155
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
156
+ elif cfg.dataset.type == 'behave-humonly':
157
+ train_type = BehaveHumanOnly
158
+ val_datatype = BehaveHumanOnly
159
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
160
+ elif cfg.dataset.type == 'behave-hum-noscale':
161
+ train_type = BehaveHumNoscale
162
+ val_datatype = BehaveHumNoscale
163
+ elif cfg.dataset.type == 'behave-hum-pe':
164
+ train_type = BehaveHumanOnlyPosEnc
165
+ val_datatype = BehaveHumanOnlyPosEnc
166
+ elif cfg.dataset.type == 'behave-hum-surf':
167
+ train_type = BehaveHumanOnlySurfSample
168
+ val_datatype = BehaveHumanOnlySurfSample
169
+ elif cfg.dataset.type == 'behave-humonly-segm':
170
+ assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
171
+ train_type = BehaveHumanOnly
172
+ val_datatype = BehaveHumanOnlySegmInput
173
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
174
+ elif cfg.dataset.type == 'behave-objonly-segm':
175
+ assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
176
+ train_type = BehaveObjOnly
177
+ val_datatype = BehaveObjOnlySegmInput
178
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
179
+ elif cfg.dataset.type == 'behave-dtransl':
180
+ train_type = DirectTranslDataset
181
+ val_datatype = DirectTranslDataset
182
+ elif cfg.dataset.type == 'behave-attn':
183
+ train_type = BehaveCrossAttnDataset
184
+ val_datatype = BehaveCrossAttnDataset
185
+ elif cfg.dataset.type == 'behave-attn-test':
186
+ train_type = BehaveCrossAttnDataset
187
+ val_datatype = BehaveCrossAttnTest
188
+ elif cfg.dataset.type == 'behave-objv2v':
189
+ train_type = BehaveObjOnlyV2V
190
+ val_datatype = BehaveObjOnlyV2V
191
+ else:
192
+ raise NotImplementedError
193
+
194
+ dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
195
+ (dataset_cfg.image_size, dataset_cfg.image_size),
196
+ split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
197
+ normalize_type=dataset_cfg.normalize_type, smpl_type='gt',
198
+ load_corr_points=dataset_cfg.load_corr_points,
199
+ uniform_obj_sample=dataset_cfg.uniform_obj_sample,
200
+ bkg_type=dataset_cfg.bkg_type,
201
+ bbox_params=dataset_cfg.bbox_params,
202
+ pred_binary=cfg.model.predict_binary,
203
+ ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
204
+ compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
205
+ use_gt_transl=cfg.dataset.use_gt_transl,
206
+ cam_noise_std=cfg.dataset.cam_noise_std,
207
+ sep_same_crop=cfg.dataset.sep_same_crop,
208
+ aug_blur=cfg.dataset.aug_blur,
209
+ std_coverage=cfg.dataset.std_coverage,
210
+ v2v_path=cfg.dataset.v2v_path)
211
+
212
+ dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
213
+ (dataset_cfg.image_size, dataset_cfg.image_size),
214
+ split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
215
+ normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
216
+ load_corr_points=dataset_cfg.load_corr_points,
217
+ test_transl_type=dataset_cfg.test_transl_type,
218
+ uniform_obj_sample=dataset_cfg.uniform_obj_sample,
219
+ bkg_type=dataset_cfg.bkg_type,
220
+ bbox_params=dataset_cfg.bbox_params,
221
+ pred_binary=cfg.model.predict_binary,
222
+ ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
223
+ compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
224
+ use_gt_transl=cfg.dataset.use_gt_transl,
225
+ sep_same_crop=cfg.dataset.sep_same_crop,
226
+ std_coverage=cfg.dataset.std_coverage,
227
+ v2v_path=cfg.dataset.v2v_path)
228
+ # dataset_test = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
229
+ # (dataset_cfg.image_size, dataset_cfg.image_size),
230
+ # split='test', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
231
+ # normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
232
+ # load_corr_points=dataset_cfg.load_corr_points,
233
+ # test_transl_type=dataset_cfg.test_transl_type,
234
+ # uniform_obj_sample=dataset_cfg.uniform_obj_sample,
235
+ # bkg_type=dataset_cfg.bkg_type,
236
+ # bbox_params=dataset_cfg.bbox_params,
237
+ # pred_binary=cfg.model.predict_binary,
238
+ # ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
239
+ # compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
240
+ # use_gt_transl=cfg.dataset.use_gt_transl,
241
+ # sep_same_crop=cfg.dataset.sep_same_crop)
242
+ dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
243
+ collate_fn=collate_batched_meshes,
244
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
245
+ shuffle = cfg.run.job == 'train'
246
+ dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
247
+ collate_fn=collate_batched_meshes,
248
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
249
+ dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
250
+ collate_fn=collate_batched_meshes,
251
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
252
+
253
+ # datasets = [BehaveDataset(p, dataset_cfg.max_points, dataset_cfg.fix_sample,
254
+ # (dataset_cfg.image_size, dataset_cfg.image_size),
255
+ # split=s, sample_ratio_hum=dataset_cfg.sample_ratio_hum,
256
+ # normalize_type=dataset_cfg.normalize_type) for p, s in zip([train_paths, val_paths, val_paths],
257
+ # ['train', 'val', 'test'])]
258
+ # dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
259
+ # collate_fn=collate_batched_meshes,
260
+ # num_workers=cfg.dataloader.num_workers, shuffle=True)
261
+ # dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
262
+ # collate_fn=collate_batched_meshes,
263
+ # num_workers=cfg.dataloader.num_workers, shuffle=False)
264
+ # dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
265
+ # collate_fn=collate_batched_meshes,
266
+ # num_workers=cfg.dataloader.num_workers, shuffle=False)
267
+ elif cfg.dataset.type in ['shape']:
268
+ from .shape_dataset import ShapeDataset
269
+ from .behave_paths import DataPaths
270
+ from configs.structured import ShapeDatasetConfig
271
+ dataset_cfg: ShapeDatasetConfig = cfg.dataset
272
+
273
+ train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
274
+ val_paths = train_paths # same as training, this is for overfitting
275
+ # split validation paths to only consider the selected batches
276
+ bs = cfg.dataloader.batch_size
277
+ num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size))
278
+ end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
279
+ # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
280
+ val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]
281
+
282
+ dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
283
+ (dataset_cfg.image_size, dataset_cfg.image_size),
284
+ split='train', )
285
+ dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
286
+ (dataset_cfg.image_size, dataset_cfg.image_size),
287
+ split='train', )
288
+ dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
289
+ collate_fn=collate_batched_meshes,
290
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
291
+ shuffle = cfg.run.job == 'train'
292
+ dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
293
+ collate_fn=collate_batched_meshes,
294
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
295
+ dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
296
+ collate_fn=collate_batched_meshes,
297
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
298
+ else:
299
+ raise NotImplementedError(cfg.dataset.type)
300
+
301
+ return dataloader_train, dataloader_val, dataloader_vis
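`get_dataset` is the single entry point that turns the composed config into train/val/vis dataloaders; which dataset class backs them is selected purely by `cfg.dataset.type`. A minimal sketch of how it is typically consumed, mirroring the Hydra decoration used in `app.py` (the printout is illustrative only):

```python
import hydra
from configs.structured import ProjectConfig
from dataset import get_dataset

@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def sanity_check(cfg: ProjectConfig):
    # e.g. run with overrides such as: dataset=behave dataloader.batch_size=4
    train_loader, val_loader, vis_loader = get_dataset(cfg)
    print(len(train_loader), len(val_loader), len(vis_loader))

if __name__ == '__main__':
    sanity_check()
```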
dataset/base_data.py ADDED
@@ -0,0 +1,110 @@
1
+ from os import path as osp
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+
7
+ from dataset.img_utils import masks2bbox, resize, crop
8
+
9
+
10
+ class BaseDataset(Dataset):
11
+ def __init__(self, data_paths, input_size=(224, 224)):
12
+ self.data_paths = data_paths # RGB image files
13
+ self.input_size = input_size
14
+ opencv2py3d = np.eye(4)
15
+ opencv2py3d[0, 0] = opencv2py3d[1, 1] = -1
16
+ self.opencv2py3d = opencv2py3d
17
+
18
+ def __len__(self):
19
+ return len(self.data_paths)
20
+
21
+ def load_masks(self, rgb_file):
22
+ person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.png")
23
+ if not osp.isfile(person_mask_file):
24
+ person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.jpg")
25
+ obj_mask_file = None
26
+ for pat in [".obj_rend_mask.png", ".obj_rend_mask.jpg", ".obj_mask.png", ".obj_mask.jpg", ".object_rend.png"]:
27
+ obj_mask_file = rgb_file.replace('.color.jpg', pat)
28
+ if osp.isfile(obj_mask_file):
29
+ break
30
+ person_mask = cv2.imread(person_mask_file, cv2.IMREAD_GRAYSCALE)
31
+ obj_mask = cv2.imread(obj_mask_file, cv2.IMREAD_GRAYSCALE)
32
+
33
+ return person_mask, obj_mask
34
+
35
+ def get_crop_params(self, mask_hum, mask_obj, bbox_exp=1.0):
36
+ "compute bounding box based on masks"
37
+ bmin, bmax = masks2bbox([mask_hum, mask_obj])
38
+ crop_center = (bmin + bmax) // 2
39
+ # crop_size = np.max(bmax - bmin)
40
+ crop_size = int(np.max(bmax - bmin) * bbox_exp)
41
+ if crop_size % 2 == 1:
42
+ crop_size += 1 # make sure it is an even number
43
+ return bmax, bmin, crop_center, crop_size
44
+
45
+ def is_behave_dataset(self, image_width):
46
+ assert image_width in [2048, 1920, 1024, 960], f'unknown image width {image_width}!'
47
+ if image_width in [2048, 1024]:
48
+ is_behave = True
49
+ else:
50
+ is_behave = False
51
+ return is_behave
52
+
53
+ def compute_K_roi(self, bbox_square,
54
+ image_width=2048,
55
+ image_height=1536,
56
+ fx=979.7844, fy=979.840,
57
+ cx=1018.952, cy=779.486):
58
+ "return results in ndc coordinate, this is correct!!!"
59
+ x, y, b, w = bbox_square
60
+ assert b == w
61
+ is_behave = self.is_behave_dataset(image_width)
62
+
63
+ if is_behave:
64
+ assert image_height / image_width == 0.75, f"invalid image aspect ratio: width={image_width}, height={image_height}"
65
+ # the image might be rendered at different size
66
+ ratio = image_width/2048.
67
+ fx, fy = 979.7844*ratio, 979.840*ratio
68
+ cx, cy = 1018.952*ratio, 779.486*ratio
69
+ else:
70
+ assert image_height / image_width == 9/16, f"invalid image aspect ratio: width={image_width}, height={image_height}"
71
+ # intercap camera
72
+ ratio = image_width/1920
73
+ fx, fy = 918.457763671875*ratio, 918.4373779296875*ratio
74
+ cx, cy = 956.9661865234375*ratio, 555.944580078125*ratio
75
+
76
+ cx, cy = cx - x, cy - y
77
+ scale = b/2.
78
+ # in ndc
79
+ cx_ = (scale - cx)/scale
80
+ cy_ = (scale - cy)/scale
81
+ fx_ = fx/scale
82
+ fy_ = fy/scale
83
+
84
+ K_roi = np.array([
85
+ [fx_, 0, cx_, 0],
86
+ [0., fy_, cy_, 0, ],
87
+ [0, 0, 0, 1.],
88
+ [0, 0, 1, 0]
89
+ ])
90
+ return K_roi
91
+
92
+ def crop_full_image(self, mask_hum, mask_obj, rgb_full, crop_masks, bbox_exp=1.0):
93
+ """
94
+ crop the image based on the given masks
95
+ :param mask_hum:
96
+ :param mask_obj:
97
+ :param rgb_full:
98
+ :param crop_masks: a list of masks used to do the crop
99
+ :return: Kroi, cropped human, object mask and RGB images (background masked out).
100
+ """
101
+ bmax, bmin, crop_center, crop_size = self.get_crop_params(*crop_masks, bbox_exp)
102
+ rgb = resize(crop(rgb_full, crop_center, crop_size), self.input_size) / 255.
103
+ person_mask = resize(crop(mask_hum, crop_center, crop_size), self.input_size) / 255.
104
+ obj_mask = resize(crop(mask_obj, crop_center, crop_size), self.input_size) / 255.
105
+ xywh = np.concatenate([crop_center - crop_size // 2, np.array([crop_size, crop_size])])
106
+ Kroi = self.compute_K_roi(xywh, rgb_full.shape[1], rgb_full.shape[0])
107
+ # mask bkg out
108
+ mask_comb = (person_mask > 0.5) | (obj_mask > 0.5)
109
+ rgb = rgb * np.expand_dims(mask_comb, -1)
110
+ return Kroi, obj_mask, person_mask, rgb
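`compute_K_roi` re-expresses the full-frame intrinsics in the NDC space of the square crop: the principal point is shifted by the crop origin and everything is normalized by half the crop size. A worked numeric sketch with hypothetical crop values (BEHAVE intrinsics from the defaults above):

```python
# hypothetical 800x800 crop with top-left corner at x=500 in a 2048x1536 BEHAVE frame
fx, cx = 979.7844, 1018.952   # full-image focal length / principal point (x-axis)
x, b = 500, 800               # crop origin and crop size in pixels
scale = b / 2.                # 400.0: half the crop size
cx_crop = cx - x              # 518.952: principal point relative to the crop window
cx_ndc = (scale - cx_crop) / scale   # about -0.297
fx_ndc = fx / scale                  # about 2.449
print(cx_ndc, fx_ndc)
```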
dataset/behave_paths.py ADDED
@@ -0,0 +1,228 @@
1
+ import glob
2
+ import os, re
3
+ import pickle as pkl
4
+ from os.path import join, basename, dirname, isfile
5
+ import os.path as osp
6
+
7
+ import cv2, json
8
+ import numpy as np
9
+
10
+ # PROCESSED_PATH = paths['PROCESSED_PATH']
11
+ BEHAVE_PATH = "/BS/xxie-5/static00/behave_release/sequences/"
12
+ RECON_PATH = "/BS/xxie-5/static00/behave-train"
13
+
14
+ class DataPaths:
15
+ """
16
+ class to handle path operations based on BEHAVE dataset structure
17
+ """
18
+ def __init__(self):
19
+ pass
20
+
21
+ @staticmethod
22
+ def load_splits(split_file, dataset_path=None):
23
+ assert os.path.exists(dataset_path), f'the given dataset path {dataset_path} does not exist, please check that your training data is placed there!'
24
+ train, val = DataPaths.get_train_test_from_pkl(split_file)
25
+ return train, val
26
+ # print(train[:5], val[:5])
27
+ if isinstance(train[0], list):
28
+ # video data
29
+ train_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in train]
30
+ val_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in val]
31
+ else:
32
+ train_full = [join(dataset_path, x) for x in train] # full path to the training data
33
+ val_full = [join(dataset_path, x) for x in val] # full path to the validation data files
34
+ # print(train_full[:5], val_full[:5])
35
+ return train_full, val_full
36
+
37
+ @staticmethod
38
+ def load_splits_online(split_file, dataset_path=BEHAVE_PATH):
39
+ "load rgb file, smpl and object mesh paths"
40
+ keys = ['rgb', 'smpl', 'obj']
41
+ types = ['train', 'val']
42
+ splits = {}
43
+ data = pkl.load(open(split_file, 'rb'))
44
+ for type in types:
45
+ for key in keys:
46
+ k = f'{type}_{key}'
47
+ splits[k] = [join(dataset_path, x) for x in data[k]]
48
+ return splits
49
+
50
+ @staticmethod
51
+ def get_train_test_from_pkl(pkl_file):
52
+ data = pkl.load(open(pkl_file, 'rb'))
53
+ return data['train'], data['test']
54
+
55
+ @staticmethod
56
+ def get_image_paths_seq(seq, tid=1, check_occlusion=False, pat='t*.000'):
57
+ """
58
+ find all image paths in one sequence
59
+ :param seq: path to one behave sequence
60
+ :param tid: test on images from which camera
61
+ :param check_occlusion: whether to load full object mask and check occlusion ratio
62
+ :return: a list of paths to test image files
63
+ """
64
+ image_files = sorted(glob.glob(seq + f"/{pat}/k{tid}.color.jpg"))
65
+ # print(image_files, seq + f"/{pat}/k{tid}.color.jpg")
66
+ if not check_occlusion:
67
+ return image_files
68
+ # check object occlusion ratio
69
+ valid_files = []
70
+ count = 0
71
+ for img_file in image_files:
72
+ mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.png')
73
+ if not os.path.isfile(mask_file):
74
+ mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.jpg')
75
+ full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.png')
76
+ if not os.path.isfile(full_mask_file):
77
+ full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.jpg')
78
+ if not isfile(mask_file) or not isfile(full_mask_file):
79
+ continue
80
+
81
+ mask = np.sum(cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) > 127)
82
+ mask_full = np.sum(cv2.imread(full_mask_file, cv2.IMREAD_GRAYSCALE) > 127)
83
+ if mask_full == 0:
84
+ count += 1
85
+ continue
86
+
87
+ ratio = mask / mask_full
88
+ if ratio > 0.3:
89
+ valid_files.append(img_file)
90
+ else:
91
+ count += 1
92
+ print(f'{mask_file} occluded by {1 - ratio}!')
93
+ return valid_files
94
+
95
+ @staticmethod
96
+ def get_kinect_id(rgb_file):
97
+ "extract kinect id from the rgb file"
98
+ filename = osp.basename(rgb_file)
99
+ try:
100
+ kid = int(filename.split('.')[0][1])
101
+ assert kid in [0, 1, 2, 3, 4, 5], f'found invalid kinect id {kid} for file {rgb_file}'
102
+ return kid
103
+ except Exception as e:
104
+ print(rgb_file)
105
+ raise ValueError()
106
+
107
+ @staticmethod
108
+ def get_seq_date(rgb_file):
109
+ "date for the sequence"
110
+ seq_name = str(rgb_file).split(os.sep)[-3]
111
+ date = seq_name.split('_')[0]
112
+ assert date in ['Date01', 'Date02', 'Date03', 'Date04', 'Date05', 'Date06', 'Date07',
113
+ "ICapS01", "ICapS02", "ICapS03", "Date08", "Date09"], f"invalid date for {rgb_file}"
114
+ return date
115
+
116
+ @staticmethod
117
+ def rgb2obj_path(rgb_file:str, save_name='fit01-smooth'):
118
+ "convert an rgb file to a obj mesh file"
119
+ ss = rgb_file.split(os.sep)
120
+ seq_name = ss[-3]
121
+ obj_name = seq_name.split('_')[2]
122
+ real_name = obj_name
123
+ if 'chair' in obj_name:
124
+ real_name = 'chair'
125
+ if 'ball' in obj_name:
126
+ real_name = 'sports ball'
127
+
128
+ frame_folder = osp.dirname(rgb_file)
129
+ mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
130
+
131
+ if not osp.isfile(mesh_file):
132
+ # synthetic data
133
+ mesh_file = osp.join(frame_folder, obj_name, save_name, f'{obj_name}_fit.ply')
134
+ return mesh_file
135
+
136
+ @staticmethod
137
+ def rgb2smpl_path(rgb_file:str, save_name='fit03'):
138
+ frame_folder = osp.dirname(rgb_file)
139
+ real_name = 'person'
140
+ mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
141
+ return mesh_file
142
+
143
+ @staticmethod
144
+ def rgb2seq_frame(rgb_file:str):
145
+ "rgb file to seq_name, frame time"
146
+ ss = rgb_file.split(os.sep)
147
+ return ss[-3], ss[-2]
148
+
149
+ @staticmethod
150
+ def rgb2recon_folder(rgb_file, save_name, recon_path):
151
+ "convert rgb file to the subfolder"
152
+ dataset_path = osp.dirname(osp.dirname(osp.dirname(rgb_file)))
153
+ recon_folder = osp.join(osp.dirname(rgb_file.replace(dataset_path, recon_path)), save_name)
154
+ return recon_folder
155
+
156
+ @staticmethod
157
+ def get_seq_name(rgb_file):
158
+ return osp.basename(osp.dirname(osp.dirname(rgb_file)))
159
+
160
+ @staticmethod
161
+ def rgb2template_path(rgb_file):
162
+ "return the path to the object template"
163
+ from recon.opt_utils import get_template_path
164
+ # seq_name = DataPaths.get_seq_name(rgb_file)
165
+ # obj_name = seq_name.split('_')[2]
166
+ obj_name = DataPaths.rgb2object_name(rgb_file)
167
+ path = get_template_path(BEHAVE_PATH+"/../objects", obj_name)
168
+ return path
169
+
170
+ @staticmethod
171
+ def rgb2object_name(rgb_file):
172
+ seq_name = DataPaths.get_seq_name(rgb_file)
173
+ obj_name = seq_name.split('_')[2]
174
+ return obj_name
175
+
176
+ @staticmethod
177
+ def rgb2recon_frame(rgb_file, recon_path=RECON_PATH):
178
+ "return the frame folder in recon path"
179
+ ss = rgb_file.split(os.sep)
180
+ seq_name, frame = ss[-3], ss[-2]
181
+ return osp.join(recon_path, seq_name, frame)
182
+
183
+ @staticmethod
184
+ def rgb2gender(rgb_file):
185
+ "find the gender of this image"
186
+ seq_name = str(rgb_file).split(os.sep)[-3]
187
+ sub = seq_name.split('_')[1]
188
+ return _sub_gender[sub]
189
+
190
+ @staticmethod
191
+ def get_dataset_root(rgb_file):
192
+ "return the root path to all sequences"
193
+ from pathlib import Path
194
+ path = Path(rgb_file)
195
+ return str(path.parents[2])
196
+
197
+ @staticmethod
198
+ def seqname2gender(seq_name:str):
199
+ sub = seq_name.split('_')[1]
200
+ return _sub_gender[sub]
201
+
202
+ ICAP_PATH = "/BS/xxie-6/static00/InterCap" # assume same root folder
203
+ date_seqs = {
204
+ "Date01": BEHAVE_PATH + "/Date01_Sub01_backpack_back",
205
+ "Date02": BEHAVE_PATH + "/Date02_Sub02_backpack_back",
206
+ "Date03": BEHAVE_PATH + "/Date03_Sub03_backpack_back",
207
+ "Date04": BEHAVE_PATH + "/Date04_Sub05_backpack",
208
+ "Date05": BEHAVE_PATH + "/Date05_Sub05_backpack",
209
+ "Date06": BEHAVE_PATH + "/Date06_Sub07_backpack_back",
210
+ "Date07": BEHAVE_PATH + "/Date07_Sub04_backpack_back",
211
+ # "Date08": "/BS/xxie-6/static00/synthesize/Date08_Subxx_chairwood_synzv2-02",
212
+ "Date08": "/BS/xxie-6/static00/synz-backup/Date08_Subxx_chairwood_synzv2-02",
213
+ "Date09": "/BS/xxie-6/static00/synthesize/Date09_Subxx_obj01_icap", # InterCap sequence synz
214
+ "ICapS01": ICAP_PATH + "/ICapS01_sub01_obj01_Seg_0",
215
+ "ICapS02": ICAP_PATH + "/ICapS02_sub01_obj08_Seg_0",
216
+ "ICapS03": ICAP_PATH + "/ICapS03_sub07_obj05_Seg_0",
217
+ }
218
+
219
+ _sub_gender = {
220
+ "Sub01": 'male',
221
+ "Sub02": 'male',
222
+ "Sub03": 'male',
223
+ "Sub04": 'male',
224
+ "Sub05": 'male',
225
+ "Sub06": 'female',
226
+ "Sub07": 'female',
227
+ "Sub08": 'female',
228
+ }
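`DataPaths` is mostly static path arithmetic over the BEHAVE folder layout `<root>/<sequence>/<frame>/k<kinect_id>.color.jpg`. A small sketch of the conventions it assumes (the frame path below is a hypothetical example; the sequence name is taken from `date_seqs` above):

```python
from dataset.behave_paths import DataPaths

rgb = ("/BS/xxie-5/static00/behave_release/sequences/"
       "Date01_Sub01_backpack_back/t0003.000/k1.color.jpg")
print(DataPaths.get_kinect_id(rgb))    # 1
print(DataPaths.get_seq_date(rgb))     # 'Date01'
print(DataPaths.rgb2seq_frame(rgb))    # ('Date01_Sub01_backpack_back', 't0003.000')
print(DataPaths.rgb2object_name(rgb))  # 'backpack'
```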
dataset/demo_dataset.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from .base_data import BaseDataset
7
+ from .behave_paths import DataPaths
8
+ from .img_utils import compute_translation, masks2bbox, crop
9
+
10
+
11
+ def padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio=0.75):
12
+ """
13
+ pad images to have 4:3 aspect ratio
14
+ :param rgb: (H, W, 3)
15
+ :param person_mask:
16
+ :param obj_mask:
17
+ :return: all images at the given aspect ratio
18
+ """
19
+ h, w = rgb.shape[:2]
20
+ if w > h * 1/aspect_ratio:
21
+ # pad top
22
+ h_4x3 = int(w * aspect_ratio)
23
+ pad_top = h_4x3 - h
24
+ rgb_pad = np.pad(rgb, ((pad_top, 0), (0, 0), (0, 0)))
25
+ person_mask = np.pad(person_mask, ((pad_top, 0), (0, 0))) if person_mask is not None else None
26
+ obj_mask = np.pad(obj_mask, ((pad_top, 0), (0, 0))) if obj_mask is not None else None
27
+ else:
28
+ # pad both sides
29
+ w_new = np.lcm.reduce([h * 2, 16]) # least common multiple
30
+ h_4x3 = int(w_new * aspect_ratio)
31
+ pad_top = h_4x3 - h
32
+ pad_left = (w_new - w) // 2
33
+ pad_right = w_new - w - pad_left
34
+ rgb_pad = np.pad(rgb, ((pad_top, 0), (pad_left, pad_right), (0, 0)))
35
+ obj_mask = np.pad(obj_mask, ((pad_top, 0), (pad_left, pad_right))) if obj_mask is not None else None
36
+ person_mask = np.pad(person_mask, ((pad_top, 0), (pad_left, pad_right))) if person_mask is not None else None
37
+ return rgb_pad, obj_mask, person_mask
38
+
39
+
40
+ def recrop_input(rgb, person_mask, obj_mask, dataset_name='behave'):
41
+ "recrop input images"
42
+ exp_ratio = 1.42
43
+ if dataset_name == 'behave':
44
+ mean_center = np.array([1008, 995]) # mean RGB image crop center
45
+ behave_size = (2048, 1536)
46
+ new_size = (int(750 * exp_ratio), int(exp_ratio * 750))
47
+ else:
48
+ mean_center = np.array([904, 668]) # mean RGB image crop center for bottle sequences of ICAP
49
+ behave_size = (1920, 1080)
50
+ new_size = (int(593.925 * exp_ratio), int(exp_ratio * 593.925)) # mean width of bottle sequences
51
+ aspect_ratio = behave_size[1] / behave_size[0]
52
+ pad_top = mean_center[1] - new_size[0] // 2
53
+ pad_bottom = behave_size[1] - (mean_center[1] + new_size[0] // 2)
54
+ pad_left = mean_center[0] - new_size[0] // 2
55
+ pad_right = behave_size[0] - (mean_center[0] + new_size[0] // 2)
56
+
57
+ # First resize to the same aspect ratio
58
+ if rgb.shape[0] / rgb.shape[1] != aspect_ratio:
59
+ rgb, obj_mask, person_mask = padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio)
60
+
61
+ # Resize to the same size as behave image, to have a comparable pixel size
62
+ rgb = cv2.resize(rgb, behave_size)
63
+ mask_ps = cv2.resize(person_mask, behave_size)
64
+ mask_obj = cv2.resize(obj_mask, behave_size)
65
+
66
+ # Crop and resize the human + object patch
67
+ bmin, bmax = masks2bbox([mask_ps, mask_obj])
68
+ center = (bmin + bmax) // 2
69
+ crop_size = int(np.max(bmax - bmin) * exp_ratio) # larger crop to have background
70
+ img_crop = cv2.resize(crop(rgb, center, crop_size), new_size)
71
+ mask_ps = cv2.resize(crop(mask_ps, center, crop_size), new_size)
72
+ mask_obj = cv2.resize(crop(mask_obj, center, crop_size), new_size)
73
+
74
+ # Pad back to have same shape as behave image
75
+ img_full = np.pad(img_crop, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
76
+ mask_ps_full = np.pad(mask_ps, [[pad_top, pad_bottom], [pad_left, pad_right]])
77
+ mask_obj_full = np.pad(mask_obj, [[pad_top, pad_bottom], [pad_left, pad_right]])
78
+
79
+ # Make sure the image shape is the same
80
+ if img_full.shape[:2] != behave_size[::-1]:
81
+ img_full = cv2.resize(img_full, behave_size)
82
+ mask_ps_full = cv2.resize(mask_ps_full, behave_size)
83
+ mask_obj_full = cv2.resize(mask_obj_full, behave_size)
84
+ return img_full, mask_ps_full, mask_obj_full
85
+
86
+
87
+ class DemoDataset(BaseDataset):
88
+ def __init__(self, data_paths, input_size=(224, 224),
89
+ std_coverage=3.5, # used to estimate camera translation
90
+ ):
91
+ super().__init__(data_paths, input_size)
92
+ self.std_coverage = std_coverage
93
+
94
+ def __len__(self):
95
+ return len(self.data_paths)
96
+
97
+ def __getitem__(self, idx):
98
+ rgb_file = self.data_paths[idx]
99
+ mask_hum, mask_obj = self.load_masks(rgb_file)
100
+ rgb_full = cv2.imread(rgb_file)[:, :, ::-1]
101
+
102
+ return self.image2dict(mask_hum, mask_obj, rgb_full, rgb_file)
103
+
104
+ def image2dict(self, mask_hum, mask_obj, rgb_full, rgb_file=None):
105
+ "do all the necessary preprocessing for images"
106
+ if rgb_full.shape[:2] != mask_obj.shape[:2]:
107
+ raise ValueError(f"The given object mask shape {mask_obj.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
108
+ if rgb_full.shape[:2] != mask_hum.shape[:2]:
109
+ raise ValueError(f"The given human mask shape {mask_hum.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
110
+
111
+ if rgb_full.shape[:2] not in [(1080, 1920), (1536, 2048)]:
112
+ # crop and resize the image to behave image size
113
+ print(f"Recropping the input image and masks for {rgb_file}")
114
+ rgb_full, mask_hum, mask_obj = recrop_input(rgb_full, mask_hum, mask_obj)
115
+ color_h, color_w = rgb_full.shape[:2]
116
+ # Input to the first stage model: human + object crop
117
+ Kroi, objmask_fullcrop, psmask_fullcrop, rgb_fullcrop = self.crop_full_image(mask_hum.copy(),
118
+ mask_obj.copy(),
119
+ rgb_full.copy(),
120
+ [mask_hum, mask_obj],
121
+ 1.00)
122
+ # Input to the second stage model: human and object crops
123
+ Kroi_h, masko_hum, maskh_hum, rgb_hum = self.crop_full_image(mask_hum.copy(),
124
+ mask_obj.copy(),
125
+ rgb_full.copy(),
126
+ [mask_hum, mask_hum], 1.05)
127
+ Kroi_o, masko_obj, maskh_obj, rgb_obj = self.crop_full_image(mask_hum.copy(),
128
+ mask_obj.copy(),
129
+ rgb_full.copy(),
130
+ [mask_obj, mask_obj], 1.5)
131
+ # Estimate camera translation
132
+ cent_transform = np.eye(4) # the transform applied to the mesh that moves it back to kinect camera frame
133
+ bmin_ho, bmax_ho = masks2bbox([mask_hum, mask_obj])
134
+ crop_size_ho = int(np.max(bmax_ho - bmin_ho) * 1.0)
135
+ if crop_size_ho % 2 == 1:
136
+ crop_size_ho += 1 # make sure it is an even number
137
+ is_behave = self.is_behave_dataset(rgb_full.shape[1])
138
+ if rgb_full.shape[1] not in [2048, 1920]:
139
+ raise ValueError('the image is not normalized to BEHAVE or ICAP size!')
140
+ indices = np.indices(rgb_full.shape[:2])
141
+ if np.sum(mask_obj > 127) < 5:
142
+ raise ValueError(f'not enough object mask found for {rgb_file}')
143
+ pts_h = np.stack([indices[1][mask_hum > 127], indices[0][mask_hum > 127]], -1)
144
+ pts_o = np.stack([indices[1][mask_obj > 127], indices[0][mask_obj > 127]], -1)
145
+ proj_cent_est = (np.mean(pts_h, 0) + np.mean(pts_o, 0)) / 2. # heuristic to obtain 2d projection center
146
+ transl_estimate = compute_translation(proj_cent_est, crop_size_ho, is_behave, self.std_coverage)
147
+ cent_transform[:3, 3] = transl_estimate / 7.0
148
+ radius = 0.5 # don't do normalization anymore
149
+ cent = transl_estimate / 7.0
150
+ comb = np.matmul(self.opencv2py3d, cent_transform)
151
+ R = torch.from_numpy(comb[:3, :3]).float()
152
+ T = torch.from_numpy(comb[:3, 3]).float() / (radius * 2)
153
+ data_dict = {
154
+ "R": R,
155
+ "T": T,
156
+ "K": torch.from_numpy(Kroi).float(),
157
+ "T_ho": torch.from_numpy(cent).float(), # translation for H+O
158
+ "image_path": rgb_file,
159
+ "image_size_hw": torch.tensor(self.input_size),
160
+ "images": torch.from_numpy(rgb_fullcrop).float().permute(2, 0, 1),
161
+ "masks": torch.from_numpy(np.stack([psmask_fullcrop, objmask_fullcrop], 0)).float(),
162
+ 'orig_image_size': torch.tensor([color_h, color_w]),
163
+
164
+ # Human input to stage 2
165
+ "images_hum": torch.from_numpy(rgb_hum).float().permute(2, 0, 1),
166
+ "masks_hum": torch.from_numpy(np.stack([maskh_hum, masko_hum], 0)).float(),
167
+ "K_hum": torch.from_numpy(Kroi_h).float(),
168
+
169
+ # Object input to stage 2
170
+ "images_obj": torch.from_numpy(rgb_obj).float().permute(2, 0, 1),
171
+ "masks_obj": torch.from_numpy(np.stack([maskh_obj, masko_obj], 0)).float(),
172
+ "K_obj": torch.from_numpy(Kroi_o).float(),
173
+
174
+ # some normalization parameters
175
+ "gt_trans": cent,
176
+ 'radius': radius,
177
+ "estimated_trans": transl_estimate,
178
+ }
179
+ return data_dict
180
+
181
+ def image2batch(self, rgb, mask_hum, mask_obj):
182
+ """
183
+ given input image, convert it into a batch object ready for model inference
184
+ :param rgb: (h, w, 3), np array
185
+ :param mask_hum: (h, w, 3), np array
186
+ :param mask_obj: (h, w, 3), np array
187
+ :return:
188
+ """
189
+ mask_hum = np.mean(mask_hum, -1)
190
+ mask_obj = np.mean(mask_obj, -1)
191
+
192
+ data_dict = self.image2dict(mask_hum, mask_obj, rgb, 'input image')
193
+ # wrap each value in a list to form a batch of size 1
194
+ new_dict = {k:[v] for k, v in data_dict.items()}
195
+
196
+ return new_dict
197
+
198
+
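As a usage sketch (not part of this file), `DemoDataset` is consumed through a standard `DataLoader` with PyTorch3D's batched collation, mirroring `demo.py` below. The glob pattern and sizes are illustrative, and the person/object masks are assumed to sit next to each RGB image (cf. the `examples/017450` files added in this commit):

```python
from glob import glob
from torch.utils.data import DataLoader
from pytorch3d.datasets import collate_batched_meshes

from dataset.demo_dataset import DemoDataset

image_files = sorted(glob("examples/*/k1.color.jpg"))       # illustrative pattern
data = DemoDataset(image_files, input_size=(224, 224), std_coverage=3.5)
loader = DataLoader(data, batch_size=1, shuffle=False,
                    collate_fn=collate_batched_meshes, num_workers=1)
batch = next(iter(loader))  # dict of lists: images, masks, R/T/K cameras, ...
```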
dataset/img_utils.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ common functions for image operations
3
+ """
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ def crop(img, center, crop_size):
10
+ """
11
+ crop image around the given center, pad zeros for borders
12
+ :param img:
13
+ :param center: np array
14
+ :param crop_size: np array or a float size of the resulting crop
15
+ :return: a square crop around the center
16
+ """
17
+ assert isinstance(img, np.ndarray)
18
+ h, w = img.shape[:2]
19
+ topleft = np.round(center - crop_size / 2).astype(int)
20
+ bottom_right = np.round(center + crop_size / 2).astype(int)
21
+
22
+ x1 = max(0, topleft[0])
23
+ y1 = max(0, topleft[1])
24
+ x2 = min(w - 1, bottom_right[0])
25
+ y2 = min(h - 1, bottom_right[1])
26
+ cropped = img[y1:y2, x1:x2]
27
+
28
+ p1 = max(0, -topleft[0]) # padding in x, top
29
+ p2 = max(0, -topleft[1]) # padding in y, top
30
+ p3 = max(0, bottom_right[0] - w + 1) # padding in x, bottom
31
+ p4 = max(0, bottom_right[1] - h + 1) # padding in y, bottom
32
+
33
+ dim = len(img.shape)
34
+ if dim == 3:
35
+ padded = np.pad(cropped, [[p2, p4], [p1, p3], [0, 0]])
36
+ elif dim == 2:
37
+ padded = np.pad(cropped, [[p2, p4], [p1, p3]])
38
+ else:
39
+ raise NotImplemented
40
+ return padded
41
+
42
+
43
+ def resize(img, img_size, mode=cv2.INTER_LINEAR):
44
+ """
45
+ resize image to the input
46
+ :param img:
47
+ :param img_size: (width, height) of the target image size
48
+ :param mode:
49
+ :return:
50
+ """
51
+ h, w = img.shape[:2]
52
+ load_ratio = 1.0 * w / h
53
+ netin_ratio = 1.0 * img_size[0] / img_size[1]
54
+ assert load_ratio == netin_ratio, "image aspect ratio not matching, given image: {}, net input: {}".format(
55
+ img.shape, img_size)
56
+ resized = cv2.resize(img, img_size, interpolation=mode)
57
+ return resized
58
+
59
+
60
+ def masks2bbox(masks, threshold=127):
61
+ """
62
+
63
+ :param masks:
64
+ :param threshold:
65
+ :return: bounding box corner coordinate
66
+ """
67
+ mask_comb = np.zeros_like(masks[0], dtype=bool)
68
+ for m in masks:
69
+ mask_comb = mask_comb | (m > threshold)
70
+
71
+ yid, xid = np.where(mask_comb)
72
+ bmin = np.array([xid.min(), yid.min()])
73
+ bmax = np.array([xid.max(), yid.max()])
74
+ return bmin, bmax
75
+
76
+
77
+ def compute_translation(crop_center, crop_size, is_behave=True, std_coverage=3.5):
78
+ """
79
+ solve for an optimal translation that project gaussian in origin to the crop
80
+ Parameters
81
+ ----------
82
+ crop_center: (x, y) of the crop center
83
+ crop_size: float, the size of the square crop
84
+ std_coverage: which edge point should be projected back to the edge of the 2d crop
85
+
86
+ Returns
87
+ -------
88
+ the estimated translation
89
+
90
+ """
91
+ x0, y0 = crop_center
92
+ x1, y1 = x0 + crop_size/2, y0
93
+ x2, y2 = x0 - crop_size/2, y0
94
+ x3, y3 = x0, y0 + crop_size/2.
95
+ # predefined kinect intrinsics
96
+ if is_behave:
97
+ fx = 979.7844
98
+ fy = 979.840
99
+ cx = 1018.952
100
+ cy = 779.486
101
+ else:
102
+ # intercap camera
103
+ fx, fy = 918.457763671875, 918.4373779296875
104
+ cx, cy = 956.9661865234375, 555.944580078125
105
+
106
+ # construct the matrix
107
+ # A = np.array([
108
+ # [fx, 0, cx-x0, cx-x0, 0, 0],
109
+ # [0, fy, cy-y0, cy-y0, 0, 0],
110
+ # [fx, 0, cx-x1, 0, cx-x1, 0],
111
+ # [0, fy, cy-y1, 0, cy-y1, 0],
112
+ # [fx, 0, cx-x2, 0, 0, cx-x2],
113
+ # [0, fy, cy-y2, 0, 0, cy-y2]
114
+ # ]) # this matrix is low-rank because columns are linearly dependent: col3 - col4 = col5 + col6
115
+ # # find linearly dependent rows
116
+ # lambdas, V = np.linalg.eig(A)
117
+ # # print()
118
+ # # The linearly dependent row vectors
119
+ # print(lambdas == 0, np.linalg.det(A), A[lambdas == 0, :]) # some have determinant zero, some don't??
120
+ # print(np.linalg.inv(A))
121
+
122
+ # A = np.array([
123
+ # [fx, 0, cx - x0, cx - x0, 0, 0],
124
+ # [0, fy, cy - y0, cy - y0, 0, 0],
125
+ # [fx, 0, cx - x1, 0, cx - x1, 0],
126
+ # [0, fy, cy - y1, 0, cy - y1, 0],
127
+ # [fx, 0, cx - x3, 0, 0, cx - x3],
128
+ # [0, fy, cy - y3, 0, 0, cy - y3]
129
+ # ]) # this is also low rank!
130
+ # b = np.array([0, 0, -3*fx, 0, 0, -3*fy]).reshape((-1, 1))
131
+ # print("rank of the coefficient matrix:", np.linalg.matrix_rank(A)) # rank is 5! underconstrained matrix!
132
+ # x = np.matmul(np.linalg.inv(A), b)
133
+
134
+ # fix z0 as 0, then A is a full-rank matrix
135
+ # first two equations: origin (0, 0, 0) is projected to the crop center
136
+ # last two equations: edge point (3.5, 0, z) is projected to the edge of crop
137
+ A = np.array([
138
+ [fx, 0, cx-x0, cx-x0],
139
+ [0, fy, cy-y0, cy-y0],
140
+ [fx, 0, fx-x1, 0],
141
+ [0, fy, cy-y1, 0]
142
+ ])
143
+ # b = np.array([0, 0, -3.5*fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
144
+ b = np.array([0, 0, -std_coverage * fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
145
+ x = np.matmul(np.linalg.inv(A), b) # use 4 or 5 does not really matter, same results
146
+
147
+ # A is always a full-rank matrix
148
+
149
+ return x.flatten()[:3]
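For reference, the reduced linear system assembled in `compute_translation` (the intended form, matching the commented-out full system above, with the edge point's depth offset fixed to zero) can be written as:

$$
\begin{bmatrix}
f_x & 0 & c_x - x_0 & c_x - x_0 \\
0 & f_y & c_y - y_0 & c_y - y_0 \\
f_x & 0 & c_x - x_1 & 0 \\
0 & f_y & c_y - y_1 & 0
\end{bmatrix}
\begin{bmatrix} X \\ Y \\ Z \\ z_0 \end{bmatrix}
=
\begin{bmatrix} 0 \\ 0 \\ -s\,f_x \\ 0 \end{bmatrix}
$$

where $(x_0, y_0)$ is the crop center, $(x_1, y_1)$ the midpoint of the crop's right edge, $s$ the `std_coverage` parameter, and $(f_x, f_y, c_x, c_y)$ the fixed intrinsics. The first two rows require the origin to project to the crop center, the last two require the point $(s, 0, Z)$ to project to the crop edge, and the returned translation is $(X, Y, Z)$.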
demo.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ Demo for template-free reconstruction
3
+
4
+ python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0
5
+ """
6
+ import pickle as pkl
7
+ import sys, os
8
+ import os.path as osp
9
+ from typing import Iterable, Optional
10
+
11
+ import cv2
12
+ from accelerate import Accelerator
13
+ from tqdm import tqdm
14
+ from glob import glob
15
+
16
+ sys.path.append(os.getcwd())
17
+ import hydra
18
+ import torch
19
+ import numpy as np
20
+ import imageio
21
+ from torch.utils.data import DataLoader
22
+ from pytorch3d.datasets import R2N2, collate_batched_meshes
23
+ from pytorch3d.structures import Pointclouds
24
+ from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
25
+ from pytorch3d.io import IO
26
+ import torchvision.transforms.functional as TVF
27
+ from huggingface_hub import hf_hub_download
28
+
29
+ import training_utils
30
+ from configs.structured import ProjectConfig
31
+ from dataset.demo_dataset import DemoDataset
32
+ from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm
33
+ from render.pyt3d_wrapper import PcloudRenderer
34
+
35
+
36
+ class DemoRunner:
37
+ def __init__(self, cfg: ProjectConfig):
38
+ cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True
39
+ model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model)
40
+ cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False # stage 2 does not predict segmentation
41
+ model_stage2 = CrossAttenHODiffusionModel(**cfg.model)
42
+
43
+ # Load from checkpoint
44
+ # ckpt_file1 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage1_name}/single/checkpoint-latest.pth')
45
+ # self.load_checkpoint(ckpt_file1, model_stage1)
46
+ # ckpt_file2 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage2_name}/single/checkpoint-latest.pth')
47
+ # self.load_checkpoint(ckpt_file2, model_stage2)
48
+ # Load ckpt from hf
49
+ ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth')
50
+ self.load_checkpoint(ckpt_file1, model_stage1)
51
+ ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth')
52
+ self.load_checkpoint(ckpt_file2, model_stage2)
53
+
54
+ self.model_stage1, self.model_stage2 = model_stage1, model_stage2
55
+ self.model_stage1.eval()
56
+ self.model_stage2.eval()
57
+ self.model_stage1.to('cuda')
58
+ self.model_stage2.to('cuda')
59
+
60
+ self.cfg = cfg
61
+ self.io_pc = IO()
62
+
63
+ # For visualization
64
+ self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075)
65
+ self.rend_size = cfg.dataset.image_size
66
+ self.device = 'cuda'
67
+
68
+ def load_checkpoint(self, ckpt_file1, model_stage1):
69
+ checkpoint = torch.load(ckpt_file1, map_location='cpu')
70
+ state_dict, key = checkpoint['model'], 'model'
71
+ if any(k.startswith('module.') for k in state_dict.keys()):
72
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
73
+ print('Removed "module." from checkpoint state dict')
74
+ missing_keys, unexpected_keys = model_stage1.load_state_dict(state_dict, strict=False)
75
+ print(f'Loaded model checkpoint {key} from {ckpt_file1}')
76
+ if len(missing_keys):
77
+ print(f' - Missing_keys: {missing_keys}')
78
+ if len(unexpected_keys):
79
+ print(f' - Unexpected_keys: {unexpected_keys}')
80
+
81
+ @torch.no_grad()
82
+ def run(self):
83
+ "simply run the demo on given images, and save the results"
84
+ # Set random seed
85
+ training_utils.set_seed(self.cfg.run.seed)
86
+
87
+ outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo')
88
+ os.makedirs(outdir, exist_ok=True)
89
+ cfg = self.cfg
90
+
91
+ # Init data
92
+ image_files = sorted(glob(cfg.run.image_path))
93
+ data = DemoDataset(image_files,
94
+ (cfg.dataset.image_size, cfg.dataset.image_size),
95
+ cfg.dataset.std_coverage)
96
+ dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size,
97
+ collate_fn=collate_batched_meshes,
98
+ num_workers=1, shuffle=False)
99
+ dataloader = dataloader
100
+ progress_bar = tqdm(dataloader)
101
+ for batch_idx, batch in enumerate(progress_bar):
102
+ progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}')
103
+
104
+ out_stage1, out_stage2 = self.forward_batch(batch, cfg)
105
+
106
+ bs = len(out_stage1)
107
+ camera_full = PerspectiveCameras(
108
+ R=torch.stack(batch['R']),
109
+ T=torch.stack(batch['T']),
110
+ K=torch.stack(batch['K']),
111
+ device='cuda',
112
+ in_ndc=True)
113
+
114
+ # save output
115
+ for i in range(bs):
116
+ image_path = str(batch['image_path'][i])
117
+ folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0]
118
+ out_i = osp.join(outdir, folder)
119
+ os.makedirs(out_i, exist_ok=True)
120
+ self.io_pc.save_pointcloud(data=out_stage1[i],
121
+ path=osp.join(out_i, f'{fname}_stage1.ply'))
122
+ self.io_pc.save_pointcloud(data=out_stage2[i],
123
+ path=osp.join(out_i, f'{fname}_stage2.ply'))
124
+ TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png'))
125
+
126
+ # Save metadata as well
127
+ metadata = dict(index=i,
128
+ camera=camera_full[i],
129
+ image_size_hw=batch['image_size_hw'][i],
130
+ image_path=batch['image_path'][i])
131
+ torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth'))
132
+
133
+ # Visualize
134
+ # front_camera = camera_full[i]
135
+ pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()],
136
+ features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()])
137
+ video_file = osp.join(out_i, f'{fname}_360view.mp4')
138
+ video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1)
139
+
140
+ # first render front view
141
+ rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask')
142
+ rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask')
143
+ comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1)
144
+ video_writer.append_data((comb*255).astype(np.uint8))
145
+
146
+ for azim in range(180, 180+360, 30):
147
+ R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),), )
148
+ side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),),
149
+ device=self.device,
150
+ R=R.repeat(2, 1, 1), T=T.repeat(2, 1),
151
+ focal_length=self.rend_size * 1.5,
152
+ principal_point=((self.rend_size / 2., self.rend_size / 2.),),
153
+ in_ndc=False)
154
+ rend, mask = self.renderer.render(pc_comb, side_camera, mode='mask')
155
+
156
+ imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy()]
157
+ imgs.extend([rend[0], rend[1]])
158
+ video_writer.append_data((np.concatenate(imgs, 1)*255).astype(np.uint8))
159
+ print(f"Visualization saved to {out_i}")
160
+
161
+ @torch.no_grad()
162
+ def forward_batch(self, batch, cfg):
163
+ """
164
+ forward one batch
165
+ :param batch:
166
+ :param cfg:
167
+ :return: predicted point clouds of stage 1 and 2
168
+ """
169
+ camera_full = PerspectiveCameras(
170
+ R=torch.stack(batch['R']),
171
+ T=torch.stack(batch['T']),
172
+ K=torch.stack(batch['K']),
173
+ device='cuda',
174
+ in_ndc=True)
175
+ out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points,
176
+ camera=camera_full,
177
+ image_rgb=torch.stack(batch['images']).to('cuda'),
178
+ mask=torch.stack(batch['masks']).to('cuda'),
179
+ scheduler=cfg.run.diffusion_scheduler,
180
+ num_inference_steps=cfg.run.num_inference_steps,
181
+ )
182
+ # segment and normalize human/object
183
+ bs = len(out_stage1)
184
+ pred_hum, pred_obj = [], [] # predicted human/object points
185
+ cent_hum_pred, cent_obj_pred = [], []
186
+ radius_hum_pred, radius_obj_pred = [], []
187
+ T_hum, T_obj = [], []
188
+ num_samples = int(cfg.dataset.max_points / 2)
189
+ for i in range(bs):
190
+ pc: Pointclouds = out_stage1[i]
191
+ vc = pc.features_packed().cpu() # (P, 3), human is light blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0]
192
+ points = pc.points_packed().cpu() # (P, 3)
193
+ mask_hum = vc[:, 2] > 0.5
194
+ pc_hum, pc_obj = points[mask_hum], points[~mask_hum]
195
+ # Up/Down-sample the points
196
+ pc_obj = self.upsample_predicted_pc(num_samples, pc_obj)
197
+ pc_hum = self.upsample_predicted_pc(num_samples, pc_hum)
198
+
199
+ # Normalize
200
+ cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True)
201
+ scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max())
202
+ scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max())
203
+ pc_hum = (pc_hum - cent_hum) / (2 * scale_hum)
204
+ pc_obj = (pc_obj - cent_obj) / (2 * scale_obj)
205
+ # Also update camera parameters for separate human + object
206
+ T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum)
207
+ T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj)
208
+
209
+ pred_hum.append(pc_hum)
210
+ pred_obj.append(pc_obj)
211
+ cent_hum_pred.append(cent_hum.squeeze(0))
212
+ cent_obj_pred.append(cent_obj.squeeze(0))
213
+ T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1])) # apply opencv to pytorch3d transform: flip x and y
214
+ T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1]))
215
+ radius_hum_pred.append(scale_hum)
216
+ radius_obj_pred.append(scale_obj)
217
+ # Pack data into a new batch dict
218
+ camera_hum = PerspectiveCameras(
219
+ R=torch.stack(batch['R']),
220
+ T=torch.stack(T_hum),
221
+ K=torch.stack(batch['K_hum']),
222
+ device='cuda',
223
+ in_ndc=True
224
+ )
225
+ camera_obj = PerspectiveCameras(
226
+ R=torch.stack(batch['R']),
227
+ T=torch.stack(T_obj),
228
+ K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!!
229
+ device='cuda',
230
+ in_ndc=True
231
+ )
232
+ # use pc from predicted
233
+ pc_hum = Pointclouds([x.to('cuda') for x in pred_hum])
234
+ pc_obj = Pointclouds([x.to('cuda') for x in pred_obj])
235
+ # use center and radius from predicted
236
+ cent_hum = torch.stack(cent_hum_pred, 0).to('cuda')
237
+ cent_obj = torch.stack(cent_obj_pred, 0).to('cuda') # B, 3
238
+ radius_hum = torch.stack(radius_hum_pred, 0).to('cuda') # B, 1
239
+ radius_obj = torch.stack(radius_obj_pred, 0).to('cuda')
240
+ out_stage2: Pointclouds = self.model_stage2.forward_sample(
241
+ num_points=num_samples,
242
+ camera=camera_hum,
243
+ image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'),
244
+ mask=torch.stack(batch['masks_hum'], 0).to('cuda'),
245
+ gt_pc=pc_hum,
246
+ rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'),
247
+ mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'),
248
+ pc_obj=pc_obj,
249
+ camera_obj=camera_obj,
250
+ cent_hum=cent_hum,
251
+ cent_obj=cent_obj,
252
+ radius_hum=radius_hum.unsqueeze(-1),
253
+ radius_obj=radius_obj.unsqueeze(-1),
254
+ sample_from_interm=True,
255
+ noise_step=cfg.run.sample_noise_step)
256
+ return out_stage1, out_stage2
257
+
258
+ def upsample_predicted_pc(self, num_samples, pc_obj):
259
+ """
260
+ Up/Downsample the points to given number
261
+ :param num_samples: the target number
262
+ :param pc_obj: (N, 3)
263
+ :return: (num_samples, 3)
264
+ """
265
+ if len(pc_obj) > num_samples:
266
+ ind_obj = np.random.choice(len(pc_obj), num_samples)
267
+ else:
268
+ ind_obj = np.concatenate([np.arange(len(pc_obj)), np.random.choice(len(pc_obj), num_samples - len(pc_obj))])
269
+ pc_obj = pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)]
270
+ return pc_obj
271
+
272
+
273
+ @hydra.main(config_path='configs', config_name='configs', version_base='1.1')
274
+ def main(cfg: ProjectConfig):
275
+ runner = DemoRunner(cfg)
276
+ runner.run()
277
+
278
+
279
+ if __name__ == '__main__':
280
+ main()
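One detail of `forward_batch` worth calling out: the stage-1 prediction is split into human and object by the predicted point colors, and each part is re-centered and rescaled so its farthest point lies at radius 0.5 before being passed to stage 2. A standalone sketch of that per-part normalization (not part of the repository):

```python
import torch

def normalize_part(points: torch.Tensor):
    """points: (N, 3) -> (normalized points, center (1, 3), scale (scalar tensor))."""
    cent = points.mean(0, keepdim=True)
    scale = torch.sqrt(((points - cent) ** 2).sum(-1).max())  # max distance to the center
    return (points - cent) / (2 * scale), cent, scale
```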
diffusion_utils.py ADDED
@@ -0,0 +1,313 @@
1
+ import math
2
+ from typing import List, Optional, Sequence, Union
3
+
4
+ import imageio
5
+ import logging
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ from PIL import Image
10
+ from torch.distributions import Normal
11
+ from torchvision.transforms.functional import to_pil_image
12
+ from torchvision.utils import make_grid
13
+ from tqdm import tqdm, trange
14
+ from pytorch3d.renderer import (
15
+ AlphaCompositor,
16
+ NormWeightedCompositor,
17
+ OrthographicCameras,
18
+ PointsRasterizationSettings,
19
+ PointsRasterizer,
20
+ PointsRenderer,
21
+ look_at_view_transform)
22
+ from pytorch3d.renderer.cameras import CamerasBase
23
+ from pytorch3d.structures import Pointclouds
24
+ from pytorch3d.structures.pointclouds import join_pointclouds_as_batch
25
+
26
+
27
+ # Disable unnecessary imageio logging
28
+ logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
29
+
30
+
31
+ def rotation_matrix(axis, theta):
32
+ """
33
+ Return the rotation matrix associated with counterclockwise rotation about
34
+ the given axis by theta radians.
35
+ """
36
+ axis = np.asarray(axis)
37
+ axis = axis / np.sqrt(np.dot(axis, axis))
38
+ a = np.cos(theta / 2.0)
39
+ b, c, d = -axis * np.sin(theta / 2.0)
40
+ aa, bb, cc, dd = a * a, b * b, c * c, d * d
41
+ bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
42
+ return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
43
+ [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
44
+ [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
45
+
46
+
47
+ def rotate(vertices, faces):
48
+ '''
49
+ vertices: [numpoints, 3]
50
+ '''
51
+ M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
52
+ N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
53
+ K = rotation_matrix([0, 0, 1], np.pi).transpose()
54
+
55
+ v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]]
56
+ return v, f
57
+
58
+
59
+ def norm(v, f):
60
+ v = (v - v.min()) / (v.max() - v.min()) - 0.5
61
+
62
+ return v, f
63
+
64
+
65
+ def getGradNorm(net):
66
+ pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
67
+ gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
68
+ return pNorm, gradNorm
69
+
70
+
71
+ def weights_init(m):
72
+ classname = m.__class__.__name__
73
+ if classname.find('Conv') != -1 and m.weight is not None:
74
+ torch.nn.init.xavier_normal_(m.weight)
75
+ elif classname.find('BatchNorm') != -1:
76
+ m.weight.data.normal_()
77
+ m.bias.data.fill_(0)
78
+
79
+
80
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
81
+ # Assumes data is integers [0, 1]
82
+ assert x.shape == means.shape == log_scales.shape
83
+ px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
84
+
85
+ centered_x = x - means
86
+ inv_stdv = torch.exp(-log_scales)
87
+ plus_in = inv_stdv * (centered_x + 0.5)
88
+ cdf_plus = px0.cdf(plus_in)
89
+ min_in = inv_stdv * (centered_x - .5)
90
+ cdf_min = px0.cdf(min_in)
91
+ log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
92
+ log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * 1e-12))
93
+ cdf_delta = cdf_plus - cdf_min
94
+
95
+ log_probs = torch.where(
96
+ x < 0.001, log_cdf_plus,
97
+ torch.where(x > 0.999, log_one_minus_cdf_min,
98
+ torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))))
99
+ assert log_probs.shape == x.shape
100
+ return log_probs
101
+
102
+
103
+ def fig2img(fig):
104
+ """Convert a Matplotlib figure to a PIL Image and return it"""
105
+ import io
106
+ buf = io.BytesIO()
107
+ fig.savefig(buf)
108
+ buf.seek(0)
109
+ img = Image.open(buf)
110
+ return img
111
+
112
+
113
+ @torch.no_grad()
114
+ def visualize_distance_transform(
115
+ path_stem: str,
116
+ images: torch.Tensor,
117
+ ) -> str:
118
+ output_file_image = f'{path_stem}.png'
119
+ if images.shape[3] in [1, 3]: # convert to (B, C, H, W)
120
+ images = images.permute(0, 3, 1, 2)
121
+ images = images[:, -1:] # (B, 1, H, W) # get only distances (not vectors for now, for simplicity)
122
+ image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1, normalize=True)
123
+ to_pil_image(image_grid).save(output_file_image)
124
+ return output_file_image
125
+
126
+
127
+ @torch.no_grad()
128
+ def visualize_image(
129
+ path_stem: str,
130
+ images: torch.Tensor,
131
+ mean: Union[torch.Tensor, float] = 0.5,
132
+ std: Union[torch.Tensor, float] = 0.5,
133
+ ) -> str:
134
+ output_file_image = f'{path_stem}.png'
135
+ if images.shape[3] in [1, 3, 4]: # convert to (B, C, H, W)
136
+ images = images.permute(0, 3, 1, 2)
137
+ if images.shape[1] in [3, 4]: # normalize (single-channel images are not normalized)
138
+ images[:, :3] = images[:, :3] * std + mean # denormalize (color channels only, not alpha channel)
139
+ if images.shape[1] == 4: # composite the RGBA image over a light background color
140
+ image_alpha = images[:, 3:] # (B, 1, H, W)
141
+ bg_color = torch.tensor([230, 220, 250], device=images.device).reshape(1, 3, 1, 1) / 255
142
+ images = images[:, :3] * image_alpha + bg_color * (1 - image_alpha) # (B, 3, H, W)
143
+ image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1)
144
+ to_pil_image(image_grid).save(output_file_image)
145
+ return output_file_image
146
+
147
+
148
+ def ensure_point_cloud_has_colors(pointcloud: Pointclouds):
149
+ if pointcloud.features_padded() is None:
150
+ pointcloud = type(pointcloud)(points=pointcloud.points_padded(),
151
+ normals=pointcloud.normals_padded(), features=torch.zeros_like(pointcloud.points_padded()))
152
+ return pointcloud
153
+
154
+
155
+ @torch.no_grad()
156
+ def render_pointcloud_batch_pytorch3d(
157
+ cameras: CamerasBase,
158
+ pointclouds: Pointclouds,
159
+ image_size: int = 224,
160
+ radius: float = 0.01,
161
+ points_per_pixel: int = 10,
162
+ background_color: Sequence[float] = (0.78431373, 0.78431373, 0.78431373),
163
+ compositor: str = 'norm_weighted'
164
+ ):
165
+ # Define the settings for rasterization and shading. Here we set the output image to be of size
166
+ # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1
167
+ # and blur_radius=0.0. Refer to rasterize_points.py for explanations of these parameters.
168
+ raster_settings = PointsRasterizationSettings(
169
+ image_size=image_size,
170
+ radius=radius,
171
+ points_per_pixel=points_per_pixel,
172
+ )
173
+
174
+ # Rasterizer
175
+ rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
176
+
177
+ # Compositor
178
+ if compositor == 'alpha':
179
+ compositor = AlphaCompositor(background_color=background_color)
180
+ elif compositor == 'norm_weighted':
181
+ compositor = NormWeightedCompositor(background_color=background_color)
182
+ else:
183
+ raise ValueError(compositor)
184
+
185
+ # Create a points renderer by compositing points using a weighted compositor (3D points are
186
+ # weighted according to their distance to a pixel and accumulated using a weighted sum)
187
+ renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
188
+
189
+ # We cannot render a point cloud without colors, so add them if the pointcloud does
190
+ # not already have them
191
+ pointclouds = ensure_point_cloud_has_colors(pointclouds)
192
+
193
+ # Render batch of image
194
+ images = renderer(pointclouds)
195
+
196
+ return images
197
+
198
+
199
+ @torch.no_grad()
200
+ def visualize_pointcloud_batch_pytorch3d(
201
+ pointclouds: Pointclouds,
202
+ output_file_video: Optional[str] = None,
203
+ output_file_image: Optional[str] = None,
204
+ cameras: Optional[CamerasBase] = None, # if None, we rotate
205
+ scale_factor: float = 1.0,
206
+ num_frames: int = 1, # note that it takes a while with 30 * batch_size frames
207
+ elev: int = 30,
208
+ ):
209
+ """Saves a video and a single image of a point cloud"""
210
+ assert 360 % num_frames == 0, 'please select a better number of frames'
211
+
212
+ # Sizes
213
+ B, N, C, F = *(pointclouds.points_padded().shape), num_frames
214
+ device = pointclouds.device
215
+
216
+ # If a camera has not been provided, we render from a rotating view around an image
217
+ if cameras is None:
218
+
219
+ # Create view transforms - R is (F, 3, 3) and T is (F, 3)
220
+ R, T = look_at_view_transform(dist=10.0, elev=elev, azim=list(range(0, 360, 360 // F)), degrees=True, device=device)
221
+
222
+ # Repeat
223
+ R = R.repeat_interleave(B, dim=0) # (F * B, 3, 3)
224
+ T = T.repeat_interleave(B, dim=0) # (F * B, 3)
225
+ points = pointclouds.points_padded().tile(F, 1, 1) # (F * B, num_points, 3)
226
+ colors = (torch.zeros_like(points) if pointclouds.features_padded() is None else
227
+ pointclouds.features_padded().tile(F, 1, 1)) # (F * B, num_points, 3)
228
+
229
+ # Initialize batch of cameras
230
+ cameras = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
231
+
232
+ # Wrap in Pointclouds (with color, even if the original point cloud had no color)
233
+ pointclouds = Pointclouds(points=points, features=colors).to(device)
234
+
235
+ # Render image
236
+ images = render_pointcloud_batch_pytorch3d(cameras, pointclouds)
237
+
238
+ # Convert images into grid
239
+ image_grids = []
240
+ images_for_grids = images.reshape(F, B, *images.shape[1:]).permute(0, 1, 4, 2, 3)
241
+ for image_for_grids in images_for_grids:
242
+ image_grid = make_grid(image_for_grids, nrow=int(math.sqrt(B)), pad_value=1)
243
+ image_grids.append(image_grid)
244
+ image_grids = torch.stack(image_grids, dim=0)
245
+ image_grids = image_grids.detach().cpu()
246
+
247
+ # Save image
248
+ if output_file_image is not None:
249
+ to_pil_image(image_grids[0]).save(output_file_image)
250
+
251
+ # Save video
252
+ if output_file_video:
253
+ video = (image_grids * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy()
254
+ imageio.mimwrite(output_file_video, video, fps=10)
255
+
256
+
257
+ @torch.no_grad()
258
+ def visualize_pointcloud_evolution_pytorch3d(
259
+ pointclouds: Pointclouds,
260
+ output_file_video: str,
261
+ camera: Optional[CamerasBase] = None, # if None, we rotate
262
+ scale_factor: float = 1.0,
263
+ ):
264
+
265
+ # Device
266
+ B, device = len(pointclouds), pointclouds.device
267
+
268
+ # Cameras
269
+ if camera is None:
270
+ R, T = look_at_view_transform(dist=10.0, elev=30, azim=0, device=device)
271
+ camera = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
272
+
273
+ # Render
274
+ frames = render_pointcloud_batch_pytorch3d(camera, pointclouds)
275
+
276
+ # Save video
277
+ video = (frames.detach().cpu() * 255).to(torch.uint8).numpy()
278
+ imageio.mimwrite(output_file_video, video, fps=10)
279
+
280
+
281
+ def get_camera_index(cameras: CamerasBase, index: Optional[int] = None):
282
+ if index is None:
283
+ return cameras
284
+ kwargs = dict(
285
+ R=cameras.R[index].unsqueeze(0),
286
+ T=cameras.T[index].unsqueeze(0),
287
+ K=cameras.K[index].unsqueeze(0) if cameras.K is not None else None,
288
+ )
289
+ if hasattr(cameras, 'focal_length'):
290
+ kwargs['focal_length'] = cameras.focal_length[index].unsqueeze(0)
291
+ if hasattr(cameras, 'principal_point'):
292
+ kwargs['principal_point'] = cameras.principal_point[index].unsqueeze(0)
293
+ return type(cameras)(**kwargs).to(cameras.device)
294
+
295
+
296
+ def get_metadata(item) -> str:
297
+ s = '-------------\n'
298
+ for key in item.keys():
299
+ value = item[key]
300
+ if torch.is_tensor(value) and value.numel() < 25:
301
+ value_str = value
302
+ elif torch.is_tensor(value):
303
+ value_str = value.shape
304
+ elif isinstance(value, str):
305
+ value_str = value
306
+ elif isinstance(value, list) and 0 < len(value) and len(value) < 25 and isinstance(value[0], str):
307
+ value_str = value
308
+ elif isinstance(value, dict):
309
+ value_str = str({k: type(v) for k, v in value.items()})
310
+ else:
311
+ value_str = type(value)
312
+ s += f"{key:<30} {value_str}\n"
313
+ return s
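A hedged usage sketch for the rendering helpers above (file names and point counts are illustrative; rendering runs on whatever device the point cloud lives on):

```python
import torch
from pytorch3d.structures import Pointclouds
from diffusion_utils import visualize_pointcloud_batch_pytorch3d

points = torch.rand(2, 4096, 3) - 0.5   # batch of two random point clouds
colors = torch.rand(2, 4096, 3)         # per-point RGB in [0, 1]
pcs = Pointclouds(points=points, features=colors)
visualize_pointcloud_batch_pytorch3d(pcs,
                                     output_file_video='pc_turntable.mp4',
                                     output_file_image='pc_grid.png',
                                     num_frames=30)  # 360 must be divisible by num_frames
```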
examples/017450/k1.color.jpg ADDED
examples/017450/k1.obj_rend_mask.png ADDED
examples/017450/k1.person_mask.png ADDED
model/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from configs.structured import ProjectConfig
2
+ from .model import ConditionalPointCloudDiffusionModel
3
+ from .model_coloring import PointCloudColoringModel
4
+ from .model_utils import set_requires_grad
5
+ from .model_diff_data import ConditionalPCDiffusionSeparateSegm
6
+ from .model_hoattn import CrossAttenHODiffusionModel
7
+
8
+ def get_model(cfg: ProjectConfig):
9
+ if cfg.model.model_name == 'pc2-diff':
10
+ model = ConditionalPointCloudDiffusionModel(**cfg.model)
11
+ elif cfg.model.model_name == 'pc2-diff-ho-sepsegm':
12
+ model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
13
+ print("Using a separate model to predict segmentation label")
14
+ elif cfg.model.model_name == 'diff-ho-attn':
15
+ model = CrossAttenHODiffusionModel(**cfg.model)
16
+ print("Using separate model for human + object with cross attention.")
17
+ else:
18
+ raise NotImplementedError
19
+ if cfg.run.freeze_feature_model:
20
+ set_requires_grad(model.feature_model, False)
21
+ return model
22
+
23
+
24
+ def get_coloring_model(cfg: ProjectConfig):
25
+ model = PointCloudColoringModel(**cfg.model)
26
+ if cfg.run.freeze_feature_model:
27
+ set_requires_grad(model.feature_model, False)
28
+ return model
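A sketch of how the dispatch above is typically driven. The config path and name follow the `@hydra.main(config_path='configs', config_name='configs')` decorator used in `demo.py`; treat the compose call as an assumption rather than a documented entry point:

```python
from hydra import compose, initialize

from model import get_model

with initialize(config_path="configs", version_base="1.1"):
    cfg = compose(config_name="configs")

cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True   # stage 1
stage1 = get_model(cfg)
cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False         # stage 2
stage2 = get_model(cfg)
```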
model/feature_model.py ADDED
@@ -0,0 +1,160 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers import ModelMixin
8
+ from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
9
+ from torch import Tensor
10
+ from torchvision.transforms import functional as TVF
11
+
12
+
13
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
14
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
15
+
16
+ MODEL_URLS = {
17
+ 'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
18
+ 'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
19
+ 'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
20
+ }
21
+
22
+ NORMALIZATION = {
23
+ 'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
24
+ 'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
25
+ 'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
26
+ }
27
+
28
+ MODEL_KWARGS = {
29
+ 'vit_base_patch16_224_mae': dict(
30
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
31
+ ),
32
+ 'vit_small_patch16_224_msn': dict(
33
+ patch_size=16, embed_dim=384, depth=12, num_heads=6,
34
+ ),
35
+ 'vit_large_patch7_224_msn': dict(
36
+ patch_size=7, embed_dim=1024, depth=24, num_heads=16,
37
+ )
38
+ }
39
+
40
+
41
+ class FeatureModel(ModelMixin, ConfigMixin):
42
+
43
+ @register_to_config
44
+ def __init__(
45
+ self,
46
+ image_size: int = 224,
47
+ model_name: str = 'vit_small_patch16_224_mae',
48
+ global_pool: str = '', # '' or 'token'
49
+ ) -> None:
50
+ super().__init__()
51
+ self.model_name = model_name
52
+
53
+ # Identity
54
+ if self.model_name == 'identity':
55
+ return
56
+
57
+ # Create model
58
+ self.model = VisionTransformer(
59
+ img_size=image_size, num_classes=0, global_pool=global_pool,
60
+ **MODEL_KWARGS[model_name])
61
+
62
+ # Model properties
63
+ self.feature_dim = self.model.embed_dim
64
+ self.mean, self.std = NORMALIZATION[model_name]
65
+
66
+ # # Modify MSN model with output head from training
67
+ # if model_name.endswith('msn'):
68
+ # use_bn = True
69
+ # emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else
70
+ # 768 if 'base' in model_name else 1024 if 'large' in model_name else 1280)
71
+ # hidden_dim = 2048
72
+ # output_dim = 256
73
+ # self.model.fc = None
74
+ # fc = OrderedDict([])
75
+ # fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim)
76
+ # if use_bn:
77
+ # fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim)
78
+ # fc['gelu1'] = torch.nn.GELU()
79
+ # fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim)
80
+ # if use_bn:
81
+ # fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim)
82
+ # fc['gelu2'] = torch.nn.GELU()
83
+ # fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
84
+ # self.model.fc = torch.nn.Sequential(fc)
85
+
86
+ # Load pretrained checkpoint
87
+ checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
88
+ if 'model' in checkpoint:
89
+ state_dict = checkpoint['model']
90
+ elif 'target_encoder' in checkpoint:
91
+ state_dict = checkpoint['target_encoder']
92
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
93
+ # NOTE: Comment the line below if using the projection head, uncomment if not using it
94
+ # See https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
95
+ # for more info about the projection head
96
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
97
+ else:
98
+ raise NotImplementedError()
99
+ state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
100
+ self.model.load_state_dict(state_dict)
101
+ self.model.eval()
102
+
103
+ # # Modify MSN model with output head from training
104
+ # if model_name.endswith('msn'):
105
+ # self.fc = self.model.fc
106
+ # del self.model.fc
107
+ # else:
108
+ # self.fc = nn.Identity()
109
+
110
+ # NOTE: I've disabled the whole projection head stuff for simplicity for now
111
+ self.fc = nn.Identity()
112
+
113
+ def denormalize(self, img: Tensor):
114
+ img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std])
115
+ return torch.clip(img, 0, 1)
116
+
117
+ def normalize(self, img: Tensor):
118
+ return TVF.normalize(img, mean=self.mean, std=self.std)
119
+
120
+ def forward(
121
+ self,
122
+ x: Tensor,
123
+ return_type: str = 'features',
124
+ return_upscaled_features: bool = True,
125
+ return_projection_head_output: bool = False,
126
+ ):
127
+ """Normalizes the input `x` and runs it through `model` to obtain features"""
128
+ assert return_type in {'cls_token', 'features', 'all'}
129
+
130
+ # Identity
131
+ if self.model_name == 'identity':
132
+ return x
133
+
134
+ # Normalize and forward
135
+ B, C, H, W = x.shape
136
+ x = self.normalize(x)
137
+ feats = self.model(x)
138
+
139
+ # Reshape to image-like size
140
+ if return_type in {'features', 'all'}:
141
+ B, T, D = feats.shape
142
+ assert math.sqrt(T - 1).is_integer()
143
+ HW_down = int(math.sqrt(T - 1)) # subtract one for CLS token
144
+ output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2) # (B, D, H_down, W_down)
145
+ if return_upscaled_features:
146
+ output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
147
+ align_corners=False) # (B, D, H_orig, W_orig)
148
+
149
+ # Head for MSN
150
+ output_cls = feats[:, 0]
151
+ if return_projection_head_output and return_type in {'cls_token', 'all'}:
152
+ output_cls = self.fc(output_cls)
153
+
154
+ # Return
155
+ if return_type == 'cls_token':
156
+ return output_cls
157
+ elif return_type == 'features':
158
+ return output_feats
159
+ else:
160
+ return output_cls, output_feats
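A minimal usage sketch for the feature extractor above (the MAE weights are fetched from the corresponding `MODEL_URLS` entry on first use; the input batch below is random and purely illustrative):

```python
import torch
from model.feature_model import FeatureModel

feat_model = FeatureModel(image_size=224, model_name='vit_base_patch16_224_mae').eval()
img = torch.rand(1, 3, 224, 224)                        # RGB in [0, 1]
with torch.no_grad():
    feats = feat_model(img, return_type='features')     # (1, 768, 224, 224): pixel-aligned features
```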
model/model.py ADDED
@@ -0,0 +1,303 @@
1
+ import inspect
2
+ import random
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
8
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
9
+ from diffusers.schedulers.scheduling_pndm import PNDMScheduler
10
+ from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
11
+ from pytorch3d.renderer.cameras import CamerasBase
12
+ from pytorch3d.structures import Pointclouds
13
+ from torch import Tensor
14
+ from tqdm import tqdm
15
+
16
+ from .model_utils import get_num_points, get_custom_betas
17
+ from .point_cloud_model import PointCloudModel
18
+ from .projection_model import PointCloudProjectionModel
19
+
20
+
21
+ class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):
22
+
23
+ def __init__(
24
+ self,
25
+ beta_start: float,
26
+ beta_end: float,
27
+ beta_schedule: str,
28
+ point_cloud_model: str,
29
+ point_cloud_model_embed_dim: int,
30
+ **kwargs, # projection arguments
31
+ ):
32
+ super().__init__(**kwargs)
33
+
34
+ # Checks
35
+ if not self.predict_shape:
36
+ raise NotImplementedError('Must predict shape if performing diffusion.')
37
+
38
+ # Create diffusion model schedulers which define the sampling timesteps
39
+ self.dm_pred_type = kwargs.get('dm_pred_type', "epsilon")
40
+ assert self.dm_pred_type in ['epsilon','sample']
41
+ scheduler_kwargs = {"prediction_type": self.dm_pred_type}
42
+ if beta_schedule == 'custom':
43
+ scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
44
+ else:
45
+ scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
46
+ self.schedulers_map = {
47
+ 'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
48
+ 'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
49
+ 'pndm': PNDMScheduler(**scheduler_kwargs),
50
+ }
51
+ self.scheduler = self.schedulers_map['ddpm'] # this can be changed for inference
52
+
53
+ # Create point cloud model for processing point cloud at each diffusion step
54
+ self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)
55
+
56
+ self.load_sample_init = kwargs.get('load_sample_init', False)
57
+ self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
58
+ self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
59
+
60
+ self.consistent_center = kwargs.get('consistent_center', False)
61
+ self.cam_noise_std = kwargs.get('cam_noise_std', 0.0) # add noise to camera based on timestamps
62
+
63
+ def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
64
+ self.point_cloud_model = PointCloudModel(
65
+ model_type=point_cloud_model,
66
+ embed_dim=point_cloud_model_embed_dim,
67
+ in_channels=self.in_channels,
68
+ out_channels=self.out_channels, # voxel resolution multiplier is 1.
69
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
70
+ )
71
+
72
+ def forward_train(
73
+ self,
74
+ pc: Pointclouds,
75
+ camera: Optional[CamerasBase],
76
+ image_rgb: Optional[Tensor],
77
+ mask: Optional[Tensor],
78
+ return_intermediate_steps: bool = False,
79
+ **kwargs
80
+ ):
81
+
82
+ # Normalize colors and convert to tensor
83
+ x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors
84
+ B, N, D = x_0.shape
85
+
86
+ # Sample random noise
87
+ noise = torch.randn_like(x_0)
88
+ if self.consistent_center:
89
+ # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
90
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
91
+
92
+ # Sample random timesteps for each point_cloud
93
+ timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
94
+ device=self.device, dtype=torch.long)
95
+
96
+ # Add noise to points
97
+ x_t = self.scheduler.add_noise(x_0, noise, timestep) # add diffusion noise to the point coordinates only, not to the features
98
+
99
+ # add noise to the camera pose, based on timestamps
100
+ if self.cam_noise_std > 0.000001:
101
+ # noise magnitude scales with camera distance and diffusion timestep
102
+ camera = camera.clone()
103
+ camT = camera.T # (B, 3)
104
+ dist = torch.sqrt(torch.sum(camT**2, -1, keepdim=True))
105
+ nratio = timestep[:, None] / self.scheduler.num_train_timesteps # time-dependent noise
106
+ tnoise = torch.randn(B, 3).to(dist.device)/3. * dist * self.cam_noise_std * nratio
107
+ camera.T = camera.T + tnoise
108
+
109
+ # Conditioning, the pixel-aligned feature is based on points with noise (new points)
110
+ x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)
111
+
112
+ # Forward
113
+ loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)
114
+
115
+ # Whether to return intermediate steps
116
+ if return_intermediate_steps:
117
+ return loss, (x_0, x_t, noise, noise_pred)
118
+
119
+ return loss
120
+
121
+ def compute_loss(self, noise, timestep, x_0, x_t_input):
122
+ x_pred = torch.zeros_like(x_0)
123
+ if self.self_conditioning:
124
+ # self conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
125
+ if random.uniform(0, 1.) > 0.5:
126
+ with torch.no_grad():
127
+ x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
128
+ noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
129
+ else:
130
+ noise_pred = self.point_cloud_model(x_t_input, timestep)
131
+ # Check
132
+ if not noise_pred.shape == noise.shape:
133
+ raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')
134
+ # Loss
135
+ if self.dm_pred_type == 'epsilon':
136
+ loss = F.mse_loss(noise_pred, noise)
137
+ elif self.dm_pred_type == 'sample':
138
+ loss = F.mse_loss(noise_pred, x_0) # predicting sample
139
+ else:
140
+ raise NotImplementedError
141
+ return loss, noise_pred
142
+
143
+ def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
144
+ "return: (B, N, D), the exact input to the diffusion model, x_t: (B, N, 3)"
145
+ x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
146
+ image_rgb=image_rgb, mask=mask, t=timestep)
147
+ return x_t_input
148
+
149
+ @torch.no_grad()
150
+ def forward_sample(
151
+ self,
152
+ num_points: int,
153
+ camera: Optional[CamerasBase],
154
+ image_rgb: Optional[Tensor],
155
+ mask: Optional[Tensor],
156
+ # Optional overrides
157
+ scheduler: Optional[str] = 'ddpm',
158
+ # Inference parameters
159
+ num_inference_steps: Optional[int] = 1000,
160
+ eta: Optional[float] = 0.0, # for DDIM
161
+ # Whether to return all the intermediate steps in generation
162
+ return_sample_every_n_steps: int = -1,
163
+ # Whether to disable tqdm
164
+ disable_tqdm: bool = False,
165
+ gt_pc: Pointclouds = None,
166
+ **kwargs
167
+ ):
168
+
169
+ # Get scheduler from mapping, or use self.scheduler if None
170
+ scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
171
+
172
+ # Get the size of the noise
173
+ N = num_points
174
+ B = 1 if image_rgb is None else image_rgb.shape[0]
175
+ D = self.get_x_T_channel()
176
+ device = self.device if image_rgb is None else image_rgb.device
177
+
178
+ sample_from_interm = kwargs.get('sample_from_interm', False)
179
+ interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
180
+ x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
181
+ x_pred = torch.zeros_like(x_t)
182
+
183
+ # Set timesteps
184
+ extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
185
+
186
+ # Loop over timesteps
187
+ all_outputs = []
188
+ return_all_outputs = (return_sample_every_n_steps > 0)
189
+ progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)
190
+
191
+ for i, t in enumerate(progress_bar):
192
+ add_interm_output = (return_all_outputs and (
193
+ i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))
194
+ # Conditioning
195
+ x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
196
+ if self.self_conditioning:
197
+ x_t_input = torch.cat([x_t_input, x_pred], -1) # add self-conditioning
198
+ inference_binary = (i == len(progress_bar) - 1) | add_interm_output
199
+ # One reverse step with conditioning
200
+ x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
201
+ inference_binary=inference_binary) # (B, N, D), D=3 or 4
202
+ x_pred = x_t # for next iteration self conditioning
203
+
204
+ # Append to output list if desired
205
+ if add_interm_output:
206
+ all_outputs.append(x_t)
207
+
208
+ # Convert output back into a point cloud, undoing normalization and scaling
209
+ output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True) # this converts the points back to the original scale
210
+ if return_all_outputs:
211
+ all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D)
212
+ all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]
213
+
214
+ return (output, all_outputs) if return_all_outputs else output
215
+
216
+ def get_x_T_channel(self):
217
+ D = 3 + (self.color_channels if self.predict_color else 0)
218
+ return D
219
+
220
+ def initialize_x_T(self, device, gt_pc, shape, interm_steps:int=-1, scheduler=None):
221
+ B, N, D = shape
222
+ # Sample noise initialization
223
+ if interm_steps > 0:
224
+ # Sample from some intermediate steps
225
+ x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
226
+ noise = torch.randn(B, N, D, device=device)
227
+
228
+ # always make sure the noise does not change the point cloud center; this is important and reduces the Chamfer distance by ~0.1cm!
229
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
230
+
231
+ x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device)) # Add noise
232
+ else:
233
+ # Sample from random Gaussian
234
+ x_t = torch.randn(B, N, D, device=device)
235
+
236
+ x_t = x_t * self.sample_init_scale # for test
237
+ if self.consistent_center:
238
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
239
+ return x_t
240
+
241
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
242
+ """
243
+ run one reverse step to compute x_t
244
+ :param extra_step_kwargs:
245
+ :param scheduler:
246
+ :param t: [1], diffusion time step
247
+ :param x_t: (B, N, 3)
248
+ :param x_t_input: conditional features (B, N, F)
249
+ :param kwargs: other configurations to run diffusion step
250
+ :return: denoised x_t
251
+ """
252
+ B = x_t.shape[0]
253
+ # Forward
254
+ noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
255
+ if self.consistent_center:
256
+ assert self.dm_pred_type != 'sample', 'incompatible dm prediction type for CCD!'
257
+ # suggested by the CCD-3DR paper
258
+ noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
259
+ # Step
260
+ x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
261
+ if self.consistent_center:
262
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
263
+ return x_t
264
+
265
+ def setup_reverse_process(self, eta, num_inference_steps, scheduler):
266
+ """
267
+ setup diffusion chain, and others.
268
+ """
269
+ accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
270
+ extra_set_kwargs = {"offset": 1} if accepts_offset else {}
271
+ scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
272
+ # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
273
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
274
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
275
+ # and should be between [0, 1]
276
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
277
+ extra_step_kwargs = {"eta": eta} if accepts_eta else {}
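+ # e.g. the DDIM scheduler's step() accepts `eta`, so this becomes {'eta': 0.0} by default,
+ # while the DDPM scheduler's step() takes no eta and the dict stays empty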
278
+ return extra_step_kwargs
279
+
280
+ def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
281
+ """
282
+ A wrapper around the forward method for training and inference
283
+ """
284
+ if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict
285
+ batch = FrameData(**batch) # workaround: multiprocessing collation can turn FrameData into a plain dict
286
+
287
+ if mode == 'train':
288
+ return self.forward_train(
289
+ pc=batch.sequence_point_cloud,
290
+ camera=batch.camera,
291
+ image_rgb=batch.image_rgb,
292
+ mask=batch.fg_probability,
293
+ **kwargs)
294
+ elif mode == 'sample':
295
+ num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
296
+ return self.forward_sample(
297
+ num_points=num_points,
298
+ camera=batch.camera,
299
+ image_rgb=batch.image_rgb,
300
+ mask=batch.fg_probability,
301
+ **kwargs)
302
+ else:
303
+ raise NotImplementedError()
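+
+ # Usage sketch (hypothetical names `model` and `batch`, a FrameData-style batch):
+ #   pred_pc = model(batch, mode='sample', scheduler='ddim', num_inference_steps=50)
+ # returns a pytorch3d Pointclouds already mapped back to the original scale.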
model/model_coloring.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
6
+ from pytorch3d.renderer.cameras import CamerasBase
7
+ from pytorch3d.structures import Pointclouds
8
+ from torch import Tensor
9
+
10
+ from .point_cloud_transformer_model import PointCloudTransformerModel
11
+ from .projection_model import PointCloudProjectionModel
12
+
13
+ class PointCloudColoringModel(PointCloudProjectionModel):
14
+
15
+ def __init__(
16
+ self,
17
+ point_cloud_model: str,
18
+ point_cloud_model_layers: int,
19
+ point_cloud_model_embed_dim: int,
20
+ **kwargs, # projection arguments
21
+ ):
22
+ super().__init__(**kwargs)
23
+
24
+ # Checks
25
+ if self.predict_shape or not self.predict_color:
26
+ raise NotImplementedError('Must predict color, not shape, for coloring')
27
+
28
+ # Create point cloud model for processing point cloud
29
+ self.point_cloud_model = PointCloudTransformerModel(
30
+ num_layers=point_cloud_model_layers,
31
+ model_type=point_cloud_model,
32
+ embed_dim=point_cloud_model_embed_dim,
33
+ in_channels=self.in_channels,
34
+ out_channels=self.out_channels,
35
+ ) # why use transformer instead???
36
+
37
+ def _forward(
38
+ self,
39
+ pc: Pointclouds,
40
+ camera: Optional[CamerasBase],
41
+ image_rgb: Optional[Tensor],
42
+ mask: Optional[Tensor],
43
+ return_point_cloud: bool = False,
44
+ noise_std: float = 0.0,
45
+ ):
46
+
47
+ # Normalize colors and convert to tensor
48
+ x = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
49
+ x_points, x_colors = x[:, :, :3], x[:, :, 3:]
50
+
51
+ # Add noise to points. TODO: Add to configs.
52
+ x_input = x_points + torch.randn_like(x_points) * noise_std # simulate noise of the predicted pc?
53
+
54
+ # Conditioning
55
+ # x_input = self.get_input_with_conditioning(x_input, camera=camera,
56
+ # image_rgb=image_rgb, mask=mask)
57
+ # XH: edit to run
58
+ x_input = self.get_input_with_conditioning(x_input, camera=camera,
59
+ image_rgb=image_rgb, mask=mask, t=None)
60
+
61
+ # Forward
62
+ pred_colors = self.point_cloud_model(x_input)
63
+
64
+ # During inference, we return the point cloud with the predicted colors
65
+ if return_point_cloud:
66
+ pred_pointcloud = self.tensor_to_point_cloud(
67
+ torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True)
68
+ return pred_pointcloud
69
+
70
+ # During training, we have ground truth colors and return the loss
71
+ loss = F.mse_loss(pred_colors, x_colors)
72
+ return loss
73
+
74
+ def forward(self, batch: FrameData, **kwargs):
75
+ """A wrapper around the forward method"""
76
+ if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict
77
+ batch = FrameData(**batch) # workaround: multiprocessing collation can turn FrameData into a plain dict
78
+ return self._forward(
79
+ pc=batch.sequence_point_cloud,
80
+ camera=batch.camera,
81
+ image_rgb=batch.image_rgb,
82
+ mask=batch.fg_probability,
83
+ **kwargs,
84
+ )
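+
+ # Usage sketch (hypothetical names): colored_pc = model(batch, return_point_cloud=True, noise_std=0.0)
+ # returns a Pointclouds with predicted per-point colors; without the flag the training MSE loss is returned.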
model/model_diff_data.py ADDED
@@ -0,0 +1,238 @@
1
+ """
2
+ models to handle ShapeNet inputs and other datasets such as Behave and ProciGen;
3
+ these models take a different data dictionary in the forward function
4
+ """
5
+ import inspect
6
+ from typing import Optional
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
12
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
13
+ from diffusers.schedulers.scheduling_pndm import PNDMScheduler
14
+ from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
15
+ from pytorch3d.renderer.cameras import CamerasBase
16
+ from pytorch3d.structures import Pointclouds
17
+ from torch import Tensor
18
+ from tqdm import tqdm
19
+ from pytorch3d.renderer import PerspectiveCameras
20
+ from pytorch3d.datasets.r2n2.utils import BlenderCamera
21
+
22
+
23
+ from .model import ConditionalPointCloudDiffusionModel
24
+ from .model_utils import get_num_points
25
+
26
+
27
+ class ConditionalPCDiffusionShapenet(ConditionalPointCloudDiffusionModel):
28
+ def forward(self, batch, mode: str = 'train', **kwargs):
29
+ """
30
+ take a batch of data from ShapeNet
31
+ """
32
+ images = torch.stack(batch['images'], 0).to('cuda')
33
+ masks = torch.stack(batch['masks'], 0).to('cuda')
34
+ pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
35
+ camera = BlenderCamera(
36
+ torch.stack(batch['R']),
37
+ torch.stack(batch['T']),
38
+ torch.stack(batch['K']), device='cuda'
39
+ )
40
+
41
+ if mode == 'train':
42
+ return self.forward_train(
43
+ pc=pc,
44
+ camera=camera,
45
+ image_rgb=images,
46
+ mask=masks,
47
+
48
+ **kwargs)
49
+ elif mode == 'sample':
50
+ num_points = kwargs.pop('num_points', get_num_points(pc))
51
+ return self.forward_sample(
52
+ num_points=num_points,
53
+ camera=camera,
54
+ image_rgb=images,
55
+ mask=masks,
56
+ gt_pc=pc,
57
+ **kwargs)
58
+ else:
59
+ raise NotImplementedError()
60
+
61
+
62
+ class ConditionalPCDiffusionBehave(ConditionalPointCloudDiffusionModel):
63
+ "diffusion model for Behave dataset"
64
+ def forward(self, batch, mode: str = 'train', **kwargs):
65
+ images = torch.stack(batch['images'], 0).to('cuda')
66
+ masks = torch.stack(batch['masks'], 0).to('cuda')
67
+ pc = self.get_input_pc(batch)
68
+ camera = PerspectiveCameras(
69
+ R=torch.stack(batch['R']),
70
+ T=torch.stack(batch['T']),
71
+ K=torch.stack(batch['K']),
72
+ device='cuda',
73
+ in_ndc=True
74
+ )
75
+ grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
76
+ num_points = kwargs.pop('num_points', get_num_points(pc))
77
+ if mode == 'train':
78
+ return self.forward_train(
79
+ pc=pc,
80
+ camera=camera,
81
+ image_rgb=images,
82
+ mask=masks,
83
+ grid_df=grid_df,
84
+ **kwargs)
85
+ elif mode == 'sample':
86
+ return self.forward_sample(
87
+ num_points=num_points,
88
+ camera=camera,
89
+ image_rgb=images,
90
+ mask=masks,
91
+ gt_pc=pc,
92
+ **kwargs)
93
+ else:
94
+ raise NotImplementedError()
95
+
96
+ def get_input_pc(self, batch):
97
+ pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
98
+ return pc
99
+
100
+
101
+ class ConditionalPCDiffusionSeparateSegm(ConditionalPCDiffusionBehave):
102
+ "a separate model to predict binary labels, the final segmentation model"
103
+ def __init__(self,
104
+ beta_start: float,
105
+ beta_end: float,
106
+ beta_schedule: str,
107
+ point_cloud_model: str,
108
+ point_cloud_model_embed_dim: int,
109
+ **kwargs, # projection arguments
110
+ ):
111
+ super(ConditionalPCDiffusionSeparateSegm, self).__init__(beta_start, beta_end, beta_schedule,
112
+ point_cloud_model,
113
+ point_cloud_model_embed_dim, **kwargs)
114
+ # add a separate model to predict binary label
115
+ from .point_cloud_transformer_model import PointCloudTransformerModel, PointCloudModel
116
+
117
+ self.binary_model = PointCloudTransformerModel(
118
+ num_layers=1, # XH: use the default color model number of layers
119
+ model_type=point_cloud_model, # pvcnn
120
+ embed_dim=point_cloud_model_embed_dim, # same as the pc shape model
121
+ in_channels=self.in_channels,
122
+ out_channels=1,
123
+ )
124
+ self.binary_training_noise_std = kwargs.get("binary_training_noise_std", 0.1)
125
+
126
+ # re-initialize point cloud model
127
+ assert self.predict_binary
128
+ self.point_cloud_model = PointCloudModel(
129
+ model_type=point_cloud_model,
130
+ embed_dim=point_cloud_model_embed_dim,
131
+ in_channels=self.in_channels,
132
+ out_channels=self.out_channels - 1, # not predicting binary from this anymore
133
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
134
+ )
135
+
136
+ def forward_train(
137
+ self,
138
+ pc: Pointclouds,
139
+ camera: Optional[CamerasBase],
140
+ image_rgb: Optional[Tensor],
141
+ mask: Optional[Tensor],
142
+ return_intermediate_steps: bool = False,
143
+ **kwargs
144
+ ):
145
+ # first run shape forward, then binary label forward
146
+ assert not return_intermediate_steps
147
+ assert self.predict_binary
148
+ loss_shape = super(ConditionalPCDiffusionSeparateSegm, self).forward_train(pc,
149
+ camera,
150
+ image_rgb,
151
+ mask,
152
+ return_intermediate_steps,
153
+ **kwargs)
154
+
155
+ # binary label forward
156
+ x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
157
+ x_points, x_colors = x_0[:, :, :3], x_0[:, :, 3:]
158
+
159
+ # Add noise to points.
160
+ x_input = x_points + torch.randn_like(x_points) * self.binary_training_noise_std # std=0.1
161
+ x_input = self.get_input_with_conditioning(x_input, camera=camera,
162
+ image_rgb=image_rgb, mask=mask, t=None)
163
+
164
+ # Forward
165
+ pred_segm = self.binary_model(x_input)
166
+
167
+ # use compressed bits
168
+ df_grid = kwargs.get('grid_df', None).unsqueeze(1) # (B, 1, resz, resy, resx)
169
+ points = x_points.clone().detach() / self.scale_factor * 2 # normalize to [-1, 1]
170
+ points[:, :, 0], points[:, :, 2] = points[:, :, 2].clone(), points[:, :,0].clone() # swap, make sure clone is used!
171
+ points = points.unsqueeze(1).unsqueeze(1) # (B,1, 1, N, 3)
172
+ with torch.no_grad():
173
+ df_interp = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True).squeeze(1).squeeze(1) # (B, 1, 1, 1, N)
174
+ binary_label = df_interp[:, 0] > 0.5 # (B, 1, N)
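+ # i.e. points whose interpolated grid_df value exceeds 0.5 are labeled as human;
+ # grid_df presumably stores a per-voxel human indicator field (the 'compressed bits' above),
+ # giving pseudo ground truth for the binary human/object segmentation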
175
+
176
+ binary_pred = torch.sigmoid(pred_segm.squeeze(-1)) # add a sigmoid layer
177
+ loss_binary = F.mse_loss(binary_pred, binary_label.float().squeeze(1).squeeze(1)) * self.lw_binary
178
+ loss = loss_shape + loss_binary
179
+
180
+ return loss, torch.tensor([loss_shape, loss_binary])
181
+
182
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
183
+ "return (B, N, 4), the 4-th channel is binary label"
184
+ B = x_t.shape[0]
185
+ # Forward
186
+ noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
187
+ if self.consistent_center:
188
+ assert self.dm_pred_type != 'sample', 'incompatible dm prediction type!'
189
+ # suggested by the CCD-3DR paper
190
+ noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
191
+ # Step: make sure only update the shape (first 3 channels)
192
+ x_t = scheduler.step(noise_pred, t, x_t[:, :, :3], **extra_step_kwargs).prev_sample
193
+ if self.consistent_center:
194
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
195
+
196
+ # also add binary prediction
197
+ if kwargs.get('inference_binary', False):
198
+ pred_segm = self.binary_model(x_t_input)
199
+ else:
200
+ pred_segm = torch.zeros_like(x_t[:, :, 0:1])
201
+
202
+ x_t = torch.cat([x_t, torch.sigmoid(pred_segm)], -1)
203
+
204
+ return x_t
205
+
206
+ def get_coord_feature(self, x_t):
207
+ x_t_input = [x_t[:, :, :3]]
208
+ return x_t_input
209
+
210
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
211
+ """
212
+ take binary label into account
213
+ :param self:
214
+ :param x: (B, N, 4), the 4th channel is the binary segmentation, 1-human, 0-object
215
+ :param denormalize: denormalize the per-point colors, from pc2
216
+ :param unscale: undo point scaling, from pc2
217
+ :return: pc with point colors if predict binary label or per-point color
218
+ """
219
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
220
+ if self.predict_color:
221
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
222
+ return Pointclouds(points=points, features=colors)
223
+ else:
224
+ if self.predict_binary:
225
+ assert x.shape[2] == 4
226
+ # add color to predicted binary labels
227
+ is_hum = x[:, :, 3] > 0.5
228
+ features = []
229
+ for mask in is_hum:
230
+ color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
231
+ color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green
232
+ features.append(color)
233
+ else:
234
+ assert x.shape[2] == 3
235
+ features = None
236
+ return Pointclouds(points=points, features=features)
237
+
238
+
model/model_hoattn.py ADDED
@@ -0,0 +1,457 @@
1
+ """
2
+ model that use cross attention to predict human + object
3
+ """
4
+
5
+ import inspect
6
+ import random
7
+ from typing import Optional
8
+ from torch import Tensor
9
+ import torch
10
+ import numpy as np
11
+
12
+ from pytorch3d.structures import Pointclouds
13
+ from pytorch3d.renderer import CamerasBase
14
+ from .model_diff_data import ConditionalPCDiffusionBehave
15
+ from .pvcnn.pvcnn_ho import PVCNN2HumObj
16
+ import torch.nn.functional as F
17
+ from pytorch3d.renderer import PerspectiveCameras
18
+ from .model_utils import get_num_points
19
+ from tqdm import tqdm
20
+
21
+
22
+ class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
23
+ def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
24
+ """use cross attention model"""
25
+ if point_cloud_model == 'pvcnn':
26
+ self.point_cloud_model = PVCNN2HumObj(embed_dim=point_cloud_model_embed_dim,
27
+ num_classes=self.out_channels,
28
+ extra_feature_channels=(self.in_channels - 3),
29
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),
30
+ attn_type=kwargs.get('attn_type', 'simple-cross'),
31
+ attn_weight=kwargs.get("attn_weight", 1.0)
32
+ )
33
+ else:
34
+ raise ValueError(f"Unknown point cloud model {point_cloud_model}!")
35
+ self.point_visible_test = kwargs.get("point_visible_test", 'single') # when doing point visibility test, use only human points or human + object?
36
+ assert self.point_visible_test in ['single', 'combine'], f'invalid point visibility test option {self.point_visible_test}'
37
+ # print(f"Point visibility test is based on {self.point_visible_test} point clouds!")
38
+
39
+ def forward_train(
40
+ self,
41
+ pc: Pointclouds,
42
+ camera: Optional[CamerasBase],
43
+ image_rgb: Optional[Tensor],
44
+ mask: Optional[Tensor],
45
+ return_intermediate_steps: bool = False,
46
+ **kwargs
47
+ ):
48
+ "additional input (RGB, mask, camera, and pc) for object is read from kwargs"
49
+ # assert not self.consistent_center
50
+ assert not self.self_conditioning
51
+
52
+ # Normalize colors and convert to tensor
53
+ x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors
54
+ x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True)
55
+ B, N, D = x0_h.shape
56
+
57
+ # Sample random noise
58
+ noise = torch.randn_like(x0_h)
59
+ if self.consistent_center:
60
+ # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
61
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
62
+
63
+ # Sample random timesteps for each point_cloud
64
+ timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
65
+ device=self.device, dtype=torch.long)
66
+ # timestep = torch.randint(0, 1, (B,),
67
+ # device=self.device, dtype=torch.long)
68
+
69
+ # Add noise to points
70
+ xt_h = self.scheduler.add_noise(x0_h, noise, timestep)
71
+ xt_o = self.scheduler.add_noise(x0_o, noise, timestep)
72
+ norm_parms = self.pack_norm_params(kwargs) # (2, B, 4)
73
+
74
+ # get input conditioning
75
+ x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask, norm_parms, timestep,
76
+ xt_h, xt_o)
77
+
78
+ # Diffusion prediction
79
+ noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input_h, x_t_input_o, timestep, norm_parms)
80
+
81
+ # Check
82
+ if not noise_pred_h.shape == noise.shape:
83
+ raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}')
84
+ if not noise_pred_o.shape == noise.shape:
85
+ raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}')
86
+
87
+ # Loss
88
+ loss_h = F.mse_loss(noise_pred_h, noise)
89
+ loss_o = F.mse_loss(noise_pred_o, noise)
90
+
91
+ loss = loss_h + loss_o
92
+
93
+ # Whether to return intermediate steps
94
+ if return_intermediate_steps:
95
+ return loss, (x0_h, xt_h, noise, noise_pred_h)
96
+
97
+ return loss, torch.tensor([loss_h, loss_o])
98
+
99
+ def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o):
100
+ """
101
+ compute image features for each point
102
+ :param camera:
103
+ :param image_rgb:
104
+ :param kwargs:
105
+ :param mask:
106
+ :param norm_parms:
107
+ :param timestep:
108
+ :param xt_h:
109
+ :param xt_o:
110
+ :return:
111
+ """
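+ # Normalization convention (consistent with pack_norm_params / denormalize_pclouds):
+ #   x_normalized = (x_joint - cent) / (2 * radius)  and  x_joint = x_normalized * 2 * radius + cent;
+ # the 'combine' branch below uses this to move points between the human, object and joint frames.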
112
+ if self.point_visible_test == 'single':
113
+ # Visibility test is done independently for the human and object point clouds
114
+ x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera,
115
+ image_rgb=image_rgb, mask=mask, t=timestep)
116
+ x_t_input_o = self.get_input_with_conditioning(xt_o, camera=kwargs.get('camera_obj'),
117
+ image_rgb=kwargs.get('rgb_obj'),
118
+ mask=kwargs.get('mask_obj'), t=timestep)
119
+ elif self.point_visible_test == 'combine':
120
+ # Combine human + object points to do visibility test and obtain features
121
+ B, N = xt_h.shape[:2] # (B, N, 3)
122
+ # for human: transform object points first to H+O space, then to human space
123
+ xt_o_in_ho = xt_o * 2 * norm_parms[1, :, 3:].unsqueeze(1) + norm_parms[1, :, :3].unsqueeze(1)
124
+ xt_o_in_hum = (xt_o_in_ho - norm_parms[0, :, :3].unsqueeze(1)) / (2 * norm_parms[0, :, 3:].unsqueeze(1))
125
+ # compute features for all points, take only first half feature for human
126
+ x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera,
127
+ image_rgb=image_rgb, mask=mask, t=timestep)[:,:N]
128
+ # for object: transform human points to H+O space, then to object space
129
+ xt_h_in_ho = xt_h * 2 * norm_parms[0, :, 3:].unsqueeze(1) + norm_parms[0, :, :3].unsqueeze(1)
130
+ xt_h_in_obj = (xt_h_in_ho - norm_parms[1, :, :3].unsqueeze(1)) / (2 * norm_parms[1, :, 3:].unsqueeze(1))
131
+ x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1),
132
+ camera=kwargs.get('camera_obj'),
133
+ image_rgb=kwargs.get('rgb_obj'),
134
+ mask=kwargs.get('mask_obj'), t=timestep)[:, :N]
135
+ else:
136
+ raise NotImplementedError
137
+ return x_t_input_h, x_t_input_o
138
+
139
+ def forward(self, batch, mode: str = 'train', **kwargs):
140
+ """"""
141
+ images = torch.stack(batch['images'], 0).to('cuda')
142
+ masks = torch.stack(batch['masks'], 0).to('cuda')
143
+ pc = self.get_input_pc(batch)
144
+ camera = PerspectiveCameras(
145
+ R=torch.stack(batch['R']),
146
+ T=torch.stack(batch['T_hum']),
147
+ K=torch.stack(batch['K_hum']),
148
+ device='cuda',
149
+ in_ndc=True
150
+ )
151
+ grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
152
+ num_points = kwargs.pop('num_points', get_num_points(pc))
153
+
154
+ rgb_obj = torch.stack(batch['images_obj'], 0).to('cuda')
155
+ masks_obj = torch.stack(batch['masks_obj'], 0).to('cuda')
156
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pclouds_obj']])
157
+ camera_obj = PerspectiveCameras(
158
+ R=torch.stack(batch['R']),
159
+ T=torch.stack(batch['T_obj']),
160
+ K=torch.stack(batch['K_obj']),
161
+ device='cuda',
162
+ in_ndc=True
163
+ )
164
+
165
+ # normalization parameters
166
+ cent_hum = torch.stack(batch['cent_hum'], 0).to('cuda')
167
+ cent_obj = torch.stack(batch['cent_obj'], 0).to('cuda') # B, 3
168
+ radius_hum = torch.stack(batch['radius_hum'], 0).to('cuda') # B, 1
169
+ radius_obj = torch.stack(batch['radius_obj'], 0).to('cuda')
170
+
171
+ # print(batch['image_path'])
172
+
173
+ if mode == 'train':
174
+ return self.forward_train(
175
+ pc=pc,
176
+ camera=camera,
177
+ image_rgb=images,
178
+ mask=masks,
179
+ grid_df=grid_df,
180
+ rgb_obj=rgb_obj,
181
+ mask_obj=masks_obj,
182
+ pc_obj=pc_obj,
183
+ camera_obj=camera_obj,
184
+ cent_hum=cent_hum,
185
+ cent_obj=cent_obj,
186
+ radius_hum=radius_hum,
187
+ radius_obj=radius_obj,
188
+ )
189
+ elif mode == 'sample':
190
+ # this use GT centers to do projection
191
+ return self.forward_sample(
192
+ num_points=num_points,
193
+ camera=camera,
194
+ image_rgb=images,
195
+ mask=masks,
196
+ gt_pc=pc,
197
+ rgb_obj=rgb_obj,
198
+ mask_obj=masks_obj,
199
+ pc_obj=pc_obj,
200
+ camera_obj=camera_obj,
201
+ cent_hum=cent_hum,
202
+ cent_obj=cent_obj,
203
+ radius_hum=radius_hum,
204
+ radius_obj=radius_obj,
205
+ **kwargs)
206
+ elif mode == 'interm-gt':
207
+ return self.forward_sample(
208
+ num_points=num_points,
209
+ camera=camera,
210
+ image_rgb=images,
211
+ mask=masks,
212
+ gt_pc=pc,
213
+ rgb_obj=rgb_obj,
214
+ mask_obj=masks_obj,
215
+ pc_obj=pc_obj,
216
+ camera_obj=camera_obj,
217
+ cent_hum=cent_hum,
218
+ cent_obj=cent_obj,
219
+ radius_hum=radius_hum,
220
+ radius_obj=radius_obj,
221
+ sample_from_interm=True,
222
+ **kwargs)
223
+ elif mode == 'interm-pred':
224
+ # use camera from predicted
225
+ camera = PerspectiveCameras(
226
+ R=torch.stack(batch['R']),
227
+ T=torch.stack(batch['T_hum_scaled']),
228
+ K=torch.stack(batch['K_hum']),
229
+ device='cuda',
230
+ in_ndc=True
231
+ )
232
+ camera_obj = PerspectiveCameras(
233
+ R=torch.stack(batch['R']),
234
+ T=torch.stack(batch['T_obj_scaled']),
235
+ K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!!
236
+ device='cuda',
237
+ in_ndc=True
238
+ )
239
+ # use pc from predicted
240
+ pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
241
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
242
+ # use center and radius from predicted
243
+ cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
244
+ cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3
245
+ radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1
246
+ radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')
247
+
248
+ return self.forward_sample(
249
+ num_points=num_points,
250
+ camera=camera,
251
+ image_rgb=images,
252
+ mask=masks,
253
+ gt_pc=pc,
254
+ rgb_obj=rgb_obj,
255
+ mask_obj=masks_obj,
256
+ pc_obj=pc_obj,
257
+ camera_obj=camera_obj,
258
+ cent_hum=cent_hum,
259
+ cent_obj=cent_obj,
260
+ radius_hum=radius_hum,
261
+ radius_obj=radius_obj,
262
+ sample_from_interm=True,
263
+ **kwargs)
264
+ elif mode == 'interm-pred-ts':
265
+ # use only the estimated translation and scale, but sample from Gaussian noise
266
+ # this works, the camera is GT!!!
267
+ pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
268
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
269
+ # use center and radius from predicted
270
+ cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
271
+ cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3
272
+ radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1
273
+ radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')
274
+ # print(cent_hum[0], radius_hum[0], cent_obj[0], radius_obj[0])
275
+
276
+ return self.forward_sample(
277
+ num_points=num_points,
278
+ camera=camera,
279
+ image_rgb=images,
280
+ mask=masks,
281
+ gt_pc=pc,
282
+ rgb_obj=rgb_obj,
283
+ mask_obj=masks_obj,
284
+ pc_obj=pc_obj,
285
+ camera_obj=camera_obj,
286
+ cent_hum=cent_hum,
287
+ cent_obj=cent_obj,
288
+ radius_hum=radius_hum,
289
+ radius_obj=radius_obj,
290
+ sample_from_interm=False,
291
+ **kwargs)
292
+ else:
293
+ raise NotImplementedError
294
+
295
+ def forward_sample(
296
+ self,
297
+ num_points: int,
298
+ camera: Optional[CamerasBase],
299
+ image_rgb: Optional[Tensor],
300
+ mask: Optional[Tensor],
301
+ # Optional overrides
302
+ scheduler: Optional[str] = 'ddpm',
303
+ # Inference parameters
304
+ num_inference_steps: Optional[int] = 1000,
305
+ eta: Optional[float] = 0.0, # for DDIM
306
+ # Whether to return all the intermediate steps in generation
307
+ return_sample_every_n_steps: int = -1,
308
+ # Whether to disable tqdm
309
+ disable_tqdm: bool = False,
310
+ gt_pc: Pointclouds = None,
311
+ **kwargs
312
+ ):
313
+ "run the reverse diffusion with the two branches (human and object), then use translation and scale to put them back into the joint frame"
314
+ assert not self.self_conditioning
315
+ # Get scheduler from mapping, or use self.scheduler if None
316
+ scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
317
+
318
+ # Get the size of the noise
319
+ N = num_points
320
+ B = 1 if image_rgb is None else image_rgb.shape[0]
321
+ D = self.get_x_T_channel()
322
+ device = self.device if image_rgb is None else image_rgb.device
323
+
324
+ # sample from full steps or only a few steps
325
+ sample_from_interm = kwargs.get('sample_from_interm', False)
326
+ interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
327
+
328
+ xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
329
+ xt_o = self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler)
330
+
331
+ # the segmentation mask
332
+ segm_mask = torch.zeros(B, 2*N, 1).to(device)
333
+ segm_mask[:, :N] = 1.0
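+ # convention: the first N points are human (mask 1), the last N are object (mask 0);
+ # tensor_to_point_cloud below uses this channel to color the two parts differently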
334
+
335
+ # Set timesteps
336
+ extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
337
+
338
+ # Loop over timesteps
339
+ all_outputs = []
340
+ return_all_outputs = (return_sample_every_n_steps > 0)
341
+ progress_bar = tqdm(self.get_reverse_timesteps(scheduler, interm_steps),
342
+ desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm)
343
+
344
+ # print("Camera T:", camera.T[0], camera.R[0])
345
+ # print("Camera_obj T:", kwargs.get('camera_obj').T[0], kwargs.get('camera_obj').R[0])
346
+
347
+ norm_parms = self.pack_norm_params(kwargs)
348
+ for i, t in enumerate(progress_bar):
349
+ x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb,
350
+ kwargs, mask,
351
+ norm_parms,
352
+ t,
353
+ xt_h, xt_o)
354
+
355
+ # One reverse step with conditioning
356
+ xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0),
357
+ torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs) # (B, N, D), D=3
358
+
359
+ if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)):
360
+ # print(xt_h.shape, kwargs.get('cent_hum').shape, kwargs.get('radius_hum').shape)
361
+ x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
362
+ self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
363
+ # print(x_t.shape, xt_o.shape)
364
+ all_outputs.append(torch.cat([x_t, segm_mask], -1))
365
+ # print("Updating intermediate...")
366
+
367
+ # Convert output back into a point cloud, undoing normalization and scaling
368
+ x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
369
+ self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
370
+ x_t = torch.cat([x_t, segm_mask], -1)
371
+ output = self.tensor_to_point_cloud(x_t, denormalize=False, unscale=False) # points were already mapped back to the original scale above
372
+ if return_all_outputs:
373
+ all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D)
374
+ all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs]
375
+
376
+ return (output, all_outputs) if return_all_outputs else output
377
+
378
+ def get_reverse_timesteps(self, scheduler, interm_steps:int):
379
+ """
380
+
381
+ :param scheduler:
382
+ :param interm_steps: start from some intermediate steps
383
+ :return:
384
+ """
385
+ if interm_steps > 0:
386
+ timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
387
+ else:
388
+ timesteps = scheduler.timesteps.to(self.device)
389
+ return timesteps
390
+
391
+ def pack_norm_params(self, kwargs:dict, scale=True):
392
+ scale_factor = self.scale_factor if scale else 1.0
393
+ hum = torch.cat([kwargs.get('cent_hum')*scale_factor, kwargs.get('radius_hum')], -1)
394
+ obj = torch.cat([kwargs.get('cent_obj')*scale_factor, kwargs.get('radius_obj')], -1)
395
+ return torch.stack([hum, obj], 0) # (2, B, 4)
396
+
397
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
398
+ "x_t: (2, B, D, N), x_t_input: (2, B, D, N)"
399
+ norm_parms = self.pack_norm_params(kwargs) # (2, B, 4)
400
+ B = x_t.shape[1]
401
+ # print(f"Step {t} Norm params:", norm_parms[:, 0, :])
402
+ noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input[0], x_t_input[1], t.reshape(1).expand(B),
403
+ norm_parms)
404
+ if self.consistent_center:
405
+ assert self.dm_pred_type != 'sample', 'incompatible dm predition type!'
406
+ noise_pred_h = noise_pred_h - torch.mean(noise_pred_h, dim=1, keepdim=True)
407
+ noise_pred_o = noise_pred_o - torch.mean(noise_pred_o, dim=1, keepdim=True)
408
+
409
+ xt_h = scheduler.step(noise_pred_h, t, x_t[0], **extra_step_kwargs).prev_sample
410
+ xt_o = scheduler.step(noise_pred_o, t, x_t[1], **extra_step_kwargs).prev_sample
411
+
412
+ if self.consistent_center:
413
+ xt_h = xt_h - torch.mean(xt_h, dim=1, keepdim=True)
414
+ xt_o = xt_o - torch.mean(xt_o, dim=1, keepdim=True)
415
+
416
+ return xt_h, xt_o
417
+
418
+ def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True):
419
+ """
420
+ first denormalize, then apply center and scale to original H+O coordinate
421
+ :param x:
422
+ :param cent: (B, 3)
423
+ :param radius: (B, 1)
424
+ :param unscale:
425
+ :return:
426
+ """
427
+ # denormalize: scale down.
428
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
429
+ # translation and scale back to H+O coordinate
430
+ points = points * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1)
431
+ return points
432
+
433
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
434
+ """
435
+ take binary into account
436
+ :param self:
437
+ :param x: (B, N, 4)
438
+ :param denormalize:
439
+ :param unscale:
440
+ :return:
441
+ """
442
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
443
+ if self.predict_color:
444
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
445
+ return Pointclouds(points=points, features=colors)
446
+ else:
447
+ assert x.shape[2] == 4
448
+ # add color to predicted binary labels
449
+ is_hum = x[:, :, 3] > 0.5
450
+ features = []
451
+ for mask in is_hum:
452
+ color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
453
+ color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green
454
+ features.append(color)
455
+ return Pointclouds(points=points, features=features)
456
+
457
+
model/model_utils.py ADDED
@@ -0,0 +1,58 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from pytorch3d.structures import Pointclouds
6
+
7
+
8
+ def set_requires_grad(module: nn.Module, requires_grad: bool):
9
+ for p in module.parameters():
10
+ p.requires_grad_(requires_grad)
11
+
12
+
13
+ def compute_distance_transform(mask: torch.Tensor):
14
+ """
15
+
16
+ Parameters
17
+ ----------
18
+ mask (B, 1, H, W) or (B, 2, H, W) true for foreground
19
+
20
+ Returns
21
+ -------
22
+ the normalized distance (clipped to [0, 1]) to the closest foreground pixel, zero inside the mask
23
+
24
+ """
25
+ C = mask.shape[1]
26
+ assert C in [1, 2], f'invalid mask shape {mask.shape} found!'
27
+
28
+ image_size = mask.shape[-1]
29
+
30
+ dts = []
31
+ for i in range(C):
32
+ distance_transform = torch.stack([
33
+ torch.from_numpy(cv2.distanceTransform(
34
+ (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
35
+ ) / (image_size / 2))
36
+ for m in mask[:, i:i+1].squeeze(1).detach().cpu().numpy().astype(np.uint8)
37
+ ]).unsqueeze(1).clip(0, 1).to(mask.device)
38
+ dts.append(distance_transform)
39
+ return torch.cat(dts, 1)
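+
+ # Shape sketch (hypothetical values): a 0/1 mask of shape (B, 1, 224, 224) yields a
+ # (B, 1, 224, 224) distance map with values in [0, 1]; a (B, 2, H, W) human+object
+ # mask yields two distance-transform channels.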
40
+
41
+
42
+ def default(x, d):
43
+ return d if x is None else x
44
+
45
+
46
+ def get_num_points(x: Pointclouds, /):
47
+ return x.points_padded().shape[1]
48
+
49
+
50
+ def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
51
+ """Custom beta schedule"""
52
+ betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
54
+ warmup_time = int(num_train_timesteps * warmup_frac)
55
+ warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
56
+ warmup_time = min(warmup_time, num_train_timesteps)
57
+ betas[:warmup_time] = warmup_steps[:warmup_time]
58
+ return betas
model/point_cloud_model.py ADDED
@@ -0,0 +1,67 @@
1
+ from contextlib import nullcontext
2
+
3
+ import torch
4
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
5
+ from diffusers import ModelMixin
6
+ from torch import Tensor
7
+
8
+ from .pvcnn.pvcnn import PVCNN2
9
+ from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus
10
+ from .simple.simple_model import SimplePointModel
11
+
12
+
13
+ class PointCloudModel(ModelMixin, ConfigMixin):
14
+ @register_to_config
15
+ def __init__(
16
+ self,
17
+ model_type: str = 'pvcnn',
18
+ in_channels: int = 3,
19
+ out_channels: int = 3,
20
+ embed_dim: int = 64,
21
+ dropout: float = 0.1,
22
+ width_multiplier: int = 1,
23
+ voxel_resolution_multiplier: int = 1,
24
+ ):
25
+ super().__init__()
26
+ self.model_type = model_type
27
+ if self.model_type == 'pvcnn':
28
+ self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
29
+ self.model = PVCNN2(
30
+ embed_dim=embed_dim,
31
+ num_classes=out_channels,
32
+ extra_feature_channels=(in_channels - 3),
33
+ dropout=dropout, width_multiplier=width_multiplier,
34
+ voxel_resolution_multiplier=voxel_resolution_multiplier
35
+ )
36
+ self.model.classifier[-1].bias.data.normal_(0, 1e-6)
37
+ self.model.classifier[-1].weight.data.normal_(0, 1e-6)
38
+ elif self.model_type == 'pvcnnplusplus':
39
+ self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
40
+ self.model = PVCNN2PlusPlus(
41
+ embed_dim=embed_dim,
42
+ num_classes=out_channels,
43
+ extra_feature_channels=(in_channels - 3),
44
+ )
45
+ self.model.output_projection[-1].bias.data.normal_(0, 1e-6)
46
+ self.model.output_projection[-1].weight.data.normal_(0, 1e-6)
47
+ elif self.model_type == 'simple':
48
+ self.autocast_context = nullcontext()
49
+ self.model = SimplePointModel(
50
+ embed_dim=embed_dim,
51
+ num_classes=out_channels,
52
+ extra_feature_channels=(in_channels - 3),
53
+ )
54
+ self.model.output_projection.bias.data.normal_(0, 1e-6)
55
+ self.model.output_projection.weight.data.normal_(0, 1e-6)
56
+ else:
57
+ raise NotImplementedError()
58
+
59
+ def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor:
60
+ """ Receives input of shape (B, N, in_channels) and returns output
61
+ of shape (B, N, out_channels) """
62
+ with self.autocast_context:
63
+ if not ret_feats:
64
+ return self.model(inputs.transpose(1, 2), t, ret_feats=False).transpose(1, 2)
65
+ else:
66
+ pred, feats = self.model(inputs.transpose(1, 2), t, ret_feats=True)
67
+ return pred.transpose(1, 2), feats
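+
+ # Shape sketch: inputs of shape (B, N, in_channels) and integer timesteps t of shape (B,)
+ # give an output of shape (B, N, out_channels); with ret_feats=True the intermediate
+ # PVCNN feature list is returned as well.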
model/point_cloud_transformer_model.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers import ModelMixin
7
+ from torch import Tensor
8
+ from timm.models.vision_transformer import Attention, LayerScale, DropPath, Mlp
9
+
10
+ from .point_cloud_model import PointCloudModel
11
+
12
+
13
+ class PointCloudModelBlock(nn.Module):
14
+
15
+ def __init__(
16
+ self,
17
+ *,
18
+ # Point cloud model
19
+ dim: int,
20
+ model_type: str = 'pvcnn',
21
+ dropout: float = 0.1,
22
+ width_multiplier: int = 1,
23
+ voxel_resolution_multiplier: int = 1,
24
+ # Transformer model
25
+ num_heads=6, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
26
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_attn=False
27
+ ):
28
+ super().__init__()
29
+
30
+ # Point cloud model
31
+ self.norm0 = norm_layer(dim)
32
+ self.point_cloud_model = PointCloudModel(model_type=model_type,
33
+ in_channels=dim, out_channels=dim, embed_dim=dim, dropout=dropout,
34
+ width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier)
35
+ self.ls0 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
36
+ self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
37
+
38
+ # Attention
39
+ self.use_attn = use_attn
40
+ if self.use_attn:
41
+ self.norm1 = norm_layer(dim)
42
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
43
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
44
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
45
+
46
+ # MLP
47
+ self.norm2 = norm_layer(dim)
48
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
49
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
50
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
51
+
52
+ def apply_point_cloud_model(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
53
+ t = t if t is not None else torch.zeros(len(x), device=x.device, dtype=torch.long)
54
+ return self.point_cloud_model(x, t)
55
+
56
+ def forward(self, x: Tensor):
57
+ x = x + self.drop_path0(self.ls0(self.apply_point_cloud_model(self.norm0(x))))
58
+ if self.use_attn:
59
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
60
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
61
+ return x
62
+
63
+
64
+ class PointCloudTransformerModel(ModelMixin, ConfigMixin):
65
+ @register_to_config
66
+ def __init__(self, num_layers: int, in_channels: int = 3, out_channels: int = 3, embed_dim: int = 64, **kwargs):
67
+ super().__init__()
68
+ self.num_layers = num_layers
69
+ self.input_projection = nn.Linear(in_channels, embed_dim)
70
+ self.blocks = nn.Sequential(*[PointCloudModelBlock(dim=embed_dim, **kwargs) for i in range(self.num_layers)])
71
+ self.norm = nn.LayerNorm(embed_dim)
72
+ self.output_projection = nn.Linear(embed_dim, out_channels)
73
+
74
+ def forward(self, inputs: Tensor) -> Tensor:
75
+ """ Receives input of shape (B, N, in_channels) and returns output
76
+ of shape (B, N, out_channels) """
77
+ x = self.input_projection(inputs)
78
+ x = self.blocks(x)
79
+ x = self.output_projection(x)
80
+ return x
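+
+ # Note: this transformer-style stack is not conditioned on a diffusion timestep;
+ # PointCloudModelBlock passes t = 0 to its inner PointCloudModel when no timestep is given.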
model/projection_model.py ADDED
@@ -0,0 +1,273 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler
5
+ from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
6
+ from diffusers import ModelMixin
7
+ from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
8
+ from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
9
+ from pytorch3d.renderer.cameras import CamerasBase
10
+ from pytorch3d.structures import Pointclouds
11
+ from torch import Tensor
12
+
13
+ from .feature_model import FeatureModel
14
+ from .model_utils import compute_distance_transform
15
+
16
+ SchedulerClass = Union[DDPMScheduler, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
17
+
18
+
19
+ class PointCloudProjectionModel(ModelMixin):
20
+
21
+ def __init__(
22
+ self,
23
+ image_size: int,
24
+ image_feature_model: str,
25
+ use_local_colors: bool = True,
26
+ use_local_features: bool = True,
27
+ use_global_features: bool = False,
28
+ use_mask: bool = True,
29
+ use_distance_transform: bool = True,
30
+ predict_shape: bool = True,
31
+ predict_color: bool = False,
32
+ process_color: bool = False,
33
+ image_color_channels: int = 3, # for the input image, not the points
34
+ color_channels: int = 3, # for the points, not the input image
35
+ colors_mean: float = 0.5,
36
+ colors_std: float = 0.5,
37
+ scale_factor: float = 1.0,
38
+ # Rasterization settings
39
+ raster_point_radius: float = 0.0075, # point size
40
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
41
+ bin_size: int = 0,
42
+ model_name=None,
43
+ # additional arguments added by XH
44
+ load_sample_init=False,
45
+ sample_init_scale=1.0,
46
+ test_init_with_gtpc=False,
47
+ consistent_center=False, # from https://arxiv.org/pdf/2308.07837.pdf
48
+ voxel_resolution_multiplier: int=1,
49
+ predict_binary: bool=False, # predict a binary class label
50
+ lw_binary: float=1.0,
51
+ binary_training_noise_std: float=0.1,
52
+ dm_pred_type: str='epsilon', # diffusion prediction type
53
+ self_conditioning=False,
54
+ **kwargs,
55
+
56
+ ):
57
+ super().__init__()
58
+ self.image_size = image_size
59
+ self.scale_factor = scale_factor
60
+ self.use_local_colors = use_local_colors
61
+ self.use_local_features = use_local_features
62
+ self.use_global_features = use_global_features
63
+ self.use_mask = use_mask
64
+ self.use_distance_transform = use_distance_transform
65
+ self.predict_shape = predict_shape
66
+ self.predict_color = predict_color
67
+ self.process_color = process_color
68
+ self.image_color_channels = image_color_channels
69
+ self.color_channels = color_channels
70
+ self.colors_mean = colors_mean
71
+ self.colors_std = colors_std
72
+ self.model_name = model_name
73
+ print("PointCloud Model scale factor:", self.scale_factor, 'Model name:', self.model_name)
74
+ self.predict_binary = predict_binary
75
+ self.lw_binary = lw_binary
76
+ self.self_conditioning = self_conditioning
77
+
78
+ # Types of conditioning that are used
79
+ self.use_local_conditioning = self.use_local_colors or self.use_local_features or self.use_mask
80
+ self.use_global_conditioning = self.use_global_features
81
+ self.kwargs = kwargs
82
+
83
+ # Create feature model
84
+ self.feature_model = FeatureModel(image_size, image_feature_model)
85
+
86
+ # Input size
87
+ self.in_channels = 3 # 3 for 3D point positions
88
+ if self.use_local_colors: # whether color should be an input
89
+ self.in_channels += self.image_color_channels
90
+ if self.use_local_features:
91
+ self.in_channels += self.feature_model.feature_dim
92
+ if self.use_global_features:
93
+ self.in_channels += self.feature_model.feature_dim
94
+ if self.use_mask:
95
+ self.in_channels += 2 if self.use_distance_transform else 1
96
+ if self.process_color:
97
+ self.in_channels += self.color_channels # point color added to input or not, default False
98
+ if self.self_conditioning:
99
+ self.in_channels += 3 # add self conditioning
100
+
101
+ self.in_channels = self.add_extra_input_chennels(self.in_channels)
102
+
103
+ if self.model_name in ['pc2-diff-ho-sepsegm', 'diff-ho-attn']:
104
+ self.in_channels += 2 if self.use_distance_transform else 1
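+ # resulting per-point input (with the default flags): 3 xyz + 3 projected RGB values
+ # + image feature channels + mask (+ distance transform), plus one extra mask/DT set
+ # for the two-mask human+object variants handled above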
105
+
106
+ # Output size
107
+ self.out_channels = 0
108
+ if self.predict_shape:
109
+ self.out_channels += 3
110
+ if self.predict_color:
111
+ self.out_channels += self.color_channels
112
+ if self.predict_binary:
113
+ print("Output binary classification score!")
114
+ self.out_channels += 1
115
+
116
+ # Save rasterization settings
117
+ self.raster_settings = PointsRasterizationSettings(
118
+ image_size=(image_size, image_size),
119
+ radius=raster_point_radius,
120
+ points_per_pixel=raster_points_per_pixel,
121
+ bin_size=bin_size,
122
+ )
123
+
124
+ def add_extra_input_chennels(self, input_channels):
125
+ return input_channels
126
+
127
+ def denormalize(self, x: Tensor, /, clamp: bool = True):
128
+ x = x * self.colors_std + self.colors_mean
129
+ return torch.clamp(x, 0, 1) if clamp else x
130
+
131
+ def normalize(self, x: Tensor, /):
132
+ x = (x - self.colors_mean) / self.colors_std
133
+ return x
134
+
135
+ def get_global_conditioning(self, image_rgb: Tensor):
136
+ global_conditioning = []
137
+ if self.use_global_features:
138
+ global_conditioning.append(self.feature_model(image_rgb,
139
+ return_cls_token_only=True)) # (B, D)
140
+ global_conditioning = torch.cat(global_conditioning, dim=1) # (B, D_cond)
141
+ return global_conditioning
142
+
143
+ def get_local_conditioning(self, image_rgb: Tensor, mask: Tensor):
144
+ """
145
+ compute per-point conditioning
146
+ Parameters
147
+ ----------
148
+ image_rgb: (B, 3, 224, 224), values normalized to 0-1, background is masked by the given mask
149
+ mask: (B, 1, 224, 224), or (B, 2, 224, 224) for h+o
150
+ """
151
+ local_conditioning = []
152
+ # import pdb; pdb.set_trace()
153
+
154
+ if self.use_local_colors: # XH: default True
155
+ local_conditioning.append(self.normalize(image_rgb))
156
+ if self.use_local_features: # XH: default True
157
+ local_conditioning.append(self.feature_model(image_rgb)) # I guess no mask here? feature model: 'vit_small_patch16_224_mae'
158
+ if self.use_mask: # default True
159
+ local_conditioning.append(mask.float())
160
+ if self.use_distance_transform: # default True
161
+ if not self.use_mask:
162
+ raise ValueError('No mask for distance transform?')
163
+ if mask.is_floating_point():
164
+ mask = mask > 0.5
165
+ local_conditioning.append(compute_distance_transform(mask))
166
+ local_conditioning = torch.cat(local_conditioning, dim=1) # (B, D_cond, H, W)
167
+ return local_conditioning
168
+
169
+ @torch.autocast('cuda', dtype=torch.float32)
170
+ def surface_projection(
171
+ self, points: Tensor, camera: CamerasBase, local_features: Tensor,
172
+ ):
173
+ B, C, H, W, device = *local_features.shape, local_features.device
174
+ R = self.raster_settings.points_per_pixel
175
+ N = points.shape[1]
176
+
177
+ # Scale camera by scaling T. ASSUMES CAMERA IS LOOKING AT ORIGIN!
178
+ camera = camera.clone()
179
+ camera.T = camera.T * self.scale_factor
180
+
181
+ # Create rasterizer
182
+ rasterizer = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings)
183
+
184
+ # Associate points with features via rasterization
185
+ fragments = rasterizer(Pointclouds(points)) # (B, H, W, R)
186
+ fragments_idx: Tensor = fragments.idx.long()
187
+ visible_pixels = (fragments_idx > -1) # (B, H, W, R)
188
+ points_to_visible_pixels = fragments_idx[visible_pixels]
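+ # fragments.idx maps each pixel (and each of its R point slots) to the index of the
+ # rasterized point in the packed cloud, or -1 where no point projects; indexing with
+ # visible_pixels therefore tells, for every covered pixel, which point it sees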
189
+
190
+ # Reshape local features to (B, H, W, R, C)
191
+ local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
192
+
193
+ # Get local features corresponding to visible points
194
+ local_features_proj = torch.zeros(B * N, C, device=device)
195
+ # local feature includes: raw RGB color, image features, mask, distance transform
196
+ local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
197
+ local_features_proj = local_features_proj.reshape(B, N, C)
198
+
199
+ return local_features_proj
200
+
201
+ def point_cloud_to_tensor(self, pc: Pointclouds, /, normalize: bool = False, scale: bool = False):
202
+ """Converts a point cloud to a tensor, with color if and only if self.predict_color"""
203
+ points = pc.points_padded() * (self.scale_factor if scale else 1)
204
+ if self.predict_color and pc.features_padded() is not None: # normalize color, not point locations
205
+ colors = self.normalize(pc.features_padded()) if normalize else pc.features_padded()
206
+ return torch.cat((points, colors), dim=2)
207
+ else:
208
+ return points
209
+
210
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
211
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
212
+ if self.predict_color:
213
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
214
+ return Pointclouds(points=points, features=colors)
215
+ else:
216
+ assert x.shape[2] == 3
217
+ return Pointclouds(points=points)
218
+
219
+ def get_input_with_conditioning(
220
+ self,
221
+ x_t: Tensor,
222
+ camera: Optional[CamerasBase],
223
+ image_rgb: Optional[Tensor],
224
+ mask: Optional[Tensor],
225
+ t: Optional[Tensor],
226
+ ):
227
+ """ Extracts local features from the input image and projects them onto the points
228
+ in the point cloud to obtain the input to the model. Then extracts global
229
+ features, replicates them across points, and concats them to the input.
230
+ image_rgb: masked background
231
+ XH: why there is no positional encoding as described by the supp??
232
+ """
233
+ B, N = x_t.shape[:2]
234
+
235
+ # Initial input is the point locations (and colors if and only if predicting color)
236
+ x_t_input = self.get_coord_feature(x_t)
237
+
238
+ # Local conditioning
239
+ if self.use_local_conditioning:
240
+
241
+ # Get local features and check that they are the same size as the input image
242
+ local_features = self.get_local_conditioning(image_rgb=image_rgb, mask=mask) # concatenate RGB + mask + RGB feature + distance transform
243
+ if local_features.shape[-2:] != image_rgb.shape[-2:]:
244
+ raise ValueError(f'{local_features.shape=} and {image_rgb.shape=}')
245
+
246
+ # Project local features. Here that we only need the point locations, not colors
247
+ local_features_proj = self.surface_projection(points=x_t[:, :, :3],
248
+ camera=camera, local_features=local_features) # (B, N, D_local)
249
+
250
+ x_t_input.append(local_features_proj)
251
+
252
+ # Global conditioning
253
+ if self.use_global_conditioning: # False
254
+
255
+ # Get and repeat global features
256
+ global_features = self.get_global_conditioning(image_rgb=image_rgb) # (B, D_global)
257
+ global_features = global_features.unsqueeze(1).expand(-1, N, -1) # (B, D_global, N)
258
+
259
+ x_t_input.append(global_features)
260
+
261
+ # Concatenate together all the pointwise features
262
+ x_t_input = torch.cat(x_t_input, dim=2) # (B, N, D)
263
+
264
+ return x_t_input
265
+
266
+ def get_coord_feature(self, x_t):
267
+ """get coordinate feature, for model that uses separate model to predict binary, we use first 3 channels only"""
268
+ x_t_input = [x_t]
269
+ return x_t_input
270
+
271
+ def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
272
+ """ The forward method may be defined differently for different models. """
273
+ raise NotImplementedError()
model/pvcnn/__init__.py ADDED
File without changes
model/pvcnn/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .ball_query import BallQuery, BallQueryHO
2
+ from .frustum import FrustumPointNetLoss
3
+ from .loss import KLLoss
4
+ from .pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
5
+ from .pvconv import PVConv, Attention, Swish, PVConvReLU
6
+ from .se import SE3d
7
+ from .shared_mlp import SharedMLP
8
+ from .voxelization import Voxelization
model/pvcnn/modules/ball_query.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+
6
+ __all__ = ['BallQuery']
7
+
8
+
9
+ class BallQuery(nn.Module):
10
+ def __init__(self, radius, num_neighbors, include_coordinates=True):
11
+ super().__init__()
12
+ self.radius = radius
13
+ self.num_neighbors = num_neighbors
14
+ self.include_coordinates = include_coordinates
15
+
16
+ def forward(self, points_coords, centers_coords, temb, points_features=None):
17
+ points_coords = points_coords.contiguous()
18
+ centers_coords = centers_coords.contiguous()
19
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
20
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
21
+ neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
22
+
23
+ if points_features is None:
24
+ assert self.include_coordinates, 'No Features For Grouping'
25
+ neighbor_features = neighbor_coordinates
26
+ else:
27
+ neighbor_features = F.grouping(points_features, neighbor_indices) # return [B, C, M, U] C=feat dim, M=# centers, U=# neighbours
28
+ if self.include_coordinates:
29
+ neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
30
+ return neighbor_features, F.grouping(temb, neighbor_indices)
31
+
32
+ def extra_repr(self):
33
+ return 'radius={}, num_neighbors={}{}'.format(
34
+ self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
35
+
36
+
37
+ class BallQueryHO(nn.Module):
38
+ "no point feature, but only relative and abs coordinate"
39
+ def __init__(self, radius, num_neighbors, include_relative=False):
40
+ super().__init__()
41
+ self.radius = radius
42
+ self.num_neighbors = num_neighbors
43
+ self.include_relative = include_relative
44
+
45
+ def forward(self, points_coords, centers_coords, points_features=None):
46
+ """
47
+ if not enough points inside the given radius, the entries will be zero
48
+ if too many points inside the radius, the order is random??? (not sure)
49
+ :param points_coords: (B, 3, N)
50
+ :param centers_coords: (B, 3, M)
51
+ :param points_features: None
52
+ :return:
53
+ """
54
+ points_coords = points_coords.contiguous()
55
+ centers_coords = centers_coords.contiguous()
56
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
57
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices) # (B, 3, M, U)
58
+ if self.include_relative:
59
+ neighbor_coordinates_rela = neighbor_coordinates - centers_coords.unsqueeze(-1)
60
+ neighbor_coordinates = torch.cat([neighbor_coordinates, neighbor_coordinates_rela], 1) # (B, 6, M, U)
61
+ # flatten the coordinate
62
+ neighbor_coordinates = neighbor_coordinates.permute(0, 1, 3, 2) # (B, 3/6, U, M)
63
+ neighbor_coordinates = torch.flatten(neighbor_coordinates, 1, 2) # (B, 3*U, M)
64
+ return neighbor_coordinates
65
+
66
+ def extra_repr(self):
67
+ return 'radius={}, num_neighbors={}{}'.format(
68
+ self.radius, self.num_neighbors, ', include relative' if self.include_relative else '')
69
+
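`BallQueryHO` flattens the (up to) `num_neighbors` neighbour coordinates of each center into one fixed-length vector per center. A small shape walk-through, with a random tensor standing in for the CUDA `grouping` output:

```python
import torch

B, M, U = 2, 128, 16
neighbor_coordinates = torch.randn(B, 3, M, U)   # stand-in for F.grouping(points_coords, neighbor_indices)

flat = neighbor_coordinates.permute(0, 1, 3, 2)  # (B, 3, U, M)
flat = torch.flatten(flat, 1, 2)                 # (B, 3*U, M)
print(flat.shape)                                # torch.Size([2, 48, 128])
```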
model/pvcnn/modules/frustum.py ADDED
@@ -0,0 +1,138 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from . import functional as PF  # renamed so it does not shadow torch.nn.functional (F); PF.huber_loss is used below
7
+
8
+ __all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
9
+
10
+
11
+ class FrustumPointNetLoss(nn.Module):
12
+ def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
13
+ corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
14
+ super().__init__()
15
+ self.box_loss_weight = box_loss_weight
16
+ self.corners_loss_weight = corners_loss_weight
17
+ self.heading_residual_loss_weight = heading_residual_loss_weight
18
+ self.size_residual_loss_weight = size_residual_loss_weight
19
+
20
+ self.num_heading_angle_bins = num_heading_angle_bins
21
+ self.num_size_templates = num_size_templates
22
+ self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
23
+ self.register_buffer(
24
+ 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
25
+ )
26
+
27
+ def forward(self, inputs, targets):
28
+ mask_logits = inputs['mask_logits'] # (B, 2, N)
29
+ center_reg = inputs['center_reg'] # (B, 3)
30
+ center = inputs['center'] # (B, 3)
31
+ heading_scores = inputs['heading_scores'] # (B, NH)
32
+ heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
33
+ heading_residuals = inputs['heading_residuals'] # (B, NH)
34
+ size_scores = inputs['size_scores'] # (B, NS)
35
+ size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
36
+ size_residuals = inputs['size_residuals'] # (B, NS, 3)
37
+
38
+ mask_logits_target = targets['mask_logits'] # (B, N)
39
+ center_target = targets['center'] # (B, 3)
40
+ heading_bin_id_target = targets['heading_bin_id'] # (B, )
41
+ heading_residual_target = targets['heading_residual'] # (B, )
42
+ size_template_id_target = targets['size_template_id'] # (B, )
43
+ size_residual_target = targets['size_residual'] # (B, 3)
44
+
45
+ batch_size = center.size(0)
46
+ batch_id = torch.arange(batch_size, device=center.device)
47
+
48
+ # Basic Classification and Regression losses
49
+ mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
50
+ heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
51
+ size_loss = F.cross_entropy(size_scores, size_template_id_target)
52
+ center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
53
+ center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
54
+
55
+ # Refinement losses for size/heading
56
+ heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
57
+ heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
58
+ heading_residual_normalized_loss = PF.huber_loss(
59
+ heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
60
+ )
61
+ size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
62
+ size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
63
+ size_residual_normalized_loss = PF.huber_loss(
64
+ torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
65
+ )
66
+
67
+ # Bounding box losses
68
+ heading = (heading_residuals[batch_id, heading_bin_id_target]
69
+ + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
70
+ # Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
71
+ size = (size_residuals[batch_id, size_template_id_target]
72
+ + self.size_templates[size_template_id_target]) # (B, 3)
73
+ corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
74
+ heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
75
+ size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
76
+ corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
77
+ sizes=size_target, with_flip=True) # (B, 3, 8)
78
+ corners_loss = PF.huber_loss(torch.min(
79
+ torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
80
+ ), delta=1.0)
81
+ # Summing up
82
+ loss = mask_loss + self.box_loss_weight * (
83
+ center_loss + center_reg_loss + heading_loss + size_loss
84
+ + self.heading_residual_loss_weight * heading_residual_normalized_loss
85
+ + self.size_residual_loss_weight * size_residual_normalized_loss
86
+ + self.corners_loss_weight * corners_loss
87
+ )
88
+
89
+ return loss
90
+
91
+
92
+ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
93
+ """
94
+ :param centers: coords of box centers, FloatTensor[N, 3]
95
+ :param headings: heading angles, FloatTensor[N, ]
96
+ :param sizes: box sizes, FloatTensor[N, 3]
97
+ :param with_flip: bool, whether to return flipped box (headings + np.pi)
98
+ :return:
99
+ coords of box corners, FloatTensor[N, 3, 8]
100
+ NOTE: corner points are in counter clockwise order, e.g.,
101
+ 2--1
102
+ 3--0 5
103
+ 7--4
104
+ """
105
+ l = sizes[:, 0] # (N,)
106
+ w = sizes[:, 1] # (N,)
107
+ h = sizes[:, 2] # (N,)
108
+ x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
109
+ y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
110
+ z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
111
+
112
+ c = torch.cos(headings) # (N,)
113
+ s = torch.sin(headings) # (N,)
114
+ o = torch.ones_like(headings) # (N,)
115
+ z = torch.zeros_like(headings) # (N,)
116
+
117
+ centers = centers.unsqueeze(-1) # (B, 3, 1)
118
+ corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
119
+ R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
120
+ if with_flip:
121
+ R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
122
+ return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
123
+ else:
124
+ return torch.matmul(R, corners) + centers
125
+
126
+ # centers = centers.unsqueeze(1) # (B, 1, 3)
127
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
128
+ # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
129
+ # if with_flip:
130
+ # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
131
+ # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
132
+ # else:
133
+ # return torch.matmul(corners, RT) + centers # (N, 8, 3)
134
+
135
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
136
+ # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
137
+ # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
138
+ # corners = corners.transpose(1, 2) # (N, 8, 3)
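The boxes in `get_box_corners_3d` are rotated about the y axis by the heading angle. Below is a self-contained sanity check of that convention, a re-statement for intuition only (so it runs without the CUDA backend), not the module itself: for zero heading the corners span an axis-aligned box whose x/y/z extents are l/h/w.

```python
import torch

def box_corners_y_rotation(centers, headings, sizes):
    # same construction as get_box_corners_3d (without the flipped variant)
    l, w, h = sizes[:, 0], sizes[:, 1], sizes[:, 2]
    x = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1)
    y = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1)
    z = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1)
    c, s = torch.cos(headings), torch.sin(headings)
    o, zr = torch.ones_like(headings), torch.zeros_like(headings)
    R = torch.stack([c, zr, s, zr, o, zr, -s, zr, c], dim=1).view(-1, 3, 3)  # rotation about y
    corners = torch.stack([x, y, z], dim=1)                                  # (N, 3, 8)
    return torch.matmul(R, corners) + centers.unsqueeze(-1)

corners = box_corners_y_rotation(torch.tensor([[1., 2., 3.]]),
                                 torch.zeros(1),
                                 torch.tensor([[2., 4., 6.]]))   # sizes = (l, w, h)
extent = corners.max(-1).values - corners.min(-1).values
print(extent)   # tensor([[2., 6., 4.]]) -> x extent = l, y extent = h, z extent = w
```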
model/pvcnn/modules/functional/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .ball_query import ball_query
2
+ from .devoxelization import trilinear_devoxelize
3
+ from .grouping import grouping
4
+ from .interpolatation import nearest_neighbor_interpolate
5
+ from .loss import kl_loss, huber_loss
6
+ from .sampling import gather, furthest_point_sample, logits_mask
7
+ from .voxelization import avg_voxelize
model/pvcnn/modules/functional/backend.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from torch.utils.cpp_extension import load
5
+
6
+
7
+ gcc_path = os.getenv('CC', default='/usr/bin/gcc')
8
+ if not Path(gcc_path).is_file():
9
+ raise ValueError('Could not find your gcc, please replace it here.')
10
+
11
+ _src_path = os.path.dirname(os.path.abspath(__file__))
12
+ _backend = load(
13
+ name='_pvcnn_backend',
14
+ extra_cflags=['-O3', '-std=c++17'],
15
+ extra_cuda_cflags=[f'--compiler-bindir={gcc_path}'],
16
+ sources=[os.path.join(_src_path,'src', f) for f in [
17
+ 'ball_query/ball_query.cpp',
18
+ 'ball_query/ball_query.cu',
19
+ 'grouping/grouping.cpp',
20
+ 'grouping/grouping.cu',
21
+ 'interpolate/neighbor_interpolate.cpp',
22
+ 'interpolate/neighbor_interpolate.cu',
23
+ 'interpolate/trilinear_devox.cpp',
24
+ 'interpolate/trilinear_devox.cu',
25
+ 'sampling/sampling.cpp',
26
+ 'sampling/sampling.cu',
27
+ 'voxelization/vox.cpp',
28
+ 'voxelization/vox.cu',
29
+ 'bindings.cpp',
30
+ ]]
31
+ )
32
+
33
+ __all__ = ['_backend']
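Because the backend is JIT-compiled on first import, the host compiler must be visible either at `/usr/bin/gcc` or through the `CC` environment variable, and it has to be compatible with the installed CUDA toolkit. A hedged usage sketch (the compiler path is only an example, and the import assumes the repo root is on `PYTHONPATH`):

```python
import os

# point the extension build at a specific host compiler before the backend is imported
os.environ.setdefault('CC', '/usr/bin/gcc-10')   # hypothetical path, adjust to your system

# importing any functional op triggers the one-off JIT build of _pvcnn_backend
from model.pvcnn.modules.functional import ball_query
```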
model/pvcnn/modules/functional/ball_query.py ADDED
@@ -0,0 +1,19 @@
1
+ from torch.autograd import Function
2
+
3
+ from .backend import _backend
4
+
5
+ __all__ = ['ball_query']
6
+
7
+
8
+ def ball_query(centers_coords, points_coords, radius, num_neighbors):
9
+ """
10
+ :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
11
+ :param points_coords: coordinates of points, FloatTensor[B, 3, N]
12
+ :param radius: float, radius of ball query
13
+ :param num_neighbors: int, maximum number of neighbors
14
+ :return:
15
+ neighbor_indices: indices of neighbors, IntTensor[B, M, U]
16
+ """
17
+ centers_coords = centers_coords.contiguous()
18
+ points_coords = points_coords.contiguous()
19
+ return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
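A pure-PyTorch reference of what `ball_query` computes can be handy for CPU-side testing. The sketch below mirrors the padding behaviour of the CUDA kernel (`ball_query.cu` further down): slots beyond the found neighbours repeat the first neighbour, and a center with no point inside the radius keeps index 0. This is an assumption-laden test helper, not part of the repo:

```python
import torch

def ball_query_reference(centers_coords, points_coords, radius, num_neighbors):
    """centers_coords: (B, 3, M), points_coords: (B, 3, N) -> IntTensor (B, M, U)."""
    B, _, M = centers_coords.shape
    out = torch.zeros(B, M, num_neighbors, dtype=torch.int32)
    d2 = torch.cdist(centers_coords.transpose(1, 2), points_coords.transpose(1, 2)) ** 2
    within = d2 < radius * radius                      # (B, M, N)
    for b in range(B):
        for j in range(M):
            idx = within[b, j].nonzero(as_tuple=False).flatten()[:num_neighbors]
            if idx.numel() == 0:
                continue
            out[b, j] = idx[0]                         # pad remaining slots with the first neighbour
            out[b, j, :idx.numel()] = idx.to(torch.int32)
    return out
```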
model/pvcnn/modules/functional/devoxelization.py ADDED
@@ -0,0 +1,42 @@
1
+ from torch.autograd import Function
2
+
3
+ from .backend import _backend
4
+
5
+ __all__ = ['trilinear_devoxelize']
6
+
7
+
8
+ class TrilinearDevoxelization(Function):
9
+ @staticmethod
10
+ def forward(ctx, features, coords, resolution, is_training=True):
11
+ """
12
+ :param ctx:
13
+ :param coords: the coordinates of points, FloatTensor[B, 3, N]
14
+ :param features: FloatTensor[B, C, R, R, R]
15
+ :param resolution: int, the voxel resolution
16
+ :param is_training: bool, training mode
17
+ :return:
18
+ FloatTensor[B, C, N]
19
+ """
20
+ B, C = features.shape[:2]
21
+ features = features.contiguous().view(B, C, -1)
22
+ coords = coords.contiguous()
23
+ outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features)
24
+ if is_training:
25
+ ctx.save_for_backward(inds, wgts)
26
+ ctx.r = resolution
27
+ return outs
28
+
29
+ @staticmethod
30
+ def backward(ctx, grad_output):
31
+ """
32
+ :param ctx:
33
+ :param grad_output: gradient of outputs, FloatTensor[B, C, N]
34
+ :return:
35
+ gradient of inputs, FloatTensor[B, C, R, R, R]
36
+ """
37
+ inds, wgts = ctx.saved_tensors
38
+ grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r)
39
+ return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None
40
+
41
+
42
+ trilinear_devoxelize = TrilinearDevoxelization.apply
model/pvcnn/modules/functional/grouping.py ADDED
@@ -0,0 +1,32 @@
1
+ from torch.autograd import Function
2
+
3
+ from .backend import _backend
4
+
5
+ __all__ = ['grouping']
6
+
7
+
8
+ class Grouping(Function):
9
+ @staticmethod
10
+ def forward(ctx, features, indices):
11
+ """
12
+ :param ctx:
13
+ :param features: features of points, FloatTensor[B, C, N]
14
+ :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
15
+ :return:
16
+ grouped_features: grouped features, FloatTensor[B, C, M, U]
17
+ """
18
+ features = features.contiguous()
19
+ indices = indices.contiguous()
20
+ ctx.save_for_backward(indices)
21
+ ctx.num_points = features.size(-1)
22
+ # print(features.dtype, features.shape)
23
+ return _backend.grouping_forward(features, indices)
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output):
27
+ indices, = ctx.saved_tensors
28
+ grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
29
+ return grad_features, None
30
+
31
+
32
+ grouping = Grouping.apply
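The grouping op is essentially a batched gather along the point dimension, so a CPU reference is a one-liner with `torch.gather`. A sketch for testing (not part of the repo):

```python
import torch

def grouping_reference(features, indices):
    """features: (B, C, N), indices: (B, M, U) -> (B, C, M, U), same semantics as the CUDA op."""
    B, C, _ = features.shape
    M, U = indices.shape[1:]
    idx = indices.long().reshape(B, 1, M * U).expand(-1, C, -1)    # (B, C, M*U)
    return torch.gather(features, 2, idx).view(B, C, M, U)

features = torch.randn(2, 8, 100)
indices = torch.randint(0, 100, (2, 16, 4))
print(grouping_reference(features, indices).shape)   # torch.Size([2, 8, 16, 4])
```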
model/pvcnn/modules/functional/interpolatation.py ADDED
@@ -0,0 +1,38 @@
1
+ from torch.autograd import Function
2
+
3
+ from .backend import _backend
4
+
5
+ __all__ = ['nearest_neighbor_interpolate']
6
+
7
+
8
+ class NeighborInterpolation(Function):
9
+ @staticmethod
10
+ def forward(ctx, points_coords, centers_coords, centers_features):
11
+ """
12
+ :param ctx:
13
+ :param points_coords: coordinates of points, FloatTensor[B, 3, N]
14
+ :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
15
+ :param centers_features: features of centers, FloatTensor[B, C, M]
16
+ :return:
17
+ points_features: features of points, FloatTensor[B, C, N]
18
+ """
19
+ centers_coords = centers_coords.contiguous()
20
+ points_coords = points_coords.contiguous()
21
+ centers_features = centers_features.contiguous()
22
+ points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
23
+ points_coords, centers_coords, centers_features
24
+ )
25
+ ctx.save_for_backward(indices, weights)
26
+ ctx.num_centers = centers_coords.size(-1)
27
+ return points_features
28
+
29
+ @staticmethod
30
+ def backward(ctx, grad_output):
31
+ indices, weights = ctx.saved_tensors
32
+ grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
33
+ grad_output.contiguous(), indices, weights, ctx.num_centers
34
+ )
35
+ return None, None, grad_centers_features
36
+
37
+
38
+ nearest_neighbor_interpolate = NeighborInterpolation.apply
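What the kernels behind `nearest_neighbor_interpolate` compute is inverse-(squared-)distance interpolation over the three nearest centers. A pure-PyTorch sketch of the forward pass, useful as a CPU reference (an approximation for testing, not the repo's implementation):

```python
import torch

def three_nn_interpolate_reference(points_coords, centers_coords, centers_features):
    """points_coords: (B, 3, N), centers_coords: (B, 3, M), centers_features: (B, C, M) -> (B, C, N)."""
    d2 = torch.cdist(points_coords.transpose(1, 2),
                     centers_coords.transpose(1, 2)) ** 2              # (B, N, M) squared distances
    dist, idx = d2.topk(3, dim=-1, largest=False)                      # three nearest centers
    weights = 1.0 / dist.clamp(min=1e-10)
    weights = weights / weights.sum(dim=-1, keepdim=True)              # (B, N, 3)
    B, C, _ = centers_features.shape
    N = points_coords.shape[-1]
    gathered = torch.gather(centers_features.unsqueeze(2).expand(-1, -1, N, -1), 3,
                            idx.unsqueeze(1).expand(-1, C, -1, -1))    # (B, C, N, 3)
    return (gathered * weights.unsqueeze(1)).sum(dim=-1)
```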
model/pvcnn/modules/functional/loss.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ __all__ = ['kl_loss', 'huber_loss']
5
+
6
+
7
+ def kl_loss(x, y):
8
+ x = F.softmax(x.detach(), dim=1)
9
+ y = F.log_softmax(y, dim=1)
10
+ return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))
11
+
12
+
13
+ def huber_loss(error, delta):
14
+ abs_error = torch.abs(error)
15
+ quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
16
+ losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
17
+ return torch.mean(losses)
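`huber_loss` is quadratic for errors below `delta` and linear above it. A worked numeric check, with the formula restated inline so it runs without building the CUDA backend:

```python
import torch

def huber_loss(error, delta):
    # same formula as above: 0.5 * q^2 + delta * (|e| - q), with q = min(|e|, delta)
    abs_error = torch.abs(error)
    quadratic = torch.min(abs_error, torch.full_like(abs_error, delta))
    return torch.mean(0.5 * quadratic ** 2 + delta * (abs_error - quadratic))

print(huber_loss(torch.tensor([1.0]), delta=2.0))   # 0.5 * 1^2           -> tensor(0.5000)
print(huber_loss(torch.tensor([3.0]), delta=2.0))   # 0.5 * 2^2 + 2*(3-2) -> tensor(4.)
```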
model/pvcnn/modules/functional/sampling.py ADDED
@@ -0,0 +1,84 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.autograd import Function
4
+
5
+ from .backend import _backend
6
+
7
+ __all__ = ['gather', 'furthest_point_sample', 'logits_mask']
8
+
9
+
10
+ class Gather(Function):
11
+ @staticmethod
12
+ def forward(ctx, features, indices):
13
+ """
14
+ Gather
15
+ :param ctx:
16
+ :param features: features of points, FloatTensor[B, C, N]
17
+ :param indices: centers' indices in points, IntTensor[b, m]
18
+ :return:
19
+ centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
20
+ """
21
+ features = features.contiguous()
22
+ indices = indices.int().contiguous()
23
+ ctx.save_for_backward(indices)
24
+ ctx.num_points = features.size(-1)
25
+ return _backend.gather_features_forward(features, indices)
26
+
27
+ @staticmethod
28
+ def backward(ctx, grad_output):
29
+ indices, = ctx.saved_tensors
30
+ grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
31
+ return grad_features, None
32
+
33
+
34
+ gather = Gather.apply
35
+
36
+
37
+ def furthest_point_sample(coords, num_samples):
38
+ """
39
+ Uses iterative furthest point sampling to select a set of num_samples points that have the largest
+ minimum distance to the already selected point set
41
+ :param coords: coordinates of points, FloatTensor[B, 3, N]
42
+ :param num_samples: int, M
43
+ :return:
44
+ centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
45
+ """
46
+ coords = coords.contiguous()
47
+ indices = _backend.furthest_point_sampling(coords, num_samples)
48
+ return gather(coords, indices)
49
+
50
+
51
+ def logits_mask(coords, logits, num_points_per_object):
52
+ """
53
+ Use logits to sample points
54
+ :param coords: coords of points, FloatTensor[B, 3, N]
55
+ :param logits: binary classification logits, FloatTensor[B, 2, N]
56
+ :param num_points_per_object: M, #points per object after masking, int
57
+ :return:
58
+ selected_coords: FloatTensor[B, 3, M]
59
+ masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
60
+ mask: mask to select points, BoolTensor[B, N]
61
+ """
62
+ batch_size, _, num_points = coords.shape
63
+ mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
64
+ num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
65
+ masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
66
+ masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
67
+ torch.ones_like(num_candidates)).float() # [B, C]
68
+ selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
69
+ for i in range(batch_size):
70
+ current_mask = mask[i] # [N]
71
+ current_candidates = current_mask.nonzero().view(-1)
72
+ current_num_candidates = current_candidates.numel()
73
+ if current_num_candidates >= num_points_per_object:
74
+ choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
75
+ selected_indices[i] = current_candidates[choices]
76
+ elif current_num_candidates > 0:
77
+ choices = np.concatenate([
78
+ np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
79
+ np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
80
+ ])
81
+ np.random.shuffle(choices)
82
+ selected_indices[i] = current_candidates[choices]
83
+ selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
84
+ return selected_coords, masked_coords_mean, mask
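`furthest_point_sample` greedily grows the sample set by always adding the point that is furthest from everything selected so far. A pure-PyTorch sketch of the same iteration, for intuition and CPU testing (the CUDA kernel is the one actually used; the start index 0 is a common convention and may differ from the kernel):

```python
import torch

def furthest_point_sample_reference(coords, num_samples):
    """coords: (B, 3, N) -> sampled coordinates (B, 3, num_samples)."""
    B, _, N = coords.shape
    pts = coords.transpose(1, 2)                           # (B, N, 3)
    selected = torch.zeros(B, num_samples, dtype=torch.long)
    min_dist = torch.full((B, N), float('inf'))
    farthest = torch.zeros(B, dtype=torch.long)            # start from point 0 (assumed convention)
    batch = torch.arange(B)
    for i in range(num_samples):
        selected[:, i] = farthest
        centroid = pts[batch, farthest].unsqueeze(1)       # (B, 1, 3)
        min_dist = torch.minimum(min_dist, ((pts - centroid) ** 2).sum(-1))
        farthest = min_dist.argmax(dim=-1)
    return torch.gather(coords, 2, selected.unsqueeze(1).expand(-1, 3, -1))

print(furthest_point_sample_reference(torch.randn(2, 3, 1000), 128).shape)  # torch.Size([2, 3, 128])
```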
model/pvcnn/modules/functional/src/ball_query/ball_query.cpp ADDED
@@ -0,0 +1,30 @@
1
+ #include "ball_query.hpp"
2
+ #include "ball_query.cuh"
3
+
4
+ #include "../utils.hpp"
5
+
6
+ at::Tensor ball_query_forward(at::Tensor centers_coords,
7
+ at::Tensor points_coords, const float radius,
8
+ const int num_neighbors) {
9
+ CHECK_CUDA(centers_coords);
10
+ CHECK_CUDA(points_coords);
11
+ CHECK_CONTIGUOUS(centers_coords);
12
+ CHECK_CONTIGUOUS(points_coords);
13
+ CHECK_IS_FLOAT(centers_coords);
14
+ CHECK_IS_FLOAT(points_coords);
15
+
16
+ int b = centers_coords.size(0);
17
+ int m = centers_coords.size(2);
18
+ int n = points_coords.size(2);
19
+
20
+ at::Tensor neighbors_indices = torch::zeros(
21
+ {b, m, num_neighbors},
22
+ at::device(centers_coords.device()).dtype(at::ScalarType::Int));
23
+
24
+ ball_query(b, n, m, radius * radius, num_neighbors,
25
+ centers_coords.data_ptr<float>(),
26
+ points_coords.data_ptr<float>(),
27
+ neighbors_indices.data_ptr<int>());
28
+
29
+ return neighbors_indices;
30
+ }
model/pvcnn/modules/functional/src/ball_query/ball_query.cu ADDED
@@ -0,0 +1,59 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "../cuda_utils.cuh"
6
+
7
+ /*
8
+ Function: ball query
9
+ Args:
10
+ b : batch size
11
+ n : number of points in point clouds
12
+ m : number of query centers
13
+ r2 : ball query radius ** 2
14
+ u : maximum number of neighbors
15
+ centers_coords: coordinates of centers, FloatTensor[b, 3, m]
16
+ points_coords : coordinates of points, FloatTensor[b, 3, n]
17
+ neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
18
+ */
19
+ __global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
20
+ const float *__restrict__ centers_coords,
21
+ const float *__restrict__ points_coords,
22
+ int *__restrict__ neighbors_indices) {
23
+ int batch_index = blockIdx.x;
24
+ int index = threadIdx.x;
25
+ int stride = blockDim.x;
26
+ points_coords += batch_index * n * 3;
27
+ centers_coords += batch_index * m * 3;
28
+ neighbors_indices += batch_index * m * u;
29
+
30
+ for (int j = index; j < m; j += stride) {
31
+ float center_x = centers_coords[j];
32
+ float center_y = centers_coords[j + m];
33
+ float center_z = centers_coords[j + m + m];
34
+ for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
35
+ float dx = center_x - points_coords[k];
36
+ float dy = center_y - points_coords[k + n];
37
+ float dz = center_z - points_coords[k + n + n];
38
+ float d2 = dx * dx + dy * dy + dz * dz;
39
+ if (d2 < r2) {
40
+ if (cnt == 0) {
41
+ for (int v = 0; v < u; ++v) {
42
+ neighbors_indices[j * u + v] = k;
43
+ }
44
+ }
45
+ neighbors_indices[j * u + cnt] = k;
46
+ ++cnt;
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ void ball_query(int b, int n, int m, float r2, int u,
53
+ const float *centers_coords, const float *points_coords,
54
+ int *neighbors_indices) {
55
+ ball_query_kernel<<<b, optimal_num_threads(m), 0,
56
+ at::cuda::getCurrentCUDAStream()>>>(
57
+ b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
58
+ CUDA_CHECK_ERRORS();
59
+ }
model/pvcnn/modules/functional/src/ball_query/ball_query.cuh ADDED
@@ -0,0 +1,8 @@
1
+ #ifndef _BALL_QUERY_CUH
2
+ #define _BALL_QUERY_CUH
3
+
4
+ void ball_query(int b, int n, int m, float r2, int u,
5
+ const float *centers_coords, const float *points_coords,
6
+ int *neighbors_indices);
7
+
8
+ #endif
model/pvcnn/modules/functional/src/ball_query/ball_query.hpp ADDED
@@ -0,0 +1,10 @@
1
+ #ifndef _BALL_QUERY_HPP
2
+ #define _BALL_QUERY_HPP
3
+
4
+ #include <torch/extension.h>
5
+
6
+ at::Tensor ball_query_forward(at::Tensor centers_coords,
7
+ at::Tensor points_coords, const float radius,
8
+ const int num_neighbors);
9
+
10
+ #endif
model/pvcnn/modules/functional/src/bindings.cpp ADDED
@@ -0,0 +1,37 @@
1
+ #include <pybind11/pybind11.h>
2
+
3
+ #include "ball_query/ball_query.hpp"
4
+ #include "grouping/grouping.hpp"
5
+ #include "interpolate/neighbor_interpolate.hpp"
6
+ #include "interpolate/trilinear_devox.hpp"
7
+ #include "sampling/sampling.hpp"
8
+ #include "voxelization/vox.hpp"
9
+
10
+ PYBIND11_MODULE(_pvcnn_backend, m) {
11
+ m.def("gather_features_forward", &gather_features_forward,
12
+ "Gather Centers' Features forward (CUDA)");
13
+ m.def("gather_features_backward", &gather_features_backward,
14
+ "Gather Centers' Features backward (CUDA)");
15
+ m.def("furthest_point_sampling", &furthest_point_sampling_forward,
16
+ "Furthest Point Sampling (CUDA)");
17
+ m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
18
+ m.def("grouping_forward", &grouping_forward,
19
+ "Grouping Features forward (CUDA)");
20
+ m.def("grouping_backward", &grouping_backward,
21
+ "Grouping Features backward (CUDA)");
22
+ m.def("three_nearest_neighbors_interpolate_forward",
23
+ &three_nearest_neighbors_interpolate_forward,
24
+ "3 Nearest Neighbors Interpolate forward (CUDA)");
25
+ m.def("three_nearest_neighbors_interpolate_backward",
26
+ &three_nearest_neighbors_interpolate_backward,
27
+ "3 Nearest Neighbors Interpolate backward (CUDA)");
28
+
29
+ m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
30
+ "Trilinear Devoxelization forward (CUDA)");
31
+ m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
32
+ "Trilinear Devoxelization backward (CUDA)");
33
+ m.def("avg_voxelize_forward", &avg_voxelize_forward,
34
+ "Voxelization forward with average pooling (CUDA)");
35
+ m.def("avg_voxelize_backward", &avg_voxelize_backward,
36
+ "Voxelization backward (CUDA)");
37
+ }
model/pvcnn/modules/functional/src/cuda_utils.cuh ADDED
@@ -0,0 +1,39 @@
1
+ #ifndef _CUDA_UTILS_H
2
+ #define _CUDA_UTILS_H
3
+
4
+ #include <ATen/ATen.h>
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <cmath>
7
+
8
+ #include <cuda.h>
9
+ #include <cuda_runtime.h>
10
+
11
+ #include <vector>
12
+
13
+ #define MAXIMUM_THREADS 512
14
+
15
+ inline int optimal_num_threads(int work_size) {
16
+ const int pow_2 = std::log2(static_cast<double>(work_size));
17
+ return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
18
+ }
19
+
20
+ inline dim3 optimal_block_config(int x, int y) {
21
+ const int x_threads = optimal_num_threads(x);
22
+ const int y_threads =
23
+ max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
24
+ dim3 block_config(x_threads, y_threads, 1);
25
+ return block_config;
26
+ }
27
+
28
+ #define CUDA_CHECK_ERRORS() \
29
+ { \
30
+ cudaError_t err = cudaGetLastError(); \
31
+ if (cudaSuccess != err) { \
32
+ fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
33
+ cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
34
+ __FILE__); \
35
+ exit(-1); \
36
+ } \
37
+ }
38
+
39
+ #endif
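`optimal_num_threads` picks the largest power of two that does not exceed the work size, capped at 512 threads per block. A quick Python restatement to illustrate the values it produces (for intuition only):

```python
import math

MAXIMUM_THREADS = 512

def optimal_num_threads(work_size: int) -> int:
    # largest power of two <= work_size, clamped to [1, MAXIMUM_THREADS]
    pow_2 = int(math.log2(work_size))
    return max(min(1 << pow_2, MAXIMUM_THREADS), 1)

print([optimal_num_threads(n) for n in (1, 5, 64, 100, 4096)])   # [1, 4, 64, 64, 512]
```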
model/pvcnn/modules/functional/src/grouping/grouping.cpp ADDED
@@ -0,0 +1,44 @@
1
+ #include "grouping.hpp"
2
+ #include "grouping.cuh"
3
+
4
+ #include "../utils.hpp"
5
+
6
+ at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
7
+ CHECK_CUDA(features);
8
+ CHECK_CUDA(indices);
9
+ CHECK_CONTIGUOUS(features);
10
+ CHECK_CONTIGUOUS(indices);
11
+ CHECK_IS_FLOAT(features);
12
+ CHECK_IS_INT(indices);
13
+
14
+ int b = features.size(0);
15
+ int c = features.size(1);
16
+ int n = features.size(2);
17
+ int m = indices.size(1);
18
+ int u = indices.size(2);
19
+ at::Tensor output = torch::zeros(
20
+ {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
21
+ grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
22
+ output.data_ptr<float>());
23
+ return output;
24
+ }
25
+
26
+ at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
27
+ const int n) {
28
+ CHECK_CUDA(grad_y);
29
+ CHECK_CUDA(indices);
30
+ CHECK_CONTIGUOUS(grad_y);
31
+ CHECK_CONTIGUOUS(indices);
32
+ CHECK_IS_FLOAT(grad_y);
33
+ CHECK_IS_INT(indices);
34
+
35
+ int b = grad_y.size(0);
36
+ int c = grad_y.size(1);
37
+ int m = indices.size(1);
38
+ int u = indices.size(2);
39
+ at::Tensor grad_x = torch::zeros(
40
+ {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
41
+ grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
42
+ indices.data_ptr<int>(), grad_x.data_ptr<float>());
43
+ return grad_x;
44
+ }
model/pvcnn/modules/functional/src/grouping/grouping.cu ADDED
@@ -0,0 +1,85 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+
4
+ #include "../cuda_utils.cuh"
5
+
6
+ /*
7
+ Function: grouping features of neighbors (forward)
8
+ Args:
9
+ b : batch size
10
+ c : #channels of features
11
+ n : number of points in point clouds
12
+ m : number of query centers
13
+ u : maximum number of neighbors
14
+ features: points' features, FloatTensor[b, c, n]
15
+ indices : neighbor indices in points, IntTensor[b, m, u]
16
+ out : gathered features, FloatTensor[b, c, m, u]
17
+ */
18
+ __global__ void grouping_kernel(int b, int c, int n, int m, int u,
19
+ const float *__restrict__ features,
20
+ const int *__restrict__ indices,
21
+ float *__restrict__ out) {
22
+ int batch_index = blockIdx.x;
23
+ features += batch_index * n * c;
24
+ indices += batch_index * m * u;
25
+ out += batch_index * m * u * c;
26
+
27
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
28
+ const int stride = blockDim.y * blockDim.x;
29
+ for (int i = index; i < c * m; i += stride) {
30
+ const int l = i / m;
31
+ const int j = i % m;
32
+ for (int k = 0; k < u; ++k) {
33
+ out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
34
+ }
35
+ }
36
+ }
37
+
38
+ void grouping(int b, int c, int n, int m, int u, const float *features,
39
+ const int *indices, float *out) {
40
+ grouping_kernel<<<b, optimal_block_config(m, c), 0,
41
+ at::cuda::getCurrentCUDAStream()>>>(b, c, n, m, u, features,
42
+ indices, out);
43
+ CUDA_CHECK_ERRORS();
44
+ }
45
+
46
+ /*
47
+ Function: grouping features of neighbors (backward)
48
+ Args:
49
+ b : batch size
50
+ c : #channels of features
51
+ n : number of points in point clouds
52
+ m : number of query centers
53
+ u : maximum number of neighbors
54
+ grad_y : grad of gathered features, FloatTensor[b, c, m, u]
55
+ indices : neighbor indices in points, IntTensor[b, m, u]
56
+ grad_x: grad of points' features, FloatTensor[b, c, n]
57
+ */
58
+ __global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
59
+ const float *__restrict__ grad_y,
60
+ const int *__restrict__ indices,
61
+ float *__restrict__ grad_x) {
62
+ int batch_index = blockIdx.x;
63
+ grad_y += batch_index * m * u * c;
64
+ indices += batch_index * m * u;
65
+ grad_x += batch_index * n * c;
66
+
67
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
68
+ const int stride = blockDim.y * blockDim.x;
69
+ for (int i = index; i < c * m; i += stride) {
70
+ const int l = i / m;
71
+ const int j = i % m;
72
+ for (int k = 0; k < u; ++k) {
73
+ atomicAdd(grad_x + l * n + indices[j * u + k],
74
+ grad_y[(l * m + j) * u + k]);
75
+ }
76
+ }
77
+ }
78
+
79
+ void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
80
+ const int *indices, float *grad_x) {
81
+ grouping_grad_kernel<<<b, optimal_block_config(m, c), 0,
82
+ at::cuda::getCurrentCUDAStream()>>>(
83
+ b, c, n, m, u, grad_y, indices, grad_x);
84
+ CUDA_CHECK_ERRORS();
85
+ }
model/pvcnn/modules/functional/src/grouping/grouping.cuh ADDED
@@ -0,0 +1,9 @@
1
+ #ifndef _GROUPING_CUH
2
+ #define _GROUPING_CUH
3
+
4
+ void grouping(int b, int c, int n, int m, int u, const float *features,
5
+ const int *indices, float *out);
6
+ void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
7
+ const int *indices, float *grad_x);
8
+
9
+ #endif
model/pvcnn/modules/functional/src/grouping/grouping.hpp ADDED
@@ -0,0 +1,10 @@
1
+ #ifndef _GROUPING_HPP
2
+ #define _GROUPING_HPP
3
+
4
+ #include <torch/extension.h>
5
+
6
+ at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
7
+ at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
8
+ const int n);
9
+
10
+ #endif
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp ADDED
@@ -0,0 +1,65 @@
1
+ #include "neighbor_interpolate.hpp"
2
+ #include "neighbor_interpolate.cuh"
3
+
4
+ #include "../utils.hpp"
5
+
6
+ std::vector<at::Tensor>
7
+ three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
8
+ at::Tensor centers_coords,
9
+ at::Tensor centers_features) {
10
+ CHECK_CUDA(points_coords);
11
+ CHECK_CUDA(centers_coords);
12
+ CHECK_CUDA(centers_features);
13
+ CHECK_CONTIGUOUS(points_coords);
14
+ CHECK_CONTIGUOUS(centers_coords);
15
+ CHECK_CONTIGUOUS(centers_features);
16
+ CHECK_IS_FLOAT(points_coords);
17
+ CHECK_IS_FLOAT(centers_coords);
18
+ CHECK_IS_FLOAT(centers_features);
19
+
20
+ int b = centers_features.size(0);
21
+ int c = centers_features.size(1);
22
+ int m = centers_features.size(2);
23
+ int n = points_coords.size(2);
24
+
25
+ at::Tensor indices = torch::zeros(
26
+ {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
27
+ at::Tensor weights = torch::zeros(
28
+ {b, 3, n},
29
+ at::device(points_coords.device()).dtype(at::ScalarType::Float));
30
+ at::Tensor output = torch::zeros(
31
+ {b, c, n},
32
+ at::device(centers_features.device()).dtype(at::ScalarType::Float));
33
+
34
+ three_nearest_neighbors_interpolate(
35
+ b, c, m, n, points_coords.data_ptr<float>(),
36
+ centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
37
+ indices.data_ptr<int>(), weights.data_ptr<float>(),
38
+ output.data_ptr<float>());
39
+ return {output, indices, weights};
40
+ }
41
+
42
+ at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
43
+ at::Tensor indices,
44
+ at::Tensor weights,
45
+ const int m) {
46
+ CHECK_CUDA(grad_y);
47
+ CHECK_CUDA(indices);
48
+ CHECK_CUDA(weights);
49
+ CHECK_CONTIGUOUS(grad_y);
50
+ CHECK_CONTIGUOUS(indices);
51
+ CHECK_CONTIGUOUS(weights);
52
+ CHECK_IS_FLOAT(grad_y);
53
+ CHECK_IS_INT(indices);
54
+ CHECK_IS_FLOAT(weights);
55
+
56
+ int b = grad_y.size(0);
57
+ int c = grad_y.size(1);
58
+ int n = grad_y.size(2);
59
+ at::Tensor grad_x = torch::zeros(
60
+ {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
61
+ three_nearest_neighbors_interpolate_grad(
62
+ b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
63
+ weights.data_ptr<float>(), grad_x.data_ptr<float>());
64
+ return grad_x;
65
+ }
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu ADDED
@@ -0,0 +1,181 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "../cuda_utils.cuh"
6
+
7
+ /*
8
+ Function: three nearest neighbors
9
+ Args:
10
+ b : batch size
11
+ n : number of points in point clouds
12
+ m : number of query centers
13
+ points_coords : coordinates of points, FloatTensor[b, 3, n]
14
+ centers_coords: coordinates of centers, FloatTensor[b, 3, m]
15
+ weights : weights of nearest 3 centers to the point,
16
+ FloatTensor[b, 3, n]
17
+ indices : indices of nearest 3 centers to the point,
18
+ IntTensor[b, 3, n]
19
+ */
20
+ __global__ void three_nearest_neighbors_kernel(
21
+ int b, int n, int m, const float *__restrict__ points_coords,
22
+ const float *__restrict__ centers_coords, float *__restrict__ weights,
23
+ int *__restrict__ indices) {
24
+ int batch_index = blockIdx.x;
25
+ int index = threadIdx.x;
26
+ int stride = blockDim.x;
27
+ points_coords += batch_index * 3 * n;
28
+ weights += batch_index * 3 * n;
29
+ indices += batch_index * 3 * n;
30
+ centers_coords += batch_index * 3 * m;
31
+
32
+ for (int j = index; j < n; j += stride) {
33
+ float ux = points_coords[j];
34
+ float uy = points_coords[j + n];
35
+ float uz = points_coords[j + n + n];
36
+
37
+ double best0 = 1e40, best1 = 1e40, best2 = 1e40;
38
+ int besti0 = 0, besti1 = 0, besti2 = 0;
39
+ for (int k = 0; k < m; ++k) {
40
+ float x = centers_coords[k];
41
+ float y = centers_coords[k + m];
42
+ float z = centers_coords[k + m + m];
43
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
44
+ if (d < best2) {
45
+ best2 = d;
46
+ besti2 = k;
47
+ if (d < best1) {
48
+ best2 = best1;
49
+ besti2 = besti1;
50
+ best1 = d;
51
+ besti1 = k;
52
+ if (d < best0) {
53
+ best1 = best0;
54
+ besti1 = besti0;
55
+ best0 = d;
56
+ besti0 = k;
57
+ }
58
+ }
59
+ }
60
+ }
61
+ best0 = max(min(1e10f, best0), 1e-10f);
62
+ best1 = max(min(1e10f, best1), 1e-10f);
63
+ best2 = max(min(1e10f, best2), 1e-10f);
64
+ float d0d1 = best0 * best1;
65
+ float d0d2 = best0 * best2;
66
+ float d1d2 = best1 * best2;
67
+ float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
68
+ weights[j] = d1d2 * d0d1d2;
69
+ indices[j] = besti0;
70
+ weights[j + n] = d0d2 * d0d1d2;
71
+ indices[j + n] = besti1;
72
+ weights[j + n + n] = d0d1 * d0d1d2;
73
+ indices[j + n + n] = besti2;
74
+ }
75
+ }
76
+
77
+ /*
78
+ Function: interpolate three nearest neighbors (forward)
79
+ Args:
80
+ b : batch size
81
+ c : #channels of features
82
+ m : number of query centers
83
+ n : number of points in point clouds
84
+ centers_features: features of centers, FloatTensor[b, c, m]
85
+ indices : indices of nearest 3 centers to the point,
86
+ IntTensor[b, 3, n]
87
+ weights : weights for interpolation, FloatTensor[b, 3, n]
88
+ out : features of points, FloatTensor[b, c, n]
89
+ */
90
+ __global__ void three_nearest_neighbors_interpolate_kernel(
91
+ int b, int c, int m, int n, const float *__restrict__ centers_features,
92
+ const int *__restrict__ indices, const float *__restrict__ weights,
93
+ float *__restrict__ out) {
94
+ int batch_index = blockIdx.x;
95
+ centers_features += batch_index * m * c;
96
+ indices += batch_index * n * 3;
97
+ weights += batch_index * n * 3;
98
+ out += batch_index * n * c;
99
+
100
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
101
+ const int stride = blockDim.y * blockDim.x;
102
+ for (int i = index; i < c * n; i += stride) {
103
+ const int l = i / n;
104
+ const int j = i % n;
105
+ float w1 = weights[j];
106
+ float w2 = weights[j + n];
107
+ float w3 = weights[j + n + n];
108
+ int i1 = indices[j];
109
+ int i2 = indices[j + n];
110
+ int i3 = indices[j + n + n];
111
+
112
+ out[i] = centers_features[l * m + i1] * w1 +
113
+ centers_features[l * m + i2] * w2 +
114
+ centers_features[l * m + i3] * w3;
115
+ }
116
+ }
117
+
118
+ void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
119
+ const float *points_coords,
120
+ const float *centers_coords,
121
+ const float *centers_features,
122
+ int *indices, float *weights,
123
+ float *out) {
124
+ three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0,
125
+ at::cuda::getCurrentCUDAStream()>>>(
126
+ b, n, m, points_coords, centers_coords, weights, indices);
127
+ three_nearest_neighbors_interpolate_kernel<<<
128
+ b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
129
+ b, c, m, n, centers_features, indices, weights, out);
130
+ CUDA_CHECK_ERRORS();
131
+ }
132
+
133
+ /*
134
+ Function: interpolate three nearest neighbors (backward)
135
+ Args:
136
+ b : batch size
137
+ c : #channels of features
138
+ m : number of query centers
139
+ n : number of points in point clouds
140
+ grad_y : grad of features of points, FloatTensor[b, c, n]
141
+ indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
142
+ weights : weights for interpolation, FloatTensor[b, 3, n]
143
+ grad_x : grad of features of centers, FloatTensor[b, c, m]
144
+ */
145
+ __global__ void three_nearest_neighbors_interpolate_grad_kernel(
146
+ int b, int c, int n, int m, const float *__restrict__ grad_y,
147
+ const int *__restrict__ indices, const float *__restrict__ weights,
148
+ float *__restrict__ grad_x) {
149
+ int batch_index = blockIdx.x;
150
+ grad_y += batch_index * n * c;
151
+ indices += batch_index * n * 3;
152
+ weights += batch_index * n * 3;
153
+ grad_x += batch_index * m * c;
154
+
155
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
156
+ const int stride = blockDim.y * blockDim.x;
157
+ for (int i = index; i < c * n; i += stride) {
158
+ const int l = i / n;
159
+ const int j = i % n;
160
+ float w1 = weights[j];
161
+ float w2 = weights[j + n];
162
+ float w3 = weights[j + n + n];
163
+ int i1 = indices[j];
164
+ int i2 = indices[j + n];
165
+ int i3 = indices[j + n + n];
166
+ atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
167
+ atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
168
+ atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
169
+ }
170
+ }
171
+
172
+ void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
173
+ const float *grad_y,
174
+ const int *indices,
175
+ const float *weights,
176
+ float *grad_x) {
177
+ three_nearest_neighbors_interpolate_grad_kernel<<<
178
+ b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
179
+ b, c, n, m, grad_y, indices, weights, grad_x);
180
+ CUDA_CHECK_ERRORS();
181
+ }
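The weights produced by `three_nearest_neighbors_kernel` are inverse-(squared-)distance weights in disguise: writing the clamped squared distances to the three nearest centers as $d_0, d_1, d_2$,

```latex
w_0 = \frac{d_1 d_2}{d_0 d_1 + d_0 d_2 + d_1 d_2}
    = \frac{1/d_0}{1/d_0 + 1/d_1 + 1/d_2},
\qquad w_1, w_2 \text{ analogously}, \qquad w_0 + w_1 + w_2 = 1 .
```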
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh ADDED
@@ -0,0 +1,16 @@
1
+ #ifndef _NEIGHBOR_INTERPOLATE_CUH
2
+ #define _NEIGHBOR_INTERPOLATE_CUH
3
+
4
+ void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
5
+ const float *points_coords,
6
+ const float *centers_coords,
7
+ const float *centers_features,
8
+ int *indices, float *weights,
9
+ float *out);
10
+ void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
11
+ const float *grad_y,
12
+ const int *indices,
13
+ const float *weights,
14
+ float *grad_x);
15
+
16
+ #endif
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp ADDED
@@ -0,0 +1,16 @@
1
+ #ifndef _NEIGHBOR_INTERPOLATE_HPP
2
+ #define _NEIGHBOR_INTERPOLATE_HPP
3
+
4
+ #include <torch/extension.h>
5
+ #include <vector>
6
+
7
+ std::vector<at::Tensor>
8
+ three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
9
+ at::Tensor centers_coords,
10
+ at::Tensor centers_features);
11
+ at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
12
+ at::Tensor indices,
13
+ at::Tensor weights,
14
+ const int m);
15
+
16
+ #endif