Stable-X committed on
Commit 82b898c
1 Parent(s): d7542a1

feat: Add pose_utils to solve camera and depth

Files changed (2):
  1. demo.py +10 -2
  2. pose_utils.py +126 -0
demo.py CHANGED
@@ -15,6 +15,7 @@ from spann3r.datasets import *
 from torch.utils.data import DataLoader
 from spann3r.tools.eval_recon import accuracy, completion
 from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
+from pose_utils import solve_cemara
 from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
 
 def get_args_parser():
@@ -63,7 +64,7 @@ def main(args):
     model.load_state_dict(torch.load(args.ckpt_path)['model'])
     model.eval()
 
-    if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.MOV'):
+    if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.webm'):
         args.demo_path = extract_frames(args.demo_path)
         args.kf_every = 1
 
@@ -139,10 +140,15 @@ def main(args):
         conf_sig = (conf - 1) / conf
         pts_gt = view['pts3d'].cpu().numpy()[0]
 
+        camera, last_focal, depth_map = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > args.conf_thresh,
+                                                     args.device, focal=last_focal)
+        pts_scale = depth_map / last_focal
+
         images_all.append((image[None, ...] + 1.0)/2.0)
         pts_all.append(pts[None, ...])
         pts_normal_all.append(pts_normal[None, ...])
         pts_gt_all.append(pts_gt[None, ...])
+        pts_scale_all.append(pts_scale[None, ...])
         masks_all.append(mask[None, ...])
         conf_sig_all.append(conf_sig[None, ...])
 
@@ -163,7 +169,9 @@ def main(args):
         pcd.normals = o3d.utility.Vector3dVector(pts_normal_all[j][mask])
         pcds.append(pcd)
 
-    pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=args.voxel_size * 0.1)
+    print("Performing global registration...")
+    pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
+    # pcd_combined = combine_and_clean_point_clouds(transformed_pcds, voxel_size=args.voxel_size * 0.1)
     mesh_recon = point2mesh(pcd_combined)
 
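Note on the per-frame solve added above: solve_cemara is called with focal=last_focal, so the focal length recovered on the first frame is reused on later frames, and pts_scale_all collects the new depth_map / focal maps. The initialization of these variables falls outside the hunks shown; the following is a minimal sketch of the assumed surrounding loop (last_focal = None, pts_scale_all = [], and the dataloader loop are inferred from the hunks, not confirmed by this diff):

    last_focal = None      # no focal known yet; solve_cemara estimates one on the first frame
    pts_scale_all = []     # accumulates the per-frame depth/focal scale maps added in this commit

    for view in dataloader:
        # ... existing per-frame prediction producing pts, conf_sig, etc. ...
        camera, last_focal, depth_map = solve_cemara(
            torch.tensor(pts), torch.tensor(conf_sig) > args.conf_thresh,
            args.device, focal=last_focal)
        pts_scale = depth_map / last_focal
        pts_scale_all.append(pts_scale[None, ...])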
 
pose_utils.py ADDED
@@ -0,0 +1,126 @@
+import numpy as np
+import torch
+import cv2
+import open3d as o3d
+from dust3r.post_process import estimate_focal_knowing_depth
+from dust3r.utils.geometry import inv
+
+def estimate_focal(pts3d_i, pp=None):
+    if pp is None:
+        H, W, THREE = pts3d_i.shape
+        assert THREE == 3
+        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
+    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
+    return float(focal)
+
+def pixel_grid(H, W):
+    return np.mgrid[:W, :H].T.astype(np.float32)
+
+def sRT_to_4x4(scale, R, T, device):
+    trf = torch.eye(4, device=device)
+    trf[:3, :3] = R * scale
+    trf[:3, 3] = T.ravel()  # doesn't need scaling
+    return trf
+
+def to_numpy(tensor):
+    return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
+
+def calculate_depth_map(pts3d, R, T):
+    """
+    Calculate ray depths directly using camera center and 3D points.
+
+    Args:
+        pts3d (np.array): 3D points in world coordinates, shape (H, W, 3)
+        R (np.array): Rotation matrix, shape (3, 3)
+        T (np.array): Translation vector, shape (3, 1)
+
+    Returns:
+        np.array: Depth map of shape (H, W)
+    """
+    # Camera center in world coordinates is simply -T
+    C = -T.ravel()
+
+    # Calculate ray vectors
+    ray_vectors = pts3d - C
+
+    # Calculate ray depths
+    depth_map = np.linalg.norm(ray_vectors, axis=2)
+
+    return depth_map
+
+def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
+    # extract camera poses and focals with RANSAC-PnP
+    if msk.sum() < 4:
+        return None  # we need at least 4 points for PnP
+    pts3d, msk = map(to_numpy, (pts3d, msk))
+
+    H, W, THREE = pts3d.shape
+    assert THREE == 3
+    pixels = pixel_grid(H, W)
+
+    if focal is None:
+        S = max(W, H)
+        tentative_focals = np.geomspace(S/2, S*3, 21)
+    else:
+        tentative_focals = [focal]
+
+    if pp is None:
+        pp = (W/2, H/2)
+    else:
+        pp = to_numpy(pp)
+
+    best = 0, None, None, None, None
+    for focal in tentative_focals:
+        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+
+        if not success:
+            continue
+
+        score = len(inliers)
+        if success and score > best[0]:
+            depth_map = calculate_depth_map(pts3d, R, T)
+            best = score, R, T, focal, depth_map
+
+    if not best[0]:
+        return None
+
+    _, R, T, best_focal, depth_map = best
+    R = cv2.Rodrigues(R)[0]  # world to cam
+    R, T = map(torch.from_numpy, (R, T))
+    depth_map = torch.from_numpy(depth_map).to(device)
+
+    cam_to_world = inv(sRT_to_4x4(1, R, T, device))  # cam to world
+
+    return best_focal, cam_to_world, depth_map
+
+def solve_cemara(pts3d, msk, device, focal=None, pp=None):
+    # Estimate focal length
+    if focal is None:
+        focal = estimate_focal(pts3d, pp)
+
+    # Compute camera pose using PnP
+    result = fast_pnp(pts3d, focal, msk, device, pp)
+
+    if result is None:
+        return None, focal, None
+
+    best_focal, camera_to_world, depth_map = result
+
+    # Construct K matrix
+    H, W, _ = pts3d.shape
+    if pp is None:
+        pp = (W/2, H/2)
+
+    camera_parameters = o3d.camera.PinholeCameraParameters()
+    intrinsic = o3d.camera.PinholeCameraIntrinsic()
+    intrinsic.set_intrinsics(W, H,
+                             best_focal, best_focal,
+                             pp[0], pp[1])
+
+    camera_parameters.intrinsic = intrinsic
+    camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()
+
+    return camera_parameters, best_focal, depth_map
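
For reference, a minimal standalone sketch of how the new module can be exercised (the random point map and all-true mask are placeholder inputs, not data from the repo; shapes follow the docstrings above):

    import torch
    from pose_utils import solve_cemara

    H, W = 240, 320
    pts3d = torch.randn(H, W, 3)               # (H, W, 3) point map in world coordinates
    msk = torch.ones(H, W, dtype=torch.bool)   # confidence mask; PnP needs at least 4 True pixels

    camera, focal, depth_map = solve_cemara(pts3d, msk, device='cpu')
    if camera is not None:
        print(focal)                # estimated focal length in pixels
        print(camera.extrinsic)     # 4x4 world-to-camera matrix (o3d PinholeCameraParameters)
        print(depth_map.shape)      # (H, W) per-pixel ray depths

Note that solve_cemara returns (None, focal, None) when RANSAC-PnP fails or fewer than four masked points are available, so callers should guard on the first return value.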