yocabon commited on
Commit
f64a2f2
1 Parent(s): 706762b

move part of the demo to the lib

Browse files
Files changed (3) hide show
  1. demo.py +2 -267
  2. dust3r +1 -1
  3. mast3r/demo.py +277 -0
demo.py CHANGED
@@ -3,286 +3,21 @@
3
  # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
  #
5
  # --------------------------------------------------------
6
- # gradio demo
7
  # --------------------------------------------------------
8
- import math
9
- import gradio
10
  import os
11
  import torch
12
- import numpy as np
13
  import tempfile
14
- import functools
15
- import trimesh
16
- import copy
17
- from scipy.spatial.transform import Rotation
18
 
19
- from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
20
- from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
21
 
22
  from mast3r.model import AsymmetricMASt3R
23
  from mast3r.utils.misc import hash_md5
24
- import mast3r.utils.path_to_dust3r # noqa
25
- from dust3r.image_pairs import make_pairs
26
- from dust3r.utils.image import load_images
27
- from dust3r.utils.device import to_numpy
28
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
29
- from dust3r.demo import get_args_parser as dust3r_get_args_parser
30
 
31
  import matplotlib.pyplot as pl
32
  pl.ion()
33
 
34
  torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
35
- batch_size = 1
36
-
37
-
38
- def get_args_parser():
39
- parser = dust3r_get_args_parser()
40
- parser.add_argument('--share', action='store_true')
41
-
42
- actions = parser._actions
43
- for action in actions:
44
- if action.dest == 'model_name':
45
- action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
46
- # change defaults
47
- parser.prog = 'mast3r demo'
48
- return parser
49
-
50
-
51
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
52
- cam_color=None, as_pointcloud=False,
53
- transparent_cams=False, silent=False):
54
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
55
- pts3d = to_numpy(pts3d)
56
- imgs = to_numpy(imgs)
57
- focals = to_numpy(focals)
58
- cams2world = to_numpy(cams2world)
59
-
60
- scene = trimesh.Scene()
61
-
62
- # full pointcloud
63
- if as_pointcloud:
64
- pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)])
65
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
66
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
67
- scene.add_geometry(pct)
68
- else:
69
- meshes = []
70
- for i in range(len(imgs)):
71
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i].reshape(imgs[i].shape), mask[i]))
72
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
73
- scene.add_geometry(mesh)
74
-
75
- # add each camera
76
- for i, pose_c2w in enumerate(cams2world):
77
- if isinstance(cam_color, list):
78
- camera_edge_color = cam_color[i]
79
- else:
80
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
81
- add_scene_cam(scene, pose_c2w, camera_edge_color,
82
- None if transparent_cams else imgs[i], focals[i],
83
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
84
-
85
- rot = np.eye(4)
86
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
87
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
88
- outfile = os.path.join(outdir, 'scene.glb')
89
- if not silent:
90
- print('(exporting 3D scene to', outfile, ')')
91
- scene.export(file_obj=outfile)
92
- return outfile
93
-
94
-
95
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
96
- clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
97
- """
98
- extract 3D_model (glb file) from a reconstructed scene
99
- """
100
- if scene is None:
101
- return None
102
-
103
- # get optimized values from scene
104
- rgbimg = scene.imgs
105
- focals = scene.get_focals().cpu()
106
- cams2world = scene.get_im_poses().cpu()
107
-
108
- # 3D pointcloud from depthmap, poses and intrinsics
109
- if TSDF_thresh > 0:
110
- tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
111
- pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
112
- else:
113
- pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
114
- msk = to_numpy([c > min_conf_thr for c in confs])
115
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
116
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
117
-
118
-
119
- def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2,
120
- min_conf_thr, matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
121
- cam_size, scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics,
122
- **kw):
123
- """
124
- from a list of images, run mast3r inference, sparse global aligner.
125
- then run get_3D_model_from_scene
126
- """
127
- imgs = load_images(filelist, size=image_size, verbose=not silent)
128
- if len(imgs) == 1:
129
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
130
- imgs[1]['idx'] = 1
131
- filelist = [filelist[0], filelist[0] + '_2']
132
-
133
- scene_graph_params = [scenegraph_type]
134
- if scenegraph_type in ["swin", "logwin"]:
135
- scene_graph_params.append(str(winsize))
136
- elif scenegraph_type == "oneref":
137
- scene_graph_params.append(str(refid))
138
- if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
139
- scene_graph_params.append('noncyclic')
140
- scene_graph = '-'.join(scene_graph_params)
141
- pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
142
- if optim_level == 'coarse':
143
- niter2 = 0
144
- # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
145
- scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
146
- model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
147
- opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
148
- matching_conf_thr=matching_conf_thr, **kw)
149
- outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
150
- clean_depth, transparent_cams, cam_size, TSDF_thresh)
151
- return scene, outfile
152
-
153
-
154
- def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
155
- num_files = len(inputfiles) if inputfiles is not None else 1
156
- show_win_controls = scenegraph_type in ["swin", "logwin"]
157
- show_winsize = scenegraph_type in ["swin", "logwin"]
158
- show_cyclic = scenegraph_type in ["swin", "logwin"]
159
- max_winsize, min_winsize = 1, 1
160
- if scenegraph_type == "swin":
161
- if win_cyclic:
162
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
163
- else:
164
- max_winsize = num_files - 1
165
- elif scenegraph_type == "logwin":
166
- if win_cyclic:
167
- half_size = math.ceil((num_files - 1) / 2)
168
- max_winsize = max(1, math.ceil(math.log(half_size, 2)))
169
- else:
170
- max_winsize = max(1, math.ceil(math.log(num_files, 2)))
171
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
172
- minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
173
- win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
174
- win_col = gradio.Column(visible=show_win_controls)
175
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
176
- maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
177
- return win_col, winsize, win_cyclic, refid
178
-
179
-
180
- def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, share=False):
181
- if not silent:
182
- print('Outputing stuff in', tmpdirname)
183
-
184
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
185
- model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
186
- with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MASt3R Demo") as demo:
187
- # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
188
- scene = gradio.State(None)
189
- gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
190
- with gradio.Column():
191
- inputfiles = gradio.File(file_count="multiple")
192
- with gradio.Row():
193
- with gradio.Column():
194
- with gradio.Row():
195
- lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
196
- niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
197
- label="num_iterations", info="For coarse alignment!")
198
- lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
199
- niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
200
- label="num_iterations", info="For refinement!")
201
- optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
202
- value='refine', label="OptLevel",
203
- info="Optimization level")
204
- with gradio.Row():
205
- matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
206
- minimum=0., maximum=30., step=0.1,
207
- info="Before Fallback to Regr3D!")
208
- shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
209
- info="Only optimize one set of intrinsics for all views")
210
- scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
211
- ("swin: sliding window", "swin"),
212
- ("logwin: sliding window with long range", "logwin"),
213
- ("oneref: match one image with all", "oneref")],
214
- value='complete', label="Scenegraph",
215
- info="Define how to make pairs",
216
- interactive=True)
217
- with gradio.Column(visible=False) as win_col:
218
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
219
- minimum=1, maximum=1, step=1)
220
- win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
221
- refid = gradio.Slider(label="Scene Graph: Id", value=0,
222
- minimum=0, maximum=0, step=1, visible=False)
223
-
224
- run_btn = gradio.Button("Run")
225
-
226
- with gradio.Row():
227
- # adjust the confidence threshold
228
- min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
229
- # adjust the camera size in the output pointcloud
230
- cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
231
- TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
232
- with gradio.Row():
233
- as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
234
- # two post process implemented
235
- mask_sky = gradio.Checkbox(value=False, label="Mask sky")
236
- clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
237
- transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
238
-
239
- outmodel = gradio.Model3D()
240
-
241
- # events
242
- scenegraph_type.change(set_scenegraph_options,
243
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
244
- outputs=[win_col, winsize, win_cyclic, refid])
245
- inputfiles.change(set_scenegraph_options,
246
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
247
- outputs=[win_col, winsize, win_cyclic, refid])
248
- win_cyclic.change(set_scenegraph_options,
249
- inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
250
- outputs=[win_col, winsize, win_cyclic, refid])
251
- run_btn.click(fn=recon_fun,
252
- inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
253
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
254
- scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
255
- outputs=[scene, outmodel])
256
- min_conf_thr.release(fn=model_from_scene_fun,
257
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
258
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
259
- outputs=outmodel)
260
- cam_size.change(fn=model_from_scene_fun,
261
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
262
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
263
- outputs=outmodel)
264
- TSDF_thresh.change(fn=model_from_scene_fun,
265
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
266
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
267
- outputs=outmodel)
268
- as_pointcloud.change(fn=model_from_scene_fun,
269
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
270
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
271
- outputs=outmodel)
272
- mask_sky.change(fn=model_from_scene_fun,
273
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
274
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
275
- outputs=outmodel)
276
- clean_depth.change(fn=model_from_scene_fun,
277
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
278
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
279
- outputs=outmodel)
280
- transparent_cams.change(model_from_scene_fun,
281
- inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
282
- clean_depth, transparent_cams, cam_size, TSDF_thresh],
283
- outputs=outmodel)
284
- demo.launch(share=False, server_name=server_name, server_port=server_port)
285
-
286
 
287
  if __name__ == '__main__':
288
  parser = get_args_parser()
 
3
  # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
  #
5
  # --------------------------------------------------------
6
+ # gradio demo executable
7
  # --------------------------------------------------------
 
 
8
  import os
9
  import torch
 
10
  import tempfile
 
 
 
 
11
 
12
+ from mast3r.demo import get_args_parser, main_demo
 
13
 
14
  from mast3r.model import AsymmetricMASt3R
15
  from mast3r.utils.misc import hash_md5
 
 
 
 
 
 
16
 
17
  import matplotlib.pyplot as pl
18
  pl.ion()
19
 
20
  torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  if __name__ == '__main__':
23
  parser = get_args_parser()
dust3r CHANGED
@@ -1 +1 @@
1
- Subproject commit 8cc725dd11a9b7371bfca37994f8585ca78b42e5
 
1
+ Subproject commit d99800a2d1d33f000c6f0d1c307dfb5a7a34fd53
mast3r/demo.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # sparse gradio demo functions
7
+ # --------------------------------------------------------
8
+ import math
9
+ import gradio
10
+ import os
11
+ import numpy as np
12
+ import functools
13
+ import trimesh
14
+ import copy
15
+ from scipy.spatial.transform import Rotation
16
+
17
+ from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
18
+ from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
19
+
20
+ import mast3r.utils.path_to_dust3r # noqa
21
+ from dust3r.image_pairs import make_pairs
22
+ from dust3r.utils.image import load_images
23
+ from dust3r.utils.device import to_numpy
24
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
25
+ from dust3r.demo import get_args_parser as dust3r_get_args_parser
26
+
27
+ import matplotlib.pyplot as pl
28
+
29
+
30
+ def get_args_parser():
31
+ parser = dust3r_get_args_parser()
32
+ parser.add_argument('--share', action='store_true')
33
+
34
+ actions = parser._actions
35
+ for action in actions:
36
+ if action.dest == 'model_name':
37
+ action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
38
+ # change defaults
39
+ parser.prog = 'mast3r demo'
40
+ return parser
41
+
42
+
43
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
44
+ cam_color=None, as_pointcloud=False,
45
+ transparent_cams=False, silent=False):
46
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
47
+ pts3d = to_numpy(pts3d)
48
+ imgs = to_numpy(imgs)
49
+ focals = to_numpy(focals)
50
+ cams2world = to_numpy(cams2world)
51
+
52
+ scene = trimesh.Scene()
53
+
54
+ # full pointcloud
55
+ if as_pointcloud:
56
+ pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)])
57
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
58
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
59
+ scene.add_geometry(pct)
60
+ else:
61
+ meshes = []
62
+ for i in range(len(imgs)):
63
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i].reshape(imgs[i].shape), mask[i]))
64
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
65
+ scene.add_geometry(mesh)
66
+
67
+ # add each camera
68
+ for i, pose_c2w in enumerate(cams2world):
69
+ if isinstance(cam_color, list):
70
+ camera_edge_color = cam_color[i]
71
+ else:
72
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
73
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
74
+ None if transparent_cams else imgs[i], focals[i],
75
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
76
+
77
+ rot = np.eye(4)
78
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
79
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
80
+ outfile = os.path.join(outdir, 'scene.glb')
81
+ if not silent:
82
+ print('(exporting 3D scene to', outfile, ')')
83
+ scene.export(file_obj=outfile)
84
+ return outfile
85
+
86
+
87
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
88
+ clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
89
+ """
90
+ extract 3D_model (glb file) from a reconstructed scene
91
+ """
92
+ if scene is None:
93
+ return None
94
+
95
+ # get optimized values from scene
96
+ rgbimg = scene.imgs
97
+ focals = scene.get_focals().cpu()
98
+ cams2world = scene.get_im_poses().cpu()
99
+
100
+ # 3D pointcloud from depthmap, poses and intrinsics
101
+ if TSDF_thresh > 0:
102
+ tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
103
+ pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
104
+ else:
105
+ pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
106
+ msk = to_numpy([c > min_conf_thr for c in confs])
107
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
108
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
109
+
110
+
111
+ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2,
112
+ min_conf_thr, matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
113
+ cam_size, scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics,
114
+ **kw):
115
+ """
116
+ from a list of images, run mast3r inference, sparse global aligner.
117
+ then run get_3D_model_from_scene
118
+ """
119
+ imgs = load_images(filelist, size=image_size, verbose=not silent)
120
+ if len(imgs) == 1:
121
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
122
+ imgs[1]['idx'] = 1
123
+ filelist = [filelist[0], filelist[0] + '_2']
124
+
125
+ scene_graph_params = [scenegraph_type]
126
+ if scenegraph_type in ["swin", "logwin"]:
127
+ scene_graph_params.append(str(winsize))
128
+ elif scenegraph_type == "oneref":
129
+ scene_graph_params.append(str(refid))
130
+ if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
131
+ scene_graph_params.append('noncyclic')
132
+ scene_graph = '-'.join(scene_graph_params)
133
+ pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
134
+ if optim_level == 'coarse':
135
+ niter2 = 0
136
+ # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
137
+ scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
138
+ model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
139
+ opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
140
+ matching_conf_thr=matching_conf_thr, **kw)
141
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
142
+ clean_depth, transparent_cams, cam_size, TSDF_thresh)
143
+ return scene, outfile
144
+
145
+
146
+ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
147
+ num_files = len(inputfiles) if inputfiles is not None else 1
148
+ show_win_controls = scenegraph_type in ["swin", "logwin"]
149
+ show_winsize = scenegraph_type in ["swin", "logwin"]
150
+ show_cyclic = scenegraph_type in ["swin", "logwin"]
151
+ max_winsize, min_winsize = 1, 1
152
+ if scenegraph_type == "swin":
153
+ if win_cyclic:
154
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
155
+ else:
156
+ max_winsize = num_files - 1
157
+ elif scenegraph_type == "logwin":
158
+ if win_cyclic:
159
+ half_size = math.ceil((num_files - 1) / 2)
160
+ max_winsize = max(1, math.ceil(math.log(half_size, 2)))
161
+ else:
162
+ max_winsize = max(1, math.ceil(math.log(num_files, 2)))
163
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
164
+ minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
165
+ win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
166
+ win_col = gradio.Column(visible=show_win_controls)
167
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
168
+ maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
169
+ return win_col, winsize, win_cyclic, refid
170
+
171
+
172
+ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, share=False):
173
+ if not silent:
174
+ print('Outputing stuff in', tmpdirname)
175
+
176
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
177
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
178
+ with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MASt3R Demo") as demo:
179
+ # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
180
+ scene = gradio.State(None)
181
+ gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
182
+ with gradio.Column():
183
+ inputfiles = gradio.File(file_count="multiple")
184
+ with gradio.Row():
185
+ with gradio.Column():
186
+ with gradio.Row():
187
+ lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
188
+ niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
189
+ label="num_iterations", info="For coarse alignment!")
190
+ lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
191
+ niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
192
+ label="num_iterations", info="For refinement!")
193
+ optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
194
+ value='refine', label="OptLevel",
195
+ info="Optimization level")
196
+ with gradio.Row():
197
+ matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
198
+ minimum=0., maximum=30., step=0.1,
199
+ info="Before Fallback to Regr3D!")
200
+ shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
201
+ info="Only optimize one set of intrinsics for all views")
202
+ scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
203
+ ("swin: sliding window", "swin"),
204
+ ("logwin: sliding window with long range", "logwin"),
205
+ ("oneref: match one image with all", "oneref")],
206
+ value='complete', label="Scenegraph",
207
+ info="Define how to make pairs",
208
+ interactive=True)
209
+ with gradio.Column(visible=False) as win_col:
210
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
211
+ minimum=1, maximum=1, step=1)
212
+ win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
213
+ refid = gradio.Slider(label="Scene Graph: Id", value=0,
214
+ minimum=0, maximum=0, step=1, visible=False)
215
+
216
+ run_btn = gradio.Button("Run")
217
+
218
+ with gradio.Row():
219
+ # adjust the confidence threshold
220
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
221
+ # adjust the camera size in the output pointcloud
222
+ cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
223
+ TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
224
+ with gradio.Row():
225
+ as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
226
+ # two post process implemented
227
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
228
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
229
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
230
+
231
+ outmodel = gradio.Model3D()
232
+
233
+ # events
234
+ scenegraph_type.change(set_scenegraph_options,
235
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
236
+ outputs=[win_col, winsize, win_cyclic, refid])
237
+ inputfiles.change(set_scenegraph_options,
238
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
239
+ outputs=[win_col, winsize, win_cyclic, refid])
240
+ win_cyclic.change(set_scenegraph_options,
241
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
242
+ outputs=[win_col, winsize, win_cyclic, refid])
243
+ run_btn.click(fn=recon_fun,
244
+ inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
245
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
246
+ scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
247
+ outputs=[scene, outmodel])
248
+ min_conf_thr.release(fn=model_from_scene_fun,
249
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
250
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
251
+ outputs=outmodel)
252
+ cam_size.change(fn=model_from_scene_fun,
253
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
254
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
255
+ outputs=outmodel)
256
+ TSDF_thresh.change(fn=model_from_scene_fun,
257
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
258
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
259
+ outputs=outmodel)
260
+ as_pointcloud.change(fn=model_from_scene_fun,
261
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
262
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
263
+ outputs=outmodel)
264
+ mask_sky.change(fn=model_from_scene_fun,
265
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
266
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
267
+ outputs=outmodel)
268
+ clean_depth.change(fn=model_from_scene_fun,
269
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
270
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
271
+ outputs=outmodel)
272
+ transparent_cams.change(model_from_scene_fun,
273
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
274
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
275
+ outputs=outmodel)
276
+ demo.launch(share=share, server_name=server_name, server_port=server_port)
277
+