yocabon committed on
Commit
edff93e
1 Parent(s): f64a2f2

rework temp files

Browse files
Files changed (3) hide show
  1. demo.py +6 -11
  2. dust3r +1 -1
  3. mast3r/demo.py +70 -21
demo.py CHANGED
@@ -8,6 +8,7 @@
8
  import os
9
  import torch
10
  import tempfile
 
11
 
12
  from mast3r.demo import get_args_parser, main_demo
13
 
@@ -36,17 +37,11 @@ if __name__ == '__main__':
36
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
37
  chkpt_tag = hash_md5(weights_path)
38
 
39
- # mast3r will write the 3D model inside tmpdirname/chkpt_tag
40
- if args.tmp_dir is not None:
41
- tmpdirname = args.tmp_dir
 
42
  cache_path = os.path.join(tmpdirname, chkpt_tag)
43
  os.makedirs(cache_path, exist_ok=True)
44
  main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
45
- share=args.share)
46
- else:
47
- with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
48
- cache_path = os.path.join(tmpdirname, chkpt_tag)
49
- os.makedirs(cache_path, exist_ok=True)
50
- main_demo(tmpdirname, model, args.device, args.image_size,
51
- server_name, args.server_port, silent=args.silent,
52
- share=args.share)
 
8
  import os
9
  import torch
10
  import tempfile
11
+ from contextlib import nullcontext
12
 
13
  from mast3r.demo import get_args_parser, main_demo
14
 
 
37
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
38
  chkpt_tag = hash_md5(weights_path)
39
 
40
def get_context(tmp_dir):
    """Return a context manager that yields the demo's working directory.

    When ``tmp_dir`` is given it is yielded unchanged (and never cleaned up);
    otherwise a fresh ``TemporaryDirectory`` is created, which removes itself
    when the ``with`` block exits.
    """
    if tmp_dir is not None:
        # user-supplied directory: hand it back untouched
        return nullcontext(tmp_dir)
    return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo')
43
+ with get_context(args.tmp_dir) as tmpdirname:
44
  cache_path = os.path.join(tmpdirname, chkpt_tag)
45
  os.makedirs(cache_path, exist_ok=True)
46
  main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
47
+ share=args.share, gradio_delete_cache=args.gradio_delete_cache)
 
 
 
 
 
 
 
dust3r CHANGED
@@ -1 +1 @@
1
- Subproject commit d99800a2d1d33f000c6f0d1c307dfb5a7a34fd53
 
1
+ Subproject commit 8cc725dd11a9b7371bfca37994f8585ca78b42e5
mast3r/demo.py CHANGED
@@ -13,6 +13,8 @@ import functools
13
  import trimesh
14
  import copy
15
  from scipy.spatial.transform import Rotation
 
 
16
 
17
  from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
18
  from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
@@ -27,9 +29,30 @@ from dust3r.demo import get_args_parser as dust3r_get_args_parser
27
  import matplotlib.pyplot as pl
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def get_args_parser():
31
  parser = dust3r_get_args_parser()
32
  parser.add_argument('--share', action='store_true')
 
 
33
 
34
  actions = parser._actions
35
  for action in actions:
@@ -40,7 +63,7 @@ def get_args_parser():
40
  return parser
41
 
42
 
43
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
44
  cam_color=None, as_pointcloud=False,
45
  transparent_cams=False, silent=False):
46
  assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
@@ -53,14 +76,17 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
53
 
54
  # full pointcloud
55
  if as_pointcloud:
56
- pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)])
57
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
58
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
 
59
  scene.add_geometry(pct)
60
  else:
61
  meshes = []
62
  for i in range(len(imgs)):
63
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i].reshape(imgs[i].shape), mask[i]))
 
 
64
  mesh = trimesh.Trimesh(**cat_meshes(meshes))
65
  scene.add_geometry(mesh)
66
 
@@ -77,20 +103,22 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
77
  rot = np.eye(4)
78
  rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
79
  scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
80
- outfile = os.path.join(outdir, 'scene.glb')
81
  if not silent:
82
  print('(exporting 3D scene to', outfile, ')')
83
  scene.export(file_obj=outfile)
84
  return outfile
85
 
86
 
87
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
88
  clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
89
  """
90
  extract 3D_model (glb file) from a reconstructed scene
91
  """
92
  if scene is None:
93
  return None
 
 
 
94
 
95
  # get optimized values from scene
96
  rgbimg = scene.imgs
@@ -104,14 +132,14 @@ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud
104
  else:
105
  pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
106
  msk = to_numpy([c > min_conf_thr for c in confs])
107
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
108
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
109
 
110
 
111
- def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2,
112
- min_conf_thr, matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
113
- cam_size, scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics,
114
- **kw):
115
  """
116
  from a list of images, run mast3r inference, sparse global aligner.
117
  then run get_3D_model_from_scene
@@ -134,11 +162,23 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist,
134
  if optim_level == 'coarse':
135
  niter2 = 0
136
  # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
137
- scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
 
 
 
 
 
 
138
  model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
139
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
140
  matching_conf_thr=matching_conf_thr, **kw)
141
- outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
 
 
 
 
 
 
142
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
143
  return scene, outfile
144
 
@@ -169,13 +209,24 @@ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
169
  return win_col, winsize, win_cyclic, refid
170
 
171
 
172
- def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, share=False):
 
173
  if not silent:
174
  print('Outputing stuff in', tmpdirname)
175
 
176
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
177
- model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
178
- with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="MASt3R Demo") as demo:
 
 
 
 
 
 
 
 
 
 
179
  # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
180
  scene = gradio.State(None)
181
  gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
@@ -212,7 +263,6 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
212
  win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
213
  refid = gradio.Slider(label="Scene Graph: Id", value=0,
214
  minimum=0, maximum=0, step=1, visible=False)
215
-
216
  run_btn = gradio.Button("Run")
217
 
218
  with gradio.Row():
@@ -241,7 +291,7 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
241
  inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
242
  outputs=[win_col, winsize, win_cyclic, refid])
243
  run_btn.click(fn=recon_fun,
244
- inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
245
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
246
  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
247
  outputs=[scene, outmodel])
@@ -274,4 +324,3 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
274
  clean_depth, transparent_cams, cam_size, TSDF_thresh],
275
  outputs=outmodel)
276
  demo.launch(share=share, server_name=server_name, server_port=server_port)
277
-
 
13
  import trimesh
14
  import copy
15
  from scipy.spatial.transform import Rotation
16
+ import tempfile
17
+ import shutil
18
 
19
  from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
20
  from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
 
29
  import matplotlib.pyplot as pl
30
 
31
 
32
class SparseGAState():
    """Pair a sparse global-alignment result with the on-disk files tied to it.

    Attribute lookups that are not found on the wrapper are forwarded to the
    wrapped ``sparse_ga`` object, so the state can be used wherever the raw
    result is expected.

    Args:
        sparse_ga: the object returned by the sparse global alignment.
        should_delete: when true, ``cache_dir`` and ``outfile_name`` are
            removed from disk once this state is garbage-collected.
        cache_dir: directory holding the alignment cache, or None.
        outfile_name: path of the exported scene file, or None.
    """

    def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
        self.sparse_ga = sparse_ga
        self.cache_dir = cache_dir
        self.outfile_name = outfile_name
        self.should_delete = should_delete

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If 'sparse_ga' itself is
        # missing (partially initialised or copied instance) delegating to it would
        # recurse forever, so fail fast instead.
        if name == 'sparse_ga':
            raise AttributeError(name)
        return getattr(self.sparse_ga, name)

    def __del__(self):
        # Only purge files when this state was explicitly made responsible for them;
        # otherwise a shared cache directory would be wiped on garbage collection.
        if not getattr(self, 'should_delete', False):
            return
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        self.cache_dir = None
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None
49
+
50
+
51
  def get_args_parser():
52
  parser = dust3r_get_args_parser()
53
  parser.add_argument('--share', action='store_true')
54
+ parser.add_argument('--gradio_delete_cache', default=None, type=int,
55
+ help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
56
 
57
  actions = parser._actions
58
  for action in actions:
 
63
  return parser
64
 
65
 
66
+ def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
67
  cam_color=None, as_pointcloud=False,
68
  transparent_cams=False, silent=False):
69
  assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
 
76
 
77
  # full pointcloud
78
  if as_pointcloud:
79
+ pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
80
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
81
+ valid_msk = np.isfinite(pts.sum(axis=1))
82
+ pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
83
  scene.add_geometry(pct)
84
  else:
85
  meshes = []
86
  for i in range(len(imgs)):
87
+ pts3d_i = pts3d[i].reshape(imgs[i].shape)
88
+ msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
89
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
90
  mesh = trimesh.Trimesh(**cat_meshes(meshes))
91
  scene.add_geometry(mesh)
92
 
 
103
  rot = np.eye(4)
104
  rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
105
  scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
 
106
  if not silent:
107
  print('(exporting 3D scene to', outfile, ')')
108
  scene.export(file_obj=outfile)
109
  return outfile
110
 
111
 
112
+ def get_3D_model_from_scene(silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
113
  clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
114
  """
115
  extract 3D_model (glb file) from a reconstructed scene
116
  """
117
  if scene is None:
118
  return None
119
+ outfile = scene.outfile_name
120
+ if outfile is None:
121
+ return None
122
 
123
  # get optimized values from scene
124
  rgbimg = scene.imgs
 
132
  else:
133
  pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
134
  msk = to_numpy([c > min_conf_thr for c in confs])
135
+ return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
136
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
137
 
138
 
139
+ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
140
+ filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
141
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
142
+ win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
143
  """
144
  from a list of images, run mast3r inference, sparse global aligner.
145
  then run get_3D_model_from_scene
 
162
  if optim_level == 'coarse':
163
  niter2 = 0
164
  # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
165
+ if current_scene_state is not None and current_scene_state.cache_dir is not None:
166
+ cache_dir = current_scene_state.cache_dir
167
+ elif gradio_delete_cache:
168
+ cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
169
+ else:
170
+ cache_dir = os.path.join(outdir, 'cache')
171
+ scene = sparse_global_alignment(filelist, pairs, cache_dir,
172
  model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
173
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
174
  matching_conf_thr=matching_conf_thr, **kw)
175
+ if current_scene_state is not None and current_scene_state.outfile_name is not None:
176
+ outfile_name = current_scene_state.outfile_name
177
+ else:
178
+ outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
179
+
180
+ scene = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
181
+ outfile = get_3D_model_from_scene(silent, scene, min_conf_thr, as_pointcloud, mask_sky,
182
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
183
  return scene, outfile
184
 
 
209
  return win_col, winsize, win_cyclic, refid
210
 
211
 
212
+ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
213
+ share=False, gradio_delete_cache=False):
214
  if not silent:
215
  print('Outputing stuff in', tmpdirname)
216
 
217
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
218
+ silent, image_size)
219
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
220
+
221
def get_context(delete_cache):
    """Create the top-level ``gradio.Blocks`` container for the demo.

    ``delete_cache`` is forwarded to gradio as a ``(delete_cache, delete_cache)``
    pair only when it is truthy; otherwise the keyword is omitted entirely.
    """
    blocks_kwargs = {
        'css': """.gradio-container {margin: 0 !important; min-width: 100%};""",
        'title': "MASt3R Demo",
    }
    if delete_cache:
        blocks_kwargs['delete_cache'] = (delete_cache, delete_cache)
    # leaving the keyword out keeps compatibility with older gradio versions
    return gradio.Blocks(**blocks_kwargs)
228
+
229
+ with get_context(gradio_delete_cache) as demo:
230
  # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
231
  scene = gradio.State(None)
232
  gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
 
263
  win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
264
  refid = gradio.Slider(label="Scene Graph: Id", value=0,
265
  minimum=0, maximum=0, step=1, visible=False)
 
266
  run_btn = gradio.Button("Run")
267
 
268
  with gradio.Row():
 
291
  inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
292
  outputs=[win_col, winsize, win_cyclic, refid])
293
  run_btn.click(fn=recon_fun,
294
+ inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
295
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
296
  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
297
  outputs=[scene, outmodel])
 
324
  clean_depth, transparent_cams, cam_size, TSDF_thresh],
325
  outputs=outmodel)
326
  demo.launch(share=share, server_name=server_name, server_port=server_port)