Spaces:
Sleeping
Sleeping
fix: Update demo
Browse files
demo.py
CHANGED
@@ -15,6 +15,7 @@ from spann3r.datasets import *
|
|
15 |
from torch.utils.data import DataLoader
|
16 |
from spann3r.tools.eval_recon import accuracy, completion
|
17 |
from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
|
|
|
18 |
|
19 |
def get_args_parser():
|
20 |
parser = argparse.ArgumentParser('Spann3R demo', add_help=False)
|
@@ -27,9 +28,28 @@ def get_args_parser():
|
|
27 |
parser.add_argument('--conf_thresh', type=float, default=1e-3, help='confidence threshold')
|
28 |
parser.add_argument('--kf_every', type=int, default=10, help='map every kf_every frames')
|
29 |
parser.add_argument('--vis', action='store_true', help='visualize')
|
30 |
-
|
31 |
return parser
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
@torch.no_grad()
|
34 |
def main(args):
|
35 |
|
@@ -42,6 +62,10 @@ def main(args):
|
|
42 |
|
43 |
model.load_state_dict(torch.load(args.ckpt_path)['model'])
|
44 |
model.eval()
|
|
|
|
|
|
|
|
|
45 |
|
46 |
##### Load dataset
|
47 |
dataset = Demo(ROOT=args.demo_path, resolution=224, full_video=True, kf_every=args.kf_every)
|
@@ -96,59 +120,51 @@ def main(args):
|
|
96 |
os.makedirs(save_demo_path, exist_ok=True)
|
97 |
|
98 |
pts_all = []
|
|
|
99 |
pts_gt_all = []
|
100 |
images_all = []
|
101 |
masks_all = []
|
102 |
-
|
|
|
103 |
|
|
|
104 |
for j, view in enumerate(ordered_batch):
|
105 |
|
106 |
image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
|
107 |
mask = view['valid_mask'].cpu().numpy()[0]
|
108 |
|
109 |
pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
|
|
|
110 |
conf = preds[j]['conf'][0].cpu().data.numpy()
|
111 |
-
|
112 |
pts_gt = view['pts3d'].cpu().numpy()[0]
|
113 |
|
114 |
images_all.append((image[None, ...] + 1.0)/2.0)
|
115 |
pts_all.append(pts[None, ...])
|
|
|
116 |
pts_gt_all.append(pts_gt[None, ...])
|
117 |
masks_all.append(mask[None, ...])
|
118 |
-
|
119 |
-
|
120 |
images_all = np.concatenate(images_all, axis=0)
|
121 |
pts_all = np.concatenate(pts_all, axis=0)
|
|
|
122 |
pts_gt_all = np.concatenate(pts_gt_all, axis=0)
|
123 |
masks_all = np.concatenate(masks_all, axis=0)
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
)
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
conf_sig_all = (conf_all-1) / conf_all
|
139 |
-
|
140 |
-
pcd = o3d.geometry.PointCloud()
|
141 |
-
pcd.points = o3d.utility.Vector3dVector(pts_all[conf_sig_all>args.conf_thresh].reshape(-1, 3))
|
142 |
-
pcd.colors = o3d.utility.Vector3dVector(images_all[conf_sig_all>args.conf_thresh].reshape(-1, 3))
|
143 |
-
o3d.io.write_point_cloud(os.path.join(save_demo_path, f"{demo_name}_conf{args.conf_thresh}.ply"), pcd)
|
144 |
-
|
145 |
-
|
146 |
-
if args.vis:
|
147 |
-
camera_parameters = find_render_cam(pcd)
|
148 |
-
|
149 |
-
render_frames(pts_all, images_all, camera_parameters, save_demo_path, mask=conf_sig_all>args.conf_thresh)
|
150 |
-
vis_pred_and_imgs(pts_all, save_demo_path, images_all=images_all, conf_all=conf_sig_all)
|
151 |
-
|
152 |
|
153 |
|
154 |
if __name__ == '__main__':
|
|
|
15 |
from torch.utils.data import DataLoader
|
16 |
from spann3r.tools.eval_recon import accuracy, completion
|
17 |
from spann3r.tools.vis import render_frames, find_render_cam, vis_pred_and_imgs
|
18 |
+
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
19 |
|
20 |
def get_args_parser():
|
21 |
parser = argparse.ArgumentParser('Spann3R demo', add_help=False)
|
|
|
28 |
parser.add_argument('--conf_thresh', type=float, default=1e-3, help='confidence threshold')
|
29 |
parser.add_argument('--kf_every', type=int, default=10, help='map every kf_every frames')
|
30 |
parser.add_argument('--vis', action='store_true', help='visualize')
|
31 |
+
parser.add_argument('--voxel_size', type=float, default=0.004, help='voxel size for multiway registration')
|
32 |
return parser
|
33 |
|
34 |
+
import tempfile
|
35 |
+
import subprocess
|
36 |
+
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Extract frames from the start of a video into a new temporary directory.

    Invokes ``ffmpeg`` with a ``select`` filter that keeps frames whose
    timestamp ``t`` is below ``duration``, followed by an ``fps`` filter set
    to ``fps``. Frames are written as ``%03d.jpg`` inside a directory created
    with ``tempfile.mkdtemp()``.

    Args:
        video_path: Path of the input video, passed to ffmpeg via ``-i``.
        duration: Upper bound (in the units of ffmpeg's ``t``) for the
            timestamps of the frames that are kept.
        fps: Value given to ffmpeg's ``fps`` filter.

    Returns:
        Path of the temporary directory holding the extracted frames. The
        directory is not removed on success; the caller owns it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        OSError: If the ffmpeg executable cannot be launched.
    """
    import shutil  # local import: only needed for the failure clean-up below

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")

    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"

    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        "-vsync", "0",
        output_path,
    ]

    try:
        subprocess.run(command, check=True)
    except BaseException:
        # The caller never learns temp_dir when we raise, so nobody else could
        # clean it up: remove it here (best effort) and re-raise unchanged.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
|
52 |
+
|
53 |
@torch.no_grad()
|
54 |
def main(args):
|
55 |
|
|
|
62 |
|
63 |
model.load_state_dict(torch.load(args.ckpt_path)['model'])
|
64 |
model.eval()
|
65 |
+
|
66 |
+
if args.demo_path.endswith('.mp4') or args.demo_path.endswith('.avi') or args.demo_path.endswith('.MOV'):
|
67 |
+
args.demo_path = extract_frames(args.demo_path)
|
68 |
+
args.kf_every = 1
|
69 |
|
70 |
##### Load dataset
|
71 |
dataset = Demo(ROOT=args.demo_path, resolution=224, full_video=True, kf_every=args.kf_every)
|
|
|
120 |
os.makedirs(save_demo_path, exist_ok=True)
|
121 |
|
122 |
pts_all = []
|
123 |
+
pts_normal_all = []
|
124 |
pts_gt_all = []
|
125 |
images_all = []
|
126 |
masks_all = []
|
127 |
+
conf_sig_all = []
|
128 |
+
cameras_all = []
|
129 |
|
130 |
+
last_focal = None
|
131 |
for j, view in enumerate(ordered_batch):
|
132 |
|
133 |
image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
|
134 |
mask = view['valid_mask'].cpu().numpy()[0]
|
135 |
|
136 |
pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
|
137 |
+
pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
|
138 |
conf = preds[j]['conf'][0].cpu().data.numpy()
|
139 |
+
conf_sig = (conf - 1) / conf
|
140 |
pts_gt = view['pts3d'].cpu().numpy()[0]
|
141 |
|
142 |
images_all.append((image[None, ...] + 1.0)/2.0)
|
143 |
pts_all.append(pts[None, ...])
|
144 |
+
pts_normal_all.append(pts_normal[None, ...])
|
145 |
pts_gt_all.append(pts_gt[None, ...])
|
146 |
masks_all.append(mask[None, ...])
|
147 |
+
conf_sig_all.append(conf_sig[None, ...])
|
148 |
+
|
149 |
images_all = np.concatenate(images_all, axis=0)
|
150 |
pts_all = np.concatenate(pts_all, axis=0)
|
151 |
+
pts_normal_all = np.concatenate(pts_normal_all, axis=0)
|
152 |
pts_gt_all = np.concatenate(pts_gt_all, axis=0)
|
153 |
masks_all = np.concatenate(masks_all, axis=0)
|
154 |
+
conf_sig_all = np.concatenate(conf_sig_all, axis=0)
|
155 |
+
|
156 |
+
# Create point clouds for multiway registration
|
157 |
+
pcds = []
|
158 |
+
for j in range(len(pts_all)):
|
159 |
+
pcd = o3d.geometry.PointCloud()
|
160 |
+
mask = conf_sig_all[j] > args.conf_thresh
|
161 |
+
pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
|
162 |
+
pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
|
163 |
+
pcd.normals = o3d.utility.Vector3dVector(pts_normal_all[j][mask])
|
164 |
+
pcds.append(pcd)
|
165 |
+
|
166 |
+
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=args.voxel_size * 0.1)
|
167 |
+
mesh_recon = point2mesh(pcd_combined)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
|
170 |
if __name__ == '__main__':
|