Huiwenshi committed on
Commit
5899685
1 Parent(s): b9551f9

Upload third_party/dust3r/dust3r/inference.py with huggingface_hub

Browse files
third_party/dust3r/dust3r/inference.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed for the inference
6
+ # --------------------------------------------------------
7
+ import tqdm
8
+ import torch
9
+ from dust3r.utils.device import to_cpu, collate_with_cat
10
+ from dust3r.utils.misc import invalid_to_nans
11
+ from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
12
+
13
+
14
+ def _interleave_imgs(img1, img2):
15
+ res = {}
16
+ for key, value1 in img1.items():
17
+ value2 = img2[key]
18
+ if isinstance(value1, torch.Tensor):
19
+ value = torch.stack((value1, value2), dim=1).flatten(0, 1)
20
+ else:
21
+ value = [x for pair in zip(value1, value2) for x in pair]
22
+ res[key] = value
23
+ return res
24
+
25
+
26
def make_batch_symmetric(batch):
    """Build a symmetrized pair of views from a ``(view1, view2)`` batch.

    The first returned view interleaves ``view1`` with ``view2`` and the
    second interleaves ``view2`` with ``view1`` (see ``_interleave_imgs``),
    so each original pair appears in both orders.
    """
    first, second = batch
    forward = _interleave_imgs(first, second)
    backward = _interleave_imgs(second, first)
    return forward, backward
30
+
31
+
32
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model (and optionally the criterion) on one batch of view pairs.

    Args:
        batch: a ``(view1, view2)`` pair of dicts. Entries whose key is not in
            the skip list below are moved to ``device`` IN PLACE.
        model: callable invoked as ``model(view1, view2)``; must return
            ``(pred1, pred2)``.
        criterion: callable invoked as ``criterion(view1, view2, pred1, pred2)``,
            or ``None`` to skip the loss.
        device: target passed to ``.to(device, non_blocking=True)``.
        symmetrize_batch: if true, the views are replaced by the output of
            ``make_batch_symmetric(batch)`` before the forward pass.
        use_amp: enables CUDA autocast for the model forward pass.
        ret: optional key; when truthy only ``result[ret]`` is returned.

    Returns:
        A dict with keys ``view1``, ``view2``, ``pred1``, ``pred2``, ``loss``
        (``loss`` is ``None`` when no criterion is given), or the single entry
        selected by ``ret``.
    """
    view1, view2 = batch

    # entries with these keys are left as they are (no device transfer)
    skipped = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng'}
    for view in batch:
        for name in view:
            if name not in skipped:
                view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric; it is evaluated with autocast off
        with torch.cuda.amp.autocast(enabled=False):
            if criterion is None:
                loss = None
            else:
                loss = criterion(view1, view2, pred1, pred2)

    result = {'view1': view1, 'view2': view2, 'pred1': pred1, 'pred2': pred2, 'loss': loss}
    if ret:
        return result[ret]
    return result
53
+
54
+
55
@torch.no_grad()
def inference(pairs, model, device, batch_size=8, verbose=True):
    """Run the model over a list of image pairs, with gradients disabled.

    Args:
        pairs: sequence of ``(img1, img2)`` dicts, each with an ``'img'`` entry.
        model: forwarded to ``loss_of_one_batch``.
        device: device on which each batch is processed.
        batch_size: number of pairs per forward pass; forced to 1 when the
            images do not all share the same size.
        verbose: print a header line and show a progress bar.

    Returns:
        The per-batch outputs (moved through ``to_cpu``) merged by
        ``collate_with_cat``; ``lists=True`` is passed to it when the images
        have differing sizes.
    """
    if verbose:
        print(f'>> Inference with model on {len(pairs)} image pairs')

    # pairs are only batched together when every image has the same size;
    # otherwise they are processed one pair at a time
    multiple_shapes = not check_if_same_size(pairs)
    if multiple_shapes:
        batch_size = 1

    result = []
    for start in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        chunk = collate_with_cat(pairs[start:start + batch_size])
        output = loss_of_one_batch(chunk, model, None, device)
        result.append(to_cpu(output))

    return collate_with_cat(result, lists=multiple_shapes)
73
+
74
+
75
def check_if_same_size(pairs):
    """Tell whether all first images, and all second images, share one size.

    The size of an image is the last two dims of its ``'img'`` entry. First
    and second images are checked independently: they may differ from each
    other. Returns ``True`` for an empty ``pairs``.
    """
    first_shapes = [first['img'].shape[-2:] for first, _ in pairs]
    second_shapes = [second['img'].shape[-2:] for _, second in pairs]

    for shapes in (first_shapes, second_shapes):
        # any deviation from the first shape of the group is a mismatch
        if any(shape != shapes[0] for shape in shapes):
            return False
    return True
79
+
80
+
81
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D points from a prediction dict.

    The representation is chosen from the keys present in ``pred``, in this
    order of priority:

    * ``'depth'`` and ``'pseudo_focal'``: points are computed with
      ``depthmap_to_pts3d(**pred, pp=pp)``, where ``pp`` is read from
      ``gt['camera_intrinsics'][..., :2, 2]`` (``None`` if that key is absent).
    * ``'pts3d'``: used as is.
    * ``'pts3d_in_other_view'``: returned directly; requires ``use_pose=True``.

    Args:
        gt: ground-truth dict; only ``'camera_intrinsics'`` is read, and only
            in the depth/pseudo-focal case.
        pred: prediction dict.
        use_pose: if true, the points are transformed with
            ``geotrf(pred['camera_pose'], pts3d)`` (not applied to
            ``'pts3d_in_other_view'``, which is returned untouched).

    Returns:
        The 3D points.

    Raises:
        ValueError: if ``pred`` holds none of the supported representations.
        AssertionError: if ``'pts3d_in_other_view'`` is used without
            ``use_pose=True``, or ``use_pose`` is set but ``pred`` has no
            ``'camera_pose'``.
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        # presumably the principal point (last column of the intrinsics)
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # pts3d from my camera
        pts3d = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True, "'pts3d_in_other_view' requires use_pose=True"
        return pred['pts3d_in_other_view']  # return!

    else:
        # without this branch `pts3d` would be unbound further down
        raise ValueError(
            "cannot extract 3D points: expected 'depth' + 'pseudo_focal', 'pts3d' "
            f"or 'pts3d_in_other_view' in pred, got keys {list(pred)}")

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None, "use_pose=True requires pred['camera_pose']"
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
104
+
105
+
106
def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    """Estimate, for each batch element, the scale ``s`` such that ``pr ~ s * gt``.

    Args:
        gt_pts1, pr_pts1: ground-truth and predicted points of the first view;
            4-dimensional tensors of identical shape. Dims 1 and 2 are
            flattened, and the last dim is reduced by the dot products.
        gt_pts2, pr_pts2: same for the second view. Either both are given or
            both are ``None``; when given they are concatenated with the first
            view along the flattened point dim.
        fit_mode: selects the estimator by prefix:
            ``'avg...'`` ratio of the mean ``<pr, gt>`` to the mean ``<gt, gt>``;
            ``'median...'`` median of the per-point ratio ``<pr, gt> / <gt, gt>``;
            ``'weiszfeld...'`` starts from the ``avg`` solution, then runs 10
            re-weighted iterations with weights ``1 / ||pr - s * gt||``.
            A ``'stop_grad'`` suffix detaches the result from the graph.
        valid1, valid2: passed to ``invalid_to_nans`` together with the points
            of the corresponding view (presumably validity masks whose invalid
            entries become NaN and are then skipped by the nan-aware
            reductions; confirm against ``invalid_to_nans``).

    Returns:
        A tensor with one scale per batch element, clipped to at least 1e-3.

    Raises:
        ValueError: if only one of ``gt_pts2`` / ``pr_pts2`` is given, or if
            ``fit_mode`` matches none of the known prefixes.
        AssertionError: if the point tensors are not 4-dimensional or the
            ground-truth and predicted shapes differ.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if (gt_pts2 is None) != (pr_pts2 is None):
        # fail early instead of an AttributeError / shape mismatch further down
        raise ValueError('gt_pts2 and pr_pts2 must be either both given or both None')
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointclouds of both views along the flattened point dim
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    # per-point dot products <pr, gt> and <gt, gt>
    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for _ in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # out-of-place clip: `norm` keeps its output for the backward
            # pass, so mutating it in place breaks differentiation when the
            # result is not detached
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    # keep the scale strictly positive
    scaling = scaling.clip(min=1e-3)
    return scaling