yocabon commited on
Commit
7a6bc02
·
1 Parent(s): 02db12f

update sparse_ga abd add new options to the demo

Browse files
Files changed (3) hide show
  1. demo.py +50 -34
  2. dust3r +1 -1
  3. mast3r/cloud_opt/sparse_ga.py +70 -35
demo.py CHANGED
@@ -118,7 +118,7 @@ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud
118
 
119
  def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
120
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
121
- scenegraph_type, winsize, refid, TSDF_thresh, **kw):
122
  """
123
  from a list of images, run mast3r inference, sparse global aligner.
124
  then run get_3D_model_from_scene
@@ -128,42 +128,51 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist,
128
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
129
  imgs[1]['idx'] = 1
130
  filelist = [filelist[0], filelist[0] + '_2']
131
- if scenegraph_type == "swin":
132
- scenegraph_type = scenegraph_type + "-" + str(winsize)
133
- elif scenegraph_type == "oneref":
134
- scenegraph_type = scenegraph_type + "-" + str(refid)
135
 
136
- pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
 
 
 
 
 
 
 
 
137
  if optim_level == 'coarse':
138
  niter2 = 0
139
  # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
140
  scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
141
  model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
142
- opt_depth='depth' in optim_level, **kw)
143
  outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
144
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
145
  return scene, outfile
146
 
147
 
148
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
149
  num_files = len(inputfiles) if inputfiles is not None else 1
150
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
 
 
 
151
  if scenegraph_type == "swin":
152
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
153
- minimum=1, maximum=max_winsize, step=1, visible=True)
154
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
155
- maximum=num_files - 1, step=1, visible=False)
156
- elif scenegraph_type == "oneref":
157
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
158
- minimum=1, maximum=max_winsize, step=1, visible=False)
159
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
160
- maximum=num_files - 1, step=1, visible=True)
161
- else:
162
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
163
- minimum=1, maximum=max_winsize, step=1, visible=False)
164
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
165
- maximum=num_files - 1, step=1, visible=False)
166
- return winsize, refid
 
 
167
 
168
 
169
  def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, share=False):
@@ -180,21 +189,25 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
180
  inputfiles = gradio.File(file_count="multiple")
181
  with gradio.Row():
182
  lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
183
- niter1 = gradio.Number(value=200, precision=0, minimum=0, maximum=10_000,
184
  label="num_iterations", info="For coarse alignment!")
185
  lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
186
- niter2 = gradio.Number(value=500, precision=0, minimum=0, maximum=100_000,
187
  label="num_iterations", info="For refinement!")
188
  optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
189
  value='refine', label="OptLevel",
190
  info="Optimization level")
 
 
191
 
192
- scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
193
  value='complete', label="Scenegraph",
194
  info="Define how to make pairs",
195
  interactive=True)
196
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
197
- minimum=1, maximum=1, step=1, visible=False)
 
 
198
  refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
199
 
200
  run_btn = gradio.Button("Run")
@@ -216,15 +229,18 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
216
 
217
  # events
218
  scenegraph_type.change(set_scenegraph_options,
219
- inputs=[inputfiles, winsize, refid, scenegraph_type],
220
- outputs=[winsize, refid])
221
  inputfiles.change(set_scenegraph_options,
222
- inputs=[inputfiles, winsize, refid, scenegraph_type],
223
- outputs=[winsize, refid])
 
 
 
224
  run_btn.click(fn=recon_fun,
225
  inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, as_pointcloud,
226
  mask_sky, clean_depth, transparent_cams, cam_size,
227
- scenegraph_type, winsize, refid, TSDF_thresh],
228
  outputs=[scene, outmodel])
229
  min_conf_thr.release(fn=model_from_scene_fun,
230
  inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
 
118
 
119
  def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
120
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
121
+ scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
122
  """
123
  from a list of images, run mast3r inference, sparse global aligner.
124
  then run get_3D_model_from_scene
 
128
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
129
  imgs[1]['idx'] = 1
130
  filelist = [filelist[0], filelist[0] + '_2']
 
 
 
 
131
 
132
+ scene_graph_params = [scenegraph_type]
133
+ if scenegraph_type in ["swin", "logwin"]:
134
+ scene_graph_params.append(str(winsize))
135
+ elif scenegraph_type == "oneref":
136
+ scene_graph_params.append(str(refid))
137
+ if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
138
+ scene_graph_params.append('noncyclic')
139
+ scene_graph = '-'.join(scene_graph_params)
140
+ pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
141
  if optim_level == 'coarse':
142
  niter2 = 0
143
  # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
144
  scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
145
  model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
146
+ opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics, **kw)
147
  outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
148
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
149
  return scene, outfile
150
 
151
 
152
+ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
153
  num_files = len(inputfiles) if inputfiles is not None else 1
154
+ show_win_controls = scenegraph_type in ["swin", "logwin"]
155
+ show_winsize = scenegraph_type in ["swin", "logwin"]
156
+ show_cyclic = scenegraph_type in ["swin", "logwin"]
157
+ max_winsize, min_winsize = 1, 1
158
  if scenegraph_type == "swin":
159
+ if win_cyclic:
160
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
161
+ else:
162
+ max_winsize = num_files - 1
163
+ elif scenegraph_type == "logwin":
164
+ if win_cyclic:
165
+ half_size = math.ceil((num_files - 1) / 2)
166
+ max_winsize = max(1, math.ceil(math.log(half_size, 2)))
167
+ else:
168
+ max_winsize = max(1, math.ceil(math.log(num_files, 2)))
169
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
170
+ minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
171
+ win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
172
+ win_col = gradio.Column(visible=show_win_controls)
173
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
174
+ maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
175
+ return win_col, winsize, win_cyclic, refid
176
 
177
 
178
  def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, share=False):
 
189
  inputfiles = gradio.File(file_count="multiple")
190
  with gradio.Row():
191
  lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
192
+ niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
193
  label="num_iterations", info="For coarse alignment!")
194
  lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
195
+ niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
196
  label="num_iterations", info="For refinement!")
197
  optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
198
  value='refine', label="OptLevel",
199
  info="Optimization level")
200
+ shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
201
+ info="Only optimize one set of intrinsics for all views")
202
 
203
+ scenegraph_type = gradio.Dropdown(["complete", "swin", "logwin", "oneref"],
204
  value='complete', label="Scenegraph",
205
  info="Define how to make pairs",
206
  interactive=True)
207
+ with gradio.Column(visible=False) as win_col:
208
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
209
+ minimum=1, maximum=1, step=1)
210
+ win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
211
  refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
212
 
213
  run_btn = gradio.Button("Run")
 
229
 
230
  # events
231
  scenegraph_type.change(set_scenegraph_options,
232
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
233
+ outputs=[win_col, winsize, win_cyclic, refid])
234
  inputfiles.change(set_scenegraph_options,
235
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
236
+ outputs=[win_col, winsize, win_cyclic, refid])
237
+ win_cyclic.change(set_scenegraph_options,
238
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
239
+ outputs=[win_col, winsize, win_cyclic, refid])
240
  run_btn.click(fn=recon_fun,
241
  inputs=[inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, as_pointcloud,
242
  mask_sky, clean_depth, transparent_cams, cam_size,
243
+ scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
244
  outputs=[scene, outmodel])
245
  min_conf_thr.release(fn=model_from_scene_fun,
246
  inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
dust3r CHANGED
@@ -1 +1 @@
1
- Subproject commit c267f72ba845108be1d2bd5ea98c641f2835cbc7
 
1
+ Subproject commit 6b133a501f5760e0f22f8af3fef05dc9e227e9ac
mast3r/cloud_opt/sparse_ga.py CHANGED
@@ -115,7 +115,7 @@ def convert_dust3r_pairs_naming(imgs, pairs_in):
115
 
116
 
117
  def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
118
- device='cuda', dtype=torch.float32, **kw):
119
  """ Sparse alignment with MASt3R
120
  imgs: list of image paths
121
  cache_path: path where to dump temporary files (str)
@@ -148,8 +148,9 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
148
  condense_data(imgs, tmp_pairs, canonical_views, dtype)
149
 
150
  imgs, res_coarse, res_fine = sparse_scene_optimizer(
151
- imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths,
152
- mst, cache_path=cache_path, device=device, dtype=dtype, **kw)
 
153
  return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
154
 
155
 
@@ -161,8 +162,9 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
161
  opt_pp=True, opt_depth=True,
162
  schedule=cosine_schedule, depth_mode='add', exp_depth=False,
163
  lora_depth=False, # dict(k=96, gamma=15, min_norm=.5),
 
164
  init={}, device='cuda', dtype=torch.float32,
165
- matching_conf_thr=4., loss_dust3r_w=0.01,
166
  verbose=True, dbg=()):
167
 
168
  # extrinsic parameters
@@ -206,11 +208,23 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
206
  assert False, 'inverse kinematic chain not yet implemented'
207
 
208
  # intrinsics parameters
209
- pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
 
 
 
 
 
 
 
 
 
 
 
 
210
  diags = imsizes.float().norm(dim=1)
211
  min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
212
  max_focals = 10 * diags
213
- log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
214
  assert len(mst[1]) == len(pps) - 1
215
 
216
  def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
@@ -268,7 +282,11 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
268
  return K, (inv(cam2w), cam2w), depthmaps
269
 
270
  K = make_K_cam_depth(log_focals, pps, None, None, None, None)
271
- print('init focals =', to_numpy(K[:, 0, 0]))
 
 
 
 
272
 
273
  # spectral low-rank projection of depthmaps
274
  if lora_depth:
@@ -298,17 +316,39 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
298
  idxs = anchors[imgs.index(im2k)][1]
299
  subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs]) # anchors subsample
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  def loss_dust3r(cam2w, pts3d, pix_loss):
302
  # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
303
  loss = 0.
304
  cf_sum = 0.
305
- for s in imgs_slices:
306
- if not is_matching_ok[s.img1, s.img2]:
307
- # fallback to dust3r regression
308
- tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
309
- tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
310
- cf_sum += tgt_confs.sum()
311
- loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
312
  return loss / cf_sum if cf_sum != 0. else 0.
313
 
314
  def loss_3d(K, w2cam, pts3d, pix_loss):
@@ -318,17 +358,16 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
318
  pts3d_1 = []
319
  pts3d_2 = []
320
  confs = []
321
- for s in imgs_slices:
322
  if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
323
  continue
324
- if is_matching_ok[s.img1, s.img2]:
325
- pts3d_1.append(pts3d[s.img1][s.slice1])
326
- pts3d_2.append(pts3d[s.img2][s.slice2])
327
- confs.append(s.confs)
328
  else:
329
- pts3d_1 = [pts3d[s.img1][s.slice1] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
330
- pts3d_2 = [pts3d[s.img2][s.slice2] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
331
- confs = [s.confs for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
332
 
333
  if pts3d_1 != []:
334
  confs = torch.cat(confs)
@@ -347,25 +386,15 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
347
  # For each 3D point, we have 2 reproj errors
348
  proj_matrix = K @ w2cam[:, :3]
349
  loss = npix = 0
350
- for img1, pix1, confs, cf_sum, imgs_slices in corres2d:
351
  if init[imgs[img1]].get('freeze', 0) >= 1:
352
  continue # no need
353
- pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in imgs_slices if is_matching_ok[img1, img2]]
354
- pix1_filtered = []
355
- confs_filtered = []
356
- curstep = 0
357
- for img2, slice2 in imgs_slices:
358
- if is_matching_ok[img1, img2]:
359
- tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
360
- pix1_filtered.append(pix1[tslice])
361
- confs_filtered.append(confs[tslice])
362
- curstep += slice2.stop - slice2.start
363
  if pts3d_in_img1 != []:
364
  pts3d_in_img1 = torch.cat(pts3d_in_img1)
365
- pix1_filtered = torch.cat(pix1_filtered)
366
- confs_filtered = torch.cat(confs_filtered)
367
  loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
368
  npix += confs_filtered.sum()
 
369
  return loss / npix if npix != 0 else 0.
370
 
371
  def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
@@ -430,6 +459,12 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
430
  # refinement with 2d reproj
431
  res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
432
 
 
 
 
 
 
 
433
  return imgs, res_coarse, res_fine
434
 
435
 
 
115
 
116
 
117
  def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
118
+ device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
119
  """ Sparse alignment with MASt3R
120
  imgs: list of image paths
121
  cache_path: path where to dump temporary files (str)
 
148
  condense_data(imgs, tmp_pairs, canonical_views, dtype)
149
 
150
  imgs, res_coarse, res_fine = sparse_scene_optimizer(
151
+ imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
152
+ shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)
153
+
154
  return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
155
 
156
 
 
162
  opt_pp=True, opt_depth=True,
163
  schedule=cosine_schedule, depth_mode='add', exp_depth=False,
164
  lora_depth=False, # dict(k=96, gamma=15, min_norm=.5),
165
+ shared_intrinsics=False,
166
  init={}, device='cuda', dtype=torch.float32,
167
+ matching_conf_thr=5., loss_dust3r_w=0.01,
168
  verbose=True, dbg=()):
169
 
170
  # extrinsic parameters
 
208
  assert False, 'inverse kinematic chain not yet implemented'
209
 
210
  # intrinsics parameters
211
+ if shared_intrinsics:
212
+ # Optimize a single set of intrinsics for all cameras. Use averages as init.
213
+ confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps)
214
+ weighting = confs / confs.sum()
215
+ pp = nn.Parameter((weighting @ pps).to(dtype))
216
+ pps = [pp for _ in range(len(imgs))]
217
+ focal_m = weighting @ base_focals
218
+ log_focal = nn.Parameter(focal_m.view(1).log().to(dtype))
219
+ log_focals = [log_focal for _ in range(len(imgs))]
220
+ else:
221
+ pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
222
+ log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
223
+
224
  diags = imsizes.float().norm(dim=1)
225
  min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
226
  max_focals = 10 * diags
227
+
228
  assert len(mst[1]) == len(pps) - 1
229
 
230
  def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
 
282
  return K, (inv(cam2w), cam2w), depthmaps
283
 
284
  K = make_K_cam_depth(log_focals, pps, None, None, None, None)
285
+
286
+ if shared_intrinsics:
287
+ print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
288
+ else:
289
+ print('init focals =', to_numpy(K[:, 0, 0]))
290
 
291
  # spectral low-rank projection of depthmaps
292
  if lora_depth:
 
316
  idxs = anchors[imgs.index(im2k)][1]
317
  subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs]) # anchors subsample
318
 
319
+ # Prepare slices and corres for losses
320
+ dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
321
+ loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
322
+ cleaned_corres2d = []
323
+ for cci, (img1, pix1, confs, confsum, imgs_slices) in enumerate(corres2d):
324
+ cf_sum = 0
325
+ pix1_filtered = []
326
+ confs_filtered = []
327
+ curstep = 0
328
+ cleaned_slices = []
329
+ for img2, slice2 in imgs_slices:
330
+ if is_matching_ok[img1, img2]:
331
+ tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
332
+ pix1_filtered.append(pix1[tslice])
333
+ confs_filtered.append(confs[tslice])
334
+ cleaned_slices.append((img2, slice2))
335
+ curstep += slice2.stop - slice2.start
336
+ if pix1_filtered != []:
337
+ pix1_filtered = torch.cat(pix1_filtered)
338
+ confs_filtered = torch.cat(confs_filtered)
339
+ cf_sum = confs_filtered.sum()
340
+ cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices))
341
+
342
  def loss_dust3r(cam2w, pts3d, pix_loss):
343
  # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
344
  loss = 0.
345
  cf_sum = 0.
346
+ for s in dust3r_slices:
347
+ # fallback to dust3r regression
348
+ tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
349
+ tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
350
+ cf_sum += tgt_confs.sum()
351
+ loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
 
352
  return loss / cf_sum if cf_sum != 0. else 0.
353
 
354
  def loss_3d(K, w2cam, pts3d, pix_loss):
 
358
  pts3d_1 = []
359
  pts3d_2 = []
360
  confs = []
361
+ for s in loss3d_slices:
362
  if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
363
  continue
364
+ pts3d_1.append(pts3d[s.img1][s.slice1])
365
+ pts3d_2.append(pts3d[s.img2][s.slice2])
366
+ confs.append(s.confs)
 
367
  else:
368
+ pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices]
369
+ pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices]
370
+ confs = [s.confs for s in loss3d_slices]
371
 
372
  if pts3d_1 != []:
373
  confs = torch.cat(confs)
 
386
  # For each 3D point, we have 2 reproj errors
387
  proj_matrix = K @ w2cam[:, :3]
388
  loss = npix = 0
389
+ for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d:
390
  if init[imgs[img1]].get('freeze', 0) >= 1:
391
  continue # no need
392
+ pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices]
 
 
 
 
 
 
 
 
 
393
  if pts3d_in_img1 != []:
394
  pts3d_in_img1 = torch.cat(pts3d_in_img1)
 
 
395
  loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
396
  npix += confs_filtered.sum()
397
+
398
  return loss / npix if npix != 0 else 0.
399
 
400
  def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
 
459
  # refinement with 2d reproj
460
  res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
461
 
462
+ K = make_K_cam_depth(log_focals, pps, None, None, None, None)
463
+ if shared_intrinsics:
464
+ print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
465
+ else:
466
+ print('Final focals =', to_numpy(K[:, 0, 0]))
467
+
468
  return imgs, res_coarse, res_fine
469
 
470