Hancy committed on
Commit 851751e • 1 Parent(s): 8406293
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. README.md +3 -3
  2. app.py +76 -4
  3. app_config.py +17 -0
  4. cam_examples/conditioning_000011.png +0 -0
  5. cam_examples/conditioning_000153.png +0 -0
  6. cam_examples/conditioning_000354.png +0 -0
  7. cam_examples/conditioning_000555.png +0 -0
  8. cam_examples/conditioning_001026.png +0 -0
  9. data/config/semantic-kitti.yaml +211 -0
  10. lidm/data/__init__.py +0 -0
  11. lidm/data/annotated_dataset.py +48 -0
  12. lidm/data/base.py +121 -0
  13. lidm/data/conditional_builder/__init__.py +0 -0
  14. lidm/data/conditional_builder/objects_bbox.py +53 -0
  15. lidm/data/conditional_builder/objects_center_points.py +150 -0
  16. lidm/data/conditional_builder/utils.py +188 -0
  17. lidm/data/helper_types.py +20 -0
  18. lidm/data/kitti.py +345 -0
  19. lidm/eval/README.md +95 -0
  20. lidm/eval/__init__.py +62 -0
  21. lidm/eval/compile.sh +9 -0
  22. lidm/eval/eval_utils.py +138 -0
  23. lidm/eval/fid_score.py +191 -0
  24. lidm/eval/metric_utils.py +458 -0
  25. lidm/eval/models/__init__.py +0 -0
  26. lidm/eval/models/minkowskinet/__init__.py +0 -0
  27. lidm/eval/models/minkowskinet/model.py +141 -0
  28. lidm/eval/models/rangenet/__init__.py +0 -0
  29. lidm/eval/models/rangenet/model.py +372 -0
  30. lidm/eval/models/spvcnn/__init__.py +0 -0
  31. lidm/eval/models/spvcnn/model.py +179 -0
  32. lidm/eval/models/ts/__init__.py +0 -0
  33. lidm/eval/models/ts/basic_blocks.py +79 -0
  34. lidm/eval/models/ts/utils.py +90 -0
  35. lidm/eval/modules/__init__.py +0 -0
  36. lidm/eval/modules/chamfer2D/__init__.py +0 -0
  37. lidm/eval/modules/chamfer2D/chamfer2D.cu +182 -0
  38. lidm/eval/modules/chamfer2D/chamfer_cuda.cpp +33 -0
  39. lidm/eval/modules/chamfer2D/dist_chamfer_2D.py +84 -0
  40. lidm/eval/modules/chamfer2D/setup.py +14 -0
  41. lidm/eval/modules/chamfer3D/__init__.py +0 -0
  42. lidm/eval/modules/chamfer3D/chamfer3D.cu +196 -0
  43. lidm/eval/modules/chamfer3D/chamfer_cuda.cpp +33 -0
  44. lidm/eval/modules/chamfer3D/dist_chamfer_3D.py +76 -0
  45. lidm/eval/modules/chamfer3D/setup.py +14 -0
  46. lidm/eval/modules/emd/__init__.py +0 -0
  47. lidm/eval/modules/emd/emd.cpp +31 -0
  48. lidm/eval/modules/emd/emd_cuda.cu +316 -0
  49. lidm/eval/modules/emd/emd_module.py +112 -0
  50. lidm/eval/modules/emd/setup.py +14 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: LiDAR Diffusion
- emoji: 📚
- colorFrom: blue
- colorTo: green
+ emoji: 🚙🛞🚨
+ colorFrom: green
+ colorTo: indigo
  sdk: gradio
  sdk_version: 4.26.0
  app_file: app.py
app.py CHANGED
@@ -1,9 +1,81 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()

+ import spaces
+ import tempfile
+ import os
+ import torch
+ import numpy as np
+ from matplotlib.colors import LinearSegmentedColormap
+
+ from app_config import CSS, TITLE, DESCRIPTION, DEVICE
+ import sample_cond
+
+ model = sample_cond.load_model()
+
+
+ def create_custom_colormap():
+     colors = [(0, 1, 0), (0, 1, 1), (0, 0, 1), (1, 0, 1), (1, 1, 0)]
+     positions = [0, 0.38, 0.6, 0.7, 1]
+
+     custom_cmap = LinearSegmentedColormap.from_list('custom_colormap', list(zip(positions, colors)), N=256)
+     return custom_cmap
+
+
+ def colorize_depth(depth, log_scale):
+     if log_scale:
+         depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
+     mask = depth == 0
+     colormap = create_custom_colormap()
+     rgb = colormap(depth)[:, :, :3]
+     rgb[mask] = 0.
+     return rgb
+
+
+ @spaces.GPU
+ @torch.no_grad()
+ def generate_lidar(model, cond):
+     img, pcd = sample_cond.sample(model, cond)
+     return img, pcd
+
+
+ def load_camera(image):
+     split_per_view = 4
+     camera = np.array(image).astype(np.float32) / 255.
+     camera = camera.transpose(2, 0, 1)
+     camera_list = np.split(camera, split_per_view, axis=2)  # split into n chunks as different views
+     camera_cond = torch.from_numpy(np.stack(camera_list, axis=0)).unsqueeze(0).to(DEVICE)
+     return camera_cond
+
+
+ with gr.Blocks(css=CSS) as demo:
+     gr.Markdown(TITLE)
+     gr.Markdown(DESCRIPTION)
+     gr.Markdown("### Camera-to-LiDAR Demo")
+     # gr.Markdown("You can slide the output to compare the depth prediction with input image")
+
+     with gr.Row():
+         input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
+         output_image = gr.Image(label="Range Map", elem_id='img-display-output')
+     raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab")
+     submit = gr.Button("Submit")
+
+     def on_submit(image):
+         cond = load_camera(image)
+         img, pcd = generate_lidar(model, cond)
+
+         tmp = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
+         pcd.save(tmp.name)
+
+         rgb_img = colorize_depth(img, log_scale=True)
+
+         return [rgb_img, tmp.name]
+
+     submit.click(on_submit, inputs=[input_image], outputs=[output_image, raw_file])
+
+     example_files = sorted(os.listdir('cam_examples'))
+     example_files = [os.path.join('cam_examples', filename) for filename in example_files]
+     examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, raw_file],
+                            fn=on_submit, cache_examples=True)
+
+
+ if __name__ == '__main__':
+     demo.queue().launch()
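The `colorize_depth` helper above compresses the range map with a log2 curve before applying the custom colormap. As a quick illustration (not part of the commit), here is a minimal sketch of that mapping, assuming the 56 m maximum depth of the 64-beam KITTI configuration used elsewhere in this commit (5.84 is approximately log2(56 + 1)):

```python
import numpy as np

# Hedged sketch of the log-scale depth encoding used in colorize_depth():
# a uint8 range value is rescaled to metric depth in [0, 56] m, compressed
# with log2, then renormalized back to [0, 255].
depth_u8 = np.array([0, 64, 128, 255], dtype=np.float32)
log_depth = ((np.log2((depth_u8 / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
print(log_depth)  # near depths are stretched, far depths compressed
```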
app_config.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+
+ CSS = """
+ #img-display-container {
+     max-height: 100vh;
+ }
+ #img-display-input {
+     max-height: 80vh;
+ }
+ #img-display-output {
+     max-height: 80vh;
+ }
+ """
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ TITLE = "# LiDAR Diffusion"
+ DESCRIPTION = """Official demo for **LiDAR Diffusion: Towards Realistic Scene Generation with LiDAR Diffusion Models**.
+ Please refer to our [paper](https://arxiv.org/abs/2404.00815), [project page](https://lidar-diffusion.github.io/), or [github](https://github.com/hancyran/LiDAR-Diffusion) for more details."""
cam_examples/conditioning_000011.png ADDED
cam_examples/conditioning_000153.png ADDED
cam_examples/conditioning_000354.png ADDED
cam_examples/conditioning_000555.png ADDED
cam_examples/conditioning_001026.png ADDED
data/config/semantic-kitti.yaml ADDED
@@ -0,0 +1,211 @@
1
+ # This file is covered by the LICENSE file in the root of this project.
2
+ labels:
3
+ 0 : "unlabeled"
4
+ 1 : "outlier"
5
+ 10: "car"
6
+ 11: "bicycle"
7
+ 13: "bus"
8
+ 15: "motorcycle"
9
+ 16: "on-rails"
10
+ 18: "truck"
11
+ 20: "other-vehicle"
12
+ 30: "person"
13
+ 31: "bicyclist"
14
+ 32: "motorcyclist"
15
+ 40: "road"
16
+ 44: "parking"
17
+ 48: "sidewalk"
18
+ 49: "other-ground"
19
+ 50: "building"
20
+ 51: "fence"
21
+ 52: "other-structure"
22
+ 60: "lane-marking"
23
+ 70: "vegetation"
24
+ 71: "trunk"
25
+ 72: "terrain"
26
+ 80: "pole"
27
+ 81: "traffic-sign"
28
+ 99: "other-object"
29
+ 252: "moving-car"
30
+ 253: "moving-bicyclist"
31
+ 254: "moving-person"
32
+ 255: "moving-motorcyclist"
33
+ 256: "moving-on-rails"
34
+ 257: "moving-bus"
35
+ 258: "moving-truck"
36
+ 259: "moving-other-vehicle"
37
+ color_map: # bgr
38
+ 0 : [0, 0, 0]
39
+ 1 : [0, 0, 255]
40
+ 10: [245, 150, 100]
41
+ 11: [245, 230, 100]
42
+ 13: [250, 80, 100]
43
+ 15: [150, 60, 30]
44
+ 16: [255, 0, 0]
45
+ 18: [180, 30, 80]
46
+ 20: [255, 0, 0]
47
+ 30: [30, 30, 255]
48
+ 31: [200, 40, 255]
49
+ 32: [90, 30, 150]
50
+ 40: [255, 0, 255]
51
+ 44: [255, 150, 255]
52
+ 48: [75, 0, 75]
53
+ 49: [75, 0, 175]
54
+ 50: [0, 200, 255]
55
+ 51: [50, 120, 255]
56
+ 52: [0, 150, 255]
57
+ 60: [170, 255, 150]
58
+ 70: [0, 175, 0]
59
+ 71: [0, 60, 135]
60
+ 72: [80, 240, 150]
61
+ 80: [150, 240, 255]
62
+ 81: [0, 0, 255]
63
+ 99: [255, 255, 50]
64
+ 252: [245, 150, 100]
65
+ 256: [255, 0, 0]
66
+ 253: [200, 40, 255]
67
+ 254: [30, 30, 255]
68
+ 255: [90, 30, 150]
69
+ 257: [250, 80, 100]
70
+ 258: [180, 30, 80]
71
+ 259: [255, 0, 0]
72
+ content: # as a ratio with the total number of points
73
+ 0: 0.018889854628292943
74
+ 1: 0.0002937197336781505
75
+ 10: 0.040818519255974316
76
+ 11: 0.00016609538710764618
77
+ 13: 2.7879693665067774e-05
78
+ 15: 0.00039838616015114444
79
+ 16: 0.0
80
+ 18: 0.0020633612104619787
81
+ 20: 0.0016218197275284021
82
+ 30: 0.00017698551338515307
83
+ 31: 1.1065903904919655e-08
84
+ 32: 5.532951952459828e-09
85
+ 40: 0.1987493871255525
86
+ 44: 0.014717169549888214
87
+ 48: 0.14392298360372
88
+ 49: 0.0039048553037472045
89
+ 50: 0.1326861944777486
90
+ 51: 0.0723592229456223
91
+ 52: 0.002395131480328884
92
+ 60: 4.7084144280367186e-05
93
+ 70: 0.26681502148037506
94
+ 71: 0.006035012012626033
95
+ 72: 0.07814222006271769
96
+ 80: 0.002855498193863172
97
+ 81: 0.0006155958086189918
98
+ 99: 0.009923127583046915
99
+ 252: 0.001789309418528068
100
+ 253: 0.00012709999297008662
101
+ 254: 0.00016059776092534436
102
+ 255: 3.745553104802113e-05
103
+ 256: 0.0
104
+ 257: 0.00011351574470342043
105
+ 258: 0.00010157861367183268
106
+ 259: 4.3840131989471124e-05
107
+ # classes that are indistinguishable from single scan or inconsistent in
108
+ # ground truth are mapped to their closest equivalent
109
+ learning_map:
110
+ 0 : 0 # "unlabeled"
111
+ 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
112
+ 10: 1 # "car"
113
+ 11: 2 # "bicycle"
114
+ 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
115
+ 15: 3 # "motorcycle"
116
+ 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
117
+ 18: 4 # "truck"
118
+ 20: 5 # "other-vehicle"
119
+ 30: 6 # "person"
120
+ 31: 7 # "bicyclist"
121
+ 32: 8 # "motorcyclist"
122
+ 40: 9 # "road"
123
+ 44: 10 # "parking"
124
+ 48: 11 # "sidewalk"
125
+ 49: 12 # "other-ground"
126
+ 50: 13 # "building"
127
+ 51: 14 # "fence"
128
+ 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
129
+ 60: 9 # "lane-marking" to "road" ---------------------------------mapped
130
+ 70: 15 # "vegetation"
131
+ 71: 16 # "trunk"
132
+ 72: 17 # "terrain"
133
+ 80: 18 # "pole"
134
+ 81: 19 # "traffic-sign"
135
+ 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
136
+ 252: 1 # "moving-car" to "car" ------------------------------------mapped
137
+ 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
138
+ 254: 6 # "moving-person" to "person" ------------------------------mapped
139
+ 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
140
+ 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
141
+ 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
142
+ 258: 4 # "moving-truck" to "truck" --------------------------------mapped
143
+ 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped
144
+ learning_map_inv: # inverse of previous map
145
+ 0: 0 # "unlabeled", and others ignored
146
+ 1: 10 # "car"
147
+ 2: 11 # "bicycle"
148
+ 3: 15 # "motorcycle"
149
+ 4: 18 # "truck"
150
+ 5: 20 # "other-vehicle"
151
+ 6: 30 # "person"
152
+ 7: 31 # "bicyclist"
153
+ 8: 32 # "motorcyclist"
154
+ 9: 40 # "road"
155
+ 10: 44 # "parking"
156
+ 11: 48 # "sidewalk"
157
+ 12: 49 # "other-ground"
158
+ 13: 50 # "building"
159
+ 14: 51 # "fence"
160
+ 15: 70 # "vegetation"
161
+ 16: 71 # "trunk"
162
+ 17: 72 # "terrain"
163
+ 18: 80 # "pole"
164
+ 19: 81 # "traffic-sign"
165
+ learning_ignore: # Ignore classes
166
+ 0: True # "unlabeled", and others ignored
167
+ 1: False # "car"
168
+ 2: False # "bicycle"
169
+ 3: False # "motorcycle"
170
+ 4: False # "truck"
171
+ 5: False # "other-vehicle"
172
+ 6: False # "person"
173
+ 7: False # "bicyclist"
174
+ 8: False # "motorcyclist"
175
+ 9: False # "road"
176
+ 10: False # "parking"
177
+ 11: False # "sidewalk"
178
+ 12: False # "other-ground"
179
+ 13: False # "building"
180
+ 14: False # "fence"
181
+ 15: False # "vegetation"
182
+ 16: False # "trunk"
183
+ 17: False # "terrain"
184
+ 18: False # "pole"
185
+ 19: False # "traffic-sign"
186
+ split: # sequence numbers
187
+ train:
188
+ - 0
189
+ - 1
190
+ - 2
191
+ - 3
192
+ - 4
193
+ - 5
194
+ - 6
195
+ - 7
196
+ - 9
197
+ - 10
198
+ valid:
199
+ - 8
200
+ test:
201
+ - 11
202
+ - 12
203
+ - 13
204
+ - 14
205
+ - 15
206
+ - 16
207
+ - 17
208
+ - 18
209
+ - 19
210
+ - 20
211
+ - 21
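The `learning_map` in this config collapses the raw SemanticKITTI label IDs (0–259, including the moving-* classes) into the 20 training classes listed under `learning_map_inv`. A minimal sketch of applying it as a lookup table, mirroring how `SemanticKITTIBase` in `lidm/data/kitti.py` (added later in this commit) consumes the file:

```python
import numpy as np
import yaml

# Sketch: remap raw SemanticKITTI labels to the 20 training classes.
config = yaml.safe_load(open('./data/config/semantic-kitti.yaml', 'r'))
remap = config['learning_map']
lut = np.zeros(max(remap.keys()) + 100, dtype=np.int32)  # headroom, as in kitti.py
lut[list(remap.keys())] = list(remap.values())

raw = np.array([10, 252, 60, 99], dtype=np.uint32)  # car, moving-car, lane-marking, other-object
print(lut[raw & 0xFFFF])  # -> [1 1 9 0]
```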
lidm/data/__init__.py ADDED
File without changes
lidm/data/annotated_dataset.py ADDED
@@ -0,0 +1,48 @@
1
+ from pathlib import Path
2
+ from typing import Optional, List, Dict, Union, Any
3
+ import warnings
4
+
5
+ from torch.utils.data import Dataset
6
+
7
+ from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
8
+ from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
9
+
10
+
11
+ class Annotated3DObjectsDataset(Dataset):
12
+ def __init__(self, min_objects_per_image: int,
13
+ max_objects_per_image: int, no_tokens: int, num_beams: int, cats: List[str],
14
+ cat_blacklist: Optional[List[str]] = None, **kwargs):
15
+ self.min_objects_per_image = min_objects_per_image
16
+ self.max_objects_per_image = max_objects_per_image
17
+ self.no_tokens = no_tokens
18
+ self.num_beams = num_beams
19
+
20
+ self.categories = [c for c in cats if c not in cat_blacklist] if cat_blacklist is not None else cats
21
+ self._conditional_builders = None
22
+
23
+ @property
24
+ def no_classes(self) -> int:
25
+ return len(self.categories)
26
+
27
+ @property
28
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
29
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
30
+ if self._conditional_builders is None:
31
+ self._conditional_builders = {
32
+ 'center': ObjectsCenterPointsConditionalBuilder(
33
+ self.no_classes,
34
+ self.max_objects_per_image,
35
+ self.no_tokens,
36
+ self.num_beams
37
+ ),
38
+ 'bbox': ObjectsBoundingBoxConditionalBuilder(
39
+ self.no_classes,
40
+ self.max_objects_per_image,
41
+ self.no_tokens,
42
+ self.num_beams
43
+ )
44
+ }
45
+ return self._conditional_builders
46
+
47
+ def get_textual_label_for_category_id(self, category_id: int) -> str:
48
+ return self.categories[category_id]
lidm/data/base.py ADDED
@@ -0,0 +1,121 @@
1
+ import pdb
2
+ from abc import abstractmethod
3
+ from functools import partial
4
+
5
+ import PIL
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import torchvision.transforms.functional as TF
10
+ from torch.utils.data import Dataset, IterableDataset
11
+
12
+ from ..utils.aug_utils import get_lidar_transform, get_camera_transform, get_anno_transform
13
+
14
+
15
+ class DatasetBase(Dataset):
16
+ def __init__(self, data_root, split, dataset_config, aug_config, return_pcd=False, condition_key=None,
17
+ scale_factors=None, degradation=None, **kwargs):
18
+ self.data_root = data_root
19
+ self.split = split
20
+ self.data = []
21
+ self.aug_config = aug_config
22
+
23
+ self.img_size = dataset_config.size
24
+ self.fov = dataset_config.fov
25
+ self.depth_range = dataset_config.depth_range
26
+ self.filtered_map_cats = dataset_config.filtered_map_cats
27
+ self.depth_scale = dataset_config.depth_scale
28
+ self.log_scale = dataset_config.log_scale
29
+
30
+ if self.log_scale:
31
+ self.depth_thresh = (np.log2(1./255. + 1) / self.depth_scale) * 2. - 1 + 1e-6
32
+ else:
33
+ self.depth_thresh = (1./255. / self.depth_scale) * 2. - 1 + 1e-6
34
+ self.return_pcd = return_pcd
35
+
36
+ if degradation is not None and scale_factors is not None:
37
+ scaled_img_size = (int(self.img_size[0] / scale_factors[0]), int(self.img_size[1] / scale_factors[1]))
38
+ degradation_fn = {
39
+ "pil_nearest": PIL.Image.NEAREST,
40
+ "pil_bilinear": PIL.Image.BILINEAR,
41
+ "pil_bicubic": PIL.Image.BICUBIC,
42
+ "pil_box": PIL.Image.BOX,
43
+ "pil_hamming": PIL.Image.HAMMING,
44
+ "pil_lanczos": PIL.Image.LANCZOS,
45
+ }[degradation]
46
+ self.degradation_transform = partial(TF.resize, size=scaled_img_size, interpolation=degradation_fn)
47
+ else:
48
+ self.degradation_transform = None
49
+ self.condition_key = condition_key
50
+
51
+ self.lidar_transform = get_lidar_transform(aug_config, split)
52
+ self.anno_transform = get_anno_transform(aug_config, split) if condition_key in ['bbox', 'center'] else None
53
+ self.view_transform = get_camera_transform(aug_config, split) if condition_key in ['camera'] else None
54
+
55
+ self.prepare_data()
56
+
57
+ def prepare_data(self):
58
+ raise NotImplementedError
59
+
60
+ def process_scan(self, range_img):
61
+ range_img = np.where(range_img < 0, 0, range_img)
62
+
63
+ if self.log_scale:
64
+ # log scale
65
+ range_img = np.log2(range_img + 0.0001 + 1)
66
+
67
+ range_img = range_img / self.depth_scale
68
+ range_img = range_img * 2. - 1.
69
+
70
+ range_img = np.clip(range_img, -1, 1)
71
+ range_img = np.expand_dims(range_img, axis=0)
72
+
73
+ # mask
74
+ range_mask = np.ones_like(range_img)
75
+ range_mask[range_img < self.depth_thresh] = -1
76
+
77
+ return range_img, range_mask
78
+
79
+ @staticmethod
80
+ def load_lidar_sweep(*args, **kwargs):
81
+ raise NotImplementedError
82
+
83
+ @staticmethod
84
+ def load_semantic_map(*args, **kwargs):
85
+ raise NotImplementedError
86
+
87
+ @staticmethod
88
+ def load_camera(*args, **kwargs):
89
+ raise NotImplementedError
90
+
91
+ @staticmethod
92
+ def load_annotation(*args, **kwargs):
93
+ raise NotImplementedError
94
+
95
+ def __len__(self):
96
+ return len(self.data)
97
+
98
+ def __getitem__(self, idx):
99
+ example = dict()
100
+ return example
101
+
102
+
103
+ class Txt2ImgIterableBaseDataset(IterableDataset):
104
+ """
105
+ Define an interface to make the IterableDatasets for text2img data chainable
106
+ """
107
+ def __init__(self, num_records=0, valid_ids=None, size=256):
108
+ super().__init__()
109
+ self.num_records = num_records
110
+ self.valid_ids = valid_ids
111
+ self.sample_ids = valid_ids
112
+ self.size = size
113
+
114
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
115
+
116
+ def __len__(self):
117
+ return self.num_records
118
+
119
+ @abstractmethod
120
+ def __iter__(self):
121
+ pass
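`DatasetBase.process_scan` above normalizes a metric range image into [-1, 1] (optionally in log2 space) and derives a validity mask from `depth_thresh`, which corresponds to roughly one uint8 step of depth. A small sketch of that normalization, assuming `depth_scale=6` and `log_scale=True` as in the 64-beam KITTI configuration at the end of this commit:

```python
import numpy as np

# Sketch of DatasetBase.process_scan for depth_scale=6, log_scale=True.
depth_scale = 6.
range_img = np.array([[0., 1., 10., 56.]], dtype=np.float32)  # metric depth in meters

x = np.log2(np.where(range_img < 0, 0, range_img) + 0.0001 + 1)  # log scale
x = np.clip(x / depth_scale * 2. - 1., -1, 1)                    # -> [-1, 1]

# pixels closer than ~1/255 of the depth range are treated as empty
depth_thresh = (np.log2(1. / 255. + 1) / depth_scale) * 2. - 1 + 1e-6
mask = np.where(x < depth_thresh, -1., 1.)
print(x.round(3), mask)
```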
lidm/data/conditional_builder/__init__.py ADDED
File without changes
lidm/data/conditional_builder/objects_bbox.py ADDED
@@ -0,0 +1,53 @@
1
+ from itertools import cycle
2
+ from typing import List, Tuple, Callable, Optional
3
+
4
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5
+ from more_itertools.recipes import grouper
6
+ from torch import LongTensor, Tensor
7
+
8
+ from ..helper_types import BoundingBox, Annotation
9
+ from .objects_center_points import ObjectsCenterPointsConditionalBuilder, convert_pil_to_tensor
10
+ from .utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
11
+ pad_list, get_plot_font_size, absolute_bbox
12
+
13
+
14
+ class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
15
+ @property
16
+ def object_descriptor_length(self) -> int:
17
+ return 3 # 3/5: object_representation (1) + corners (2/4)
18
+
19
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
20
+ object_tuples = [
21
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
22
+ for ann in annotations
23
+ ]
24
+ object_tuples = pad_list(object_tuples, self.empty_tuple, self.no_max_objects)
25
+ return object_tuples
26
+
27
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
28
+ conditional_list = conditional.tolist()
29
+ object_triples = grouper(conditional_list, 3)
30
+ assert conditional.shape[0] == self.embedding_dim
31
+ return [(object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) for object_triple in object_triples if object_triple[0] != self.none], None
32
+
33
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
34
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
35
+ plot = pil_image.new('RGB', figure_size, WHITE)
36
+ draw = pil_img_draw.Draw(plot)
37
+ # font = ImageFont.truetype(
38
+ # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
39
+ # size=get_plot_font_size(font_size, figure_size)
40
+ # )
41
+ font = ImageFont.load_default()
42
+ width, height = plot.size
43
+ description, crop_coordinates = self.inverse_build(conditional)
44
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
45
+ annotation = self.representation_to_annotation(representation)
46
+ # class_label = label_for_category_no(annotation.category_id) + ' ' + additional_parameters_string(annotation)
47
+ class_label = label_for_category_no(annotation.category_id)
48
+ bbox = absolute_bbox(bbox, width, height)
49
+ draw.rectangle(bbox, outline=color, width=line_width)
50
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
51
+ if crop_coordinates is not None:
52
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
53
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
lidm/data/conditional_builder/objects_center_points.py ADDED
@@ -0,0 +1,150 @@
1
+ import math
2
+ import random
3
+ import warnings
4
+ from itertools import cycle
5
+ from typing import List, Optional, Tuple, Callable
6
+
7
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
8
+ from more_itertools.recipes import grouper
9
+ from .utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, pad_list, get_circle_size, \
10
+ get_plot_font_size, absolute_bbox
11
+ from ..helper_types import BoundingBox, Annotation, Image
12
+ from torch import LongTensor, Tensor
13
+ from torchvision.transforms import PILToTensor
14
+
15
+
16
+ pil_to_tensor = PILToTensor()
17
+
18
+
19
+ def convert_pil_to_tensor(image: Image) -> Tensor:
20
+ with warnings.catch_warnings():
21
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
22
+ warnings.simplefilter("ignore")
23
+ return pil_to_tensor(image)
24
+
25
+
26
+ class ObjectsCenterPointsConditionalBuilder:
27
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, num_beams: int):
28
+ self.no_object_classes = no_object_classes
29
+ self.no_max_objects = no_max_objects
30
+ self.no_tokens = no_tokens
31
+ # self.no_sections = int(math.sqrt(self.no_tokens))
32
+ self.no_sections = (self.no_tokens // num_beams, num_beams) # (width, height)
33
+
34
+ @property
35
+ def none(self) -> int:
36
+ return self.no_tokens - 1
37
+
38
+ @property
39
+ def object_descriptor_length(self) -> int:
40
+ return 2
41
+
42
+ @property
43
+ def empty_tuple(self) -> Tuple:
44
+ return (self.none,) * self.object_descriptor_length
45
+
46
+ @property
47
+ def embedding_dim(self) -> int:
48
+ return self.no_max_objects * self.object_descriptor_length
49
+
50
+ def tokenize_coordinates(self, x: float, y: float) -> int:
51
+ """
52
+ Express 2d coordinates with one number.
53
+ Example: assume self.no_tokens = 16, then no_sections = 4:
54
+ 0 0 0 0
55
+ 0 0 # 0
56
+ 0 0 0 0
57
+ 0 0 0 x
58
+ Then the # position corresponds to token 6, the x position to token 15.
59
+ @param x: float in [0, 1]
60
+ @param y: float in [0, 1]
61
+ @return: discrete tokenized coordinate
62
+ """
63
+ x_discrete = int(round(x * (self.no_sections[0] - 1)))
64
+ y_discrete = int(round(y * (self.no_sections[1] - 1)))
65
+ return y_discrete * self.no_sections[0] + x_discrete
66
+
67
+ def coordinates_from_token(self, token: int) -> (float, float):
68
+ x = token % self.no_sections[0]
69
+ y = token // self.no_sections[0]
70
+ return x / (self.no_sections[0] - 1), y / (self.no_sections[1] - 1)
71
+
72
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
73
+ x0, y0 = self.coordinates_from_token(token1)
74
+ x1, y1 = self.coordinates_from_token(token2)
75
+ # x2, y2 = self.coordinates_from_token(token3)
76
+ # x3, y3 = self.coordinates_from_token(token4)
77
+ return x0, y0, x1, y1
78
+
79
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple:
80
+ # return self.tokenize_coordinates(bbox[0], bbox[1]), self.tokenize_coordinates(bbox[2], bbox[3]), self.tokenize_coordinates(bbox[4], bbox[5]), self.tokenize_coordinates(bbox[6], bbox[7])
81
+ return self.tokenize_coordinates(bbox[0], bbox[1]), self.tokenize_coordinates(bbox[4], bbox[5])
82
+
83
+ def inverse_build(self, conditional: LongTensor) \
84
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
85
+ conditional_list = conditional.tolist()
86
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
87
+ assert conditional.shape[0] == self.embedding_dim
88
+ return [
89
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
90
+ for object_tuple in table_of_content if object_tuple[0] != self.none
91
+ ], None
92
+
93
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
94
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
95
+ plot = pil_image.new('RGB', figure_size, WHITE)
96
+ draw = pil_img_draw.Draw(plot)
97
+ circle_size = get_circle_size(figure_size)
98
+ # font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
99
+ # size=get_plot_font_size(font_size, figure_size))
100
+ font = ImageFont.load_default()
101
+ width, height = plot.size
102
+ description, crop_coordinates = self.inverse_build(conditional)
103
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
104
+ x_abs, y_abs = x * width, y * height
105
+ ann = self.representation_to_annotation(representation)
106
+ label = label_for_category_no(ann.category_id) + ' ' + additional_parameters_string(ann)
107
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
108
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
109
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
110
+ if crop_coordinates is not None:
111
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
112
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
113
+
114
+ def object_representation(self, annotation: Annotation) -> int:
115
+ return annotation.category_id
116
+
117
+ def representation_to_annotation(self, representation: int) -> Annotation:
118
+ category_id = representation % self.no_object_classes
119
+ # noinspection PyTypeChecker
120
+ return Annotation(
121
+ bbox=None,
122
+ category_id=category_id,
123
+ )
124
+
125
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
126
+ object_tuples = [
127
+ (self.object_representation(a),
128
+ self.tokenize_coordinates(a.center[0], a.center[1]))
129
+ for a in annotations
130
+ ]
131
+ empty_tuple = (self.none, self.none)
132
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
133
+ return object_tuples
134
+
135
+ def build(self, annotations: List[Annotation]) \
136
+ -> LongTensor:
137
+ if len(annotations) == 0:
138
+ warnings.warn('Did not receive any annotations.')
139
+
140
+ random.shuffle(annotations)
141
+ if len(annotations) > self.no_max_objects:
142
+ warnings.warn('Received more annotations than allowed.')
143
+ annotations = annotations[:self.no_max_objects]
144
+
145
+ object_tuples = self._make_object_descriptors(annotations)
146
+ flattened = [token for tuple_ in object_tuples for token in tuple_]
147
+ assert len(flattened) == self.embedding_dim
148
+ assert all(0 <= value < self.no_tokens for value in flattened)
149
+
150
+ return LongTensor(flattened)
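As described in the `tokenize_coordinates` docstring, the builder discretizes a relative (x, y) position on a `(no_tokens // num_beams) x num_beams` grid and encodes it as a single token. A short sketch of the round trip, with hypothetical sizes (not taken from the commit):

```python
from lidm.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder

# Hypothetical sizes: 1024 tokens over a 64-beam LiDAR gives a 16 x 64 grid.
builder = ObjectsCenterPointsConditionalBuilder(
    no_object_classes=3, no_max_objects=10, no_tokens=1024, num_beams=64)

token = builder.tokenize_coordinates(0.5, 0.25)   # x, y in [0, 1]
print(token)                                       # y_discrete * grid_width + x_discrete
print(builder.coordinates_from_token(token))       # roughly (0.53, 0.25) after quantization
```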
lidm/data/conditional_builder/utils.py ADDED
@@ -0,0 +1,188 @@
1
+ import importlib
2
+ from typing import List, Any, Tuple, Optional
3
+
4
+ import numpy as np
5
+ from ..helper_types import BoundingBox, Annotation
6
+
7
+ # source: seaborn, color palette tab10
8
+ COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
9
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
10
+ BLACK = (0, 0, 0)
11
+ GRAY_75 = (63, 63, 63)
12
+ GRAY_50 = (127, 127, 127)
13
+ GRAY_25 = (191, 191, 191)
14
+ WHITE = (255, 255, 255)
15
+ FULL_CROP = (0., 0., 1., 1.)
16
+
17
+
18
+ def corners_3d_to_2d(corners3d):
19
+ """
20
+ Args:
21
+ corners3d: (N, 8, 2)
22
+ Returns:
23
+ corners2d: (N, 4, 2)
24
+ """
25
+ # select pairs to reorganize
26
+ mask_0_3 = corners3d[:, 0:4, 0].argmax(1) // 2 != 0
27
+ mask_4_7 = corners3d[:, 4:8, 0].argmin(1) // 2 != 0
28
+
29
+ # reorganize corners in the order of (bottom-right, bottom-left)
30
+ corners3d[mask_0_3, 0:4] = corners3d[mask_0_3][:, [2, 3, 0, 1]]
31
+ # reorganize corners in the order of (top-left, top-right)
32
+ corners3d[mask_4_7, 4:8] = corners3d[mask_4_7][:, [2, 3, 0, 1]]
33
+
34
+ # calculate corners in order
35
+ bot_r = np.stack([corners3d[:, 0:2, 0].max(1), corners3d[:, 0:2, 1].min(1)], axis=-1)
36
+ bot_l = np.stack([corners3d[:, 2:4, 0].min(1), corners3d[:, 2:4, 1].min(1)], axis=-1)
37
+ top_l = np.stack([corners3d[:, 4:6, 0].min(1), corners3d[:, 4:6, 1].max(1)], axis=-1)
38
+ top_r = np.stack([corners3d[:, 6:8, 0].max(1), corners3d[:, 6:8, 1].max(1)], axis=-1)
39
+
40
+ return np.stack([bot_r, bot_l, top_l, top_r], axis=1)
41
+
42
+
43
+ def rotate_points_along_z(points, angle):
44
+ """
45
+ Args:
46
+ points: (N, 3 + C)
47
+ angle: angle along z-axis, angle increases x ==> y
48
+ Returns:
49
+
50
+ """
51
+ cosa = np.cos(angle)
52
+ sina = np.sin(angle)
53
+ zeros = np.zeros(points.shape[0])
54
+ ones = np.ones(points.shape[0])
55
+ rot_matrix = np.stack((
56
+ cosa, sina, zeros,
57
+ -sina, cosa, zeros,
58
+ zeros, zeros, ones)).reshape((-1, 3, 3))
59
+ points_rot = np.matmul(points[:, :, 0:3], rot_matrix)
60
+ points_rot = np.concatenate((points_rot, points[:, :, 3:]), axis=-1)
61
+ return points_rot
62
+
63
+
64
+ def boxes_to_corners_3d(boxes3d):
65
+ """
66
+ 7 -------- 4
67
+ /| /|
68
+ 6 -------- 5 .
69
+ | | | |
70
+ . 3 -------- 0
71
+ |/ |/
72
+ 2 -------- 1
73
+ Args:
74
+ boxes3d: (N, 7) [x, y, z, dx, dy, dz, heading], (x, y, z) is the box center
75
+
76
+ Returns:
77
+ corners3d: (N, 8, 3)
78
+ """
79
+ template = np.array(
80
+ [[1, 1, -1], [1, -1, -1], [-1, -1, -1], [-1, 1, -1],
81
+ [1, 1, 1], [1, -1, 1], [-1, -1, 1], [-1, 1, 1]],
82
+ ) / 2
83
+
84
+ # corners3d = boxes3d[:, None, 3:6].repeat(1, 8, 1) * template[None, :, :]
85
+ corners3d = np.tile(boxes3d[:, None, 3:6], (1, 8, 1)) * template[None, :, :]
86
+ corners3d = rotate_points_along_z(corners3d.reshape((-1, 8, 3)), boxes3d[:, 6]).reshape((-1, 8, 3))
87
+ corners3d += boxes3d[:, None, 0:3]
88
+
89
+ return corners3d
90
+
91
+
92
+ def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
93
+ """
94
+ Give intersection area of two rectangles.
95
+ @param rectangle1: (x0, y0, w, h) of first rectangle
96
+ @param rectangle2: (x0, y0, w, h) of second rectangle
97
+ """
98
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
99
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
100
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
101
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
102
+ return x_overlap * y_overlap
103
+
104
+
105
+ def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
106
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
107
+
108
+
109
+ def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
110
+ bbox = relative_bbox
111
+ # bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
112
+ bbox = bbox[0] * width, bbox[1] * height, bbox[2] * width, bbox[3] * height
113
+ # return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
114
+ x1, x2 = min(int(bbox[2]), int(bbox[0])), max(int(bbox[2]), int(bbox[0]))
115
+ y1, y2 = min(int(bbox[3]), int(bbox[1])), max(int(bbox[3]), int(bbox[1]))
116
+ if x1 == x2:
117
+ x2 += 1
118
+ if y1 == y2:
119
+ y2 += 1
120
+ return x1, y1, x2, y2
121
+
122
+
123
+ def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
124
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
125
+
126
+
127
+ def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
128
+ List[Annotation]:
129
+ def clamp(x: float):
130
+ return max(min(x, 1.), 0.)
131
+
132
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
133
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
134
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
135
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
136
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
137
+ if flip:
138
+ x0 = 1 - (x0 + w)
139
+ return x0, y0, w, h
140
+
141
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
142
+
143
+
144
+ def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
145
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
146
+
147
+
148
+ def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
149
+ sl = slice(1) if short else slice(None)
150
+ string = ''
151
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
152
+ return string
153
+ if annotation.is_group_of:
154
+ string += 'group'[sl] + ','
155
+ if annotation.is_occluded:
156
+ string += 'occluded'[sl] + ','
157
+ if annotation.is_depiction:
158
+ string += 'depiction'[sl] + ','
159
+ if annotation.is_inside:
160
+ string += 'inside'[sl]
161
+ return '(' + string.strip(",") + ')'
162
+
163
+
164
+ def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
165
+ if font_size is None:
166
+ font_size = 10
167
+ if max(figure_size) >= 256:
168
+ font_size = 12
169
+ if max(figure_size) >= 512:
170
+ font_size = 15
171
+ return font_size
172
+
173
+
174
+ def get_circle_size(figure_size: Tuple[int, int]) -> int:
175
+ circle_size = 2
176
+ if max(figure_size) >= 256:
177
+ circle_size = 3
178
+ if max(figure_size) >= 512:
179
+ circle_size = 4
180
+ return circle_size
181
+
182
+
183
+ def load_object_from_string(object_string: str) -> Any:
184
+ """
185
+ Source: https://stackoverflow.com/a/10773699
186
+ """
187
+ module_name, class_name = object_string.rsplit(".", 1)
188
+ return getattr(importlib.import_module(module_name), class_name)
lidm/data/helper_types.py ADDED
@@ -0,0 +1,20 @@
+ from typing import Tuple, Optional, NamedTuple, Union, List
+ from PIL.Image import Image as pil_image
+ from torch import Tensor
+
+ try:
+     from typing import Literal
+ except ImportError:
+     from typing_extensions import Literal
+
+ Image = Union[Tensor, pil_image]
+ # BoundingBox = Tuple[float, float, float, float]  # x0, y0, w, h | x0, y0, x1, y1
+ # BoundingBox3D = Tuple[float, float, float, float, float, float]  # x0, y0, z0, l, w, h
+ BoundingBox = Tuple[float, float, float, float]  # corner coordinates (x, y) in the order of bottom-right -> bottom-left -> top-left -> top-right
+ Center = Tuple[float, float]
+
+
+ class Annotation(NamedTuple):
+     category_id: int
+     bbox: Optional[BoundingBox] = None
+     center: Optional[Center] = None
lidm/data/kitti.py ADDED
@@ -0,0 +1,345 @@
1
+ import glob
2
+ import os
3
+ import pickle
4
+ import numpy as np
5
+ import yaml
6
+ from PIL import Image
7
+ import xml.etree.ElementTree as ET
8
+
9
+ from lidm.data.base import DatasetBase
10
+ from .annotated_dataset import Annotated3DObjectsDataset
11
+ from .conditional_builder.utils import corners_3d_to_2d
12
+ from .helper_types import Annotation
13
+ from ..utils.lidar_utils import pcd2range, pcd2coord2d, range2pcd
14
+
15
+ # TODO add annotation categories and semantic categories
16
+ CATEGORIES = ['ignore', 'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist',
17
+ 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'trunk', 'terrain',
18
+ 'pole', 'traffic-sign']
19
+ CATE2LABEL = {k: v for v, k in enumerate(CATEGORIES)} # 0: invalid, 1~10: categories
20
+ LABEL2RGB = np.array([(0, 0, 0), (0, 0, 142), (119, 11, 32), (0, 0, 230), (0, 0, 70), (0, 0, 90), (220, 20, 60),
21
+ (255, 0, 0), (0, 0, 110), (128, 64, 128), (250, 170, 160), (244, 35, 232), (230, 150, 140),
22
+ (70, 70, 70), (190, 153, 153), (107, 142, 35), (0, 80, 100), (230, 150, 140), (153, 153, 153),
23
+ (220, 220, 0)])
24
+ CAMERAS = ['CAM_FRONT']
25
+ BBOX_CATS = ['car', 'people', 'cycle']
26
+ BBOX_CAT2LABEL = {'car': 0, 'truck': 0, 'bus': 0, 'caravan': 0, 'person': 1, 'rider': 2, 'motorcycle': 2, 'bicycle': 2}
27
+
28
+ # train + test
29
+ SEM_KITTI_TRAIN_SET = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10']
30
+ KITTI_TRAIN_SET = SEM_KITTI_TRAIN_SET + ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21']
31
+ KITTI360_TRAIN_SET = ['00', '02', '04', '05', '06', '07', '09', '10'] + ['08'] # partial test data at '02' sequence
32
+ CAM_KITTI360_TRAIN_SET = ['00', '04', '05', '06', '07', '08', '09', '10'] # cam mismatch lidar in '02'
33
+
34
+ # validation
35
+ SEM_KITTI_VAL_SET = KITTI_VAL_SET = ['08']
36
+ CAM_KITTI360_VAL_SET = KITTI360_VAL_SET = ['03']
37
+
38
+
39
+ class KITTIBase(DatasetBase):
40
+ def __init__(self, **kwargs):
41
+ super().__init__(**kwargs)
42
+ self.dataset_name = 'kitti'
43
+ self.num_sem_cats = kwargs['dataset_config'].num_sem_cats + 1
44
+
45
+ @staticmethod
46
+ def load_lidar_sweep(path):
47
+ scan = np.fromfile(path, dtype=np.float32)
48
+ scan = scan.reshape((-1, 4))
49
+ points = scan[:, 0:3] # get xyz
50
+ return points
51
+
52
+ def load_semantic_map(self, path, pcd):
53
+ raise NotImplementedError
54
+
55
+ def load_camera(self, path):
56
+ raise NotImplementedError
57
+
58
+ def __getitem__(self, idx):
59
+ example = dict()
60
+ data_path = self.data[idx]
61
+ # lidar point cloud
62
+ sweep = self.load_lidar_sweep(data_path)
63
+
64
+ if self.lidar_transform:
65
+ sweep, _ = self.lidar_transform(sweep, None)
66
+
67
+ if self.condition_key == 'segmentation':
68
+ # semantic maps
69
+ proj_range, sem_map = self.load_semantic_map(data_path, sweep)
70
+ example[self.condition_key] = sem_map
71
+ else:
72
+ proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range)
73
+ proj_range, proj_mask = self.process_scan(proj_range)
74
+ example['image'], example['mask'] = proj_range, proj_mask
75
+ if self.return_pcd:
76
+ reproj_sweep, _, _ = range2pcd(proj_range[0] * .5 + .5, self.fov, self.depth_range, self.depth_scale, self.log_scale)
77
+ example['raw'] = sweep
78
+ example['reproj'] = reproj_sweep.astype(np.float32)
79
+
80
+ # image degradation
81
+ if self.degradation_transform:
82
+ degraded_proj_range = self.degradation_transform(proj_range)
83
+ example['degraded_image'] = degraded_proj_range
84
+
85
+ # cameras
86
+ if self.condition_key == 'camera':
87
+ cameras = self.load_camera(data_path)
88
+ example[self.condition_key] = cameras
89
+
90
+ return example
91
+
92
+
93
+ class SemanticKITTIBase(KITTIBase):
94
+ def __init__(self, **kwargs):
95
+ super().__init__(**kwargs)
96
+ assert self.condition_key in ['segmentation'] # for segmentation input only
97
+ self.label2rgb = LABEL2RGB
98
+
99
+ def prepare_data(self):
100
+ # read data paths from KITTI
101
+ for seq_id in eval('SEM_KITTI_%s_SET' % self.split.upper()):
102
+ self.data.extend(glob.glob(os.path.join(
103
+ self.data_root, f'dataset/sequences/{seq_id}/velodyne/*.bin')))
104
+ # read label mapping
105
+ data_config = yaml.safe_load(open('./data/config/semantic-kitti.yaml', 'r'))
106
+ remap_dict = data_config["learning_map"]
107
+ max_key = max(remap_dict.keys())
108
+ self.learning_map = np.zeros((max_key + 100), dtype=np.int32)
109
+ self.learning_map[list(remap_dict.keys())] = list(remap_dict.values())
110
+
111
+ def load_semantic_map(self, path, pcd):
112
+ label_path = path.replace('velodyne', 'labels').replace('.bin', '.label')
113
+ labels = np.fromfile(label_path, dtype=np.uint32)
114
+ labels = labels.reshape((-1))
115
+ labels = labels & 0xFFFF # semantic label in lower half
116
+ labels = self.learning_map[labels]
117
+
118
+ proj_range, sem_map = pcd2range(pcd, self.img_size, self.fov, self.depth_range, labels=labels)
119
+ # sem_map = np.expand_dims(sem_map, axis=0).astype(np.int64)
120
+ sem_map = sem_map.astype(np.int64)
121
+ if self.filtered_map_cats is not None:
122
+ sem_map[np.isin(sem_map, self.filtered_map_cats)] = 0 # set filtered category as noise
123
+ onehot = np.eye(self.num_sem_cats, dtype=np.float32)[sem_map].transpose(2, 0, 1)
124
+ return proj_range, onehot
125
+
126
+
127
+ class SemanticKITTITrain(SemanticKITTIBase):
128
+ def __init__(self, **kwargs):
129
+ super().__init__(data_root='./dataset/SemanticKITTI', split='train', **kwargs)
130
+
131
+
132
+ class SemanticKITTIValidation(SemanticKITTIBase):
133
+ def __init__(self, **kwargs):
134
+ super().__init__(data_root='./dataset/SemanticKITTI', split='val', **kwargs)
135
+
136
+
137
+ class KITTI360Base(KITTIBase):
138
+ def __init__(self, split_per_view=None, **kwargs):
139
+ super().__init__(**kwargs)
140
+ self.split_per_view = split_per_view
141
+ if self.condition_key == 'camera':
142
+ assert self.split_per_view is not None, 'For camera-to-lidar, need to specify split_per_view'
143
+
144
+ def prepare_data(self):
145
+ # read data paths
146
+ self.data = []
147
+ if self.condition_key == 'camera':
148
+ seq_list = eval('CAM_KITTI360_%s_SET' % self.split.upper())
149
+ else:
150
+ seq_list = eval('KITTI360_%s_SET' % self.split.upper())
151
+ for seq_id in seq_list:
152
+ self.data.extend(glob.glob(os.path.join(
153
+ self.data_root, f'data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin')))
154
+
155
+ def random_drop_camera(self, camera_list):
156
+ if np.random.rand() < self.aug_config['camera_drop'] and self.split == 'train':
157
+ camera_list = [np.zeros_like(c) if i != len(camera_list) // 2 else c for i, c in enumerate(camera_list)] # keep the middle view only
158
+ return camera_list
159
+
160
+ def load_camera(self, path):
161
+ camera_path = path.replace('data_3d_raw', 'data_2d_camera').replace('velodyne_points/data', 'image_00/data_rect').replace('.bin', '.png')
162
+ camera = np.array(Image.open(camera_path)).astype(np.float32) / 255.
163
+ camera = camera.transpose(2, 0, 1)
164
+ if self.view_transform:
165
+ camera = self.view_transform(camera)
166
+ camera_list = np.split(camera, self.split_per_view, axis=2) # split into n chunks as different views
167
+ camera_list = self.random_drop_camera(camera_list)
168
+ return camera_list
169
+
170
+
171
+ class KITTI360Train(KITTI360Base):
172
+ def __init__(self, **kwargs):
173
+ super().__init__(data_root='./dataset/KITTI-360', split='train', **kwargs)
174
+
175
+
176
+ class KITTI360Validation(KITTI360Base):
177
+ def __init__(self, **kwargs):
178
+ super().__init__(data_root='./dataset/KITTI-360', split='val', **kwargs)
179
+
180
+
181
+ class AnnotatedKITTI360Base(Annotated3DObjectsDataset, KITTI360Base):
182
+ def __init__(self, **kwargs):
183
+ self.id_bbox_dict = dict()
184
+ self.id_label_dict = dict()
185
+
186
+ Annotated3DObjectsDataset.__init__(self, **kwargs)
187
+ KITTI360Base.__init__(self, **kwargs)
188
+ assert self.condition_key in ['center', 'bbox'] # for annotated images only
189
+
190
+ @staticmethod
191
+ def parseOpencvMatrix(node):
192
+ rows = int(node.find('rows').text)
193
+ cols = int(node.find('cols').text)
194
+ data = node.find('data').text.split(' ')
195
+
196
+ mat = []
197
+ for d in data:
198
+ d = d.replace('\n', '')
199
+ if len(d) < 1:
200
+ continue
201
+ mat.append(float(d))
202
+ mat = np.reshape(mat, [rows, cols])
203
+ return mat
204
+
205
+ def parseVertices(self, child):
206
+ transform = self.parseOpencvMatrix(child.find('transform'))
207
+ R = transform[:3, :3]
208
+ T = transform[:3, 3]
209
+ vertices = self.parseOpencvMatrix(child.find('vertices'))
210
+ vertices = np.matmul(R, vertices.transpose()).transpose() + T
211
+ return vertices
212
+
213
+ def parse_bbox_xml(self, path):
214
+ tree = ET.parse(path)
215
+ root = tree.getroot()
216
+
217
+ bbox_dict = dict()
218
+ label_dict = dict()
219
+ for child in root:
220
+ if child.find('transform') is None:
221
+ continue
222
+
223
+ label_name = child.find('label').text
224
+ if label_name not in BBOX_CAT2LABEL:
225
+ continue
226
+
227
+ label = BBOX_CAT2LABEL[label_name]
228
+ timestamp = int(child.find('timestamp').text)
229
+ # verts = self.parseVertices(child)
230
+ verts = self.parseOpencvMatrix(child.find('vertices'))[:8]
231
+ if timestamp in bbox_dict:
232
+ bbox_dict[timestamp].append(verts)
233
+ label_dict[timestamp].append(label)
234
+ else:
235
+ bbox_dict[timestamp] = [verts]
236
+ label_dict[timestamp] = [label]
237
+ return bbox_dict, label_dict
238
+
239
+ def prepare_data(self):
240
+ KITTI360Base.prepare_data(self)
241
+
242
+ self.data = [p for p in self.data if '2013_05_28_drive_0008_sync' not in p] # remove unlabeled sequence 08
243
+ seq_list = eval('KITTI360_%s_SET' % self.split.upper())
244
+ for seq_id in seq_list:
245
+ if seq_id != '08':
246
+ xml_path = os.path.join(self.data_root, f'data_3d_bboxes/train/2013_05_28_drive_00{seq_id}_sync.xml')
247
+ bbox_dict, label_dict = self.parse_bbox_xml(xml_path)
248
+ self.id_bbox_dict[seq_id] = bbox_dict
249
+ self.id_label_dict[seq_id] = label_dict
250
+
251
+ def load_annotation(self, path):
252
+ seq_id = path.split('/')[-4].split('_')[-2][-2:]
253
+ timestamp = int(path.split('/')[-1].replace('.bin', ''))
254
+ verts_list = self.id_bbox_dict[seq_id][timestamp]
255
+ label_list = self.id_label_dict[seq_id][timestamp]
256
+
257
+ if self.condition_key == 'bbox':
258
+ points = np.stack(verts_list)
259
+ elif self.condition_key == 'center':
260
+ points = (verts_list[0] + verts_list[6]) / 2.
261
+ else:
262
+ raise NotImplementedError
263
+ labels = np.array([label_list])
264
+ if self.anno_transform:
265
+ points, labels = self.anno_transform(points, labels)
266
+ return points, labels
267
+
268
+ def __getitem__(self, idx):
269
+ example = dict()
270
+ data_path = self.data[idx]
271
+
272
+ # lidar point cloud
273
+ sweep = self.load_lidar_sweep(data_path)
274
+
275
+ # annotations
276
+ bbox_points, bbox_labels = self.load_annotation(data_path)
277
+
278
+ if self.lidar_transform:
279
+ sweep, bbox_points = self.lidar_transform(sweep, bbox_points)
280
+
281
+ # point cloud -> range
282
+ proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range)
283
+ proj_range, proj_mask = self.process_scan(proj_range)
284
+ example['image'], example['mask'] = proj_range, proj_mask
285
+ if self.return_pcd:
286
+ example['reproj'] = sweep
287
+
288
+ # annotation -> range
289
+ # NOTE: do not need to transform bbox points along with lidar, since their coordinates are based on range-image space instead of 3D space
290
+ proj_bbox_points, proj_bbox_labels = pcd2coord2d(bbox_points, self.fov, self.depth_range, labels=bbox_labels)
291
+ builder = self.conditional_builders[self.condition_key]
292
+ if self.condition_key == 'bbox':
293
+ proj_bbox_points = corners_3d_to_2d(proj_bbox_points)
294
+ annotations = [Annotation(bbox=bbox.flatten(), category_id=label) for bbox, label in
295
+ zip(proj_bbox_points, proj_bbox_labels)]
296
+ else:
297
+ annotations = [Annotation(center=center, category_id=label) for center, label in
298
+ zip(proj_bbox_points, proj_bbox_labels)]
299
+ example[self.condition_key] = builder.build(annotations)
300
+
301
+ return example
302
+
303
+
304
+ class AnnotatedKITTI360Train(AnnotatedKITTI360Base):
305
+ def __init__(self, **kwargs):
306
+ super().__init__(data_root='./dataset/KITTI-360', split='train', cats=BBOX_CATS, **kwargs)
307
+
308
+
309
+ class AnnotatedKITTI360Validation(AnnotatedKITTI360Base):
310
+ def __init__(self, **kwargs):
311
+ super().__init__(data_root='./dataset/KITTI-360', split='train', cats=BBOX_CATS, **kwargs)
312
+
313
+
314
+ class KITTIImageBase(KITTIBase):
315
+ """
316
+ Range ImageSet only combining KITTI-360 and SemanticKITTI
317
+
318
+ #Samples (Training): 98014, #Samples (Val): 3511
319
+
320
+ """
321
+ def __init__(self, **kwargs):
322
+ super().__init__(**kwargs)
323
+ assert self.condition_key in [None, 'image'] # for image input only
324
+
325
+ def prepare_data(self):
326
+ # read data paths from KITTI-360
327
+ self.data = []
328
+ for seq_id in eval('KITTI360_%s_SET' % self.split.upper()):
329
+ self.data.extend(glob.glob(os.path.join(
330
+ self.data_root, f'KITTI-360/data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin')))
331
+
332
+ # read data paths from KITTI
333
+ for seq_id in eval('KITTI_%s_SET' % self.split.upper()):
334
+ self.data.extend(glob.glob(os.path.join(
335
+ self.data_root, f'SemanticKITTI/dataset/sequences/{seq_id}/velodyne/*.bin')))
336
+
337
+
338
+ class KITTIImageTrain(KITTIImageBase):
339
+ def __init__(self, **kwargs):
340
+ super().__init__(data_root='./dataset', split='train', **kwargs)
341
+
342
+
343
+ class KITTIImageValidation(KITTIImageBase):
344
+ def __init__(self, **kwargs):
345
+ super().__init__(data_root='./dataset', split='val', **kwargs)
lidm/eval/README.md ADDED
@@ -0,0 +1,95 @@
1
+ # Evaluation Toolbox for LiDAR Generation
2
+
3
+ This directory is a **self-contained**, **memory-friendly** and mostly **CUDA-accelerated** toolbox of multiple evaluation metrics for LiDAR generative models, including:
4
+ * Perceptual metrics (our proposed):
5
+ * FrΓ©chet Range Image Distance (**FRID**)
6
+ * FrΓ©chet Sparse Volume Distance (**FSVD**)
7
+ * FrΓ©chet Point-based Volume Distance (**FPVD**)
8
+ * Statistical metrics (proposed in [Learning Representations and Generative Models for 3D Point Clouds](https://arxiv.org/abs/1707.02392)):
9
+ * Minimum Matching Distance (**MMD**)
10
+ * Jensen-Shannon Divergence (**JSD**)
11
+ * Statistical pairwise metrics (for reconstruction only):
12
+ * Chamfer Distance (**CD**)
13
+ * Earth Mover's Distance (**EMD**)
14
+
15
+ ## Citation
16
+
17
+ If you find this project useful in your research, please consider citing:
18
+ ```
19
+ @article{ran2024towards,
20
+ title={Towards Realistic Scene Generation with LiDAR Diffusion Models},
21
+ author={Ran, Haoxi and Guizilini, Vitor and Wang, Yue},
22
+ journal={arXiv preprint arXiv:2404.00815},
23
+ year={2024}
24
+ }
25
+ ```
26
+
27
+
28
+ ## Dependencies
29
+
30
+ ### Basic (install through **pip**):
31
+ * scipy
32
+ * numpy
33
+ * torch
34
+ * pyyaml
35
+
36
+ ### Required by FSVD and FPVD:
37
+ * [Torchsparse v1.4.0](https://github.com/mit-han-lab/torchsparse/tree/v1.4.0) (pip install git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0)
38
+ * [Google Sparse Hash library](https://github.com/sparsehash/sparsehash) (apt-get install libsparsehash-dev **or** compile locally and update variable CPLUS_INCLUDE_PATH with directory path)
39
+
40
+
41
+ ## Model Zoo
42
+
43
+ To evaluate with perceptual metrics on different types of LiDAR data, you can download all models through:
44
+ * this [google drive link](https://drive.google.com/file/d/1Ml4p4_nMlwLkSp7JB528GJv2_HxO8v1i/view?usp=drive_link) in the .zip file
45
+
46
+ or
47
+ * the **full directory** of one specific model:
48
+
49
+ ### 64-beam LiDAR (trained on [SemanticKITTI](http://semantic-kitti.org/dataset.html)):
50
+
51
+ | Metric | Model | Arch | Link | Code | Comments |
52
+ |:------:|:-------------------------------------------------------------------------------------------:|:-----------------------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------|---------------------------------------------------------------------------|
53
+ | FRID | [RangeNet++](https://www.ipb.uni-bonn.de/wp-content/papercite-data/pdf/milioto2019iros.pdf) | DarkNet21-based UNet | [Google Drive](https://drive.google.com/drive/folders/1ZS8KOoxB9hjB6kwKbH5Zfc8O5qJlKsbl?usp=drive_link) | [./models/rangenet/model.py](./models/rangenet/model.py) | range image input (our trained model without the need of remission input) |
54
+ | FSVD | [MinkowskiNet](https://arxiv.org/abs/1904.08755) | Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1zN12ZEvjIvo4PCjAsncgC22yvtRrCCMe?usp=drive_link) | [./models/minkowskinet/model.py](./models/minkowskinet/model.py) | point cloud input |
55
+ | FPVD | [SPVCNN](https://arxiv.org/abs/2007.16100) | Point-Voxel Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1oEm3qpxfGetiVAfXIvecawEiFqW79M6B?usp=drive_link) | [./models/spvcnn/model.py](./models/spvcnn/model.py) | point cloud input |
56
+
57
+
58
+ ### 32-beam LiDAR (trained on [nuScenes](https://www.nuscenes.org/nuscenes)):
59
+
60
+ | Metric | Model | Arch | Link | Code | Comments |
61
+ |:------:|:------------------------------------------------:|:-----------------------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------|-------------------|
62
+ | FSVD | [MinkowskiNet](https://arxiv.org/abs/1904.08755) | Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1oZIS9FlklCQ6dlh3TZ8Junir7QwgT-Me?usp=drive_link) | [./models/minkowskinet/model.py](./models/minkowskinet/model.py) | point cloud input |
63
+ | FPVD | [SPVCNN](https://arxiv.org/abs/2007.16100) | Point-Voxel Sparse UNet | [Google Drive](https://drive.google.com/drive/folders/1F69RbprAoT6MOJ7iI0KHjxuq-tbeqGiR?usp=drive_link) | [./models/spvcnn/model.py](./models/spvcnn/model.py) | point cloud input |
64
+
65
+
66
+ ## Usage
67
+
68
+ 1. Place the unzipped `pretrained_weights` folder under the directory you run Python from, **or** modify the `DEFAULT_ROOT` variable in `__init__.py`.
69
+ 2. Prepare input data, including the synthesized samples and the reference dataset. **Note**: The reference data should be the **point clouds projected back from range images** instead of raw point clouds.
70
+ 3. Specify the data type (`32` or `64`) and the metrics to evaluate. Options: `mmd`, `jsd`, `frid`, `fsvd`, `fpvd`, `cd`, `emd`.
71
+ 4. (Optional) To compute the `frid`, `fsvd`, or `fpvd` metrics, adjust the corresponding batch size through `MODAL2BATCHSIZE` in `__init__.py` according to your maximum GPU memory (default: ~24GB).
72
+ 5. Start the evaluation; all results will be printed to the console.
73
+
74
+ ### Example:
75
+
76
+ ```
77
+ from .eval_utils import evaluate
78
+
79
+ data = '64' # specify data type to evaluate
80
+ metrics = ['mmd', 'jsd', 'frid', 'fsvd', 'fpvd'] # specify metrics to evaluate
81
+
82
+ # list of np.float32 array
83
+ # shape of each array: (#points, #dim=3), #dim: xyz coordinate (NOTE: no need to input remission)
84
+ reference = ...
85
+ samples = ...
86
+
87
+ evaluate(reference, samples, metrics, data)
88
+ ```
89
+
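+ **Note**: the `cd`, `emd`, and `mmd` metrics rely on the CUDA ops under `./modules/` (Chamfer Distance and EMD). If these are not built yet, compile them first with the provided `compile.sh`, which runs each module's `setup.py build_ext --inplace`.
+ 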
90
+
91
+ ## Acknowledgement
92
+
93
+ - The implementation of MinkowskiNet and SPVCNN is borrowed from [2DPASS](https://github.com/yanx27/2DPASS).
94
+ - The implementation of RangeNet++ is borrowed from [the official RangeNet++ codebase](https://github.com/PRBonn/lidar-bonnetal).
95
+ - The implementations of Chamfer Distance and Earth Mover's Distance are adapted from [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch) and the [MSN official repo](https://github.com/Colin97/MSN-Point-Cloud-Completion), respectively.
lidm/eval/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ @Author: Haoxi Ran
3
+ @Date: 01/03/2024
4
+ @Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models
5
+
6
+ """
7
+
8
+ import os
9
+
10
+ import torch
11
+ import yaml
12
+
13
+ from lidm.utils.misc_utils import dict2namespace
14
+ from ..modules.rangenet.model import Model as rangenet
15
+
16
+ try:
17
+ from ..modules.spvcnn.model import Model as spvcnn
18
+ from ..modules.minkowskinet.model import Model as minkowskinet
19
+ except:
20
+ print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24')
21
+
22
+ # user settings
23
+ DEFAULT_ROOT = './pretrained_weights'
24
+ MODAL2BATCHSIZE = {'range': 100, 'voxel': 50, 'point_voxel': 25}
25
+ OUTPUT_TEMPLATE = 50 * '-' + '\n|' + 16 * ' ' + '{}:{:.4E}' + 17 * ' ' + '|\n' + 50 * '-'
26
+
27
+ # eval settings (do not modify)
28
+ VOXEL_SIZE = 0.05
29
+ NUM_SECTORS = 16
30
+ AGG_TYPE = 'depth'
31
+ TYPE2DATASET = {'32': 'nuscenes', '64': 'kitti'}
32
+ DATA_CONFIG = {'64': {'x': [-50, 50], 'y': [-50, 50], 'z': [-3, 1]},
33
+ '32': {'x': [-30, 30], 'y': [-30, 30], 'z': [-3, 6]}}
34
+ MODALITY2MODEL = {'range': 'rangenet', 'voxel': 'minkowskinet', 'point_voxel': 'spvcnn'}
35
+ DATASET_CONFIG = {'kitti': {'size': [64, 1024], 'fov': [3, -25], 'depth_range': [1.0, 56.0], 'depth_scale': 6},
36
+ 'nuscenes': {'size': [32, 1024], 'fov': [10, -30], 'depth_range': [1.0, 45.0]}}
37
+
38
+
39
+ def build_model(dataset_name, model_name, device='cpu'):
40
+ # config
41
+ model_folder = os.path.join(DEFAULT_ROOT, dataset_name, model_name)
42
+
43
+ if not os.path.isdir(model_folder):
44
+         raise Exception('Pretrained weights not found at: %s' % model_folder)
45
+
46
+ config = yaml.safe_load(open(os.path.join(model_folder, 'config.yaml'), 'r'))
47
+ if model_name != 'rangenet':
48
+ config = dict2namespace(config)
49
+
50
+ # build model
51
+ model = eval(model_name)(config)
52
+
53
+ # load checkpoint
54
+ if model_name == 'rangenet':
55
+ model.load_pretrained_weights(model_folder)
56
+ else:
57
+ ckpt = torch.load(os.path.join(model_folder, 'model.ckpt'), map_location="cpu")
58
+ model.load_state_dict(ckpt['state_dict'], strict=False)
59
+ model.to(device)
60
+ model.eval()
61
+
62
+ return model
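+ 
+ 
+ # Minimal usage sketch (assuming the pretrained weights are unpacked as
+ # '<DEFAULT_ROOT>/<dataset>/<model>/', e.g. './pretrained_weights/kitti/rangenet/'):
+ #   frid_backbone = build_model('kitti', 'rangenet', device='cuda')
+ #   fpvd_backbone = build_model('nuscenes', 'spvcnn', device='cuda')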
lidm/eval/compile.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/sh
2
+
3
+ cd modules/chamfer2D
+ python setup.py build_ext --inplace
+
+ cd ../chamfer3D
4
+ python setup.py build_ext --inplace
5
+
6
+ cd ../emd
7
+ python setup.py build_ext --inplace
8
+
9
+ cd ..
lidm/eval/eval_utils.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ @Author: Haoxi Ran
3
+ @Date: 01/03/2024
4
+ @Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models
5
+
6
+ """
7
+ import multiprocessing
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ from scipy.spatial.distance import jensenshannon
12
+ from tqdm import tqdm
13
+
14
+ from . import OUTPUT_TEMPLATE
15
+ from .metric_utils import compute_logits, compute_pairwise_cd, \
16
+ compute_pairwise_emd, pcd2bev_sum, compute_pairwise_cd_batch, pcd2bev_bin
17
+ from .fid_score import calculate_frechet_distance
18
+
19
+
20
+ def evaluate(reference, samples, metrics, data):
21
+ # perceptual
22
+ if 'frid' in metrics:
23
+ compute_frid(reference, samples, data)
24
+ if 'fsvd' in metrics:
25
+ compute_fsvd(reference, samples, data)
26
+ if 'fpvd' in metrics:
27
+ compute_fpvd(reference, samples, data)
28
+
29
+ # reconstruction
30
+ if 'cd' in metrics:
31
+ compute_cd(reference, samples)
32
+ if 'emd' in metrics:
33
+ compute_emd(reference, samples)
34
+
35
+ # statistical
36
+ if 'jsd' in metrics:
37
+ compute_jsd(reference, samples, data)
38
+ if 'mmd' in metrics:
39
+ compute_mmd(reference, samples, data)
40
+
41
+
42
+ def compute_cd(reference, samples):
43
+ """
44
+ Calculate score of Chamfer Distance (CD)
45
+
46
+ """
47
+ print('Evaluating (CD) ...')
48
+ results = []
49
+ for x, y in zip(reference, samples):
50
+ d = compute_pairwise_cd(x, y)
51
+ results.append(d)
52
+ score = sum(results) / len(results)
53
+ print(OUTPUT_TEMPLATE.format('CD ', score))
54
+
55
+
56
+ def compute_emd(reference, samples):
57
+ """
58
+ Calculate score of Earth Mover's Distance (EMD)
59
+
60
+ """
61
+ print('Evaluating (EMD) ...')
62
+ results = []
63
+ for x, y in zip(reference, samples):
64
+ d = compute_pairwise_emd(x, y)
65
+ results.append(d)
66
+ score = sum(results) / len(results)
67
+ print(OUTPUT_TEMPLATE.format('EMD ', score))
68
+
69
+
70
+ def compute_mmd(reference, samples, data, dist='cd', verbose=True):
71
+ """
72
+ Calculate the score of Minimum Matching Distance (MMD)
73
+
74
+ """
75
+ print('Evaluating (MMD) ...')
76
+ assert dist in ['cd', 'emd']
77
+ reference, samples = pcd2bev_bin(data, reference, samples)
78
+ compute_dist_func = compute_pairwise_cd_batch if dist == 'cd' else compute_pairwise_emd
79
+ results = []
80
+ for r in tqdm(reference, disable=not verbose):
81
+ dists = compute_dist_func(r, samples)
82
+ results.append(min(dists))
83
+ score = sum(results) / len(results)
84
+ print(OUTPUT_TEMPLATE.format('MMD ', score))
85
+
86
+
87
+ def compute_jsd(reference, samples, data):
88
+ """
89
+ Calculate the score of Jensen-Shannon Divergence (JSD)
90
+
91
+ """
92
+ print('Evaluating (JSD) ...')
93
+ reference, samples = pcd2bev_sum(data, reference, samples)
94
+ reference = (reference / np.sum(reference)).flatten()
95
+ samples = (samples / np.sum(samples)).flatten()
96
+ score = jensenshannon(reference, samples)
97
+ print(OUTPUT_TEMPLATE.format('JSD ', score))
98
+
99
+
100
+ def compute_fd(reference, samples):
101
+ mu1, mu2 = np.mean(reference, axis=0), np.mean(samples, axis=0)
102
+ sigma1, sigma2 = np.cov(reference, rowvar=False), np.cov(samples, rowvar=False)
103
+ distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
104
+ return distance
105
+
106
+
107
+ def compute_frid(reference, samples, data):
108
+ """
109
+ Calculate the score of FrΓ©chet Range Image Distance (FRID)
110
+
111
+ """
112
+ print('Evaluating (FRID) ...')
113
+ gt_logits, samples_logits = compute_logits(data, 'range', reference, samples)
114
+ score = compute_fd(gt_logits, samples_logits)
115
+ print(OUTPUT_TEMPLATE.format('FRID', score))
116
+
117
+
118
+ def compute_fsvd(reference, samples, data):
119
+ """
120
+ Calculate the score of FrΓ©chet Sparse Volume Distance (FSVD)
121
+
122
+ """
123
+ print('Evaluating (FSVD) ...')
124
+ gt_logits, samples_logits = compute_logits(data, 'voxel', reference, samples)
125
+ score = compute_fd(gt_logits, samples_logits)
126
+ print(OUTPUT_TEMPLATE.format('FSVD', score))
127
+
128
+
129
+ def compute_fpvd(reference, samples, data):
130
+ """
131
+ Calculate the score of FrΓ©chet Point-based Volume Distance (FPVD)
132
+
133
+ """
134
+ print('Evaluating (FPVD) ...')
135
+ gt_logits, samples_logits = compute_logits(data, 'point_voxel', reference, samples)
136
+ score = compute_fd(gt_logits, samples_logits)
137
+ print(OUTPUT_TEMPLATE.format('FPVD', score))
138
+
lidm/eval/fid_score.py ADDED
@@ -0,0 +1,191 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+ The FID metric calculates the distance between two distributions of images.
3
+ Typically, we have summary statistics (mean & covariance matrix) of one
4
+ of these distributions, while the 2nd distribution is given by a GAN.
5
+ When run as a stand-alone program, it compares the distribution of
6
+ images that are stored as PNG/JPEG at a specified location with a
7
+ distribution given by summary statistics (in pickle format).
8
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
9
+ the pool_3 layer of the inception net for generated samples and real world
10
+ samples respectively.
11
+ See --help to see further details.
12
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
13
+ of Tensorflow
14
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
15
+ Licensed under the Apache License, Version 2.0 (the "License");
16
+ you may not use this file except in compliance with the License.
17
+ You may obtain a copy of the License at
18
+ http://www.apache.org/licenses/LICENSE-2.0
19
+ Unless required by applicable law or agreed to in writing, software
20
+ distributed under the License is distributed on an "AS IS" BASIS,
21
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ See the License for the specific language governing permissions and
23
+ limitations under the License.
24
+ """
25
+ import os
26
+ import pathlib
27
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torchvision.transforms as TF
32
+ from PIL import Image
33
+ from scipy import linalg
34
+ from torch.nn.functional import adaptive_avg_pool2d
35
+
36
+ try:
37
+ from tqdm import tqdm
38
+ except ImportError:
39
+ # If tqdm is not available, provide a mock version of it
40
+ def tqdm(x):
41
+ return x
42
+
43
+ class ImagePathDataset(torch.utils.data.Dataset):
44
+ def __init__(self, files, transforms=None):
45
+ self.files = files
46
+ self.transforms = transforms
47
+
48
+ def __len__(self):
49
+ return len(self.files)
50
+
51
+ def __getitem__(self, i):
52
+ path = self.files[i]
53
+ img = Image.open(path).convert('RGB')
54
+ if self.transforms is not None:
55
+ img = self.transforms(img)
56
+ return img
57
+
58
+
59
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
60
+ num_workers=1):
61
+ """Calculates the activations of the pool_3 layer for all images.
62
+ Params:
63
+ -- files : List of image files paths
64
+ -- model : Instance of inception model
65
+ -- batch_size : Batch size of images for the model to process at once.
66
+ Make sure that the number of samples is a multiple of
67
+ the batch size, otherwise some samples are ignored. This
68
+ behavior is retained to match the original FID score
69
+ implementation.
70
+ -- dims : Dimensionality of features returned by Inception
71
+ -- device : Device to run calculations
72
+ -- num_workers : Number of parallel dataloader workers
73
+ Returns:
74
+ -- A numpy array of dimension (num images, dims) that contains the
75
+ activations of the given tensor when feeding inception with the
76
+ query tensor.
77
+ """
78
+ model.eval()
79
+
80
+ if batch_size > len(files):
81
+ print(('Warning: batch size is bigger than the data size. '
82
+ 'Setting batch size to data size'))
83
+ batch_size = len(files)
84
+
85
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
86
+ dataloader = torch.utils.data.DataLoader(dataset,
87
+ batch_size=batch_size,
88
+ shuffle=False,
89
+ drop_last=False,
90
+ num_workers=num_workers)
91
+
92
+ pred_arr = np.empty((len(files), dims))
93
+
94
+ start_idx = 0
95
+
96
+ for batch in tqdm(dataloader):
97
+ batch = batch.to(device)
98
+
99
+ with torch.no_grad():
100
+ pred = model(batch)[0]
101
+
102
+ # If model output is not scalar, apply global spatial average pooling.
103
+ # This happens if you choose a dimensionality not equal 2048.
104
+ if pred.size(2) != 1 or pred.size(3) != 1:
105
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
106
+
107
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
108
+
109
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
110
+
111
+ start_idx = start_idx + pred.shape[0]
112
+
113
+ return pred_arr
114
+
115
+
116
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
117
+ """Numpy implementation of the Frechet Distance.
118
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
119
+ and X_2 ~ N(mu_2, C_2) is
120
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
121
+ Stable version by Dougal J. Sutherland.
122
+ Params:
123
+ -- mu1 : Numpy array containing the activations of a layer of the
124
+ inception net (like returned by the function 'get_predictions')
125
+ for generated samples.
126
+ -- mu2 : The sample mean over activations, precalculated on an
127
+ representative data set.
128
+ -- sigma1: The covariance matrix over activations for generated samples.
129
+ -- sigma2: The covariance matrix over activations, precalculated on an
130
+ representative data set.
131
+ Returns:
132
+ -- : The Frechet Distance.
133
+ """
134
+
135
+ mu1 = np.atleast_1d(mu1)
136
+ mu2 = np.atleast_1d(mu2)
137
+
138
+ sigma1 = np.atleast_2d(sigma1)
139
+ sigma2 = np.atleast_2d(sigma2)
140
+
141
+ assert mu1.shape == mu2.shape, \
142
+ 'Training and test mean vectors have different lengths'
143
+ assert sigma1.shape == sigma2.shape, \
144
+ 'Training and test covariances have different dimensions'
145
+
146
+ diff = mu1 - mu2
147
+
148
+ # Product might be almost singular
149
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
150
+ if not np.isfinite(covmean).all():
151
+ msg = ('fid calculation produces singular product; '
152
+ 'adding %s to diagonal of cov estimates') % eps
153
+ print(msg)
154
+ offset = np.eye(sigma1.shape[0]) * eps
155
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
156
+
157
+ # Numerical error might give slight imaginary component
158
+ if np.iscomplexobj(covmean):
159
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
160
+ m = np.max(np.abs(covmean.imag))
161
+ raise ValueError('Imaginary component {}'.format(m))
162
+ covmean = covmean.real
163
+
164
+ tr_covmean = np.trace(covmean)
165
+
166
+ return (diff.dot(diff) + np.trace(sigma1)
167
+ + np.trace(sigma2) - 2 * tr_covmean)
168
+
169
+
170
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
171
+ device='cpu', num_workers=1):
172
+ """Calculation of the statistics used by the FID.
173
+ Params:
174
+ -- files : List of image files paths
175
+ -- model : Instance of inception model
176
+ -- batch_size : The images numpy array is split into batches with
177
+ batch size batch_size. A reasonable batch size
178
+ depends on the hardware.
179
+ -- dims : Dimensionality of features returned by Inception
180
+ -- device : Device to run calculations
181
+ -- num_workers : Number of parallel dataloader workers
182
+ Returns:
183
+ -- mu : The mean over samples of the activations of the pool_3 layer of
184
+ the inception model.
185
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
186
+ the inception model.
187
+ """
188
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
189
+ mu = np.mean(act, axis=0)
190
+ sigma = np.cov(act, rowvar=False)
191
+ return mu, sigma
lidm/eval/metric_utils.py ADDED
@@ -0,0 +1,458 @@
1
+ """
2
+ @Author: Haoxi Ran
3
+ @Date: 01/03/2024
4
+ @Citation: Towards Realistic Scene Generation with LiDAR Diffusion Models
5
+
6
+ """
7
+
8
+ import math
9
+ from itertools import repeat
10
+ from typing import List, Tuple, Union
11
+ import numpy as np
12
+ import torch
13
+
14
+ from . import build_model, VOXEL_SIZE, MODALITY2MODEL, MODAL2BATCHSIZE, DATASET_CONFIG, AGG_TYPE, NUM_SECTORS, \
15
+ TYPE2DATASET, DATA_CONFIG
16
+
17
+ try:
18
+ from torchsparse import SparseTensor, PointTensor
19
+ from torchsparse.utils.collate import sparse_collate_fn
20
+ from .modules.chamfer3D.dist_chamfer_3D import chamfer_3DDist
21
+ from .modules.chamfer2D.dist_chamfer_2D import chamfer_2DDist
22
+ from .modules.emd.emd_module import emdModule
23
+ except:
24
+ print(
25
+ 'To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24')
26
+
27
+
28
+ def ravel_hash(x: np.ndarray) -> np.ndarray:
29
+ assert x.ndim == 2, x.shape
30
+
31
+ x = x - np.min(x, axis=0)
32
+ x = x.astype(np.uint64, copy=False)
33
+ xmax = np.max(x, axis=0).astype(np.uint64) + 1
34
+
35
+ h = np.zeros(x.shape[0], dtype=np.uint64)
36
+ for k in range(x.shape[1] - 1):
37
+ h += x[:, k]
38
+ h *= xmax[k + 1]
39
+ h += x[:, -1]
40
+ return h
41
+
42
+
43
+ def sparse_quantize(coords, voxel_size: Union[float, Tuple[float, ...]] = 1, *, return_index: bool = False,
44
+ return_inverse: bool = False) -> List[np.ndarray]:
45
+ """
46
+ Modified based on https://github.com/mit-han-lab/torchsparse/blob/462dea4a701f87a7545afb3616bf2cf53dd404f3/torchsparse/utils/quantize.py
47
+
48
+ """
49
+ if isinstance(voxel_size, (float, int)):
50
+ voxel_size = tuple(repeat(voxel_size, coords.shape[1]))
51
+ assert isinstance(voxel_size, tuple) and len(voxel_size) in [2, 3] # support 2D and 3D coordinates only
52
+
53
+ voxel_size = np.array(voxel_size)
54
+ coords = np.floor(coords / voxel_size).astype(np.int32)
55
+
56
+ _, indices, inverse_indices = np.unique(
57
+ ravel_hash(coords), return_index=True, return_inverse=True
58
+ )
59
+ coords = coords[indices]
60
+
61
+ outputs = [coords]
62
+ if return_index:
63
+ outputs += [indices]
64
+ if return_inverse:
65
+ outputs += [inverse_indices]
66
+ return outputs[0] if len(outputs) == 1 else outputs
67
+
68
+
69
+ def pcd2range(pcd, size, fov, depth_range, remission=None, labels=None, **kwargs):
70
+ # laser parameters
71
+ fov_up = fov[0] / 180.0 * np.pi # field of view up in rad
72
+ fov_down = fov[1] / 180.0 * np.pi # field of view down in rad
73
+ fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad
74
+
75
+ # get depth (distance) of all points
76
+ depth = np.linalg.norm(pcd, 2, axis=1)
77
+
78
+ # mask points out of range
79
+ mask = np.logical_and(depth > depth_range[0], depth < depth_range[1])
80
+ depth, pcd = depth[mask], pcd[mask]
81
+
82
+ # get scan components
83
+ scan_x, scan_y, scan_z = pcd[:, 0], pcd[:, 1], pcd[:, 2]
84
+
85
+ # get angles of all points
86
+ yaw = -np.arctan2(scan_y, scan_x)
87
+ pitch = np.arcsin(scan_z / depth)
88
+
89
+ # get projections in image coords
90
+ proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0]
91
+ proj_y = 1.0 - (pitch + abs(fov_down)) / fov_range # in [0.0, 1.0]
92
+
93
+ # scale to image size using angular resolution
94
+ proj_x *= size[1] # in [0.0, W]
95
+ proj_y *= size[0] # in [0.0, H]
96
+
97
+ # round and clamp for use as index
98
+ proj_x = np.maximum(0, np.minimum(size[1] - 1, np.floor(proj_x))).astype(np.int32) # in [0,W-1]
99
+ proj_y = np.maximum(0, np.minimum(size[0] - 1, np.floor(proj_y))).astype(np.int32) # in [0,H-1]
100
+
101
+ # order in decreasing depth
102
+ order = np.argsort(depth)[::-1]
103
+ proj_x, proj_y = proj_x[order], proj_y[order]
104
+
105
+ # project depth
106
+ depth = depth[order]
107
+ proj_range = np.full(size, -1, dtype=np.float32)
108
+ proj_range[proj_y, proj_x] = depth
109
+
110
+ # project point feature
111
+ if remission is not None:
112
+ remission = remission[mask][order]
113
+ proj_feature = np.full(size, -1, dtype=np.float32)
114
+ proj_feature[proj_y, proj_x] = remission
115
+ elif labels is not None:
116
+ labels = labels[mask][order]
117
+ proj_feature = np.full(size, 0, dtype=np.float32)
118
+ proj_feature[proj_y, proj_x] = labels
119
+ else:
120
+ proj_feature = None
121
+
122
+ return proj_range, proj_feature
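+ 
+ 
+ # Usage sketch (hypothetical input; keyword arguments follow DATASET_CONFIG['kitti'] imported above):
+ #   pcd = np.random.rand(120000, 3).astype(np.float32) * 50.0  # placeholder (x, y, z) point cloud
+ #   proj_range, _ = pcd2range(pcd, size=[64, 1024], fov=[3, -25], depth_range=[1.0, 56.0])
+ #   # proj_range: (64, 1024) range image, filled with -1 where no point projects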
123
+
124
+
125
+ def range2xyz(range_img, fov, depth_range, depth_scale, log_scale=True, **kwargs):
126
+ # laser parameters
127
+ size = range_img.shape
128
+ fov_up = fov[0] / 180.0 * np.pi # field of view up in rad
129
+ fov_down = fov[1] / 180.0 * np.pi # field of view down in rad
130
+ fov_range = abs(fov_down) + abs(fov_up) # get field of view total in rad
131
+
132
+ # inverse transform from depth
133
+ if log_scale:
134
+ depth = (np.exp2(range_img * depth_scale) - 1)
135
+ else:
136
+ depth = range_img
137
+
138
+ scan_x, scan_y = np.meshgrid(np.arange(size[1]), np.arange(size[0]))
139
+ scan_x = scan_x.astype(np.float64) / size[1]
140
+ scan_y = scan_y.astype(np.float64) / size[0]
141
+
142
+ yaw = np.pi * (scan_x * 2 - 1)
143
+ pitch = (1.0 - scan_y) * fov_range - abs(fov_down)
144
+
145
+ xyz = -np.ones((3, *size))
146
+ xyz[0] = np.cos(yaw) * np.cos(pitch) * depth
147
+ xyz[1] = -np.sin(yaw) * np.cos(pitch) * depth
148
+ xyz[2] = np.sin(pitch) * depth
149
+
150
+ # mask out invalid points
151
+ mask = np.logical_and(depth > depth_range[0], depth < depth_range[1])
152
+ xyz[:, ~mask] = -1
153
+
154
+ return xyz
155
+
156
+
157
+ def pcd2voxel(pcd):
158
+ pcd_voxel = np.round(pcd / VOXEL_SIZE)
159
+ pcd_voxel = pcd_voxel - pcd_voxel.min(0, keepdims=1)
160
+ feat = np.concatenate((pcd, -np.ones((pcd.shape[0], 1))), axis=1) # -1 for remission placeholder
161
+ _, inds, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
162
+
163
+ feat = torch.FloatTensor(feat[inds])
164
+ pcd_voxel = torch.LongTensor(pcd_voxel[inds])
165
+ lidar = SparseTensor(feat, pcd_voxel)
166
+ output = {'lidar': lidar}
167
+ return output
168
+
169
+
170
+ def pcd2voxel_full(data_type, *args):
171
+ config = DATA_CONFIG[data_type]
172
+ x_range, y_range, z_range = config['x'], config['y'], config['z']
173
+ vol_shape = (math.ceil((x_range[1] - x_range[0]) / VOXEL_SIZE), math.ceil((y_range[1] - y_range[0]) / VOXEL_SIZE),
174
+ math.ceil((z_range[1] - z_range[0]) / VOXEL_SIZE))
175
+ min_bound = (math.ceil((x_range[0]) / VOXEL_SIZE), math.ceil((y_range[0]) / VOXEL_SIZE),
176
+ math.ceil((z_range[0]) / VOXEL_SIZE))
177
+
178
+ output = tuple()
179
+ for data in args:
180
+ volume_list = []
181
+ for pcd in data:
182
+ # mask out invalid points
183
+ mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1])
184
+ mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1])
185
+ mask_z = np.logical_and(pcd[:, 2] > z_range[0], pcd[:, 2] < z_range[1])
186
+ mask = mask_x & mask_y & mask_z
187
+ pcd = pcd[mask]
188
+
189
+ # voxelize
190
+ pcd_voxel = np.floor(pcd / VOXEL_SIZE)
191
+ _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
192
+ pcd_voxel = pcd_voxel[indices]
193
+ pcd_voxel = (pcd_voxel - min_bound).astype(np.int32)
194
+
195
+ # 2D bev grid
196
+ vol = np.zeros(vol_shape, dtype=np.float32)
197
+ vol[pcd_voxel[:, 0], pcd_voxel[:, 1], pcd_voxel[:, 2]] = 1
198
+ volume_list.append(vol)
199
+ output += (volume_list,)
200
+ return output
201
+
202
+
203
+ # def pcd2bev_full(data_type, *args, voxel_size=VOXEL_SIZE):
204
+ # config = DATA_CONFIG[data_type]
205
+ # x_range, y_range = config['x'], config['y']
206
+ # vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size))
207
+ # min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size))
208
+ #
209
+ # output = tuple()
210
+ # for data in args:
211
+ # volume_list = []
212
+ # for pcd in data:
213
+ # # mask out invalid points
214
+ # mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1])
215
+ # mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1])
216
+ # mask = mask_x & mask_y
217
+ # pcd = pcd[mask][:, :2] # keep x,y coord
218
+ #
219
+ # # voxelize
220
+ # pcd_voxel = np.floor(pcd / voxel_size)
221
+ # _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
222
+ # pcd_voxel = pcd_voxel[indices]
223
+ # pcd_voxel = (pcd_voxel - min_bound).astype(np.int32)
224
+ #
225
+ # # 2D bev grid
226
+ # vol = np.zeros(vol_shape, dtype=np.float32)
227
+ # vol[pcd_voxel[:, 0], pcd_voxel[:, 1]] = 1
228
+ # volume_list.append(vol)
229
+ # output += (volume_list,)
230
+ # return output
231
+
232
+
233
+ def pcd2bev_sum(data_type, *args, voxel_size=VOXEL_SIZE):
234
+ config = DATA_CONFIG[data_type]
235
+ x_range, y_range = config['x'], config['y']
236
+ vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size))
237
+ min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size))
238
+
239
+ output = tuple()
240
+ for data in args:
241
+ volume_sum = np.zeros(vol_shape, np.float32)
242
+ for pcd in data:
243
+ # mask out invalid points
244
+ mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1])
245
+ mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1])
246
+ mask = mask_x & mask_y
247
+ pcd = pcd[mask][:, :2] # keep x,y coord
248
+
249
+ # voxelize
250
+ pcd_voxel = np.floor(pcd / voxel_size)
251
+ _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
252
+ pcd_voxel = pcd_voxel[indices]
253
+ pcd_voxel = (pcd_voxel - min_bound).astype(np.int32)
254
+
255
+ # summation
256
+ volume_sum[pcd_voxel[:, 0], pcd_voxel[:, 1]] += 1.
257
+ output += (volume_sum,)
258
+ return output
259
+
260
+
261
+ def pcd2bev_bin(data_type, *args, voxel_size=0.5):
262
+ config = DATA_CONFIG[data_type]
263
+ x_range, y_range = config['x'], config['y']
264
+ vol_shape = (math.ceil((x_range[1] - x_range[0]) / voxel_size), math.ceil((y_range[1] - y_range[0]) / voxel_size))
265
+ min_bound = (math.ceil((x_range[0]) / voxel_size), math.ceil((y_range[0]) / voxel_size))
266
+
267
+ output = tuple()
268
+ for data in args:
269
+ pcd_list = []
270
+ for pcd in data:
271
+ # mask out invalid points
272
+ mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1])
273
+ mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1])
274
+ mask = mask_x & mask_y
275
+ pcd = pcd[mask][:, :2] # keep x,y coord
276
+
277
+ # voxelize
278
+ pcd_voxel = np.floor(pcd / voxel_size)
279
+ _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
280
+ pcd_voxel = pcd_voxel[indices]
281
+ pcd_voxel = ((pcd_voxel - min_bound) / vol_shape).astype(np.float32)
282
+ pcd_list.append(pcd_voxel)
283
+ output += (pcd_list,)
284
+ return output
285
+
286
+
287
+ def bev_sample(data_type, *args, voxel_size=0.5):
288
+ config = DATA_CONFIG[data_type]
289
+ x_range, y_range = config['x'], config['y']
290
+
291
+ output = tuple()
292
+ for data in args:
293
+ pcd_list = []
294
+ for pcd in data:
295
+ # mask out invalid points
296
+ mask_x = np.logical_and(pcd[:, 0] > x_range[0], pcd[:, 0] < x_range[1])
297
+ mask_y = np.logical_and(pcd[:, 1] > y_range[0], pcd[:, 1] < y_range[1])
298
+ mask = mask_x & mask_y
299
+ pcd = pcd[mask][:, :2] # keep x,y coord
300
+
301
+ # voxelize
302
+ pcd_voxel = np.floor(pcd / voxel_size)
303
+ _, indices, inverse_map = sparse_quantize(pcd_voxel, 1, return_index=True, return_inverse=True)
304
+ pcd = pcd[indices]
305
+ pcd_list.append(pcd)
306
+ output += (pcd_list,)
307
+ return output
308
+
309
+
310
+ def preprocess_pcd(pcd, **kwargs):
311
+ depth = np.linalg.norm(pcd, 2, axis=1)
312
+ mask = np.logical_and(depth > kwargs['depth_range'][0], depth < kwargs['depth_range'][1])
313
+ pcd = pcd[mask]
314
+ return pcd
315
+
316
+
317
+ def preprocess_range(pcd, **kwargs):
318
+ depth_img = pcd2range(pcd, **kwargs)[0]
319
+ xyz_img = range2xyz(depth_img, log_scale=False, **kwargs)
320
+ depth_img = depth_img[None]
321
+ img = np.vstack([depth_img, xyz_img])
322
+ return img
323
+
324
+
325
+ def batch2list(batch_dict, agg_type='depth', **kwargs):
326
+ """
327
+ Aggregation Type: Default 'depth', ['all', 'sector', 'depth']
328
+ """
329
+ output_list = []
330
+ batch_indices = batch_dict['batch_indices']
331
+ for b_idx in range(batch_indices.max() + 1):
332
+ # avg all
333
+ if agg_type == 'all':
334
+ logits = batch_dict['logits'][batch_indices == b_idx].mean(0)
335
+
336
+ # avg on sectors
337
+ elif agg_type == 'sector':
338
+ logits = batch_dict['logits'][batch_indices == b_idx]
339
+ coords = batch_dict['coords'][batch_indices == b_idx].float()
340
+ coords = coords - coords.mean(0)
341
+ angle = torch.atan2(coords[:, 1], coords[:, 0]) # [-pi, pi]
342
+ sector_range = torch.linspace(-np.pi - 1e-4, np.pi + 1e-4, NUM_SECTORS + 1)
343
+ logits_list = []
344
+ for i in range(NUM_SECTORS):
345
+ sector_indices = torch.where((angle >= sector_range[i]) & (angle < sector_range[i + 1]))[0]
346
+ sector_logits = logits[sector_indices].mean(0)
347
+ sector_logits = torch.nan_to_num(sector_logits, 0.)
348
+ logits_list.append(sector_logits)
349
+ logits = torch.cat(logits_list) # dim: 768
350
+
351
+ # avg by depth
352
+ elif agg_type == 'depth':
353
+ logits = batch_dict['logits'][batch_indices == b_idx]
354
+ coords = batch_dict['coords'][batch_indices == b_idx].float()
355
+ coords = coords - coords.mean(0)
356
+ bev_depth = torch.norm(coords, dim=-1) * VOXEL_SIZE
357
+ sector_range = torch.linspace(kwargs['depth_range'][0] + 3, kwargs['depth_range'][1], NUM_SECTORS + 1)
358
+ sector_range[0] = 0.
359
+ logits_list = []
360
+ for i in range(NUM_SECTORS):
361
+ sector_indices = torch.where((bev_depth >= sector_range[i]) & (bev_depth < sector_range[i + 1]))[0]
362
+ sector_logits = logits[sector_indices].mean(0)
363
+ sector_logits = torch.nan_to_num(sector_logits, 0.)
364
+ logits_list.append(sector_logits)
365
+ logits = torch.cat(logits_list) # dim: 768
366
+
367
+ else:
368
+ raise NotImplementedError
369
+
370
+ output_list.append(logits.detach().cpu().numpy())
371
+ return output_list
372
+
373
+
374
+ def compute_logits(data_type, modality, *args):
375
+ assert data_type in ['32', '64']
376
+ assert modality in ['range', 'voxel', 'point_voxel']
377
+ is_voxel = 'voxel' in modality
378
+ dataset_name = TYPE2DATASET[data_type]
379
+ dataset_config = DATASET_CONFIG[dataset_name]
380
+ bs = MODAL2BATCHSIZE[modality]
381
+
382
+ model = build_model(dataset_name, MODALITY2MODEL[modality], device='cuda')
383
+
384
+ output = tuple()
385
+ for data in args:
386
+ all_logits_list = []
387
+ for i in range(math.ceil(len(data) / bs)):
388
+ batch = data[i * bs:(i + 1) * bs]
389
+ if is_voxel:
390
+ batch = [pcd2voxel(preprocess_pcd(pcd, **dataset_config)) for pcd in batch]
391
+ batch = sparse_collate_fn(batch)
392
+ batch = {k: v.cuda() if isinstance(v, (torch.Tensor, SparseTensor, PointTensor)) else v for k, v in
393
+ batch.items()}
394
+ with torch.no_grad():
395
+ batch_out = model(batch, return_final_logits=True)
396
+ batch_out = batch2list(batch_out, AGG_TYPE, **dataset_config)
397
+ all_logits_list.extend(batch_out)
398
+ else:
399
+ batch = [preprocess_range(pcd, **dataset_config) for pcd in batch]
400
+ batch = torch.from_numpy(np.stack(batch)).float().cuda()
401
+ with torch.no_grad():
402
+ batch_out = model(batch, return_final_logits=True, agg_type=AGG_TYPE)
403
+ all_logits_list.append(batch_out)
404
+ if is_voxel:
405
+ all_logits = np.stack(all_logits_list)
406
+ else:
407
+ all_logits = np.vstack(all_logits_list)
408
+ output += (all_logits,)
409
+
410
+ del model, batch, batch_out
411
+ torch.cuda.empty_cache()
412
+ return output
413
+
414
+
415
+ def compute_pairwise_cd(x, y, module=None):
416
+ if module is None:
417
+ module = chamfer_3DDist()
418
+ if x.ndim == 2 and y.ndim == 2:
419
+ x, y = x[None], y[None]
420
+ x, y = torch.from_numpy(x).cuda(), torch.from_numpy(y).cuda()
421
+ dist1, dist2, _, _ = module(x, y)
422
+ dist = (dist1.mean() + dist2.mean()) / 2
423
+ return dist.item()
424
+
425
+
426
+ def compute_pairwise_cd_batch(reference, samples):
427
+ ndim = reference.ndim
428
+ assert ndim in [2, 3]
429
+ module = chamfer_3DDist() if ndim == 3 else chamfer_2DDist()
430
+ len_r, len_s = reference.shape[0], [s.shape[0] for s in samples]
431
+ max_len = max([len_r] + len_s)
432
+ reference = torch.from_numpy(
433
+ np.vstack([reference, np.ones((max_len - reference.shape[0], ndim), dtype=np.float32) * 1e6])).cuda()
434
+ samples = [np.vstack([s, np.ones((max_len - s.shape[0], ndim), dtype=np.float32) * 1e6]) for s in samples]
435
+ samples = torch.from_numpy(np.stack(samples)).cuda()
436
+ reference = reference.expand_as(samples)
437
+ dist_r, dist_s, _, _ = module(reference, samples)
438
+
439
+ results = []
440
+ for i in range(samples.shape[0]):
441
+ dist1, dist2, len1, len2 = dist_r[i], dist_s[i], len_r, len_s[i]
442
+ dist = (dist1[:len1].mean() + dist2[:len2].mean()) / 2.
443
+ results.append(dist.item())
444
+ return results
445
+
446
+
447
+ def compute_pairwise_emd(x, y, module=None):
448
+ if module is None:
449
+ module = emdModule()
450
+ n_points = min(x.shape[0], y.shape[0])
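+     # truncate both sets to a common size that is a multiple of 1024, which the MSN-based EMD CUDA kernel expects (assumption from the upstream implementation)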
451
+ n_points = n_points - n_points % 1024
452
+ x, y = x[:n_points], y[:n_points]
453
+ if x.ndim == 2 and y.ndim == 2:
454
+ x, y = x[None], y[None]
455
+ x, y = torch.from_numpy(x).cuda(), torch.from_numpy(y).cuda()
456
+ dist, _ = module(x, y, 0.005, 50)
457
+ dist = torch.sqrt(dist).mean()
458
+ return dist.item()
lidm/eval/models/__init__.py ADDED
File without changes
lidm/eval/models/minkowskinet/__init__.py ADDED
File without changes
lidm/eval/models/minkowskinet/model.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ try:
5
+ import torchsparse
6
+ import torchsparse.nn as spnn
7
+ from ..ts import basic_blocks
8
+ except ImportError:
9
+ raise Exception('Required ts lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0')
10
+
11
+
12
+ class Model(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+
16
+ cr = config.model_params.cr
17
+ cs = config.model_params.layer_num
18
+ cs = [int(cr * x) for x in cs]
19
+
20
+ self.pres = self.vres = config.model_params.voxel_size
21
+ self.num_classes = config.model_params.num_class
22
+
23
+ self.stem = nn.Sequential(
24
+ spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1),
25
+ spnn.BatchNorm(cs[0]), spnn.ReLU(True),
26
+ spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
27
+ spnn.BatchNorm(cs[0]), spnn.ReLU(True))
28
+
29
+ self.stage1 = nn.Sequential(
30
+ basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
31
+ basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
32
+ basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
33
+ )
34
+
35
+ self.stage2 = nn.Sequential(
36
+ basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
37
+ basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
38
+ basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
39
+ )
40
+
41
+ self.stage3 = nn.Sequential(
42
+ basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1),
43
+ basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1),
44
+ basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
45
+ )
46
+
47
+ self.stage4 = nn.Sequential(
48
+ basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1),
49
+ basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1),
50
+ basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
51
+ )
52
+
53
+ self.up1 = nn.ModuleList([
54
+ basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2),
55
+ nn.Sequential(
56
+ basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1,
57
+ dilation=1),
58
+ basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1),
59
+ )
60
+ ])
61
+
62
+ self.up2 = nn.ModuleList([
63
+ basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2),
64
+ nn.Sequential(
65
+ basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1,
66
+ dilation=1),
67
+ basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1),
68
+ )
69
+ ])
70
+
71
+ self.up3 = nn.ModuleList([
72
+ basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2),
73
+ nn.Sequential(
74
+ basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1,
75
+ dilation=1),
76
+ basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1),
77
+ )
78
+ ])
79
+
80
+ self.up4 = nn.ModuleList([
81
+ basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2),
82
+ nn.Sequential(
83
+ basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1,
84
+ dilation=1),
85
+ basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1),
86
+ )
87
+ ])
88
+
89
+ self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes))
90
+
91
+ self.weight_initialization()
92
+ self.dropout = nn.Dropout(0.3, True)
93
+
94
+ def weight_initialization(self):
95
+ for m in self.modules():
96
+ if isinstance(m, nn.BatchNorm1d):
97
+ nn.init.constant_(m.weight, 1)
98
+ nn.init.constant_(m.bias, 0)
99
+
100
+ def forward(self, data_dict, return_logits=False, return_final_logits=False):
101
+ x = data_dict['lidar']
102
+ x.C = x.C.int()
103
+
104
+ x0 = self.stem(x)
105
+ x1 = self.stage1(x0)
106
+ x2 = self.stage2(x1)
107
+ x3 = self.stage3(x2)
108
+ x4 = self.stage4(x3)
109
+
110
+ if return_logits:
111
+ output_dict = dict()
112
+ output_dict['logits'] = x4.F
113
+ output_dict['batch_indices'] = x4.C[:, -1]
114
+ return output_dict
115
+
116
+ y1 = self.up1[0](x4)
117
+ y1 = torchsparse.cat([y1, x3])
118
+ y1 = self.up1[1](y1)
119
+
120
+ y2 = self.up2[0](y1)
121
+ y2 = torchsparse.cat([y2, x2])
122
+ y2 = self.up2[1](y2)
123
+
124
+ y3 = self.up3[0](y2)
125
+ y3 = torchsparse.cat([y3, x1])
126
+ y3 = self.up3[1](y3)
127
+
128
+ y4 = self.up4[0](y3)
129
+ y4 = torchsparse.cat([y4, x0])
130
+ y4 = self.up4[1](y4)
131
+ if return_final_logits:
132
+ output_dict = dict()
133
+ output_dict['logits'] = y4.F
134
+ output_dict['coords'] = y4.C[:, :3]
135
+ output_dict['batch_indices'] = y4.C[:, -1]
136
+ return output_dict
137
+
138
+ output = self.classifier(y4.F)
139
+         data_dict['output'] = output  # the classifier already returns a plain tensor (not a SparseTensor)
140
+
141
+ return data_dict
lidm/eval/models/rangenet/__init__.py ADDED
File without changes
lidm/eval/models/rangenet/model.py ADDED
@@ -0,0 +1,372 @@
1
+ #!/usr/bin/env python3
2
+ # This file is covered by the LICENSE file in the root of this project.
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class BasicBlock(nn.Module):
11
+ def __init__(self, inplanes, planes, bn_d=0.1):
12
+ super(BasicBlock, self).__init__()
13
+ self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
14
+ stride=1, padding=0, bias=False)
15
+ self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
16
+ self.relu1 = nn.LeakyReLU(0.1)
17
+ self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
18
+ stride=1, padding=1, bias=False)
19
+ self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
20
+ self.relu2 = nn.LeakyReLU(0.1)
21
+
22
+ def forward(self, x):
23
+ residual = x
24
+
25
+ out = self.conv1(x)
26
+ out = self.bn1(out)
27
+ out = self.relu1(out)
28
+
29
+ out = self.conv2(out)
30
+ out = self.bn2(out)
31
+ out = self.relu2(out)
32
+
33
+ out += residual
34
+ return out
35
+
36
+
37
+ # ******************************************************************************
38
+
39
+ # number of layers per model
40
+ model_blocks = {
41
+ 21: [1, 1, 2, 2, 1],
42
+ 53: [1, 2, 8, 8, 4],
43
+ }
44
+
45
+
46
+ class Backbone(nn.Module):
47
+ """
48
+ Class for DarknetSeg. Subclasses PyTorch's own "nn" module
49
+ """
50
+
51
+ def __init__(self, params):
52
+ super(Backbone, self).__init__()
53
+ self.use_range = params["input_depth"]["range"]
54
+ self.use_xyz = params["input_depth"]["xyz"]
55
+ self.use_remission = params["input_depth"]["remission"]
56
+ self.drop_prob = params["dropout"]
57
+ self.bn_d = params["bn_d"]
58
+ self.OS = params["OS"]
59
+ self.layers = params["extra"]["layers"]
60
+
61
+ # input depth calc
62
+ self.input_depth = 0
63
+ self.input_idxs = []
64
+ if self.use_range:
65
+ self.input_depth += 1
66
+ self.input_idxs.append(0)
67
+ if self.use_xyz:
68
+ self.input_depth += 3
69
+ self.input_idxs.extend([1, 2, 3])
70
+ if self.use_remission:
71
+ self.input_depth += 1
72
+ self.input_idxs.append(4)
73
+
74
+ # stride play
75
+ self.strides = [2, 2, 2, 2, 2]
76
+ # check current stride
77
+ current_os = 1
78
+ for s in self.strides:
79
+ current_os *= s
80
+
81
+ # make the new stride
82
+ if self.OS > current_os:
83
+ print("Can't do OS, ", self.OS,
84
+ " because it is bigger than original ", current_os)
85
+ else:
86
+ # redo strides according to needed stride
87
+ for i, stride in enumerate(reversed(self.strides), 0):
88
+ if int(current_os) != self.OS:
89
+ if stride == 2:
90
+ current_os /= 2
91
+ self.strides[-1 - i] = 1
92
+ if int(current_os) == self.OS:
93
+ break
94
+
95
+ # check that darknet exists
96
+ assert self.layers in model_blocks.keys()
97
+
98
+ # generate layers depending on darknet type
99
+ self.blocks = model_blocks[self.layers]
100
+
101
+ # input layer
102
+ self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3,
103
+ stride=1, padding=1, bias=False)
104
+ self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d)
105
+ self.relu1 = nn.LeakyReLU(0.1)
106
+
107
+ # encoder
108
+ self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0],
109
+ stride=self.strides[0], bn_d=self.bn_d)
110
+ self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1],
111
+ stride=self.strides[1], bn_d=self.bn_d)
112
+ self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2],
113
+ stride=self.strides[2], bn_d=self.bn_d)
114
+ self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3],
115
+ stride=self.strides[3], bn_d=self.bn_d)
116
+ self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4],
117
+ stride=self.strides[4], bn_d=self.bn_d)
118
+
119
+ # for a bit of fun
120
+ self.dropout = nn.Dropout2d(self.drop_prob)
121
+
122
+ # last channels
123
+ self.last_channels = 1024
124
+
125
+ # make layer useful function
126
+ def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1):
127
+ layers = []
128
+
129
+ # downsample
130
+ layers.append(("conv", nn.Conv2d(planes[0], planes[1],
131
+ kernel_size=3,
132
+ stride=[1, stride], dilation=1,
133
+ padding=1, bias=False)))
134
+ layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
135
+ layers.append(("relu", nn.LeakyReLU(0.1)))
136
+
137
+ # blocks
138
+ inplanes = planes[1]
139
+ for i in range(0, blocks):
140
+ layers.append(("residual_{}".format(i),
141
+ block(inplanes, planes, bn_d)))
142
+
143
+ return nn.Sequential(OrderedDict(layers))
144
+
145
+ def run_layer(self, x, layer, skips, os):
146
+ y = layer(x)
147
+ if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
148
+ skips[os] = x.detach()
149
+ os *= 2
150
+ x = y
151
+ return x, skips, os
152
+
153
+ def forward(self, x, return_logits=False, return_list=None):
154
+ # filter input
155
+ x = x[:, self.input_idxs]
156
+
157
+ # run cnn
158
+ # store for skip connections
159
+ skips = {}
160
+ out_dict = {}
161
+ os = 1
162
+
163
+ # first layer
164
+ x, skips, os = self.run_layer(x, self.conv1, skips, os)
165
+ x, skips, os = self.run_layer(x, self.bn1, skips, os)
166
+ x, skips, os = self.run_layer(x, self.relu1, skips, os)
167
+ if return_list and 'enc_0' in return_list:
168
+ out_dict['enc_0'] = x.detach().cpu() # 32, 64, 1024
169
+
170
+ # all encoder blocks with intermediate dropouts
171
+ x, skips, os = self.run_layer(x, self.enc1, skips, os)
172
+ if return_list and 'enc_1' in return_list:
173
+ out_dict['enc_1'] = x.detach().cpu() # 64, 64, 512
174
+ x, skips, os = self.run_layer(x, self.dropout, skips, os)
175
+
176
+ x, skips, os = self.run_layer(x, self.enc2, skips, os)
177
+ if return_list and 'enc_2' in return_list:
178
+ out_dict['enc_2'] = x.detach().cpu() # 128, 64, 256
179
+ x, skips, os = self.run_layer(x, self.dropout, skips, os)
180
+
181
+ x, skips, os = self.run_layer(x, self.enc3, skips, os)
182
+ if return_list and 'enc_3' in return_list:
183
+ out_dict['enc_3'] = x.detach().cpu() # 256, 64, 128
184
+ x, skips, os = self.run_layer(x, self.dropout, skips, os)
185
+
186
+ x, skips, os = self.run_layer(x, self.enc4, skips, os)
187
+ if return_list and 'enc_4' in return_list:
188
+ out_dict['enc_4'] = x.detach().cpu() # 512, 64, 64
189
+ x, skips, os = self.run_layer(x, self.dropout, skips, os)
190
+
191
+ x, skips, os = self.run_layer(x, self.enc5, skips, os)
192
+ if return_list and 'enc_5' in return_list:
193
+ out_dict['enc_5'] = x.detach().cpu() # 1024, 64, 32
194
+ if return_logits:
195
+ return x
196
+
197
+ x, skips, os = self.run_layer(x, self.dropout, skips, os)
198
+
199
+ if return_list is not None:
200
+ return x, skips, out_dict
201
+ return x, skips
202
+
203
+ def get_last_depth(self):
204
+ return self.last_channels
205
+
206
+ def get_input_depth(self):
207
+ return self.input_depth
208
+
209
+
210
+ class Decoder(nn.Module):
211
+ """
212
+ Class for DarknetSeg. Subclasses PyTorch's own "nn" module
213
+ """
214
+
215
+ def __init__(self, params, OS=32, feature_depth=1024):
216
+ super(Decoder, self).__init__()
217
+ self.backbone_OS = OS
218
+ self.backbone_feature_depth = feature_depth
219
+ self.drop_prob = params["dropout"]
220
+ self.bn_d = params["bn_d"]
221
+ self.index = 0
222
+
223
+ # stride play
224
+ self.strides = [2, 2, 2, 2, 2]
225
+ # check current stride
226
+ current_os = 1
227
+ for s in self.strides:
228
+ current_os *= s
229
+ # redo strides according to needed stride
230
+ for i, stride in enumerate(self.strides):
231
+ if int(current_os) != self.backbone_OS:
232
+ if stride == 2:
233
+ current_os /= 2
234
+ self.strides[i] = 1
235
+ if int(current_os) == self.backbone_OS:
236
+ break
237
+
238
+ # decoder
239
+ self.dec5 = self._make_dec_layer(BasicBlock,
240
+ [self.backbone_feature_depth, 512],
241
+ bn_d=self.bn_d,
242
+ stride=self.strides[0])
243
+ self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d,
244
+ stride=self.strides[1])
245
+ self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d,
246
+ stride=self.strides[2])
247
+ self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d,
248
+ stride=self.strides[3])
249
+ self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d,
250
+ stride=self.strides[4])
251
+
252
+ # layer list to execute with skips
253
+ self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1]
254
+
255
+ # for a bit of fun
256
+ self.dropout = nn.Dropout2d(self.drop_prob)
257
+
258
+ # last channels
259
+ self.last_channels = 32
260
+
261
+ def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2):
262
+ layers = []
263
+
264
+ # downsample
265
+ if stride == 2:
266
+ layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1],
267
+ kernel_size=[1, 4], stride=[1, 2],
268
+ padding=[0, 1])))
269
+ else:
270
+ layers.append(("conv", nn.Conv2d(planes[0], planes[1],
271
+ kernel_size=3, padding=1)))
272
+ layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
273
+ layers.append(("relu", nn.LeakyReLU(0.1)))
274
+
275
+ # blocks
276
+ layers.append(("residual", block(planes[1], planes, bn_d)))
277
+
278
+ return nn.Sequential(OrderedDict(layers))
279
+
280
+ def run_layer(self, x, layer, skips, os):
281
+ feats = layer(x) # up
282
+ if feats.shape[-1] > x.shape[-1]:
283
+ os //= 2 # match skip
284
+ feats = feats + skips[os].detach() # add skip
285
+ x = feats
286
+ return x, skips, os
287
+
288
+ def forward(self, x, skips, return_logits=False, return_list=None):
289
+ os = self.backbone_OS
290
+ out_dict = {}
291
+
292
+ # run layers
293
+ x, skips, os = self.run_layer(x, self.dec5, skips, os)
294
+ if return_list and 'dec_4' in return_list:
295
+ out_dict['dec_4'] = x.detach().cpu() # 512, 64, 64
296
+ x, skips, os = self.run_layer(x, self.dec4, skips, os)
297
+ if return_list and 'dec_3' in return_list:
298
+ out_dict['dec_3'] = x.detach().cpu() # 256, 64, 128
299
+ x, skips, os = self.run_layer(x, self.dec3, skips, os)
300
+ if return_list and 'dec_2' in return_list:
301
+ out_dict['dec_2'] = x.detach().cpu() # 128, 64, 256
302
+ x, skips, os = self.run_layer(x, self.dec2, skips, os)
303
+ if return_list and 'dec_1' in return_list:
304
+ out_dict['dec_1'] = x.detach().cpu() # 64, 64, 512
305
+ x, skips, os = self.run_layer(x, self.dec1, skips, os)
306
+ if return_list and 'dec_0' in return_list:
307
+ out_dict['dec_0'] = x.detach().cpu() # 32, 64, 1024
308
+
309
+ logits = torch.clone(x).detach()
310
+ x = self.dropout(x)
311
+
312
+ if return_logits:
313
+ return x, logits
314
+ if return_list is not None:
315
+ return out_dict
316
+ return x
317
+
318
+ def get_last_depth(self):
319
+ return self.last_channels
320
+
321
+
322
+ class Model(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.config = config
326
+ self.backbone = Backbone(params=self.config["backbone"])
327
+ self.decoder = Decoder(params=self.config["decoder"], OS=self.config["backbone"]["OS"],
328
+ feature_depth=self.backbone.get_last_depth())
329
+
330
+ def load_pretrained_weights(self, path):
331
+ w_dict = torch.load(path + "/backbone",
332
+ map_location=lambda storage, loc: storage)
333
+ self.backbone.load_state_dict(w_dict, strict=True)
334
+ w_dict = torch.load(path + "/segmentation_decoder",
335
+ map_location=lambda storage, loc: storage)
336
+ self.decoder.load_state_dict(w_dict, strict=True)
337
+
338
+ def forward(self, x, return_logits=False, return_final_logits=False, return_list=None, agg_type='depth'):
339
+ if return_logits:
340
+ logits = self.backbone(x, return_logits)
341
+ logits = F.adaptive_avg_pool2d(logits, (1, 1)).squeeze()
342
+ logits = torch.clone(logits).detach().cpu().numpy()
343
+ return logits
344
+ elif return_list is not None:
345
+ x, skips, enc_dict = self.backbone(x, return_list=return_list)
346
+ dec_dict = self.decoder(x, skips, return_list=return_list)
347
+ out_dict = {**enc_dict, **dec_dict}
348
+ return out_dict
349
+ elif return_final_logits:
350
+ assert agg_type in ['all', 'sector', 'depth']
351
+ y, skips = self.backbone(x)
352
+ y, logits = self.decoder(y, skips, True)
353
+
354
+ B, C, H, W = logits.shape
355
+ N = 16
356
+
357
+ # avg all
358
+ if agg_type == 'all':
359
+ logits = logits.mean([2, 3])
360
+ # avg in patch
361
+ elif agg_type == 'sector':
362
+ logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1)
363
+ # avg in row
364
+ elif agg_type == 'depth':
365
+ logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1)
366
+
367
+ logits = torch.clone(logits).detach().cpu().numpy()
368
+ return logits
369
+ else:
370
+ y, skips = self.backbone(x)
371
+ y = self.decoder(y, skips, False)
372
+ return y
lidm/eval/models/spvcnn/__init__.py ADDED
File without changes
lidm/eval/models/spvcnn/model.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch.nn as nn
2
+
3
+ try:
4
+ import torchsparse
5
+ import torchsparse.nn as spnn
6
+ from torchsparse import PointTensor
7
+ from ..ts.utils import initial_voxelize, point_to_voxel, voxel_to_point
8
+ from ..ts import basic_blocks
9
+ except ImportError:
10
+ raise Exception('Required torchsparse lib. Reference: https://github.com/mit-han-lab/torchsparse/tree/v1.4.0')
11
+
12
+
13
+ class Model(nn.Module):
14
+ def __init__(self, config):
15
+ super().__init__()
16
+ cr = config.model_params.cr
17
+ cs = config.model_params.layer_num
18
+ cs = [int(cr * x) for x in cs]
19
+
20
+ self.pres = self.vres = config.model_params.voxel_size
21
+ self.num_classes = config.model_params.num_class
22
+
23
+ self.stem = nn.Sequential(
24
+ spnn.Conv3d(config.model_params.input_dims, cs[0], kernel_size=3, stride=1),
25
+ spnn.BatchNorm(cs[0]), spnn.ReLU(True),
26
+ spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
27
+ spnn.BatchNorm(cs[0]), spnn.ReLU(True))
28
+
29
+ self.stage1 = nn.Sequential(
30
+ basic_blocks.BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
31
+ basic_blocks.ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
32
+ basic_blocks.ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
33
+ )
34
+
35
+ self.stage2 = nn.Sequential(
36
+ basic_blocks.BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
37
+ basic_blocks.ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
38
+ basic_blocks.ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
39
+ )
40
+
41
+ self.stage3 = nn.Sequential(
42
+ basic_blocks.BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1),
43
+ basic_blocks.ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1),
44
+ basic_blocks.ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
45
+ )
46
+
47
+ self.stage4 = nn.Sequential(
48
+ basic_blocks.BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1),
49
+ basic_blocks.ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1),
50
+ basic_blocks.ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
51
+ )
52
+
53
+ self.up1 = nn.ModuleList([
54
+ basic_blocks.BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2),
55
+ nn.Sequential(
56
+ basic_blocks.ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1,
57
+ dilation=1),
58
+ basic_blocks.ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1),
59
+ )
60
+ ])
61
+
62
+ self.up2 = nn.ModuleList([
63
+ basic_blocks.BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2),
64
+ nn.Sequential(
65
+ basic_blocks.ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1,
66
+ dilation=1),
67
+ basic_blocks.ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1),
68
+ )
69
+ ])
70
+
71
+ self.up3 = nn.ModuleList([
72
+ basic_blocks.BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2),
73
+ nn.Sequential(
74
+ basic_blocks.ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1,
75
+ dilation=1),
76
+ basic_blocks.ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1),
77
+ )
78
+ ])
79
+
80
+ self.up4 = nn.ModuleList([
81
+ basic_blocks.BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2),
82
+ nn.Sequential(
83
+ basic_blocks.ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1,
84
+ dilation=1),
85
+ basic_blocks.ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1),
86
+ )
87
+ ])
88
+
89
+ self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes))
90
+
91
+ self.point_transforms = nn.ModuleList([
92
+ nn.Sequential(
93
+ nn.Linear(cs[0], cs[4]),
94
+ nn.BatchNorm1d(cs[4]),
95
+ nn.ReLU(True),
96
+ ),
97
+ nn.Sequential(
98
+ nn.Linear(cs[4], cs[6]),
99
+ nn.BatchNorm1d(cs[6]),
100
+ nn.ReLU(True),
101
+ ),
102
+ nn.Sequential(
103
+ nn.Linear(cs[6], cs[8]),
104
+ nn.BatchNorm1d(cs[8]),
105
+ nn.ReLU(True),
106
+ )
107
+ ])
108
+
109
+ self.weight_initialization()
110
+ self.dropout = nn.Dropout(0.3, True)
111
+
112
+ def weight_initialization(self):
113
+ for m in self.modules():
114
+ if isinstance(m, nn.BatchNorm1d):
115
+ nn.init.constant_(m.weight, 1)
116
+ nn.init.constant_(m.bias, 0)
117
+
118
+ def forward(self, data_dict, return_logits=False, return_final_logits=False):
119
+ x = data_dict['lidar']
120
+
121
+ # x: SparseTensor z: PointTensor
122
+ z = PointTensor(x.F, x.C.float())
123
+
124
+ x0 = initial_voxelize(z, self.pres, self.vres)
125
+
126
+ x0 = self.stem(x0)
127
+ z0 = voxel_to_point(x0, z, nearest=False)
128
+ z0.F = z0.F
129
+
130
+ x1 = point_to_voxel(x0, z0)
131
+ x1 = self.stage1(x1)
132
+ x2 = self.stage2(x1)
133
+ x3 = self.stage3(x2)
134
+ x4 = self.stage4(x3)
135
+ z1 = voxel_to_point(x4, z0)
136
+ z1.F = z1.F + self.point_transforms[0](z0.F)
137
+
138
+ y1 = point_to_voxel(x4, z1)
139
+
140
+ if return_logits:
141
+ output_dict = dict()
142
+ output_dict['logits'] = y1.F
143
+ output_dict['batch_indices'] = y1.C[:, -1]
144
+ return output_dict
145
+
146
+ y1.F = self.dropout(y1.F)
147
+ y1 = self.up1[0](y1)
148
+ y1 = torchsparse.cat([y1, x3])
149
+ y1 = self.up1[1](y1)
150
+
151
+ y2 = self.up2[0](y1)
152
+ y2 = torchsparse.cat([y2, x2])
153
+ y2 = self.up2[1](y2)
154
+ z2 = voxel_to_point(y2, z1)
155
+ z2.F = z2.F + self.point_transforms[1](z1.F)
156
+
157
+ y3 = point_to_voxel(y2, z2)
158
+ y3.F = self.dropout(y3.F)
159
+ y3 = self.up3[0](y3)
160
+ y3 = torchsparse.cat([y3, x1])
161
+ y3 = self.up3[1](y3)
162
+
163
+ y4 = self.up4[0](y3)
164
+ y4 = torchsparse.cat([y4, x0])
165
+ y4 = self.up4[1](y4)
166
+ z3 = voxel_to_point(y4, z2)
167
+ z3.F = z3.F + self.point_transforms[2](z2.F)
168
+
169
+ if return_final_logits:
170
+ output_dict = dict()
171
+ output_dict['logits'] = z3.F
172
+ output_dict['coords'] = z3.C[:, :3]
173
+ output_dict['batch_indices'] = z3.C[:, -1].long()
174
+ return output_dict
175
+
176
+ # output = self.classifier(z3.F)
177
+ data_dict['logits'] = z3.F
178
+
179
+ return data_dict
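The constructor above reads all of its hyperparameters from `config.model_params` (cr, layer_num, voxel_size, num_class, input_dims). A hedged sketch of the kind of config object it expects; the attribute names come from the code above, but the values are illustrative only (the real ones live in the repo's evaluation configs):

from types import SimpleNamespace

# Hypothetical values for illustration; only the attribute names are taken from the model code.
model_params = SimpleNamespace(
    cr=1.0,                                               # channel multiplier applied to layer_num
    layer_num=[32, 32, 64, 128, 256, 256, 128, 96, 96],   # base widths cs[0..8]
    voxel_size=0.05,                                       # pres == vres
    num_class=19,                                          # e.g. SemanticKITTI classes
    input_dims=4,                                          # x, y, z, intensity
)
config = SimpleNamespace(model_params=model_params)

# model = Model(config)  # requires torchsparse v1.4.0 to be installed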
lidm/eval/models/ts/__init__.py ADDED
File without changes
lidm/eval/models/ts/basic_blocks.py ADDED
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @author: Xu Yan
5
+ @file: basic_blocks.py
6
+ @time: 2021/4/14 22:53
7
+ '''
8
+ import torch.nn as nn
9
+
10
+ try:
11
+ import torchsparse.nn as spnn
12
+ except ImportError:
13
+ print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24')
14
+
15
+
16
+ class BasicConvolutionBlock(nn.Module):
17
+ def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
18
+ super().__init__()
19
+ self.net = nn.Sequential(
20
+ spnn.Conv3d(
21
+ inc,
22
+ outc,
23
+ kernel_size=ks,
24
+ dilation=dilation,
25
+ stride=stride), spnn.BatchNorm(outc),
26
+ spnn.ReLU(True))
27
+
28
+ def forward(self, x):
29
+ out = self.net(x)
30
+ return out
31
+
32
+
33
+ class BasicDeconvolutionBlock(nn.Module):
34
+ def __init__(self, inc, outc, ks=3, stride=1):
35
+ super().__init__()
36
+ self.net = nn.Sequential(
37
+ spnn.Conv3d(
38
+ inc,
39
+ outc,
40
+ kernel_size=ks,
41
+ stride=stride,
42
+ transposed=True),
43
+ spnn.BatchNorm(outc),
44
+ spnn.ReLU(True))
45
+
46
+ def forward(self, x):
47
+ return self.net(x)
48
+
49
+
50
+ class ResidualBlock(nn.Module):
51
+ def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
52
+ super().__init__()
53
+ self.net = nn.Sequential(
54
+ spnn.Conv3d(
55
+ inc,
56
+ outc,
57
+ kernel_size=ks,
58
+ dilation=dilation,
59
+ stride=stride), spnn.BatchNorm(outc),
60
+ spnn.ReLU(True),
61
+ spnn.Conv3d(
62
+ outc,
63
+ outc,
64
+ kernel_size=ks,
65
+ dilation=dilation,
66
+ stride=1),
67
+ spnn.BatchNorm(outc))
68
+
69
+ self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
70
+ nn.Sequential(
71
+ spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
72
+ spnn.BatchNorm(outc)
73
+ )
74
+
75
+ self.ReLU = spnn.ReLU(True)
76
+
77
+ def forward(self, x):
78
+ out = self.ReLU(self.net(x) + self.downsample(x))
79
+ return out
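These three sparse blocks are the building pieces that spvcnn/model.py composes into its strided encoder stages. A short sketch in that same style (assumes torchsparse v1.4.0 is installed and the repo root is on PYTHONPATH; the channel widths are arbitrary examples):

import torch.nn as nn
from lidm.eval.models.ts import basic_blocks

# One encoder stage in the style of spvcnn/model.py: downsample by 2, then two residual blocks.
stage = nn.Sequential(
    basic_blocks.BasicConvolutionBlock(32, 32, ks=2, stride=2, dilation=1),
    basic_blocks.ResidualBlock(32, 64, ks=3, stride=1, dilation=1),
    basic_blocks.ResidualBlock(64, 64, ks=3, stride=1, dilation=1),
)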
lidm/eval/models/ts/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+
3
+ try:
4
+ import torchsparse.nn.functional as F
5
+ from torchsparse import PointTensor, SparseTensor
6
+ from torchsparse.nn.utils import get_kernel_offsets
7
+ except ImportError:
8
+ print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24')
9
+
10
+ __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point']
11
+
12
+
13
+ # z: PointTensor
14
+ # return: SparseTensor
15
+ def initial_voxelize(z, init_res, after_res):
16
+ new_float_coord = torch.cat([(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)
17
+
18
+ pc_hash = F.sphash(torch.floor(new_float_coord).int())
19
+ sparse_hash = torch.unique(pc_hash)
20
+ idx_query = F.sphashquery(pc_hash, sparse_hash)
21
+ counts = F.spcount(idx_query.int(), len(sparse_hash))
22
+
23
+ inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, counts)
24
+ inserted_coords = torch.round(inserted_coords).int()
25
+ inserted_feat = F.spvoxelize(z.F, idx_query, counts)
26
+
27
+ new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
28
+ new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
29
+ z.additional_features['idx_query'][1] = idx_query
30
+ z.additional_features['counts'][1] = counts
31
+ z.C = new_float_coord
32
+
33
+ return new_tensor
34
+
35
+
36
+ # x: SparseTensor, z: PointTensor
37
+ # return: SparseTensor
38
+ def point_to_voxel(x, z):
39
+ if z.additional_features is None or \
40
+ z.additional_features.get('idx_query') is None or \
41
+ z.additional_features['idx_query'].get(x.s) is None:
42
+ pc_hash = F.sphash(
43
+ torch.cat([torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], z.C[:, -1].int().view(-1, 1)], 1))
44
+ sparse_hash = F.sphash(x.C)
45
+ idx_query = F.sphashquery(pc_hash, sparse_hash)
46
+ counts = F.spcount(idx_query.int(), x.C.shape[0])
47
+ z.additional_features['idx_query'][x.s] = idx_query
48
+ z.additional_features['counts'][x.s] = counts
49
+ else:
50
+ idx_query = z.additional_features['idx_query'][x.s]
51
+ counts = z.additional_features['counts'][x.s]
52
+
53
+ inserted_feat = F.spvoxelize(z.F, idx_query, counts)
54
+ new_tensor = SparseTensor(inserted_feat, x.C, x.s)
55
+ new_tensor.cmaps = x.cmaps
56
+ new_tensor.kmaps = x.kmaps
57
+
58
+ return new_tensor
59
+
60
+
61
+ # x: SparseTensor, z: PointTensor
62
+ # return: PointTensor
63
+ def voxel_to_point(x, z, nearest=False):
64
+ if z.idx_query is None or z.weights is None or z.idx_query.get(x.s) is None or z.weights.get(x.s) is None:
65
+ off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
66
+ old_hash = F.sphash(
67
+ torch.cat([
68
+ torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
69
+ z.C[:, -1].int().view(-1, 1)], 1), off)
70
+ pc_hash = F.sphash(x.C.to(z.F.device))
71
+ idx_query = F.sphashquery(old_hash, pc_hash)
72
+ weights = F.calc_ti_weights(z.C, idx_query, scale=x.s[0]).transpose(0, 1).contiguous()
73
+ idx_query = idx_query.transpose(0, 1).contiguous()
74
+ if nearest:
75
+ weights[:, 1:] = 0.
76
+ idx_query[:, 1:] = -1
77
+ new_feat = F.spdevoxelize(x.F, idx_query, weights)
78
+ new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights)
79
+ new_tensor.additional_features = z.additional_features
80
+ new_tensor.idx_query[x.s] = idx_query
81
+ new_tensor.weights[x.s] = weights
82
+ z.idx_query[x.s] = idx_query
83
+ z.weights[x.s] = weights
84
+
85
+ else:
86
+ new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s))
87
+ new_tensor = PointTensor(new_feat, z.C, idx_query=z.idx_query, weights=z.weights)
88
+ new_tensor.additional_features = z.additional_features
89
+
90
+ return new_tensor
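The three helpers above implement the point-to-voxel round trip used by the SPVCNN model earlier in this commit. A hedged sketch of the intended call order on a toy point cloud (requires a GPU and torchsparse v1.4.0; the sizes and resolutions are placeholders):

import torch
from torchsparse import PointTensor
from lidm.eval.modules.ts.utils import initial_voxelize, point_to_voxel, voxel_to_point  # path assumed from this file's location

# Toy single-batch cloud: features are xyz + intensity, coords are xyz + batch index in the last column.
feats = torch.rand(1000, 4).cuda()
coords = torch.cat([torch.rand(1000, 3) * 50, torch.zeros(1000, 1)], dim=1).cuda()

z = PointTensor(feats, coords)
x0 = initial_voxelize(z, 0.05, 0.05)   # PointTensor -> SparseTensor at 5 cm voxels
z0 = voxel_to_point(x0, z)             # interpolate voxel features back onto the points
x1 = point_to_voxel(x0, z0)            # scatter the point features into the same voxel grid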
lidm/eval/modules/__init__.py ADDED
File without changes
lidm/eval/modules/chamfer2D/__init__.py ADDED
File without changes
lidm/eval/modules/chamfer2D/chamfer2D.cu ADDED
@@ -0,0 +1,182 @@
1
+
2
+ #include <stdio.h>
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <cuda.h>
6
+ #include <cuda_runtime.h>
7
+
8
+ #include <vector>
9
+
10
+
11
+
12
+ __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
13
+ const int batch=512;
14
+ __shared__ float buf[batch*2];
15
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
16
+ for (int k2=0;k2<m;k2+=batch){
17
+ int end_k=min(m,k2+batch)-k2;
18
+ for (int j=threadIdx.x;j<end_k*2;j+=blockDim.x){
19
+ buf[j]=xyz2[(i*m+k2)*2+j];
20
+ }
21
+ __syncthreads();
22
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
23
+ float x1=xyz[(i*n+j)*2+0];
24
+ float y1=xyz[(i*n+j)*2+1];
25
+ int best_i=0;
26
+ float best=0;
27
+ int end_ka=end_k-(end_k&2);
28
+ if (end_ka==batch){
29
+ for (int k=0;k<batch;k+=4){
30
+ {
31
+ float x2=buf[k*2+0]-x1;
32
+ float y2=buf[k*2+1]-y1;
33
+ float d=x2*x2+y2*y2;
34
+ if (k==0 || d<best){
35
+ best=d;
36
+ best_i=k+k2;
37
+ }
38
+ }
39
+ {
40
+ float x2=buf[k*2+2]-x1;
41
+ float y2=buf[k*2+3]-y1;
42
+ float d=x2*x2+y2*y2;
43
+ if (d<best){
44
+ best=d;
45
+ best_i=k+k2+1;
46
+ }
47
+ }
48
+ {
49
+ float x2=buf[k*2+4]-x1;
50
+ float y2=buf[k*2+5]-y1;
51
+ float d=x2*x2+y2*y2;
52
+ if (d<best){
53
+ best=d;
54
+ best_i=k+k2+2;
55
+ }
56
+ }
57
+ {
58
+ float x2=buf[k*2+6]-x1;
59
+ float y2=buf[k*2+7]-y1;
60
+ float d=x2*x2+y2*y2;
61
+ if (d<best){
62
+ best=d;
63
+ best_i=k+k2+3;
64
+ }
65
+ }
66
+ }
67
+ }else{
68
+ for (int k=0;k<end_ka;k+=4){
69
+ {
70
+ float x2=buf[k*2+0]-x1;
71
+ float y2=buf[k*2+1]-y1;
72
+ float d=x2*x2+y2*y2;
73
+ if (k==0 || d<best){
74
+ best=d;
75
+ best_i=k+k2;
76
+ }
77
+ }
78
+ {
79
+ float x2=buf[k*2+2]-x1;
80
+ float y2=buf[k*2+3]-y1;
81
+ float d=x2*x2+y2*y2;
82
+ if (d<best){
83
+ best=d;
84
+ best_i=k+k2+1;
85
+ }
86
+ }
87
+ {
88
+ float x2=buf[k*2+4]-x1;
89
+ float y2=buf[k*2+5]-y1;
90
+ float d=x2*x2+y2*y2;
91
+ if (d<best){
92
+ best=d;
93
+ best_i=k+k2+2;
94
+ }
95
+ }
96
+ {
97
+ float x2=buf[k*2+6]-x1;
98
+ float y2=buf[k*2+7]-y1;
99
+ float d=x2*x2+y2*y2;
100
+ if (d<best){
101
+ best=d;
102
+ best_i=k+k2+3;
103
+ }
104
+ }
105
+ }
106
+ }
107
+ for (int k=end_ka;k<end_k;k++){
108
+ float x2=buf[k*2+0]-x1;
109
+ float y2=buf[k*2+1]-y1;
110
+ float d=x2*x2+y2*y2;
111
+ if (k==0 || d<best){
112
+ best=d;
113
+ best_i=k+k2;
114
+ }
115
+ }
116
+ if (k2==0 || result[(i*n+j)]>best){
117
+ result[(i*n+j)]=best;
118
+ result_i[(i*n+j)]=best_i;
119
+ }
120
+ }
121
+ __syncthreads();
122
+ }
123
+ }
124
+ }
125
+ // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
126
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
127
+
128
+ const auto batch_size = xyz1.size(0);
129
+ const auto n = xyz1.size(1); //num_points point cloud A
130
+ const auto m = xyz2.size(1); //num_points point cloud B
131
+
132
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
133
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
134
+
135
+ cudaError_t err = cudaGetLastError();
136
+ if (err != cudaSuccess) {
137
+ printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
138
+ //THError("aborting");
139
+ return 0;
140
+ }
141
+ return 1;
142
+
143
+
144
+ }
145
+ __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
146
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
147
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
148
+ float x1=xyz1[(i*n+j)*2+0];
149
+ float y1=xyz1[(i*n+j)*2+1];
150
+ int j2=idx1[i*n+j];
151
+ float x2=xyz2[(i*m+j2)*2+0];
152
+ float y2=xyz2[(i*m+j2)*2+1];
153
+ float g=grad_dist1[i*n+j]*2;
154
+ atomicAdd(&(grad_xyz1[(i*n+j)*2+0]),g*(x1-x2));
155
+ atomicAdd(&(grad_xyz1[(i*n+j)*2+1]),g*(y1-y2));
156
+ atomicAdd(&(grad_xyz2[(i*m+j2)*2+0]),-(g*(x1-x2)));
157
+ atomicAdd(&(grad_xyz2[(i*m+j2)*2+1]),-(g*(y1-y2)));
158
+ }
159
+ }
160
+ }
161
+ // int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
162
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
163
+ // cudaMemset(grad_xyz1,0,b*n*3*4);
164
+ // cudaMemset(grad_xyz2,0,b*m*3*4);
165
+
166
+ const auto batch_size = xyz1.size(0);
167
+ const auto n = xyz1.size(1); //num_points point cloud A
168
+ const auto m = xyz2.size(1); //num_points point cloud B
169
+
170
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
171
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
172
+
173
+ cudaError_t err = cudaGetLastError();
174
+ if (err != cudaSuccess) {
175
+ printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
176
+ //THError("aborting");
177
+ return 0;
178
+ }
179
+ return 1;
180
+
181
+ }
182
+
lidm/eval/modules/chamfer2D/chamfer_cuda.cpp ADDED
@@ -0,0 +1,33 @@
+ #include <torch/torch.h>
+ #include <vector>
+
+ ///TMP
+ //#include "common.h"
+ /// NOT TMP
+
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
+
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
+
+ int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
+     return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
+ }
+
+ int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
+                      at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
+     return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
+     m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
+ }
lidm/eval/modules/chamfer2D/dist_chamfer_2D.py ADDED
@@ -0,0 +1,84 @@
+ from torch import nn
+ from torch.autograd import Function
+ import torch
+ import importlib
+ import os
+
+ chamfer_found = importlib.find_loader("chamfer_2D") is not None
+ if not chamfer_found:
+     ## Cool trick from https://github.com/chrdiller
+     print("Jitting Chamfer 2D")
+     cur_path = os.path.dirname(os.path.abspath(__file__))
+     build_path = cur_path.replace('chamfer2D', 'tmp')
+     os.makedirs(build_path, exist_ok=True)
+
+     from torch.utils.cpp_extension import load
+
+     chamfer_2D = load(name="chamfer_2D",
+                       sources=[
+                           "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
+                           "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
+                       ], build_directory=build_path)
+     print("Loaded JIT 2D CUDA chamfer distance")
+
+ else:
+     import chamfer_2D
+
+     print("Loaded compiled 2D CUDA chamfer distance")
+
+
+ # Chamfer's distance module @thibaultgroueix
+ # GPU tensors only
+ class chamfer_2DFunction(Function):
+     @staticmethod
+     def forward(ctx, xyz1, xyz2):
+         batchsize, n, dim = xyz1.size()
+         assert dim == 2, "Wrong last dimension for the chamfer distance's input! Check with .size()"
+         _, m, dim = xyz2.size()
+         assert dim == 2, "Wrong last dimension for the chamfer distance's input! Check with .size()"
+         device = xyz1.device
+
+         dist1 = torch.zeros(batchsize, n)
+         dist2 = torch.zeros(batchsize, m)
+
+         idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
+         idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
+
+         dist1 = dist1.to(device)
+         dist2 = dist2.to(device)
+         idx1 = idx1.to(device)
+         idx2 = idx2.to(device)
+         torch.cuda.set_device(device)
+
+         chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
+         ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
+         return dist1, dist2, idx1, idx2
+
+     @staticmethod
+     def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
+         xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
+         graddist1 = graddist1.contiguous()
+         graddist2 = graddist2.contiguous()
+         device = graddist1.device
+
+         gradxyz1 = torch.zeros(xyz1.size())
+         gradxyz2 = torch.zeros(xyz2.size())
+
+         gradxyz1 = gradxyz1.to(device)
+         gradxyz2 = gradxyz2.to(device)
+         chamfer_2D.backward(
+             xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
+         )
+         return gradxyz1, gradxyz2
+
+
+ class chamfer_2DDist(nn.Module):
+     def __init__(self):
+         super(chamfer_2DDist, self).__init__()
+
+     def forward(self, input1, input2):
+         input1 = input1.contiguous()
+         input2 = input2.contiguous()
+         return chamfer_2DFunction.apply(input1, input2)
lidm/eval/modules/chamfer2D/setup.py ADDED
@@ -0,0 +1,14 @@
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ setup(
+     name='chamfer_2D',
+     ext_modules=[
+         CUDAExtension('chamfer_2D', [
+             "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
+             "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
+         ]),
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     })
lidm/eval/modules/chamfer3D/__init__.py ADDED
File without changes
lidm/eval/modules/chamfer3D/chamfer3D.cu ADDED
@@ -0,0 +1,196 @@
1
+
2
+ #include <stdio.h>
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <cuda.h>
6
+ #include <cuda_runtime.h>
7
+
8
+ #include <vector>
9
+
10
+
11
+
12
+ __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
13
+ const int batch=512;
14
+ __shared__ float buf[batch*3];
15
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
16
+ for (int k2=0;k2<m;k2+=batch){
17
+ int end_k=min(m,k2+batch)-k2;
18
+ for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
19
+ buf[j]=xyz2[(i*m+k2)*3+j];
20
+ }
21
+ __syncthreads();
22
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
23
+ float x1=xyz[(i*n+j)*3+0];
24
+ float y1=xyz[(i*n+j)*3+1];
25
+ float z1=xyz[(i*n+j)*3+2];
26
+ int best_i=0;
27
+ float best=0;
28
+ int end_ka=end_k-(end_k&3);
29
+ if (end_ka==batch){
30
+ for (int k=0;k<batch;k+=4){
31
+ {
32
+ float x2=buf[k*3+0]-x1;
33
+ float y2=buf[k*3+1]-y1;
34
+ float z2=buf[k*3+2]-z1;
35
+ float d=x2*x2+y2*y2+z2*z2;
36
+ if (k==0 || d<best){
37
+ best=d;
38
+ best_i=k+k2;
39
+ }
40
+ }
41
+ {
42
+ float x2=buf[k*3+3]-x1;
43
+ float y2=buf[k*3+4]-y1;
44
+ float z2=buf[k*3+5]-z1;
45
+ float d=x2*x2+y2*y2+z2*z2;
46
+ if (d<best){
47
+ best=d;
48
+ best_i=k+k2+1;
49
+ }
50
+ }
51
+ {
52
+ float x2=buf[k*3+6]-x1;
53
+ float y2=buf[k*3+7]-y1;
54
+ float z2=buf[k*3+8]-z1;
55
+ float d=x2*x2+y2*y2+z2*z2;
56
+ if (d<best){
57
+ best=d;
58
+ best_i=k+k2+2;
59
+ }
60
+ }
61
+ {
62
+ float x2=buf[k*3+9]-x1;
63
+ float y2=buf[k*3+10]-y1;
64
+ float z2=buf[k*3+11]-z1;
65
+ float d=x2*x2+y2*y2+z2*z2;
66
+ if (d<best){
67
+ best=d;
68
+ best_i=k+k2+3;
69
+ }
70
+ }
71
+ }
72
+ }else{
73
+ for (int k=0;k<end_ka;k+=4){
74
+ {
75
+ float x2=buf[k*3+0]-x1;
76
+ float y2=buf[k*3+1]-y1;
77
+ float z2=buf[k*3+2]-z1;
78
+ float d=x2*x2+y2*y2+z2*z2;
79
+ if (k==0 || d<best){
80
+ best=d;
81
+ best_i=k+k2;
82
+ }
83
+ }
84
+ {
85
+ float x2=buf[k*3+3]-x1;
86
+ float y2=buf[k*3+4]-y1;
87
+ float z2=buf[k*3+5]-z1;
88
+ float d=x2*x2+y2*y2+z2*z2;
89
+ if (d<best){
90
+ best=d;
91
+ best_i=k+k2+1;
92
+ }
93
+ }
94
+ {
95
+ float x2=buf[k*3+6]-x1;
96
+ float y2=buf[k*3+7]-y1;
97
+ float z2=buf[k*3+8]-z1;
98
+ float d=x2*x2+y2*y2+z2*z2;
99
+ if (d<best){
100
+ best=d;
101
+ best_i=k+k2+2;
102
+ }
103
+ }
104
+ {
105
+ float x2=buf[k*3+9]-x1;
106
+ float y2=buf[k*3+10]-y1;
107
+ float z2=buf[k*3+11]-z1;
108
+ float d=x2*x2+y2*y2+z2*z2;
109
+ if (d<best){
110
+ best=d;
111
+ best_i=k+k2+3;
112
+ }
113
+ }
114
+ }
115
+ }
116
+ for (int k=end_ka;k<end_k;k++){
117
+ float x2=buf[k*3+0]-x1;
118
+ float y2=buf[k*3+1]-y1;
119
+ float z2=buf[k*3+2]-z1;
120
+ float d=x2*x2+y2*y2+z2*z2;
121
+ if (k==0 || d<best){
122
+ best=d;
123
+ best_i=k+k2;
124
+ }
125
+ }
126
+ if (k2==0 || result[(i*n+j)]>best){
127
+ result[(i*n+j)]=best;
128
+ result_i[(i*n+j)]=best_i;
129
+ }
130
+ }
131
+ __syncthreads();
132
+ }
133
+ }
134
+ }
135
+ // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
136
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
137
+
138
+ const auto batch_size = xyz1.size(0);
139
+ const auto n = xyz1.size(1); //num_points point cloud A
140
+ const auto m = xyz2.size(1); //num_points point cloud B
141
+
142
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
143
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
144
+
145
+ cudaError_t err = cudaGetLastError();
146
+ if (err != cudaSuccess) {
147
+ printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
148
+ //THError("aborting");
149
+ return 0;
150
+ }
151
+ return 1;
152
+
153
+
154
+ }
155
+ __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
156
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
157
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
158
+ float x1=xyz1[(i*n+j)*3+0];
159
+ float y1=xyz1[(i*n+j)*3+1];
160
+ float z1=xyz1[(i*n+j)*3+2];
161
+ int j2=idx1[i*n+j];
162
+ float x2=xyz2[(i*m+j2)*3+0];
163
+ float y2=xyz2[(i*m+j2)*3+1];
164
+ float z2=xyz2[(i*m+j2)*3+2];
165
+ float g=grad_dist1[i*n+j]*2;
166
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
167
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
168
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
169
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
170
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
171
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
172
+ }
173
+ }
174
+ }
175
+ // int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
176
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
177
+ // cudaMemset(grad_xyz1,0,b*n*3*4);
178
+ // cudaMemset(grad_xyz2,0,b*m*3*4);
179
+
180
+ const auto batch_size = xyz1.size(0);
181
+ const auto n = xyz1.size(1); //num_points point cloud A
182
+ const auto m = xyz2.size(1); //num_points point cloud B
183
+
184
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
185
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
186
+
187
+ cudaError_t err = cudaGetLastError();
188
+ if (err != cudaSuccess) {
189
+ printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
190
+ //THError("aborting");
191
+ return 0;
192
+ }
193
+ return 1;
194
+
195
+ }
196
+
lidm/eval/modules/chamfer3D/chamfer_cuda.cpp ADDED
@@ -0,0 +1,33 @@
+ #include <torch/torch.h>
+ #include <vector>
+
+ ///TMP
+ //#include "common.h"
+ /// NOT TMP
+
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
+
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
+
+ int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
+     return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
+ }
+
+ int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
+                      at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
+     return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
+     m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
+ }
lidm/eval/modules/chamfer3D/dist_chamfer_3D.py ADDED
@@ -0,0 +1,76 @@
+ from torch import nn
+ from torch.autograd import Function
+ import torch
+ import importlib
+ import os
+
+ chamfer_found = importlib.find_loader("chamfer_3D") is not None
+ if not chamfer_found:
+     ## Cool trick from https://github.com/chrdiller
+     print("Jitting Chamfer 3D")
+
+     from torch.utils.cpp_extension import load
+
+     chamfer_3D = load(name="chamfer_3D",
+                       sources=[
+                           "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
+                           "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
+                       ])
+     print("Loaded JIT 3D CUDA chamfer distance")
+
+ else:
+     import chamfer_3D
+     print("Loaded compiled 3D CUDA chamfer distance")
+
+
+ # Chamfer's distance module @thibaultgroueix
+ # GPU tensors only
+ class chamfer_3DFunction(Function):
+     @staticmethod
+     def forward(ctx, xyz1, xyz2):
+         batchsize, n, _ = xyz1.size()
+         _, m, _ = xyz2.size()
+         device = xyz1.device
+
+         dist1 = torch.zeros(batchsize, n)
+         dist2 = torch.zeros(batchsize, m)
+
+         idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
+         idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
+
+         dist1 = dist1.to(device)
+         dist2 = dist2.to(device)
+         idx1 = idx1.to(device)
+         idx2 = idx2.to(device)
+         torch.cuda.set_device(device)
+
+         chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
+         ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
+         return dist1, dist2, idx1, idx2
+
+     @staticmethod
+     def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
+         xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
+         graddist1 = graddist1.contiguous()
+         graddist2 = graddist2.contiguous()
+         device = graddist1.device
+
+         gradxyz1 = torch.zeros(xyz1.size())
+         gradxyz2 = torch.zeros(xyz2.size())
+
+         gradxyz1 = gradxyz1.to(device)
+         gradxyz2 = gradxyz2.to(device)
+         chamfer_3D.backward(
+             xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
+         )
+         return gradxyz1, gradxyz2
+
+
+ class chamfer_3DDist(nn.Module):
+     def __init__(self):
+         super(chamfer_3DDist, self).__init__()
+
+     def forward(self, input1, input2):
+         input1 = input1.contiguous()
+         input2 = input2.contiguous()
+         return chamfer_3DFunction.apply(input1, input2)
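The module returns squared nearest-neighbour distances in both directions plus the matching indices; it is up to the caller to reduce them. A hedged sketch of one common way to turn that into a symmetric Chamfer distance per sample (GPU tensors only, as noted in the source; sizes are placeholders):

import torch
from lidm.eval.modules.chamfer3D.dist_chamfer_3D import chamfer_3DDist

cd = chamfer_3DDist()
pc1 = torch.rand(4, 2048, 3).cuda()   # (batch, points, xyz)
pc2 = torch.rand(4, 2048, 3).cuda()

dist1, dist2, idx1, idx2 = cd(pc1, pc2)          # squared distances, both directions
chamfer = dist1.mean(dim=1) + dist2.mean(dim=1)  # one common symmetric reduction per sample
print(chamfer.shape)                             # torch.Size([4])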
lidm/eval/modules/chamfer3D/setup.py ADDED
@@ -0,0 +1,14 @@
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ setup(
+     name='chamfer_3D',
+     ext_modules=[
+         CUDAExtension('chamfer_3D', [
+             "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
+             "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
+         ]),
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     })
lidm/eval/modules/emd/__init__.py ADDED
File without changes
lidm/eval/modules/emd/emd.cpp ADDED
@@ -0,0 +1,31 @@
+ // EMD approximation module (based on auction algorithm)
+ // author: Minghua Liu
+ #include <torch/extension.h>
+ #include <vector>
+
+ int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
+                      at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
+                      at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters);
+
+ int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx);
+
+ int emd_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
+                 at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
+                 at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) {
+     return emd_cuda_forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters);
+ }
+
+ int emd_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) {
+     return emd_cuda_backward(xyz1, xyz2, gradxyz, graddist, idx);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward", &emd_forward, "emd forward (CUDA)");
+     m.def("backward", &emd_backward, "emd backward (CUDA)");
+ }
lidm/eval/modules/emd/emd_cuda.cu ADDED
@@ -0,0 +1,316 @@
1
+ // EMD approximation module (based on auction algorithm)
2
+ // author: Minghua Liu
3
+ #include <stdio.h>
4
+ #include <ATen/ATen.h>
5
+
6
+ #include <cuda.h>
7
+ #include <iostream>
8
+ #include <cuda_runtime.h>
9
+
10
+ __device__ __forceinline__ float atomicMax(float *address, float val)
11
+ {
12
+ int ret = __float_as_int(*address);
13
+ while(val > __int_as_float(ret))
14
+ {
15
+ int old = ret;
16
+ if((ret = atomicCAS((int *)address, old, __float_as_int(val))) == old)
17
+ break;
18
+ }
19
+ return __int_as_float(ret);
20
+ }
21
+
22
+
23
+ __global__ void clear(int b, int * cnt_tmp, int * unass_cnt) {
24
+ for (int i = threadIdx.x; i < b; i += blockDim.x) {
25
+ cnt_tmp[i] = 0;
26
+ unass_cnt[i] = 0;
27
+ }
28
+ }
29
+
30
+ __global__ void calc_unass_cnt(int b, int n, int * assignment, int * unass_cnt) {
31
+ // count the number of unassigned points in each batch
32
+ const int BLOCK_SIZE = 1024;
33
+ __shared__ int scan_array[BLOCK_SIZE];
34
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
35
+ scan_array[threadIdx.x] = assignment[i * n + blockIdx.y * BLOCK_SIZE + threadIdx.x] == -1 ? 1 : 0;
36
+ __syncthreads();
37
+
38
+ int stride = 1;
39
+ while(stride <= BLOCK_SIZE / 2) {
40
+ int index = (threadIdx.x + 1) * stride * 2 - 1;
41
+ if(index < BLOCK_SIZE)
42
+ scan_array[index] += scan_array[index - stride];
43
+ stride = stride * 2;
44
+ __syncthreads();
45
+ }
46
+ __syncthreads();
47
+
48
+ if (threadIdx.x == BLOCK_SIZE - 1) {
49
+ atomicAdd(&unass_cnt[i], scan_array[threadIdx.x]);
50
+ }
51
+ __syncthreads();
52
+ }
53
+ }
54
+
55
+ __global__ void calc_unass_cnt_sum(int b, int * unass_cnt, int * unass_cnt_sum) {
56
+ // count the cumulative sum over over unass_cnt
57
+ const int BLOCK_SIZE = 512; // batch_size <= 512
58
+ __shared__ int scan_array[BLOCK_SIZE];
59
+ scan_array[threadIdx.x] = unass_cnt[threadIdx.x];
60
+ __syncthreads();
61
+
62
+ int stride = 1;
63
+ while(stride <= BLOCK_SIZE / 2) {
64
+ int index = (threadIdx.x + 1) * stride * 2 - 1;
65
+ if(index < BLOCK_SIZE)
66
+ scan_array[index] += scan_array[index - stride];
67
+ stride = stride * 2;
68
+ __syncthreads();
69
+ }
70
+ __syncthreads();
71
+ stride = BLOCK_SIZE / 4;
72
+ while(stride > 0) {
73
+ int index = (threadIdx.x + 1) * stride * 2 - 1;
74
+ if((index + stride) < BLOCK_SIZE)
75
+ scan_array[index + stride] += scan_array[index];
76
+ stride = stride / 2;
77
+ __syncthreads();
78
+ }
79
+ __syncthreads();
80
+
81
+ //printf("%d\n", unass_cnt_sum[b - 1]);
82
+ unass_cnt_sum[threadIdx.x] = scan_array[threadIdx.x];
83
+ }
84
+
85
+ __global__ void calc_unass_idx(int b, int n, int * assignment, int * unass_idx, int * unass_cnt, int * unass_cnt_sum, int * cnt_tmp) {
86
+ // list all the unassigned points
87
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
88
+ if (assignment[i * n + blockIdx.y * 1024 + threadIdx.x] == -1) {
89
+ int idx = atomicAdd(&cnt_tmp[i], 1);
90
+ unass_idx[unass_cnt_sum[i] - unass_cnt[i] + idx] = blockIdx.y * 1024 + threadIdx.x;
91
+ }
92
+ }
93
+ }
94
+
95
+ __global__ void Bid(int b, int n, const float * xyz1, const float * xyz2, float eps, int * assignment, int * assignment_inv, float * price,
96
+ int * bid, float * bid_increments, float * max_increments, int * unass_cnt, int * unass_cnt_sum, int * unass_idx) {
97
+ const int batch = 2048, block_size = 1024, block_cnt = n / 1024;
98
+ __shared__ float xyz2_buf[batch * 3];
99
+ __shared__ float price_buf[batch];
100
+ __shared__ float best_buf[block_size];
101
+ __shared__ float better_buf[block_size];
102
+ __shared__ int best_i_buf[block_size];
103
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
104
+ int _unass_cnt = unass_cnt[i];
105
+ if (_unass_cnt == 0)
106
+ continue;
107
+ int _unass_cnt_sum = unass_cnt_sum[i];
108
+ int unass_per_block = (_unass_cnt + block_cnt - 1) / block_cnt;
109
+ int thread_per_unass = block_size / unass_per_block;
110
+ int unass_this_block = max(min(_unass_cnt - (int) blockIdx.y * unass_per_block, unass_per_block), 0);
111
+
112
+ float x1, y1, z1, best = -1e9, better = -1e9;
113
+ int best_i = -1, _unass_id = -1, thread_in_unass;
114
+
115
+ if (threadIdx.x < thread_per_unass * unass_this_block) {
116
+ _unass_id = unass_per_block * blockIdx.y + threadIdx.x / thread_per_unass + _unass_cnt_sum - _unass_cnt;
117
+ _unass_id = unass_idx[_unass_id];
118
+ thread_in_unass = threadIdx.x % thread_per_unass;
119
+
120
+ x1 = xyz1[(i * n + _unass_id) * 3 + 0];
121
+ y1 = xyz1[(i * n + _unass_id) * 3 + 1];
122
+ z1 = xyz1[(i * n + _unass_id) * 3 + 2];
123
+ }
124
+
125
+ for (int k2 = 0; k2 < n; k2 += batch) {
126
+ int end_k = min(n, k2 + batch) - k2;
127
+ for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
128
+ xyz2_buf[j] = xyz2[(i * n + k2) * 3 + j];
129
+ }
130
+ for (int j = threadIdx.x; j < end_k; j += blockDim.x) {
131
+ price_buf[j] = price[i * n + k2 + j];
132
+ }
133
+ __syncthreads();
134
+
135
+ if (_unass_id != -1) {
136
+ int delta = (end_k + thread_per_unass - 1) / thread_per_unass;
137
+ int l = thread_in_unass * delta;
138
+ int r = min((thread_in_unass + 1) * delta, end_k);
139
+ for (int k = l; k < r; k++)
140
+ //if (!last || assignment_inv[i * n + k + k2] == -1)
141
+ {
142
+ float x2 = xyz2_buf[k * 3 + 0] - x1;
143
+ float y2 = xyz2_buf[k * 3 + 1] - y1;
144
+ float z2 = xyz2_buf[k * 3 + 2] - z1;
145
+ // the coordinates of points should be normalized to [0, 1]
146
+ float d = 3.0 - sqrtf(x2 * x2 + y2 * y2 + z2 * z2) - price_buf[k];
147
+ if (d > best) {
148
+ better = best;
149
+ best = d;
150
+ best_i = k + k2;
151
+ }
152
+ else if (d > better) {
153
+ better = d;
154
+ }
155
+ }
156
+ }
157
+ __syncthreads();
158
+ }
159
+
160
+ best_buf[threadIdx.x] = best;
161
+ better_buf[threadIdx.x] = better;
162
+ best_i_buf[threadIdx.x] = best_i;
163
+ __syncthreads();
164
+
165
+ if (_unass_id != -1 && thread_in_unass == 0) {
166
+ for (int j = threadIdx.x + 1; j < threadIdx.x + thread_per_unass; j++) {
167
+ if (best_buf[j] > best) {
168
+ better = max(best, better_buf[j]);
169
+ best = best_buf[j];
170
+ best_i = best_i_buf[j];
171
+ }
172
+ else better = max(better, best_buf[j]);
173
+ }
174
+ bid[i * n + _unass_id] = best_i;
175
+ bid_increments[i * n + _unass_id] = best - better + eps;
176
+ atomicMax(&max_increments[i * n + best_i], best - better + eps);
177
+ }
178
+ }
179
+ }
180
+
181
+ __global__ void GetMax(int b, int n, int * assignment, int * bid, float * bid_increments, float * max_increments, int * max_idx) {
182
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
183
+ int j = threadIdx.x + blockIdx.y * blockDim.x;
184
+ if (assignment[i * n + j] == -1) {
185
+ int bid_id = bid[i * n + j];
186
+ float bid_inc = bid_increments[i * n + j];
187
+ float max_inc = max_increments[i * n + bid_id];
188
+ if (bid_inc - 1e-6 <= max_inc && max_inc <= bid_inc + 1e-6)
189
+ {
190
+ max_idx[i * n + bid_id] = j;
191
+ }
192
+ }
193
+ }
194
+ }
195
+
196
+ __global__ void Assign(int b, int n, int * assignment, int * assignment_inv, float * price, int * bid, float * bid_increments, float * max_increments, int * max_idx, bool last) {
197
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
198
+ int j = threadIdx.x + blockIdx.y * blockDim.x;
199
+ if (assignment[i * n + j] == -1) {
200
+ int bid_id = bid[i * n + j];
201
+ if (last || max_idx[i * n + bid_id] == j)
202
+ {
203
+ float bid_inc = bid_increments[i * n + j];
204
+ int ass_inv = assignment_inv[i * n + bid_id];
205
+ if (!last && ass_inv != -1) {
206
+ assignment[i * n + ass_inv] = -1;
207
+ }
208
+ assignment_inv[i * n + bid_id] = j;
209
+ assignment[i * n + j] = bid_id;
210
+ price[i * n + bid_id] += bid_inc;
211
+ max_increments[i * n + bid_id] = -1e9;
212
+ }
213
+ }
214
+ }
215
+ }
216
+
217
+ __global__ void CalcDist(int b, int n, float * xyz1, float * xyz2, float * dist, int * assignment) {
218
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
219
+ int j = threadIdx.x + blockIdx.y * blockDim.x;
220
+ int k = assignment[i * n + j];
221
+ float deltax = xyz1[(i * n + j) * 3 + 0] - xyz2[(i * n + k) * 3 + 0];
222
+ float deltay = xyz1[(i * n + j) * 3 + 1] - xyz2[(i * n + k) * 3 + 1];
223
+ float deltaz = xyz1[(i * n + j) * 3 + 2] - xyz2[(i * n + k) * 3 + 2];
224
+ dist[i * n + j] = deltax * deltax + deltay * deltay + deltaz * deltaz;
225
+ }
226
+ }
227
+
228
+ int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
229
+ at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
230
+ at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) {
231
+
232
+ const auto batch_size = xyz1.size(0);
233
+ const auto n = xyz1.size(1); //num_points point cloud A
234
+ const auto m = xyz2.size(1); //num_points point cloud B
235
+
236
+ if (n != m) {
237
+ printf("Input Error! The two point clouds should have the same size.\n");
238
+ return -1;
239
+ }
240
+
241
+ if (batch_size > 512) {
242
+ printf("Input Error! The batch size should be less than 512.\n");
243
+ return -1;
244
+ }
245
+
246
+ if (n % 1024 != 0) {
247
+ printf("Input Error! The size of the point clouds should be a multiple of 1024.\n");
248
+ return -1;
249
+ }
250
+
251
+ //cudaEvent_t start,stop;
252
+ //cudaEventCreate(&start);
253
+ //cudaEventCreate(&stop);
254
+ //cudaEventRecord(start);
255
+ //int iters = 50;
256
+ for (int i = 0; i < iters; i++) {
257
+ clear<<<1, batch_size>>>(batch_size, cnt_tmp.data<int>(), unass_cnt.data<int>());
258
+ calc_unass_cnt<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), unass_cnt.data<int>());
259
+ calc_unass_cnt_sum<<<1, batch_size>>>(batch_size, unass_cnt.data<int>(), unass_cnt_sum.data<int>());
260
+ calc_unass_idx<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), unass_idx.data<int>(), unass_cnt.data<int>(),
261
+ unass_cnt_sum.data<int>(), cnt_tmp.data<int>());
262
+ Bid<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), eps, assignment.data<int>(), assignment_inv.data<int>(),
263
+ price.data<float>(), bid.data<int>(), bid_increments.data<float>(), max_increments.data<float>(),
264
+ unass_cnt.data<int>(), unass_cnt_sum.data<int>(), unass_idx.data<int>());
265
+ GetMax<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), bid.data<int>(), bid_increments.data<float>(), max_increments.data<float>(), max_idx.data<int>());
266
+ Assign<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), assignment_inv.data<int>(), price.data<float>(), bid.data<int>(),
267
+ bid_increments.data<float>(), max_increments.data<float>(), max_idx.data<int>(), i == iters - 1);
268
+ }
269
+ CalcDist<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), dist.data<float>(), assignment.data<int>());
270
+ //cudaEventRecord(stop);
271
+ //cudaEventSynchronize(stop);
272
+ //float elapsedTime;
273
+ //cudaEventElapsedTime(&elapsedTime,start,stop);
274
+ //printf("%lf\n", elapsedTime);
275
+
276
+ cudaError_t err = cudaGetLastError();
277
+ if (err != cudaSuccess) {
278
+ printf("error in nnd Output: %s\n", cudaGetErrorString(err));
279
+ return 0;
280
+ }
281
+ return 1;
282
+ }
283
+
284
+ __global__ void NmDistanceGradKernel(int b, int n, const float * xyz1, const float * xyz2, const float * grad_dist, const int * idx, float * grad_xyz){
285
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
286
+ for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) {
287
+ float x1 = xyz1[(i * n + j) * 3 + 0];
288
+ float y1 = xyz1[(i * n + j) * 3 + 1];
289
+ float z1 = xyz1[(i * n + j) * 3 + 2];
290
+ int j2 = idx[i * n + j];
291
+ float x2 = xyz2[(i * n + j2) * 3 + 0];
292
+ float y2 = xyz2[(i * n + j2) * 3 + 1];
293
+ float z2 = xyz2[(i * n + j2) * 3 + 2];
294
+ float g = grad_dist[i * n + j] * 2;
295
+ atomicAdd(&(grad_xyz[(i * n + j) * 3 + 0]), g * (x1 - x2));
296
+ atomicAdd(&(grad_xyz[(i * n + j) * 3 + 1]), g * (y1 - y2));
297
+ atomicAdd(&(grad_xyz[(i * n + j) * 3 + 2]), g * (z1 - z2));
298
+ }
299
+ }
300
+ }
301
+
302
+ int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx){
303
+ const auto batch_size = xyz1.size(0);
304
+ const auto n = xyz1.size(1);
305
+ const auto m = xyz2.size(1);
306
+
307
+ NmDistanceGradKernel<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), graddist.data<float>(), idx.data<int>(), gradxyz.data<float>());
308
+
309
+ cudaError_t err = cudaGetLastError();
310
+ if (err != cudaSuccess) {
311
+ printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
312
+ return 0;
313
+ }
314
+ return 1;
315
+
316
+ }
lidm/eval/modules/emd/emd_module.py ADDED
@@ -0,0 +1,112 @@
1
+ # EMD approximation module (based on auction algorithm)
2
+ # memory complexity: O(n)
3
+ # time complexity: O(n^2 * iter)
4
+ # author: Minghua Liu
5
+
6
+ # Input:
7
+ # xyz1, xyz2: [#batch, #points, 3]
8
+ # where xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud
9
+ # two point clouds should have same size and be normalized to [0, 1]
10
+ # #points should be a multiple of 1024
11
+ # #batch should be no greater than 512
12
+ # eps is a parameter which balances the error rate and the speed of convergence
13
+ # iters is the number of iteration
14
+ # we only calculate gradient for xyz1
15
+
16
+ # Output:
17
+ # dist: [#batch, #points], sqrt(dist) -> L2 distance
18
+ # assignment: [#batch, #points], index of the matched point in the ground truth point cloud
19
+ # the result is an approximation and the assignment is not guranteed to be a bijection
20
+ import importlib
21
+ import os
22
+ import time
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+ from torch.autograd import Function
27
+
28
+ emd_found = importlib.find_loader("emd") is not None
29
+ if not emd_found:
30
+ ## Cool trick from https://github.com/chrdiller
31
+ print("Jitting EMD 3D")
32
+
33
+ from torch.utils.cpp_extension import load
34
+
35
+ emd = load(name="emd",
36
+ sources=[
37
+ "/".join(os.path.abspath(__file__).split('/')[:-1] + ["emd.cpp"]),
38
+ "/".join(os.path.abspath(__file__).split('/')[:-1] + ["emd_cuda.cu"]),
39
+ ])
40
+ print("Loaded JIT 3D CUDA emd")
41
+ else:
42
+ import emd
43
+ print("Loaded compiled 3D CUDA emd")
44
+
45
+
46
+ class emdFunction(Function):
47
+ @staticmethod
48
+ def forward(ctx, xyz1, xyz2, eps, iters):
49
+ batchsize, n, _ = xyz1.size()
50
+ _, m, _ = xyz2.size()
51
+
52
+ assert (n == m)
53
+ assert (xyz1.size()[0] == xyz2.size()[0])
54
+ # assert(n % 1024 == 0)
55
+ assert (batchsize <= 512)
56
+
57
+ xyz1 = xyz1.contiguous().float().cuda()
58
+ xyz2 = xyz2.contiguous().float().cuda()
59
+ dist = torch.zeros(batchsize, n, device='cuda').contiguous()
60
+ assignment = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() - 1
61
+ assignment_inv = torch.zeros(batchsize, m, device='cuda', dtype=torch.int32).contiguous() - 1
62
+ price = torch.zeros(batchsize, m, device='cuda').contiguous()
63
+ bid = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous()
64
+ bid_increments = torch.zeros(batchsize, n, device='cuda').contiguous()
65
+ max_increments = torch.zeros(batchsize, m, device='cuda').contiguous()
66
+ unass_idx = torch.zeros(batchsize * n, device='cuda', dtype=torch.int32).contiguous()
67
+ max_idx = torch.zeros(batchsize * m, device='cuda', dtype=torch.int32).contiguous()
68
+ unass_cnt = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
69
+ unass_cnt_sum = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
70
+ cnt_tmp = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
71
+
72
+ emd.forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx,
73
+ unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters)
74
+
75
+ ctx.save_for_backward(xyz1, xyz2, assignment)
76
+ return dist, assignment
77
+
78
+ @staticmethod
79
+ def backward(ctx, graddist, gradidx):
80
+ xyz1, xyz2, assignment = ctx.saved_tensors
81
+ graddist = graddist.contiguous()
82
+
83
+ gradxyz1 = torch.zeros(xyz1.size(), device='cuda').contiguous()
84
+ gradxyz2 = torch.zeros(xyz2.size(), device='cuda').contiguous()
85
+
86
+ emd.backward(xyz1, xyz2, gradxyz1, graddist, assignment)
87
+ return gradxyz1, gradxyz2, None, None
88
+
89
+
90
+ class emdModule(nn.Module):
91
+ def __init__(self):
92
+ super(emdModule, self).__init__()
93
+
94
+ def forward(self, input1, input2, eps, iters):
95
+ return emdFunction.apply(input1, input2, eps, iters)
96
+
97
+
98
+ def test_emd():
99
+ x1 = torch.rand(20, 8192, 3).cuda()
100
+ x2 = torch.rand(20, 8192, 3).cuda()
101
+ emd = emdModule()
102
+ start_time = time.perf_counter()
103
+ dis, assigment = emd(x1, x2, 0.05, 3000)
104
+ print("Input_size: ", x1.shape)
105
+ print("Runtime: %lfs" % (time.perf_counter() - start_time))
106
+ print("EMD: %lf" % np.sqrt(dis.cpu()).mean())
107
+ print("|set(assignment)|: %d" % assigment.unique().numel())
108
+ assigment = assigment.cpu().numpy()
109
+ assigment = np.expand_dims(assigment, -1)
110
+ x2 = np.take_along_axis(x2, assigment, axis=1)
111
+ d = (x1 - x2) * (x1 - x2)
112
+ print("Verified EMD: %lf" % np.sqrt(d.cpu().sum(-1)).mean())
lidm/eval/modules/emd/setup.py ADDED
@@ -0,0 +1,14 @@
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ setup(
+     name='emd',
+     ext_modules=[
+         CUDAExtension('emd', [
+             'emd.cpp',
+             'emd_cuda.cu',
+         ]),
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     })