update demo
- app.py +152 -102
- backend_utils.py +239 -107
app.py
CHANGED
@@ -16,7 +16,7 @@ from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
 import open3d as o3d
-from backend_utils import improved_multiway_registration
+from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds


 # Default values
@@ -29,15 +29,45 @@ OPENGL = np.array([[1, 0, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1]])

-def
+def export_geometry(geometry, as_pointcloud=False):
+    if as_pointcloud:
+        if not isinstance(geometry, o3d.geometry.PointCloud):
+            raise ValueError("Expected an Open3D PointCloud object when as_pointcloud is True")
+        output_path = tempfile.mktemp(suffix='.ply')
+    else:
+        if not isinstance(geometry, o3d.geometry.TriangleMesh):
+            raise ValueError("Expected an Open3D TriangleMesh object when as_pointcloud is False")
+        output_path = tempfile.mktemp(suffix='.obj')
+
+    # Apply rotation
+    rot = np.eye(4)
+    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+    transform = np.linalg.inv(OPENGL @ rot)
+    geometry.transform(transform)
+
+    # Export the geometry
+    if as_pointcloud:
+        o3d.io.write_point_cloud(output_path, geometry, write_ascii=False, compressed=True)
+    else:
+        o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
+
+    return output_path
+
+
+def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
     temp_dir = tempfile.mkdtemp()
     output_path = os.path.join(temp_dir, "%03d.jpg")
+
+    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
+
     command = [
         "ffmpeg",
         "-i", video_path,
-        "-vf",
+        "-vf", filter_complex,
+        "-vsync", "0",
         output_path
     ]
+
     subprocess.run(command, check=True)
     return temp_dir

@@ -144,9 +174,9 @@ def generate_mask(image: np.ndarray):
     # Convert mask to numpy array
     mask_np = np.array(mask) / 255.0
     return mask_np
-
 @torch.no_grad()
-def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
+def reconstruct(video_path, conf_thresh, kf_every,
+                as_pointcloud=False, remove_background=False, refine=False):
     # Extract frames from video
     demo_path = extract_frames(video_path)

@@ -168,123 +198,143 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
     print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')

     # Process results
-
+    pcds = []
     for j, view in enumerate(batch):
         image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
+        image = (image + 1) / 2
         pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
+        pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
         conf = preds[j]['conf'][0].cpu().data.numpy()
-
+        conf_sig = (conf - 1) / conf
         if remove_background:
             mask = generate_mask(image)
         else:
             mask = np.ones_like(conf)
+
+        combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)

-
-
-
-
-
-    images_all = np.concatenate(images_all, axis=0)
-    pts_all = np.concatenate(pts_all, axis=0) * 10
-    conf_all = np.concatenate(conf_all, axis=0)
-    mask_all = np.concatenate(mask_all, axis=0)
+        pcd = o3d.geometry.PointCloud()
+        pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
+        pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
+        pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
+        pcds.append(pcd)

-
-    conf_sig_all = (conf_all-1) / conf_all
-    combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)
+    pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)

+    if as_pointcloud:
+        o3d_geometry = pcd_combined
+    else:
+        o3d_geometry = point2mesh(pcd_combined)
+
     # Create coarse result
-
-    coarse_output_path = save_scene(coarse_scene, as_pointcloud)
-
-    yield coarse_output_path, None, f"Reconstruction completed. FPS: {fps:.2f}"
+    coarse_output_path = export_geometry(o3d_geometry, as_pointcloud)

-
-    pcds = []
-    for j in range(len(pts_all)):
-        pcd = o3d.geometry.PointCloud()
-        mask = combined_mask[j]
-        pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
-        pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
-        pcds.append(pcd)
-
-    # Perform global optimization
-    print("Performing global registration...")
-    transformed_pcds, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
+    yield coarse_output_path, None

-
-
-
-
-        transformation = pose_graph.nodes[j].pose
-
-        # Reshape pts_all[j] to (H*W, 3)
-        H, W, _ = pts_all[j].shape
-        pts_reshaped = pts_all[j].reshape(-1, 3)
+    if refine:
+        # Perform global optimization
+        print("Performing global registration...")
+        transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)

-
-
-
+        if as_pointcloud:
+            o3d_geometry = transformed_pcds
+        else:
+            o3d_geometry = point2mesh(transformed_pcds)
+
+        # Create refined result
+        refined_output_path = export_geometry(o3d_geometry, as_pointcloud)

-
-        transformed_pts_all[j] = transformed_pts.reshape(H, W, 3)
-
-    print(f"Original shape: {pts_all.shape}, Transformed shape: {transformed_pts_all.shape}")
-
-    # Create refined result
-    refined_scene = create_scene(transformed_pts_all, images_all, combined_mask, as_pointcloud)
-    refined_output_path = save_scene(refined_scene, as_pointcloud)
+        yield coarse_output_path, refined_output_path

     # Clean up temporary directory
     os.system(f"rm -rf {demo_path}")
-
-    yield coarse_output_path, refined_output_path, f"Refinement completed. FPS: {fps:.2f}"

-
-
+# Update the Gradio interface with improved layout
+with gr.Blocks(
+    title="StableSpann3r: Making Spann3r stable with Odometry Backend",
+    css="""
+    #download {
+        height: 118px;
+    }
+    .slider .inner {
+        width: 5px;
+        background: #FFF;
+    }
+    .viewport {
+        aspect-ratio: 4/3;
+    }
+    .tabs button.selected {
+        font-size: 20px !important;
+        color: crimson !important;
+    }
+    h1 {
+        text-align: center;
+        display: block;
+    }
+    h2 {
+        text-align: center;
+        display: block;
+    }
+    h3 {
+        text-align: center;
+        display: block;
+    }
+    .md_feedback li {
+        margin-bottom: 0px !important;
+    }
+    """,
+    head="""
+    <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+    <script>
+        window.dataLayer = window.dataLayer || [];
+        function gtag() {dataLayer.push(arguments);}
+        gtag('js', new Date());
+        gtag('config', 'G-1FWSVCGZTG');
+    </script>
+    """,
+) as iface:
+    gr.Markdown(
+        """
+        # StableSpann3r: Making Spann3r stable with Odometry Backend
+        <p align="center">
+        <a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
+        </a>
+        <a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
+        </a>
+        <a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
+        </a>
+        <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
+        </a>
+        </p>
+        """
+    )
+    with gr.Row():
+        with gr.Column(scale=1):
+            video_input = gr.Video(label="Input Video")
+            with gr.Row():
+                conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
+                kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
+            with gr.Row():
+                remove_background = gr.Checkbox(label="Remove Background", value=False)
+                refine = gr.Checkbox(label="Enable Backend", value=False)
+                as_pointcloud = gr.Checkbox(label="As Pointcloud", value=False)
+            reconstruct_btn = gr.Button("Reconstruct")
+
+        with gr.Column(scale=2):
+            with gr.Tab("Coarse Model"):
+                coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
+            with gr.Tab("Refined Model"):
+                refined_model = gr.Model3D(label="Refined 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])

-
-
-
-
-
-            scene.add_geometry(pcd)
-    else:
-        meshes = []
-        for i in range(len(images_all)):
-            meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
-        mesh = trimesh.Trimesh(**cat_meshes(meshes))
-        scene.add_geometry(mesh)
-
-    rot = np.eye(4)
-    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
-    scene.apply_transform(np.linalg.inv(OPENGL @ rot))
-    return scene
-def save_scene(scene, as_pointcloud):
-    if as_pointcloud:
-        output_path = tempfile.mktemp(suffix='.ply')
-    else:
-        output_path = tempfile.mktemp(suffix='.obj')
-    scene.export(output_path)
-    return output_path
-
-# Update the Gradio interface
-iface = gr.Interface(
-    fn=reconstruct,
-    inputs=[
-        gr.Video(label="Input Video"),
-        gr.Slider(0, 1, value=1e-6, label="Confidence Threshold"),
-        gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
-        gr.Checkbox(label="As Pointcloud", value=False),
-        gr.Checkbox(label="Remove Background", value=False)
-    ],
-    outputs=[
-        gr.Model3D(label="Coarse 3D Model", display_mode="solid"),
-        gr.Model3D(label="Refined 3D Model", display_mode="solid"),
-        gr.Textbox(label="Status")
-    ],
-    title="3D Reconstruction with Spatial Memory, Background Removal, and Global Optimization",
-)
+    reconstruct_btn.click(
+        fn=reconstruct,
+        inputs=[video_input, conf_thresh, kf_every, as_pointcloud, remove_background, refine],
+        outputs=[coarse_model, refined_model]
+    )

 if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0"
+    iface.launch(server_name="0.0.0.0")
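The per-view loop in the new reconstruct() now builds one Open3D point cloud per frame (points, colors, and pts2normal normals, filtered by a confidence-plus-background mask) instead of concatenating numpy arrays. As a rough illustration only, here is a hypothetical, self-contained sketch of that masking-and-packing step on random stand-in data; the array names, sizes, and the 0.001 threshold are placeholders and are not part of the commit.

# Hypothetical sketch (not part of the commit): mirrors the per-view masking in
# the new reconstruct(), but uses random data so it runs without Spann3r.
import numpy as np
import open3d as o3d

H, W = 120, 160
pts = np.random.rand(H, W, 3)          # stand-in for the predicted pts3d
colors = np.random.rand(H, W, 3)       # stand-in for (img + 1) / 2
normals = np.random.rand(H, W, 3)      # stand-in for the pts2normal output
conf = 1.0 + np.random.rand(H, W)      # confidences are >= 1

conf_sig = (conf - 1) / conf                      # squash confidence into [0, 1)
combined_mask = conf_sig > 0.001                  # conf_thresh; no background mask here

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
pcd.colors = o3d.utility.Vector3dVector(colors[combined_mask])
pcd.normals = o3d.utility.Vector3dVector(normals[combined_mask])
print(len(pcd.points), "points kept")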
backend_utils.py
CHANGED
@@ -1,90 +1,152 @@
 import numpy as np
 import open3d as o3d
+import torch
+from tqdm import tqdm
+import torch.nn.functional as F

-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def pts2normal(pts):
+    h, w, _ = pts.shape
+
+    # Compute differences in x and y directions
+    dx = torch.cat([pts[2:, 1:-1] - pts[:-2, 1:-1]], dim=0)
+    dy = torch.cat([pts[1:-1, 2:] - pts[1:-1, :-2]], dim=1)
+
+    # Compute normal vectors using cross product
+    normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
+
+    # Create padded normal map
+    padded_normal_map = torch.zeros_like(pts)
+    padded_normal_map[1:-1, 1:-1, :] = normal_map
+
+    # Pad the borders
+    padded_normal_map[0, 1:-1, :] = normal_map[0, :, :]  # Top edge
+    padded_normal_map[-1, 1:-1, :] = normal_map[-1, :, :]  # Bottom edge
+    padded_normal_map[1:-1, 0, :] = normal_map[:, 0, :]  # Left edge
+    padded_normal_map[1:-1, -1, :] = normal_map[:, -1, :]  # Right edge
+
+    # Pad the corners
+    padded_normal_map[0, 0, :] = normal_map[0, 0, :]  # Top-left corner
+    padded_normal_map[0, -1, :] = normal_map[0, -1, :]  # Top-right corner
+    padded_normal_map[-1, 0, :] = normal_map[-1, 0, :]  # Bottom-left corner
+    padded_normal_map[-1, -1, :] = normal_map[-1, -1, :]  # Bottom-right corner
+
+    return padded_normal_map
+
+def point2mesh(pcd, depth=8, density_threshold=0.1, clean_mesh=True):
+    print("\nPerforming Poisson surface reconstruction...")
+    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
+        pcd, depth=depth, width=0, scale=1.1, linear_fit=False)
+
+    print(f"Reconstructed mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")
+
+    # Normalize densities
+    densities = np.asarray(densities)
+    densities = (densities - densities.min()) / (densities.max() - densities.min())
+
+    # Remove low density vertices
+    print("\nPruning low-density vertices...")
+    vertices_to_remove = densities < np.quantile(densities, density_threshold)
+    mesh.remove_vertices_by_mask(vertices_to_remove)
+
+    print(f"Pruned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")
+
+    if clean_mesh:
+        print("\nCleaning the mesh...")
+        mesh.remove_degenerate_triangles()
+        mesh.remove_duplicated_triangles()
+        mesh.remove_duplicated_vertices()
+        mesh.remove_non_manifold_edges()
+
+        print(f"Final cleaned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")

-
-
+    mesh.compute_triangle_normals()
+    return mesh

-
-
-
-        o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
+def combine_and_clean_point_clouds(pcds, voxel_size):
+    """
+    Combine, downsample, and clean a list of point clouds.

-
-
-
+    Parameters:
+    pcds (list): List of open3d.geometry.PointCloud objects to be processed.
+    voxel_size (float): The size of the voxel for downsampling.
+
+    Returns:
+    o3d.geometry.PointCloud: The cleaned and combined point cloud.
+    """
+    print("\nCombining point clouds...")
+    pcd_combined = o3d.geometry.PointCloud()
+    for p3d in pcds:
+        pcd_combined += p3d
+
+    print("\nDownsampling the combined point cloud...")
+    pcd_combined = pcd_combined.voxel_down_sample(voxel_size)
+    print(f"Downsampled from {len(pcd_combined.points)} to {len(pcd_combined.points)} points")
+
+    print("\nCleaning the combined point cloud...")
+    cl, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
+    pcd_cleaned = pcd_combined.select_by_index(ind)
+
+    print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
+    print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")
+
+    return pcd_cleaned
+
+def improved_multiway_registration(pcds, descriptors=None, voxel_size=0.05,
+                                   max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None,
+                                   overlap=5, quadratic_overlap=False, use_colored_icp=False):
+    if max_correspondence_distance_coarse is None:
+        max_correspondence_distance_coarse = voxel_size * 1.5
+    if max_correspondence_distance_fine is None:
+        max_correspondence_distance_fine = voxel_size * 0.15
+
+    def pairwise_registration(source, target, use_colored_icp, max_correspondence_distance_coarse, max_correspondence_distance_fine):
+        current_transformation = np.identity(4)
+        try:
+            if use_colored_icp:
+                icp_fine = o3d.pipelines.registration.registration_colored_icp(
+                    source, target, max_correspondence_distance_fine, current_transformation,
                     o3d.pipelines.registration.TransformationEstimationForColoredICP(),
                     o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
-
-
-
-        except RuntimeError as e:
-            print(f"Colored ICP failed at scale {scale}: {str(e)}")
-            print("Keeping the previous transformation")
-            # We keep the previous transformation, no need to reassign
-
-        transformation_icp = current_transformation
-    else:
-        print("Apply point-to-plane ICP")
-        try:
-            icp_coarse = o3d.pipelines.registration.registration_icp(
-                source, target, max_correspondence_distance_coarse, current_transformation,
-                o3d.pipelines.registration.TransformationEstimationPointToPlane())
-            current_transformation = icp_coarse.transformation
-
+                                                                      relative_rmse=1e-6,
+                                                                      max_iteration=100))
+            else:
                 icp_fine = o3d.pipelines.registration.registration_icp(
                     source, target, max_correspondence_distance_fine,
                     current_transformation,
                     o3d.pipelines.registration.TransformationEstimationPointToPlane())
-
-
-
-
-            transformation_icp = current_transformation
+
+
+            fitness = icp_fine.fitness
+            FITNESS_THRESHOLD = 0.01

-
-
-
-
+            if fitness >= FITNESS_THRESHOLD:
+                current_transformation = icp_fine.transformation
+
+                information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
+                    source, target, max_correspondence_distance_fine,
+                    current_transformation)
+                return current_transformation, information_icp, True
+            else:
+                print(f"Registration failed. Fitness {fitness} is below threshold {FITNESS_THRESHOLD}")
+                return None, None, False
+
         except RuntimeError as e:
-            print(f"
-
-            information_icp = np.identity(6)
+            print(f" ICP registration failed: {str(e)}")
+            return None, None, False

-
-
-
-
-
-
-
+    def detect_loop_closure(descriptors, min_interval=3, similarity_threshold=0.9):
+        n_pcds = len(descriptors)
+        loop_edges = []
+
+        for i in range(n_pcds):
+            for j in range(i + min_interval, n_pcds):
+                similarity = torch.dot(descriptors[i], descriptors[j])
+                if similarity > similarity_threshold:
+                    loop_edges.append((i, j))

+        return loop_edges
+
+    def generate_pairs(n_pcds, overlap, quadratic_overlap):
         pairs = []
         for i in range(n_pcds - 1):
             for j in range(i + 1, min(i + overlap + 1, n_pcds)):
@@ -93,52 +155,122 @@ def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_dis
                 q = 2**(j-i)
                 if q > overlap and i + q < n_pcds:
                     pairs.append((i, i + q))
+        return pairs
+
+    def full_registration(pcds_down, pairs, loop_edges):
+        pose_graph = o3d.pipelines.registration.PoseGraph()
+        n_pcds = len(pcds_down)

-    for
-
+        for i in range(n_pcds):
+            pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(np.identity(4)))
+
+        print("\nPerforming pairwise registration:")
+        for source_id, target_id in tqdm(pairs):
+            transformation_icp, information_icp, success = pairwise_registration(
                 pcds_down[source_id], pcds_down[target_id], use_colored_icp,
-
-            print(f"Build PoseGraph: {source_id} -> {target_id}")
-
-            if target_id == source_id + 1:
-                odometry = np.dot(transformation_icp, odometry)
-                pose_graph.nodes.append(
-                    o3d.pipelines.registration.PoseGraphNode(
-                        np.linalg.inv(odometry)))
+                max_correspondence_distance_coarse, max_correspondence_distance_fine)

-
-
-
-
-
-
+            if success:
+                uncertain = abs(target_id - source_id) == 1
+                pose_graph.edges.append(
+                    o3d.pipelines.registration.PoseGraphEdge(source_id,
+                                                             target_id,
+                                                             transformation_icp,
+                                                             information_icp,
+                                                             uncertain=uncertain))
+            else:
+                print(f" Skipping edge between {source_id} and {target_id} due to ICP failure")
+
+        # Add loop closure edges
+        print("\nAdding loop closure edges:")
+        for source_id, target_id in tqdm(loop_edges):
+            transformation_icp, information_icp, success = pairwise_registration(
+                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
+                max_correspondence_distance_coarse, max_correspondence_distance_fine)
+
+            if success:
+                pose_graph.edges.append(
+                    o3d.pipelines.registration.PoseGraphEdge(source_id,
+                                                             target_id,
+                                                             transformation_icp,
+                                                             information_icp,
+                                                             uncertain=True))
+            else:
+                print(f" Skipping loop closure edge between {source_id} and {target_id} due to ICP failure")
+
         return pose_graph

-
-    print("
-
+    print("\n--- Improved Multiway Registration Process ---")
+    print(f"Number of point clouds: {len(pcds)}")
+    print(f"Voxel size: {voxel_size}")
+    print(f"Max correspondence distance (coarse): {max_correspondence_distance_coarse}")
+    print(f"Max correspondence distance (fine): {max_correspondence_distance_fine}")
+    print(f"Overlap: {overlap}")
+    print(f"Quadratic overlap: {quadratic_overlap}")

-    print("
-
-
-
+    print("\nPreprocessing point clouds...")
+    pcds_down = pcds
+    print(f"Preprocessing complete. {len(pcds_down)} point clouds processed.")
+
+    print("\nGenerating initial graph pairs...")
+    pairs = generate_pairs(len(pcds), overlap, quadratic_overlap)
+    print(f"Generated {len(pairs)} pairs for initial graph.")
+
+    if descriptors is None:
+        print("\nNo descriptors provided. Skipping loop closure detection.")
+        loop_edges = []
+    else:
+        print(descriptors[0].shape)
+        print("\nDetecting loop closures...")
+        loop_edges = detect_loop_closure(descriptors)
+        print(f"Detected {len(loop_edges)} loop closures.")
+
+    print("\nPerforming full registration...")
+    pose_graph = full_registration(pcds_down, pairs, loop_edges)
+
+    print("\nOptimizing PoseGraph...")
     option = o3d.pipelines.registration.GlobalOptimizationOption(
         max_correspondence_distance=max_correspondence_distance_fine,
         edge_prune_threshold=0.25,
         reference_node=0)
+    o3d.pipelines.registration.global_optimization(
+        pose_graph,
+        o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
+        o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
+        option)
+
+    # Count edges for each node
+    edge_count = {i: 0 for i in range(len(pcds))}
+    for edge in pose_graph.edges:
+        edge_count[edge.source_node_id] += 1
+        edge_count[edge.target_node_id] += 1
+
+    # Filter nodes with more than 3 edges
+    valid_nodes = [count > 3 for count in edge_count.values()]

-
-    o3d.pipelines.registration.global_optimization(
-        pose_graph,
-        o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
-        o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
-        option)
-
-    print("Transform points and combine")
+    print("\nTransforming and combining point clouds...")
     pcd_combined = o3d.geometry.PointCloud()
-
-
-
-
+
+    for point_id, is_valid in enumerate(valid_nodes):
+        if is_valid:
+            pcds[point_id].transform(pose_graph.nodes[point_id].pose)
+            pcd_combined += pcds[point_id]
+        else:
+            print(f"Skipping point cloud {point_id} as it has {edge_count[point_id]} edges (<=3)")
+
+    print("\nDownsampling the combined point cloud...")
+    # pcd_combined.orient_normals_consistent_tangent_plane(k=30)
+    pcd_combined = pcd_combined.voxel_down_sample(voxel_size * 0.1)
+    print(f"Downsampled from {len(pcd_combined.points)} to {len(pcd_combined.points)} points")
+
+    print("\nCleaning the combined point cloud...")
+    cl, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
+    pcd_cleaned = pcd_combined.select_by_index(ind)
+
+    print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
+    print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")
+
+    print("\nMultiway registration complete.")
+    print(f"Included {len(valid_nodes)} out of {len(pcds)} point clouds (with >3 edges).")

-    return
+    return pcd_cleaned, pose_graph, valid_nodes
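For reference, a hypothetical usage sketch of the new improved_multiway_registration API (not part of the commit): it feeds a few synthetic, slightly offset copies of a sampled sphere through the backend and unpacks the three return values. The synthetic scans, drift values, point counts, and voxel size are made up for illustration; the sketch assumes backend_utils.py is importable and gives every cloud normals, since the default path uses point-to-plane ICP.

# Hypothetical usage sketch (not part of the commit).
import copy
import numpy as np
import open3d as o3d
from backend_utils import improved_multiway_registration

# Build one reference scan with normals, then a few slightly drifted copies.
base = o3d.geometry.TriangleMesh.create_sphere(radius=0.5).sample_points_uniformly(5000)
base.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=0.05, max_nn=30))

pcds = []
for k in range(8):
    scan = copy.deepcopy(base)
    T = np.eye(4)
    T[:3, 3] = [0.0002 * k, 0.0, 0.0]   # small translational drift between "frames"
    scan.transform(T)
    pcds.append(scan)

# Returns the merged/cleaned cloud, the optimized pose graph, and a per-node
# validity flag (nodes with <= 3 pose-graph edges are dropped).
pcd_cleaned, pose_graph, valid_nodes = improved_multiway_registration(pcds, voxel_size=0.01)
print(len(pcd_cleaned.points), "points,", sum(valid_nodes), "valid nodes")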