Stable-X committed on
Commit
a5130f4
1 Parent(s): a9bfc35

fix: Update demo

Browse files
Files changed (1) hide show
  1. demo.py +49 -33
demo.py CHANGED
@@ -15,6 +15,7 @@ from spann3r.datasets import *
15
  from torch.utils.data import DataLoader
16
  from spann3r.tools.eval_recon import accuracy, completion
17
  from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
 
18
 
19
  def get_args_parser():
20
  parser = argparse.ArgumentParser('Spann3R demo', add_help=False)
@@ -27,9 +28,28 @@ def get_args_parser():
27
  parser.add_argument('--conf_thresh', type=float, default=1e-3, help='confidence threshold')
28
  parser.add_argument('--kf_every', type=int, default=10, help='map every kf_every frames')
29
  parser.add_argument('--vis', action='store_true', help='visualize')
30
-
31
  return parser
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  @torch.no_grad()
34
  def main(args):
35
 
@@ -42,6 +62,10 @@ def main(args):
42
 
43
  model.load_state_dict(torch.load(args.ckpt_path)['model'])
44
  model.eval()
 
 
 
 
45
 
46
  ##### Load dataset
47
  dataset = Demo(ROOT=args.demo_path, resolution=224, full_video=True, kf_every=args.kf_every)
@@ -96,59 +120,51 @@ def main(args):
96
  os.makedirs(save_demo_path, exist_ok=True)
97
 
98
  pts_all = []
 
99
  pts_gt_all = []
100
  images_all = []
101
  masks_all = []
102
- conf_all = []
 
103
 
 
104
  for j, view in enumerate(ordered_batch):
105
 
106
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
107
  mask = view['valid_mask'].cpu().numpy()[0]
108
 
109
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
 
110
  conf = preds[j]['conf'][0].cpu().data.numpy()
111
-
112
  pts_gt = view['pts3d'].cpu().numpy()[0]
113
 
114
  images_all.append((image[None, ...] + 1.0)/2.0)
115
  pts_all.append(pts[None, ...])
 
116
  pts_gt_all.append(pts_gt[None, ...])
117
  masks_all.append(mask[None, ...])
118
- conf_all.append(conf[None, ...])
119
-
120
  images_all = np.concatenate(images_all, axis=0)
121
  pts_all = np.concatenate(pts_all, axis=0)
 
122
  pts_gt_all = np.concatenate(pts_gt_all, axis=0)
123
  masks_all = np.concatenate(masks_all, axis=0)
124
- conf_all = np.concatenate(conf_all, axis=0)
125
-
126
- save_params = dict(
127
- images_all=images_all,
128
- pts_all=pts_all,
129
- pts_gt_all=pts_gt_all,
130
- masks_all=masks_all,
131
- conf_all=conf_all
132
- )
133
-
134
- np.save(os.path.join(save_demo_path, f"{demo_name}.npy"), save_params)
135
-
136
-
137
- # Save point cloud
138
- conf_sig_all = (conf_all-1) / conf_all
139
-
140
- pcd = o3d.geometry.PointCloud()
141
- pcd.points = o3d.utility.Vector3dVector(pts_all[conf_sig_all>args.conf_thresh].reshape(-1, 3))
142
- pcd.colors = o3d.utility.Vector3dVector(images_all[conf_sig_all>args.conf_thresh].reshape(-1, 3))
143
- o3d.io.write_point_cloud(os.path.join(save_demo_path, f"{demo_name}_conf{args.conf_thresh}.ply"), pcd)
144
-
145
-
146
- if args.vis:
147
- camera_parameters = find_render_cam(pcd)
148
-
149
- render_frames(pts_all, images_all, camera_parameters, save_demo_path, mask=conf_sig_all>args.conf_thresh)
150
- vis_pred_and_imgs(pts_all, save_demo_path, images_all=images_all, conf_all=conf_sig_all)
151
-
152
 
153
 
154
  if __name__ == '__main__':
 
15
  from torch.utils.data import DataLoader
16
  from spann3r.tools.eval_recon import accuracy, completion
17
  from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
18
+ from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
19
 
20
  def get_args_parser():
21
  parser = argparse.ArgumentParser('Spann3R demo', add_help=False)
 
28
  parser.add_argument('--conf_thresh', type=float, default=1e-3, help='confidence threshold')
29
  parser.add_argument('--kf_every', type=int, default=10, help='map every kf_every frames')
30
  parser.add_argument('--vis', action='store_true', help='visualize')
31
+ parser.add_argument('--voxel_size', type=float, default=0.004, help='voxel size for multiway registration')
32
  return parser
33
 
34
+ import tempfile
35
+ import subprocess
36
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Dump frames of a video into a fresh temporary directory via ffmpeg.

    Frames whose timestamp ``t`` is below ``duration`` are selected, passed
    through ffmpeg's ``fps`` filter, and written as numbered JPEG files
    following the ``%03d.jpg`` pattern.

    Args:
        video_path: Path of the input video, handed to ``ffmpeg -i``.
        duration: Upper bound (in the units of ffmpeg's ``t`` variable,
            presumably seconds) on the timestamps that are kept.
        fps: Frame rate given to ffmpeg's ``fps`` filter.

    Returns:
        Path of the temporary directory containing the extracted frames.
        The caller owns the directory and is responsible for deleting it.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be launched.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Local import: only the failure path needs it, and the file's top-level
    # import block is left untouched.
    import shutil

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")

    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"

    # Argument list (no shell) so the video path is never shell-interpreted.
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        "-vsync", "0",
        output_path,
    ]

    try:
        subprocess.run(command, check=True)
    except BaseException:
        # The caller never receives temp_dir on failure, so nobody else could
        # clean it up; remove it here and let the original error propagate.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
52
+
53
  @torch.no_grad()
54
  def main(args):
55
 
 
62
 
63
  model.load_state_dict(torch.load(args.ckpt_path)['model'])
64
  model.eval()
65
+
66
+ if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.MOV'):
67
+ args.demo_path = extract_frames(args.demo_path)
68
+ args.kf_every = 1
69
 
70
  ##### Load dataset
71
  dataset = Demo(ROOT=args.demo_path, resolution=224, full_video=True, kf_every=args.kf_every)
 
120
  os.makedirs(save_demo_path, exist_ok=True)
121
 
122
  pts_all = []
123
+ pts_normal_all = []
124
  pts_gt_all = []
125
  images_all = []
126
  masks_all = []
127
+ conf_sig_all = []
128
+ cameras_all = []
129
 
130
+ last_focal = None
131
  for j, view in enumerate(ordered_batch):
132
 
133
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
134
  mask = view['valid_mask'].cpu().numpy()[0]
135
 
136
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
137
+ pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
138
  conf = preds[j]['conf'][0].cpu().data.numpy()
139
+ conf_sig = (conf - 1) / conf
140
  pts_gt = view['pts3d'].cpu().numpy()[0]
141
 
142
  images_all.append((image[None, ...] + 1.0)/2.0)
143
  pts_all.append(pts[None, ...])
144
+ pts_normal_all.append(pts_normal[None, ...])
145
  pts_gt_all.append(pts_gt[None, ...])
146
  masks_all.append(mask[None, ...])
147
+ conf_sig_all.append(conf_sig[None, ...])
148
+
149
  images_all = np.concatenate(images_all, axis=0)
150
  pts_all = np.concatenate(pts_all, axis=0)
151
+ pts_normal_all = np.concatenate(pts_normal_all, axis=0)
152
  pts_gt_all = np.concatenate(pts_gt_all, axis=0)
153
  masks_all = np.concatenate(masks_all, axis=0)
154
+ conf_sig_all = np.concatenate(conf_sig_all, axis=0)
155
+
156
+ # Create point clouds for multiway registration
157
+ pcds = []
158
+ for j in range(len(pts_all)):
159
+ pcd = o3d.geometry.PointCloud()
160
+ mask = conf_sig_all[j] > args.conf_thresh
161
+ pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
162
+ pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
163
+ pcd.normals = o3d.utility.Vector3dVector(pts_normal_all[j][mask])
164
+ pcds.append(pcd)
165
+
166
+ pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=args.voxel_size * 0.1)
167
+ mesh_recon = point2mesh(pcd_combined)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
 
170
  if __name__ == '__main__':