Upload third_party/dust3r/dust3r/inference.py with huggingface_hub
third_party/dust3r/dust3r/inference.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
import tqdm
import torch
from dust3r.utils.device import to_cpu, collate_with_cat
from dust3r.utils.misc import invalid_to_nans
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf


def _interleave_imgs(img1, img2):
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(batch):
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2
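
# Shape note: for tensor entries of shape (B, ...), _interleave_imgs returns
# (2B, ...) ordered [img1[0], img2[0], img1[1], img2[1], ...], so symmetrizing
# a batch doubles the effective batch size.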


def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    view1, view2 = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng'])
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
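
# Usage sketch (assumes a dust3r-style list of view-dict pairs): with
# criterion=None this simply runs the forward pass and returns the raw views
# and predictions; out['loss'] is then None.
#
#   out = loss_of_one_batch(collate_with_cat(pairs[0:1]), model, None, 'cuda')
#   pred1, pred2 = out['pred1'], out['pred2']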


@torch.no_grad()
def inference(pairs, model, device, batch_size=8, verbose=True):
    if verbose:
        print(f'>> Inference with model on {len(pairs)} image pairs')
    result = []

    # first, check if all images have the same size
    multiple_shapes = not (check_if_same_size(pairs))
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device)
        result.append(to_cpu(res))

    result = collate_with_cat(result, lists=multiple_shapes)

    return result


def check_if_same_size(pairs):
    shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
    shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
    return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)
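
# Minimal end-to-end sketch, assuming the usual dust3r helpers are available
# (load_images, make_pairs, AsymmetricCroCo3DStereo; exact import paths and
# checkpoint names may differ between dust3r versions):
#
#   from dust3r.model import AsymmetricCroCo3DStereo
#   from dust3r.utils.image import load_images
#   from dust3r.image_pairs import make_pairs
#
#   model = AsymmetricCroCo3DStereo.from_pretrained(
#       'naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt').to('cuda')
#   images = load_images(['a.jpg', 'b.jpg'], size=512)
#   pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
#   output = inference(pairs, model, 'cuda', batch_size=1)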


def get_pred_pts3d(gt, pred, use_pose=False):
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # pts3d from my camera
        pts3d = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred['pts3d_in_other_view']  # return!

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
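
# The three accepted prediction formats, in the order tried above (shapes are
# indicative only):
#   {'depth', 'pseudo_focal'}   -> unprojected with depthmap_to_pts3d
#   {'pts3d': (B, H, W, 3)}     -> points already in this camera's frame
#   {'pts3d_in_other_view'}     -> points already in the reference frame
#                                  (requires use_pose=True, returned as-is)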


def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointclouds
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    scaling = scaling.clip(min=1e-3)
    # assert scaling.isfinite().all(), bb()
    return scaling
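
# Quick sanity check, assuming synthetic data: when the prediction is an exact
# scaled copy of the ground truth, every fit_mode should recover that scale
# (the closed form is sum(pr*gt) / sum(gt*gt) per batch element):
#
#   gt = torch.randn(2, 8, 8, 3)  # (B, H, W, 3)
#   s = find_opt_scaling(gt, None, 2.5 * gt, fit_mode='avg')
#   # s ~ tensor([2.5000, 2.5000])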