Stable-X commited on
Commit
e66346c
·
1 Parent(s): d1dbe71

update demo

Browse files
Files changed (2) hide show
  1. app.py +152 -102
  2. backend_utils.py +239 -107
app.py CHANGED
@@ -16,7 +16,7 @@ from transformers import AutoModelForImageSegmentation
16
  from torchvision import transforms
17
  from PIL import Image
18
  import open3d as o3d
19
- from backend_utils import improved_multiway_registration
20
 
21
 
22
  # Default values
@@ -29,15 +29,45 @@ OPENGL = np.array([[1, 0, 0, 0],
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
32
- def extract_frames(video_path: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  temp_dir = tempfile.mkdtemp()
34
  output_path = os.path.join(temp_dir, "%03d.jpg")
 
 
 
35
  command = [
36
  "ffmpeg",
37
  "-i", video_path,
38
- "-vf", "fps=1",
 
39
  output_path
40
  ]
 
41
  subprocess.run(command, check=True)
42
  return temp_dir
43
 
@@ -144,9 +174,9 @@ def generate_mask(image: np.ndarray):
144
  # Convert mask to numpy array
145
  mask_np = np.array(mask) / 255.0
146
  return mask_np
147
-
148
  @torch.no_grad()
149
- def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
 
150
  # Extract frames from video
151
  demo_path = extract_frames(video_path)
152
 
@@ -168,123 +198,143 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
168
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
169
 
170
  # Process results
171
- pts_all, images_all, conf_all, mask_all = [], [], [], []
172
  for j, view in enumerate(batch):
173
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
 
174
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
 
175
  conf = preds[j]['conf'][0].cpu().data.numpy()
176
-
177
  if remove_background:
178
  mask = generate_mask(image)
179
  else:
180
  mask = np.ones_like(conf)
 
 
181
 
182
- images_all.append((image[None, ...] + 1.0)/2.0)
183
- pts_all.append(pts[None, ...])
184
- conf_all.append(conf[None, ...])
185
- mask_all.append(mask[None, ...])
186
-
187
- images_all = np.concatenate(images_all, axis=0)
188
- pts_all = np.concatenate(pts_all, axis=0) * 10
189
- conf_all = np.concatenate(conf_all, axis=0)
190
- mask_all = np.concatenate(mask_all, axis=0)
191
 
192
- # Create point cloud or mesh
193
- conf_sig_all = (conf_all-1) / conf_all
194
- combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)
195
 
 
 
 
 
 
196
  # Create coarse result
197
- coarse_scene = create_scene(pts_all, images_all, combined_mask, as_pointcloud)
198
- coarse_output_path = save_scene(coarse_scene, as_pointcloud)
199
-
200
- yield coarse_output_path, None, f"Reconstruction completed. FPS: {fps:.2f}"
201
 
202
- # Create point clouds for multiway registration
203
- pcds = []
204
- for j in range(len(pts_all)):
205
- pcd = o3d.geometry.PointCloud()
206
- mask = combined_mask[j]
207
- pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
208
- pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
209
- pcds.append(pcd)
210
-
211
- # Perform global optimization
212
- print("Performing global registration...")
213
- transformed_pcds, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
214
 
215
- # Apply transformations from pose_graph to original pts_all
216
- transformed_pts_all = np.zeros_like(pts_all)
217
- for j in range(len(pts_all)):
218
- # Get the transformation matrix from the pose graph
219
- transformation = pose_graph.nodes[j].pose
220
-
221
- # Reshape pts_all[j] to (H*W, 3)
222
- H, W, _ = pts_all[j].shape
223
- pts_reshaped = pts_all[j].reshape(-1, 3)
224
 
225
- # Apply transformation to all points
226
- homogeneous_pts = np.hstack((pts_reshaped, np.ones((pts_reshaped.shape[0], 1))))
227
- transformed_pts = (transformation @ homogeneous_pts.T).T[:, :3]
 
 
 
 
228
 
229
- # Reshape back to (H, W, 3) and store
230
- transformed_pts_all[j] = transformed_pts.reshape(H, W, 3)
231
-
232
- print(f"Original shape: {pts_all.shape}, Transformed shape: {transformed_pts_all.shape}")
233
-
234
- # Create refined result
235
- refined_scene = create_scene(transformed_pts_all, images_all, combined_mask, as_pointcloud)
236
- refined_output_path = save_scene(refined_scene, as_pointcloud)
237
 
238
  # Clean up temporary directory
239
  os.system(f"rm -rf {demo_path}")
240
-
241
- yield coarse_output_path, refined_output_path, f"Refinement completed. FPS: {fps:.2f}"
242
 
243
- def create_scene(pts_all, images_all, combined_mask, as_pointcloud):
244
- scene = trimesh.Scene()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- if as_pointcloud:
247
- pcd = trimesh.PointCloud(
248
- vertices=pts_all[combined_mask].reshape(-1, 3),
249
- colors=images_all[combined_mask].reshape(-1, 3)
250
- )
251
- scene.add_geometry(pcd)
252
- else:
253
- meshes = []
254
- for i in range(len(images_all)):
255
- meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
256
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
257
- scene.add_geometry(mesh)
258
-
259
- rot = np.eye(4)
260
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
261
- scene.apply_transform(np.linalg.inv(OPENGL @ rot))
262
- return scene
263
- def save_scene(scene, as_pointcloud):
264
- if as_pointcloud:
265
- output_path = tempfile.mktemp(suffix='.ply')
266
- else:
267
- output_path = tempfile.mktemp(suffix='.obj')
268
- scene.export(output_path)
269
- return output_path
270
-
271
- # Update the Gradio interface
272
- iface = gr.Interface(
273
- fn=reconstruct,
274
- inputs=[
275
- gr.Video(label="Input Video"),
276
- gr.Slider(0, 1, value=1e-6, label="Confidence Threshold"),
277
- gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
278
- gr.Checkbox(label="As Pointcloud", value=False),
279
- gr.Checkbox(label="Remove Background", value=False)
280
- ],
281
- outputs=[
282
- gr.Model3D(label="Coarse 3D Model", display_mode="solid"),
283
- gr.Model3D(label="Refined 3D Model", display_mode="solid"),
284
- gr.Textbox(label="Status")
285
- ],
286
- title="3D Reconstruction with Spatial Memory, Background Removal, and Global Optimization",
287
- )
288
 
289
  if __name__ == "__main__":
290
- iface.launch(server_name="0.0.0.0",)
 
16
  from torchvision import transforms
17
  from PIL import Image
18
  import open3d as o3d
19
+ from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
 
21
 
22
  # Default values
 
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
32
+ def export_geometry(geometry, as_pointcloud=False):
33
+ if as_pointcloud:
34
+ if not isinstance(geometry, o3d.geometry.PointCloud):
35
+ raise ValueError("Expected an Open3D PointCloud object when as_pointcloud is True")
36
+ output_path = tempfile.mktemp(suffix='.ply')
37
+ else:
38
+ if not isinstance(geometry, o3d.geometry.TriangleMesh):
39
+ raise ValueError("Expected an Open3D TriangleMesh object when as_pointcloud is False")
40
+ output_path = tempfile.mktemp(suffix='.obj')
41
+
42
+ # Apply rotation
43
+ rot = np.eye(4)
44
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
45
+ transform = np.linalg.inv(OPENGL @ rot)
46
+ geometry.transform(transform)
47
+
48
+ # Export the geometry
49
+ if as_pointcloud:
50
+ o3d.io.write_point_cloud(output_path, geometry, write_ascii=False, compressed=True)
51
+ else:
52
+ o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
53
+
54
+ return output_path
55
+
56
+
57
+ def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
58
  temp_dir = tempfile.mkdtemp()
59
  output_path = os.path.join(temp_dir, "%03d.jpg")
60
+
61
+ filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
62
+
63
  command = [
64
  "ffmpeg",
65
  "-i", video_path,
66
+ "-vf", filter_complex,
67
+ "-vsync", "0",
68
  output_path
69
  ]
70
+
71
  subprocess.run(command, check=True)
72
  return temp_dir
73
 
 
174
  # Convert mask to numpy array
175
  mask_np = np.array(mask) / 255.0
176
  return mask_np
 
177
  @torch.no_grad()
178
+ def reconstruct(video_path, conf_thresh, kf_every,
179
+ as_pointcloud=False, remove_background=False, refine=False):
180
  # Extract frames from video
181
  demo_path = extract_frames(video_path)
182
 
 
198
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
199
 
200
  # Process results
201
+ pcds = []
202
  for j, view in enumerate(batch):
203
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
204
+ image = (image + 1) / 2
205
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
206
+ pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
207
  conf = preds[j]['conf'][0].cpu().data.numpy()
208
+ conf_sig = (conf - 1) / conf
209
  if remove_background:
210
  mask = generate_mask(image)
211
  else:
212
  mask = np.ones_like(conf)
213
+
214
+ combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
215
 
216
+ pcd = o3d.geometry.PointCloud()
217
+ pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
218
+ pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
219
+ pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
220
+ pcds.append(pcd)
 
 
 
 
221
 
222
+ pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
 
 
223
 
224
+ if as_pointcloud:
225
+ o3d_geometry = pcd_combined
226
+ else:
227
+ o3d_geometry = point2mesh(pcd_combined)
228
+
229
  # Create coarse result
230
+ coarse_output_path = export_geometry(o3d_geometry, as_pointcloud)
 
 
 
231
 
232
+ yield coarse_output_path, None
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ if refine:
235
+ # Perform global optimization
236
+ print("Performing global registration...")
237
+ transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
 
 
 
 
 
238
 
239
+ if as_pointcloud:
240
+ o3d_geometry = transformed_pcds
241
+ else:
242
+ o3d_geometry = point2mesh(transformed_pcds)
243
+
244
+ # Create coarse result
245
+ refined_output_path = export_geometry(o3d_geometry, as_pointcloud)
246
 
247
+ yield coarse_output_path, refined_output_path
 
 
 
 
 
 
 
248
 
249
  # Clean up temporary directory
250
  os.system(f"rm -rf {demo_path}")
 
 
251
 
252
+ # Update the Gradio interface with improved layout
253
+ with gr.Blocks(
254
+ title="StableSpann3r: Making Spann3r stable with Odometry Backend",
255
+ css="""
256
+ #download {
257
+ height: 118px;
258
+ }
259
+ .slider .inner {
260
+ width: 5px;
261
+ background: #FFF;
262
+ }
263
+ .viewport {
264
+ aspect-ratio: 4/3;
265
+ }
266
+ .tabs button.selected {
267
+ font-size: 20px !important;
268
+ color: crimson !important;
269
+ }
270
+ h1 {
271
+ text-align: center;
272
+ display: block;
273
+ }
274
+ h2 {
275
+ text-align: center;
276
+ display: block;
277
+ }
278
+ h3 {
279
+ text-align: center;
280
+ display: block;
281
+ }
282
+ .md_feedback li {
283
+ margin-bottom: 0px !important;
284
+ }
285
+ """,
286
+ head="""
287
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
288
+ <script>
289
+ window.dataLayer = window.dataLayer || [];
290
+ function gtag() {dataLayer.push(arguments);}
291
+ gtag('js', new Date());
292
+ gtag('config', 'G-1FWSVCGZTG');
293
+ </script>
294
+ """,
295
+ ) as iface:
296
+ gr.Markdown(
297
+ """
298
+ # StableSpann3r: Making Spann3r stable with Odometry Backend
299
+ <p align="center">
300
+ <a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
301
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
302
+ </a>
303
+ <a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
304
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
305
+ </a>
306
+ <a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
307
+ <img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
308
+ </a>
309
+ <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
310
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
311
+ </a>
312
+ </p>
313
+ """
314
+ )
315
+ with gr.Row():
316
+ with gr.Column(scale=1):
317
+ video_input = gr.Video(label="Input Video")
318
+ with gr.Row():
319
+ conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
320
+ kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
321
+ with gr.Row():
322
+ remove_background = gr.Checkbox(label="Remove Background", value=False)
323
+ refine = gr.Checkbox(label="Enable Backend", value=False)
324
+ as_pointcloud = gr.Checkbox(label="As Pointcloud", value=False)
325
+ reconstruct_btn = gr.Button("Reconstruct")
326
+
327
+ with gr.Column(scale=2):
328
+ with gr.Tab("Coarse Model"):
329
+ coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
330
+ with gr.Tab("Refined Model"):
331
+ refined_model = gr.Model3D(label="Refined 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
332
 
333
+ reconstruct_btn.click(
334
+ fn=reconstruct,
335
+ inputs=[video_input, conf_thresh, kf_every, as_pointcloud, remove_background, refine],
336
+ outputs=[coarse_model, refined_model]
337
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  if __name__ == "__main__":
340
+ iface.launch(server_name="0.0.0.0")
backend_utils.py CHANGED
@@ -1,90 +1,152 @@
1
  import numpy as np
2
  import open3d as o3d
 
 
 
3
 
4
- def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None, overlap=3, quadratic_overlap=True, use_colored_icp=True):
5
- if max_correspondence_distance_coarse is None:
6
- max_correspondence_distance_coarse = voxel_size * 15
7
- if max_correspondence_distance_fine is None:
8
- max_correspondence_distance_fine = voxel_size * 1.5
9
-
10
- def preprocess_point_cloud(pcd, voxel_size):
11
- pcd_down = pcd.voxel_down_sample(voxel_size)
12
- pcd_down.estimate_normals(
13
- o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30))
14
- # Apply statistical outlier removal
15
- cl, ind = pcd_down.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
16
- pcd_down = pcd_down.select_by_index(ind)
17
- return pcd_down
18
-
19
- def pairwise_registration(source, target, use_colored_icp, voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine):
20
- current_transformation = np.identity(4) # Start with identity matrix
21
-
22
- if use_colored_icp:
23
- print("Apply colored point cloud registration")
24
- voxel_radius = [5*voxel_size, 3*voxel_size, voxel_size]
25
- max_iter = [60, 35, 20]
26
-
27
- for scale in range(3):
28
- iter = max_iter[scale]
29
- radius = voxel_radius[scale]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- source_down = source.voxel_down_sample(radius)
32
- target_down = target.voxel_down_sample(radius)
33
 
34
- source_down.estimate_normals(
35
- o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
36
- target_down.estimate_normals(
37
- o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
38
 
39
- try:
40
- result_icp = o3d.pipelines.registration.registration_colored_icp(
41
- source_down, target_down, radius, current_transformation,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  o3d.pipelines.registration.TransformationEstimationForColoredICP(),
43
  o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
44
- relative_rmse=1e-6,
45
- max_iteration=iter))
46
- current_transformation = result_icp.transformation
47
- except RuntimeError as e:
48
- print(f"Colored ICP failed at scale {scale}: {str(e)}")
49
- print("Keeping the previous transformation")
50
- # We keep the previous transformation, no need to reassign
51
-
52
- transformation_icp = current_transformation
53
- else:
54
- print("Apply point-to-plane ICP")
55
- try:
56
- icp_coarse = o3d.pipelines.registration.registration_icp(
57
- source, target, max_correspondence_distance_coarse, current_transformation,
58
- o3d.pipelines.registration.TransformationEstimationPointToPlane())
59
- current_transformation = icp_coarse.transformation
60
-
61
  icp_fine = o3d.pipelines.registration.registration_icp(
62
  source, target, max_correspondence_distance_fine,
63
  current_transformation,
64
  o3d.pipelines.registration.TransformationEstimationPointToPlane())
65
- transformation_icp = icp_fine.transformation
66
- except RuntimeError as e:
67
- print(f"Point-to-plane ICP failed: {str(e)}")
68
- print("Keeping the best available transformation")
69
- transformation_icp = current_transformation
70
 
71
- try:
72
- information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
73
- source, target, max_correspondence_distance_fine,
74
- transformation_icp)
 
 
 
 
 
 
 
75
  except RuntimeError as e:
76
- print(f"Failed to compute information matrix: {str(e)}")
77
- print("Using identity information matrix")
78
- information_icp = np.identity(6)
79
 
80
- return transformation_icp, information_icp
81
-
82
- def full_registration(pcds_down):
83
- pose_graph = o3d.pipelines.registration.PoseGraph()
84
- odometry = np.identity(4)
85
- pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(odometry))
86
- n_pcds = len(pcds_down)
 
 
87
 
 
 
 
88
  pairs = []
89
  for i in range(n_pcds - 1):
90
  for j in range(i + 1, min(i + overlap + 1, n_pcds)):
@@ -93,52 +155,122 @@ def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_dis
93
  q = 2**(j-i)
94
  if q > overlap and i + q < n_pcds:
95
  pairs.append((i, i + q))
 
 
 
 
 
96
 
97
- for source_id, target_id in pairs:
98
- transformation_icp, information_icp = pairwise_registration(
 
 
 
 
99
  pcds_down[source_id], pcds_down[target_id], use_colored_icp,
100
- voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine)
101
- print(f"Build PoseGraph: {source_id} -> {target_id}")
102
-
103
- if target_id == source_id + 1:
104
- odometry = np.dot(transformation_icp, odometry)
105
- pose_graph.nodes.append(
106
- o3d.pipelines.registration.PoseGraphNode(
107
- np.linalg.inv(odometry)))
108
 
109
- pose_graph.edges.append(
110
- o3d.pipelines.registration.PoseGraphEdge(source_id,
111
- target_id,
112
- transformation_icp,
113
- information_icp,
114
- uncertain=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  return pose_graph
116
 
117
- # Preprocess point clouds
118
- print("Preprocessing point clouds...")
119
- pcds_down = [preprocess_point_cloud(pcd, voxel_size) for pcd in pcds]
 
 
 
 
120
 
121
- print("Full registration ...")
122
- pose_graph = full_registration(pcds_down)
123
-
124
- print("Optimizing PoseGraph ...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  option = o3d.pipelines.registration.GlobalOptimizationOption(
126
  max_correspondence_distance=max_correspondence_distance_fine,
127
  edge_prune_threshold=0.25,
128
  reference_node=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
131
- o3d.pipelines.registration.global_optimization(
132
- pose_graph,
133
- o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
134
- o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
135
- option)
136
-
137
- print("Transform points and combine")
138
  pcd_combined = o3d.geometry.PointCloud()
139
- for point_id in range(len(pcds)):
140
- print(pose_graph.nodes[point_id].pose)
141
- pcds[point_id].transform(pose_graph.nodes[point_id].pose)
142
- pcd_combined += pcds[point_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- return pcd_combined, pose_graph
 
1
  import numpy as np
2
  import open3d as o3d
3
+ import torch
4
+ from tqdm import tqdm
5
+ import torch.nn.functional as F
6
 
7
+ def pts2normal(pts):
8
+ h, w, _ = pts.shape
9
+
10
+ # Compute differences in x and y directions
11
+ dx = torch.cat([pts[2:, 1:-1] - pts[:-2, 1:-1]], dim=0)
12
+ dy = torch.cat([pts[1:-1, 2:] - pts[1:-1, :-2]], dim=1)
13
+
14
+ # Compute normal vectors using cross product
15
+ normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
16
+
17
+ # Create padded normal map
18
+ padded_normal_map = torch.zeros_like(pts)
19
+ padded_normal_map[1:-1, 1:-1, :] = normal_map
20
+
21
+ # Pad the borders
22
+ padded_normal_map[0, 1:-1, :] = normal_map[0, :, :] # Top edge
23
+ padded_normal_map[-1, 1:-1, :] = normal_map[-1, :, :] # Bottom edge
24
+ padded_normal_map[1:-1, 0, :] = normal_map[:, 0, :] # Left edge
25
+ padded_normal_map[1:-1, -1, :] = normal_map[:, -1, :] # Right edge
26
+
27
+ # Pad the corners
28
+ padded_normal_map[0, 0, :] = normal_map[0, 0, :] # Top-left corner
29
+ padded_normal_map[0, -1, :] = normal_map[0, -1, :] # Top-right corner
30
+ padded_normal_map[-1, 0, :] = normal_map[-1, 0, :] # Bottom-left corner
31
+ padded_normal_map[-1, -1, :] = normal_map[-1, -1, :] # Bottom-right corner
32
+
33
+ return padded_normal_map
34
+
35
+ def point2mesh(pcd, depth=8, density_threshold=0.1, clean_mesh=True):
36
+ print("\nPerforming Poisson surface reconstruction...")
37
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
38
+ pcd, depth=depth, width=0, scale=1.1, linear_fit=False)
39
+
40
+ print(f"Reconstructed mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")
41
+
42
+ # Normalize densities
43
+ densities = np.asarray(densities)
44
+ densities = (densities - densities.min()) / (densities.max() - densities.min())
45
+
46
+ # Remove low density vertices
47
+ print("\nPruning low-density vertices...")
48
+ vertices_to_remove = densities < np.quantile(densities, density_threshold)
49
+ mesh.remove_vertices_by_mask(vertices_to_remove)
50
+
51
+ print(f"Pruned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")
52
+
53
+ if clean_mesh:
54
+ print("\nCleaning the mesh...")
55
+ mesh.remove_degenerate_triangles()
56
+ mesh.remove_duplicated_triangles()
57
+ mesh.remove_duplicated_vertices()
58
+ mesh.remove_non_manifold_edges()
59
+
60
+ print(f"Final cleaned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")
61
 
62
+ mesh.compute_triangle_normals()
63
+ return mesh
64
 
65
+ def combine_and_clean_point_clouds(pcds, voxel_size):
66
+ """
67
+ Combine, downsample, and clean a list of point clouds.
 
68
 
69
+ Parameters:
70
+ pcds (list): List of open3d.geometry.PointCloud objects to be processed.
71
+ voxel_size (float): The size of the voxel for downsampling.
72
+
73
+ Returns:
74
+ o3d.geometry.PointCloud: The cleaned and combined point cloud.
75
+ """
76
+ print("\nCombining point clouds...")
77
+ pcd_combined = o3d.geometry.PointCloud()
78
+ for p3d in pcds:
79
+ pcd_combined += p3d
80
+
81
+ print("\nDownsampling the combined point cloud...")
82
+ pcd_combined = pcd_combined.voxel_down_sample(voxel_size)
83
+ print(f"Downsampled from {len(pcd_combined.points)} to {len(pcd_combined.points)} points")
84
+
85
+ print("\nCleaning the combined point cloud...")
86
+ cl, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
87
+ pcd_cleaned = pcd_combined.select_by_index(ind)
88
+
89
+ print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
90
+ print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")
91
+
92
+ return pcd_cleaned
93
+
94
+ def improved_multiway_registration(pcds, descriptors=None, voxel_size=0.05,
95
+ max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None,
96
+ overlap=5, quadratic_overlap=False, use_colored_icp=False):
97
+ if max_correspondence_distance_coarse is None:
98
+ max_correspondence_distance_coarse = voxel_size * 1.5
99
+ if max_correspondence_distance_fine is None:
100
+ max_correspondence_distance_fine = voxel_size * 0.15
101
+
102
+ def pairwise_registration(source, target, use_colored_icp, max_correspondence_distance_coarse, max_correspondence_distance_fine):
103
+ current_transformation = np.identity(4)
104
+ try:
105
+ if use_colored_icp:
106
+ icp_fine = o3d.pipelines.registration.registration_colored_icp(
107
+ source, target, max_correspondence_distance_fine, current_transformation,
108
  o3d.pipelines.registration.TransformationEstimationForColoredICP(),
109
  o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
110
+ relative_rmse=1e-6,
111
+ max_iteration=100))
112
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  icp_fine = o3d.pipelines.registration.registration_icp(
114
  source, target, max_correspondence_distance_fine,
115
  current_transformation,
116
  o3d.pipelines.registration.TransformationEstimationPointToPlane())
117
+
118
+
119
+ fitness = icp_fine.fitness
120
+ FITNESS_THRESHOLD = 0.01
 
121
 
122
+ if fitness >= FITNESS_THRESHOLD:
123
+ current_transformation = icp_fine.transformation
124
+
125
+ information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
126
+ source, target, max_correspondence_distance_fine,
127
+ current_transformation)
128
+ return current_transformation, information_icp, True
129
+ else:
130
+ print(f"Registration failed. Fitness {fitness} is below threshold {FITNESS_THRESHOLD}")
131
+ return None, None, False
132
+
133
  except RuntimeError as e:
134
+ print(f" ICP registration failed: {str(e)}")
135
+ return None, None, False
 
136
 
137
+ def detect_loop_closure(descriptors, min_interval=3, similarity_threshold=0.9):
138
+ n_pcds = len(descriptors)
139
+ loop_edges = []
140
+
141
+ for i in range(n_pcds):
142
+ for j in range(i + min_interval, n_pcds):
143
+ similarity = torch.dot(descriptors[i], descriptors[j])
144
+ if similarity > similarity_threshold:
145
+ loop_edges.append((i, j))
146
 
147
+ return loop_edges
148
+
149
+ def generate_pairs(n_pcds, overlap, quadratic_overlap):
150
  pairs = []
151
  for i in range(n_pcds - 1):
152
  for j in range(i + 1, min(i + overlap + 1, n_pcds)):
 
155
  q = 2**(j-i)
156
  if q > overlap and i + q < n_pcds:
157
  pairs.append((i, i + q))
158
+ return pairs
159
+
160
+ def full_registration(pcds_down, pairs, loop_edges):
161
+ pose_graph = o3d.pipelines.registration.PoseGraph()
162
+ n_pcds = len(pcds_down)
163
 
164
+ for i in range(n_pcds):
165
+ pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(np.identity(4)))
166
+
167
+ print("\nPerforming pairwise registration:")
168
+ for source_id, target_id in tqdm(pairs):
169
+ transformation_icp, information_icp, success = pairwise_registration(
170
  pcds_down[source_id], pcds_down[target_id], use_colored_icp,
171
+ max_correspondence_distance_coarse, max_correspondence_distance_fine)
 
 
 
 
 
 
 
172
 
173
+ if success:
174
+ uncertain = abs(target_id - source_id) == 1
175
+ pose_graph.edges.append(
176
+ o3d.pipelines.registration.PoseGraphEdge(source_id,
177
+ target_id,
178
+ transformation_icp,
179
+ information_icp,
180
+ uncertain=uncertain))
181
+ else:
182
+ print(f" Skipping edge between {source_id} and {target_id} due to ICP failure")
183
+
184
+ # Add loop closure edges
185
+ print("\nAdding loop closure edges:")
186
+ for source_id, target_id in tqdm(loop_edges):
187
+ transformation_icp, information_icp, success = pairwise_registration(
188
+ pcds_down[source_id], pcds_down[target_id], use_colored_icp,
189
+ max_correspondence_distance_coarse, max_correspondence_distance_fine)
190
+
191
+ if success:
192
+ pose_graph.edges.append(
193
+ o3d.pipelines.registration.PoseGraphEdge(source_id,
194
+ target_id,
195
+ transformation_icp,
196
+ information_icp,
197
+ uncertain=True))
198
+ else:
199
+ print(f" Skipping loop closure edge between {source_id} and {target_id} due to ICP failure")
200
+
201
  return pose_graph
202
 
203
+ print("\n--- Improved Multiway Registration Process ---")
204
+ print(f"Number of point clouds: {len(pcds)}")
205
+ print(f"Voxel size: {voxel_size}")
206
+ print(f"Max correspondence distance (coarse): {max_correspondence_distance_coarse}")
207
+ print(f"Max correspondence distance (fine): {max_correspondence_distance_fine}")
208
+ print(f"Overlap: {overlap}")
209
+ print(f"Quadratic overlap: {quadratic_overlap}")
210
 
211
+ print("\nPreprocessing point clouds...")
212
+ pcds_down = pcds
213
+ print(f"Preprocessing complete. {len(pcds_down)} point clouds processed.")
214
+
215
+ print("\nGenerating initial graph pairs...")
216
+ pairs = generate_pairs(len(pcds), overlap, quadratic_overlap)
217
+ print(f"Generated {len(pairs)} pairs for initial graph.")
218
+
219
+ if descriptors is None:
220
+ print("\nNo descriptors provided. Skipping loop closure detection.")
221
+ loop_edges = []
222
+ else:
223
+ print(descriptors[0].shape)
224
+ print("\nDetecting loop closures...")
225
+ loop_edges = detect_loop_closure(descriptors)
226
+ print(f"Detected {len(loop_edges)} loop closures.")
227
+
228
+ print("\nPerforming full registration...")
229
+ pose_graph = full_registration(pcds_down, pairs, loop_edges)
230
+
231
+ print("\nOptimizing PoseGraph...")
232
  option = o3d.pipelines.registration.GlobalOptimizationOption(
233
  max_correspondence_distance=max_correspondence_distance_fine,
234
  edge_prune_threshold=0.25,
235
  reference_node=0)
236
+ o3d.pipelines.registration.global_optimization(
237
+ pose_graph,
238
+ o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
239
+ o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
240
+ option)
241
+
242
+ # Count edges for each node
243
+ edge_count = {i: 0 for i in range(len(pcds))}
244
+ for edge in pose_graph.edges:
245
+ edge_count[edge.source_node_id] += 1
246
+ edge_count[edge.target_node_id] += 1
247
+
248
+ # Filter nodes with more than 3 edges
249
+ valid_nodes = [count > 3 for count in edge_count.values()]
250
 
251
+ print("\nTransforming and combining point clouds...")
 
 
 
 
 
 
 
252
  pcd_combined = o3d.geometry.PointCloud()
253
+
254
+ for point_id, is_valid in enumerate(valid_nodes):
255
+ if is_valid:
256
+ pcds[point_id].transform(pose_graph.nodes[point_id].pose)
257
+ pcd_combined += pcds[point_id]
258
+ else:
259
+ print(f"Skipping point cloud {point_id} as it has {edge_count[point_id]} edges (<=3)")
260
+
261
+ print("\nDownsampling the combined point cloud...")
262
+ # pcd_combined.orient_normals_consistent_tangent_plane(k=30)
263
+ pcd_combined = pcd_combined.voxel_down_sample(voxel_size * 0.1)
264
+ print(f"Downsampled from {len(pcd_combined.points)} to {len(pcd_combined.points)} points")
265
+
266
+ print("\nCleaning the combined point cloud...")
267
+ cl, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
268
+ pcd_cleaned = pcd_combined.select_by_index(ind)
269
+
270
+ print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
271
+ print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")
272
+
273
+ print("\nMultiway registration complete.")
274
+ print(f"Included {len(valid_nodes)} out of {len(pcds)} point clouds (with >3 edges).")
275
 
276
+ return pcd_cleaned, pose_graph, valid_nodes