Spaces:
Sleeping
Sleeping
feat: Add pose_utils to solve camera and depth
Browse files- demo.py +10 -2
- pose_utils.py +126 -0
demo.py
CHANGED
@@ -15,6 +15,7 @@ from spann3r.datasets import *
|
|
15 |
from torch.utils.data import DataLoader
|
16 |
from spann3r.tools.eval_recon import accuracy, completion
|
17 |
from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
|
|
|
18 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
19 |
|
20 |
def get_args_parser():
|
@@ -63,7 +64,7 @@ def main(args):
|
|
63 |
model.load_state_dict(torch.load(args.ckpt_path)['model'])
|
64 |
model.eval()
|
65 |
|
66 |
-
if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.
|
67 |
args.demo_path = extract_frames(args.demo_path)
|
68 |
args.kf_every = 1
|
69 |
|
@@ -139,10 +140,15 @@ def main(args):
|
|
139 |
conf_sig = (conf - 1) / conf
|
140 |
pts_gt = view['pts3d'].cpu().numpy()[0]
|
141 |
|
|
|
|
|
|
|
|
|
142 |
images_all.append((image[None, ...] + 1.0)/2.0)
|
143 |
pts_all.append(pts[None, ...])
|
144 |
pts_normal_all.append(pts_normal[None, ...])
|
145 |
pts_gt_all.append(pts_gt[None, ...])
|
|
|
146 |
masks_all.append(mask[None, ...])
|
147 |
conf_sig_all.append(conf_sig[None, ...])
|
148 |
|
@@ -163,7 +169,9 @@ def main(args):
|
|
163 |
pcd.normals = o3d.utility.Vector3dVector(pts_normal_all[j][mask])
|
164 |
pcds.append(pcd)
|
165 |
|
166 |
-
|
|
|
|
|
167 |
mesh_recon = point2mesh(pcd_combined)
|
168 |
|
169 |
|
|
|
15 |
from torch.utils.data import DataLoader
|
16 |
from spann3r.tools.eval_recon import accuracy, completion
|
17 |
from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
|
18 |
+
from pose_utils import solve_cemara
|
19 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
20 |
|
21 |
def get_args_parser():
|
|
|
64 |
model.load_state_dict(torch.load(args.ckpt_path)['model'])
|
65 |
model.eval()
|
66 |
|
67 |
+
if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.webm'):
|
68 |
args.demo_path = extract_frames(args.demo_path)
|
69 |
args.kf_every = 1
|
70 |
|
|
|
140 |
conf_sig = (conf - 1) / conf
|
141 |
pts_gt = view['pts3d'].cpu().numpy()[0]
|
142 |
|
143 |
+
camera, last_focal, depth_map = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > args.conf_thresh,
|
144 |
+
args.device, focal=last_focal)
|
145 |
+
pts_scale = depth_map / last_focal
|
146 |
+
|
147 |
images_all.append((image[None, ...] + 1.0)/2.0)
|
148 |
pts_all.append(pts[None, ...])
|
149 |
pts_normal_all.append(pts_normal[None, ...])
|
150 |
pts_gt_all.append(pts_gt[None, ...])
|
151 |
+
pts_scale_all.append(pts_scale[None, ...])
|
152 |
masks_all.append(mask[None, ...])
|
153 |
conf_sig_all.append(conf_sig[None, ...])
|
154 |
|
|
|
169 |
pcd.normals = o3d.utility.Vector3dVector(pts_normal_all[j][mask])
|
170 |
pcds.append(pcd)
|
171 |
|
172 |
+
print("Performing global registration...")
|
173 |
+
pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
|
174 |
+
# pcd_combined = combine_and_clean_point_clouds(transformed_pcds, voxel_size=args.voxel_size * 0.1)
|
175 |
mesh_recon = point2mesh(pcd_combined)
|
176 |
|
177 |
|
pose_utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import open3d as o3d
|
5 |
+
from dust3r.post_process import estimate_focal_knowing_depth
|
6 |
+
from dust3r.utils.geometry import inv
|
7 |
+
|
8 |
+
def estimate_focal(pts3d_i, pp=None):
|
9 |
+
if pp is None:
|
10 |
+
H, W, THREE = pts3d_i.shape
|
11 |
+
assert THREE == 3
|
12 |
+
pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
|
13 |
+
focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
|
14 |
+
return float(focal)
|
15 |
+
|
16 |
+
def pixel_grid(H, W):
|
17 |
+
return np.mgrid[:W, :H].T.astype(np.float32)
|
18 |
+
|
19 |
+
def sRT_to_4x4(scale, R, T, device):
|
20 |
+
trf = torch.eye(4, device=device)
|
21 |
+
trf[:3, :3] = R * scale
|
22 |
+
trf[:3, 3] = T.ravel() # doesn't need scaling
|
23 |
+
return trf
|
24 |
+
|
25 |
+
def to_numpy(tensor):
|
26 |
+
return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
|
27 |
+
|
28 |
+
def calculate_depth_map(pts3d, R, T):
|
29 |
+
"""
|
30 |
+
Calculate ray depths directly using camera center and 3D points.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
pts3d (np.array): 3D points in world coordinates, shape (H, W, 3)
|
34 |
+
R (np.array): Rotation matrix, shape (3, 3)
|
35 |
+
T (np.array): Translation vector, shape (3, 1)
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
np.array: Depth map of shape (H, W)
|
39 |
+
"""
|
40 |
+
# Camera center in world coordinates is simply -T
|
41 |
+
C = -T.ravel()
|
42 |
+
|
43 |
+
# Calculate ray vectors
|
44 |
+
ray_vectors = pts3d - C
|
45 |
+
|
46 |
+
# Calculate ray depths
|
47 |
+
depth_map = np.linalg.norm(ray_vectors, axis=2)
|
48 |
+
|
49 |
+
return depth_map
|
50 |
+
|
51 |
+
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
|
52 |
+
# extract camera poses and focals with RANSAC-PnP
|
53 |
+
if msk.sum() < 4:
|
54 |
+
return None # we need at least 4 points for PnP
|
55 |
+
pts3d, msk = map(to_numpy, (pts3d, msk))
|
56 |
+
|
57 |
+
H, W, THREE = pts3d.shape
|
58 |
+
assert THREE == 3
|
59 |
+
pixels = pixel_grid(H, W)
|
60 |
+
|
61 |
+
if focal is None:
|
62 |
+
S = max(W, H)
|
63 |
+
tentative_focals = np.geomspace(S/2, S*3, 21)
|
64 |
+
else:
|
65 |
+
tentative_focals = [focal]
|
66 |
+
|
67 |
+
if pp is None:
|
68 |
+
pp = (W/2, H/2)
|
69 |
+
else:
|
70 |
+
pp = to_numpy(pp)
|
71 |
+
|
72 |
+
best = 0, None, None, None, None
|
73 |
+
for focal in tentative_focals:
|
74 |
+
K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
|
75 |
+
|
76 |
+
success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
|
77 |
+
iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
|
78 |
+
|
79 |
+
if not success:
|
80 |
+
continue
|
81 |
+
|
82 |
+
score = len(inliers)
|
83 |
+
if success and score > best[0]:
|
84 |
+
depth_map = calculate_depth_map(pts3d, R, T)
|
85 |
+
best = score, R, T, focal, depth_map
|
86 |
+
|
87 |
+
if not best[0]:
|
88 |
+
return None
|
89 |
+
|
90 |
+
_, R, T, best_focal, depth_map = best
|
91 |
+
R = cv2.Rodrigues(R)[0] # world to cam
|
92 |
+
R, T = map(torch.from_numpy, (R, T))
|
93 |
+
depth_map = torch.from_numpy(depth_map).to(device)
|
94 |
+
|
95 |
+
cam_to_world = inv(sRT_to_4x4(1, R, T, device)) # cam to world
|
96 |
+
|
97 |
+
return best_focal, cam_to_world, depth_map
|
98 |
+
|
99 |
+
def solve_cemara(pts3d, msk, device, focal=None, pp=None):
|
100 |
+
# Estimate focal length
|
101 |
+
if focal is None:
|
102 |
+
focal = estimate_focal(pts3d, pp)
|
103 |
+
|
104 |
+
# Compute camera pose using PnP
|
105 |
+
result = fast_pnp(pts3d, focal, msk, device, pp)
|
106 |
+
|
107 |
+
if result is None:
|
108 |
+
return None, focal, None
|
109 |
+
|
110 |
+
best_focal, camera_to_world, depth_map = result
|
111 |
+
|
112 |
+
# Construct K matrix
|
113 |
+
H, W, _ = pts3d.shape
|
114 |
+
if pp is None:
|
115 |
+
pp = (W/2, H/2)
|
116 |
+
|
117 |
+
camera_parameters = o3d.camera.PinholeCameraParameters()
|
118 |
+
intrinsic = o3d.camera.PinholeCameraIntrinsic()
|
119 |
+
intrinsic.set_intrinsics(W, H,
|
120 |
+
best_focal, best_focal,
|
121 |
+
pp[0], pp[1])
|
122 |
+
|
123 |
+
camera_parameters.intrinsic = intrinsic
|
124 |
+
camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()
|
125 |
+
|
126 |
+
return camera_parameters, best_focal, depth_map
|