Spaces:
Running
Running
update: d2net lib
Browse files- third_party/d2net/lib/dataset.py +239 -0
- third_party/d2net/lib/exceptions.py +6 -0
- third_party/d2net/lib/loss.py +340 -0
- third_party/d2net/lib/model.py +121 -0
- third_party/d2net/lib/model_test.py +187 -0
- third_party/d2net/lib/pyramid.py +129 -0
- third_party/d2net/lib/utils.py +167 -0
third_party/d2net/lib/dataset.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import h5py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import os
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
|
12 |
+
import time
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from lib.utils import preprocess_image
|
17 |
+
|
18 |
+
|
19 |
+
class MegaDepthDataset(Dataset):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
scene_list_path='megadepth_utils/train_scenes.txt',
|
23 |
+
scene_info_path='/local/dataset/megadepth/scene_info',
|
24 |
+
base_path='/local/dataset/megadepth',
|
25 |
+
train=True,
|
26 |
+
preprocessing=None,
|
27 |
+
min_overlap_ratio=.5,
|
28 |
+
max_overlap_ratio=1,
|
29 |
+
max_scale_ratio=np.inf,
|
30 |
+
pairs_per_scene=100,
|
31 |
+
image_size=256
|
32 |
+
):
|
33 |
+
self.scenes = []
|
34 |
+
with open(scene_list_path, 'r') as f:
|
35 |
+
lines = f.readlines()
|
36 |
+
for line in lines:
|
37 |
+
self.scenes.append(line.strip('\n'))
|
38 |
+
|
39 |
+
self.scene_info_path = scene_info_path
|
40 |
+
self.base_path = base_path
|
41 |
+
|
42 |
+
self.train = train
|
43 |
+
|
44 |
+
self.preprocessing = preprocessing
|
45 |
+
|
46 |
+
self.min_overlap_ratio = min_overlap_ratio
|
47 |
+
self.max_overlap_ratio = max_overlap_ratio
|
48 |
+
self.max_scale_ratio = max_scale_ratio
|
49 |
+
|
50 |
+
self.pairs_per_scene = pairs_per_scene
|
51 |
+
|
52 |
+
self.image_size = image_size
|
53 |
+
|
54 |
+
self.dataset = []
|
55 |
+
|
56 |
+
def build_dataset(self):
|
57 |
+
self.dataset = []
|
58 |
+
if not self.train:
|
59 |
+
np_random_state = np.random.get_state()
|
60 |
+
np.random.seed(42)
|
61 |
+
print('Building the validation dataset...')
|
62 |
+
else:
|
63 |
+
print('Building a new training dataset...')
|
64 |
+
for scene in tqdm(self.scenes, total=len(self.scenes)):
|
65 |
+
scene_info_path = os.path.join(
|
66 |
+
self.scene_info_path, '%s.npz' % scene
|
67 |
+
)
|
68 |
+
if not os.path.exists(scene_info_path):
|
69 |
+
continue
|
70 |
+
scene_info = np.load(scene_info_path, allow_pickle=True)
|
71 |
+
overlap_matrix = scene_info['overlap_matrix']
|
72 |
+
scale_ratio_matrix = scene_info['scale_ratio_matrix']
|
73 |
+
|
74 |
+
valid = np.logical_and(
|
75 |
+
np.logical_and(
|
76 |
+
overlap_matrix >= self.min_overlap_ratio,
|
77 |
+
overlap_matrix <= self.max_overlap_ratio
|
78 |
+
),
|
79 |
+
scale_ratio_matrix <= self.max_scale_ratio
|
80 |
+
)
|
81 |
+
|
82 |
+
pairs = np.vstack(np.where(valid))
|
83 |
+
try:
|
84 |
+
selected_ids = np.random.choice(
|
85 |
+
pairs.shape[1], self.pairs_per_scene
|
86 |
+
)
|
87 |
+
except:
|
88 |
+
continue
|
89 |
+
|
90 |
+
image_paths = scene_info['image_paths']
|
91 |
+
depth_paths = scene_info['depth_paths']
|
92 |
+
points3D_id_to_2D = scene_info['points3D_id_to_2D']
|
93 |
+
points3D_id_to_ndepth = scene_info['points3D_id_to_ndepth']
|
94 |
+
intrinsics = scene_info['intrinsics']
|
95 |
+
poses = scene_info['poses']
|
96 |
+
|
97 |
+
for pair_idx in selected_ids:
|
98 |
+
idx1 = pairs[0, pair_idx]
|
99 |
+
idx2 = pairs[1, pair_idx]
|
100 |
+
matches = np.array(list(
|
101 |
+
points3D_id_to_2D[idx1].keys() &
|
102 |
+
points3D_id_to_2D[idx2].keys()
|
103 |
+
))
|
104 |
+
|
105 |
+
# Scale filtering
|
106 |
+
matches_nd1 = np.array([points3D_id_to_ndepth[idx1][match] for match in matches])
|
107 |
+
matches_nd2 = np.array([points3D_id_to_ndepth[idx2][match] for match in matches])
|
108 |
+
scale_ratio = np.maximum(matches_nd1 / matches_nd2, matches_nd2 / matches_nd1)
|
109 |
+
matches = matches[np.where(scale_ratio <= self.max_scale_ratio)[0]]
|
110 |
+
|
111 |
+
point3D_id = np.random.choice(matches)
|
112 |
+
point2D1 = points3D_id_to_2D[idx1][point3D_id]
|
113 |
+
point2D2 = points3D_id_to_2D[idx2][point3D_id]
|
114 |
+
nd1 = points3D_id_to_ndepth[idx1][point3D_id]
|
115 |
+
nd2 = points3D_id_to_ndepth[idx2][point3D_id]
|
116 |
+
central_match = np.array([
|
117 |
+
point2D1[1], point2D1[0],
|
118 |
+
point2D2[1], point2D2[0]
|
119 |
+
])
|
120 |
+
self.dataset.append({
|
121 |
+
'image_path1': image_paths[idx1],
|
122 |
+
'depth_path1': depth_paths[idx1],
|
123 |
+
'intrinsics1': intrinsics[idx1],
|
124 |
+
'pose1': poses[idx1],
|
125 |
+
'image_path2': image_paths[idx2],
|
126 |
+
'depth_path2': depth_paths[idx2],
|
127 |
+
'intrinsics2': intrinsics[idx2],
|
128 |
+
'pose2': poses[idx2],
|
129 |
+
'central_match': central_match,
|
130 |
+
'scale_ratio': max(nd1 / nd2, nd2 / nd1)
|
131 |
+
})
|
132 |
+
np.random.shuffle(self.dataset)
|
133 |
+
if not self.train:
|
134 |
+
np.random.set_state(np_random_state)
|
135 |
+
|
136 |
+
def __len__(self):
|
137 |
+
return len(self.dataset)
|
138 |
+
|
139 |
+
def recover_pair(self, pair_metadata):
|
140 |
+
depth_path1 = os.path.join(
|
141 |
+
self.base_path, pair_metadata['depth_path1']
|
142 |
+
)
|
143 |
+
with h5py.File(depth_path1, 'r') as hdf5_file:
|
144 |
+
depth1 = np.array(hdf5_file['/depth'])
|
145 |
+
assert(np.min(depth1) >= 0)
|
146 |
+
image_path1 = os.path.join(
|
147 |
+
self.base_path, pair_metadata['image_path1']
|
148 |
+
)
|
149 |
+
image1 = Image.open(image_path1)
|
150 |
+
if image1.mode != 'RGB':
|
151 |
+
image1 = image1.convert('RGB')
|
152 |
+
image1 = np.array(image1)
|
153 |
+
assert(image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
|
154 |
+
intrinsics1 = pair_metadata['intrinsics1']
|
155 |
+
pose1 = pair_metadata['pose1']
|
156 |
+
|
157 |
+
depth_path2 = os.path.join(
|
158 |
+
self.base_path, pair_metadata['depth_path2']
|
159 |
+
)
|
160 |
+
with h5py.File(depth_path2, 'r') as hdf5_file:
|
161 |
+
depth2 = np.array(hdf5_file['/depth'])
|
162 |
+
assert(np.min(depth2) >= 0)
|
163 |
+
image_path2 = os.path.join(
|
164 |
+
self.base_path, pair_metadata['image_path2']
|
165 |
+
)
|
166 |
+
image2 = Image.open(image_path2)
|
167 |
+
if image2.mode != 'RGB':
|
168 |
+
image2 = image2.convert('RGB')
|
169 |
+
image2 = np.array(image2)
|
170 |
+
assert(image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
|
171 |
+
intrinsics2 = pair_metadata['intrinsics2']
|
172 |
+
pose2 = pair_metadata['pose2']
|
173 |
+
|
174 |
+
central_match = pair_metadata['central_match']
|
175 |
+
image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
|
176 |
+
|
177 |
+
depth1 = depth1[
|
178 |
+
bbox1[0] : bbox1[0] + self.image_size,
|
179 |
+
bbox1[1] : bbox1[1] + self.image_size
|
180 |
+
]
|
181 |
+
depth2 = depth2[
|
182 |
+
bbox2[0] : bbox2[0] + self.image_size,
|
183 |
+
bbox2[1] : bbox2[1] + self.image_size
|
184 |
+
]
|
185 |
+
|
186 |
+
return (
|
187 |
+
image1, depth1, intrinsics1, pose1, bbox1,
|
188 |
+
image2, depth2, intrinsics2, pose2, bbox2
|
189 |
+
)
|
190 |
+
|
191 |
+
def crop(self, image1, image2, central_match):
|
192 |
+
bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
|
193 |
+
if bbox1_i + self.image_size >= image1.shape[0]:
|
194 |
+
bbox1_i = image1.shape[0] - self.image_size
|
195 |
+
bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
|
196 |
+
if bbox1_j + self.image_size >= image1.shape[1]:
|
197 |
+
bbox1_j = image1.shape[1] - self.image_size
|
198 |
+
|
199 |
+
bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
|
200 |
+
if bbox2_i + self.image_size >= image2.shape[0]:
|
201 |
+
bbox2_i = image2.shape[0] - self.image_size
|
202 |
+
bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
|
203 |
+
if bbox2_j + self.image_size >= image2.shape[1]:
|
204 |
+
bbox2_j = image2.shape[1] - self.image_size
|
205 |
+
|
206 |
+
return (
|
207 |
+
image1[
|
208 |
+
bbox1_i : bbox1_i + self.image_size,
|
209 |
+
bbox1_j : bbox1_j + self.image_size
|
210 |
+
],
|
211 |
+
np.array([bbox1_i, bbox1_j]),
|
212 |
+
image2[
|
213 |
+
bbox2_i : bbox2_i + self.image_size,
|
214 |
+
bbox2_j : bbox2_j + self.image_size
|
215 |
+
],
|
216 |
+
np.array([bbox2_i, bbox2_j])
|
217 |
+
)
|
218 |
+
|
219 |
+
def __getitem__(self, idx):
|
220 |
+
(
|
221 |
+
image1, depth1, intrinsics1, pose1, bbox1,
|
222 |
+
image2, depth2, intrinsics2, pose2, bbox2
|
223 |
+
) = self.recover_pair(self.dataset[idx])
|
224 |
+
|
225 |
+
image1 = preprocess_image(image1, preprocessing=self.preprocessing)
|
226 |
+
image2 = preprocess_image(image2, preprocessing=self.preprocessing)
|
227 |
+
|
228 |
+
return {
|
229 |
+
'image1': torch.from_numpy(image1.astype(np.float32)),
|
230 |
+
'depth1': torch.from_numpy(depth1.astype(np.float32)),
|
231 |
+
'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)),
|
232 |
+
'pose1': torch.from_numpy(pose1.astype(np.float32)),
|
233 |
+
'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
|
234 |
+
'image2': torch.from_numpy(image2.astype(np.float32)),
|
235 |
+
'depth2': torch.from_numpy(depth2.astype(np.float32)),
|
236 |
+
'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)),
|
237 |
+
'pose2': torch.from_numpy(pose2.astype(np.float32)),
|
238 |
+
'bbox2': torch.from_numpy(bbox2.astype(np.float32))
|
239 |
+
}
|
third_party/d2net/lib/exceptions.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class EmptyTensorError(Exception):
|
2 |
+
pass
|
3 |
+
|
4 |
+
|
5 |
+
class NoGradientError(Exception):
|
6 |
+
pass
|
third_party/d2net/lib/loss.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from lib.utils import (
|
10 |
+
grid_positions,
|
11 |
+
upscale_positions,
|
12 |
+
downscale_positions,
|
13 |
+
savefig,
|
14 |
+
imshow_image
|
15 |
+
)
|
16 |
+
from lib.exceptions import NoGradientError, EmptyTensorError
|
17 |
+
|
18 |
+
matplotlib.use('Agg')
|
19 |
+
|
20 |
+
|
21 |
+
def loss_function(
|
22 |
+
model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False
|
23 |
+
):
|
24 |
+
output = model({
|
25 |
+
'image1': batch['image1'].to(device),
|
26 |
+
'image2': batch['image2'].to(device)
|
27 |
+
})
|
28 |
+
|
29 |
+
loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
|
30 |
+
has_grad = False
|
31 |
+
|
32 |
+
n_valid_samples = 0
|
33 |
+
for idx_in_batch in range(batch['image1'].size(0)):
|
34 |
+
# Annotations
|
35 |
+
depth1 = batch['depth1'][idx_in_batch].to(device) # [h1, w1]
|
36 |
+
intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device) # [3, 3]
|
37 |
+
pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device) # [4, 4]
|
38 |
+
bbox1 = batch['bbox1'][idx_in_batch].to(device) # [2]
|
39 |
+
|
40 |
+
depth2 = batch['depth2'][idx_in_batch].to(device)
|
41 |
+
intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
|
42 |
+
pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
|
43 |
+
bbox2 = batch['bbox2'][idx_in_batch].to(device)
|
44 |
+
|
45 |
+
# Network output
|
46 |
+
dense_features1 = output['dense_features1'][idx_in_batch]
|
47 |
+
c, h1, w1 = dense_features1.size()
|
48 |
+
scores1 = output['scores1'][idx_in_batch].view(-1)
|
49 |
+
|
50 |
+
dense_features2 = output['dense_features2'][idx_in_batch]
|
51 |
+
_, h2, w2 = dense_features2.size()
|
52 |
+
scores2 = output['scores2'][idx_in_batch]
|
53 |
+
|
54 |
+
all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
|
55 |
+
descriptors1 = all_descriptors1
|
56 |
+
|
57 |
+
all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)
|
58 |
+
|
59 |
+
# Warp the positions from image 1 to image 2
|
60 |
+
fmap_pos1 = grid_positions(h1, w1, device)
|
61 |
+
pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
|
62 |
+
try:
|
63 |
+
pos1, pos2, ids = warp(
|
64 |
+
pos1,
|
65 |
+
depth1, intrinsics1, pose1, bbox1,
|
66 |
+
depth2, intrinsics2, pose2, bbox2
|
67 |
+
)
|
68 |
+
except EmptyTensorError:
|
69 |
+
continue
|
70 |
+
fmap_pos1 = fmap_pos1[:, ids]
|
71 |
+
descriptors1 = descriptors1[:, ids]
|
72 |
+
scores1 = scores1[ids]
|
73 |
+
|
74 |
+
# Skip the pair if not enough GT correspondences are available
|
75 |
+
if ids.size(0) < 128:
|
76 |
+
continue
|
77 |
+
|
78 |
+
# Descriptors at the corresponding positions
|
79 |
+
fmap_pos2 = torch.round(
|
80 |
+
downscale_positions(pos2, scaling_steps=scaling_steps)
|
81 |
+
).long()
|
82 |
+
descriptors2 = F.normalize(
|
83 |
+
dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
|
84 |
+
dim=0
|
85 |
+
)
|
86 |
+
positive_distance = 2 - 2 * (
|
87 |
+
descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
|
88 |
+
).squeeze()
|
89 |
+
|
90 |
+
all_fmap_pos2 = grid_positions(h2, w2, device)
|
91 |
+
position_distance = torch.max(
|
92 |
+
torch.abs(
|
93 |
+
fmap_pos2.unsqueeze(2).float() -
|
94 |
+
all_fmap_pos2.unsqueeze(1)
|
95 |
+
),
|
96 |
+
dim=0
|
97 |
+
)[0]
|
98 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
99 |
+
distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
|
100 |
+
negative_distance2 = torch.min(
|
101 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
102 |
+
dim=1
|
103 |
+
)[0]
|
104 |
+
|
105 |
+
all_fmap_pos1 = grid_positions(h1, w1, device)
|
106 |
+
position_distance = torch.max(
|
107 |
+
torch.abs(
|
108 |
+
fmap_pos1.unsqueeze(2).float() -
|
109 |
+
all_fmap_pos1.unsqueeze(1)
|
110 |
+
),
|
111 |
+
dim=0
|
112 |
+
)[0]
|
113 |
+
is_out_of_safe_radius = position_distance > safe_radius
|
114 |
+
distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
|
115 |
+
negative_distance1 = torch.min(
|
116 |
+
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
|
117 |
+
dim=1
|
118 |
+
)[0]
|
119 |
+
|
120 |
+
diff = positive_distance - torch.min(
|
121 |
+
negative_distance1, negative_distance2
|
122 |
+
)
|
123 |
+
|
124 |
+
scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
|
125 |
+
|
126 |
+
loss = loss + (
|
127 |
+
torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
|
128 |
+
torch.sum(scores1 * scores2)
|
129 |
+
)
|
130 |
+
|
131 |
+
has_grad = True
|
132 |
+
n_valid_samples += 1
|
133 |
+
|
134 |
+
if plot and batch['batch_idx'] % batch['log_interval'] == 0:
|
135 |
+
pos1_aux = pos1.cpu().numpy()
|
136 |
+
pos2_aux = pos2.cpu().numpy()
|
137 |
+
k = pos1_aux.shape[1]
|
138 |
+
col = np.random.rand(k, 3)
|
139 |
+
n_sp = 4
|
140 |
+
plt.figure()
|
141 |
+
plt.subplot(1, n_sp, 1)
|
142 |
+
im1 = imshow_image(
|
143 |
+
batch['image1'][idx_in_batch].cpu().numpy(),
|
144 |
+
preprocessing=batch['preprocessing']
|
145 |
+
)
|
146 |
+
plt.imshow(im1)
|
147 |
+
plt.scatter(
|
148 |
+
pos1_aux[1, :], pos1_aux[0, :],
|
149 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
150 |
+
)
|
151 |
+
plt.axis('off')
|
152 |
+
plt.subplot(1, n_sp, 2)
|
153 |
+
plt.imshow(
|
154 |
+
output['scores1'][idx_in_batch].data.cpu().numpy(),
|
155 |
+
cmap='Reds'
|
156 |
+
)
|
157 |
+
plt.axis('off')
|
158 |
+
plt.subplot(1, n_sp, 3)
|
159 |
+
im2 = imshow_image(
|
160 |
+
batch['image2'][idx_in_batch].cpu().numpy(),
|
161 |
+
preprocessing=batch['preprocessing']
|
162 |
+
)
|
163 |
+
plt.imshow(im2)
|
164 |
+
plt.scatter(
|
165 |
+
pos2_aux[1, :], pos2_aux[0, :],
|
166 |
+
s=0.25**2, c=col, marker=',', alpha=0.5
|
167 |
+
)
|
168 |
+
plt.axis('off')
|
169 |
+
plt.subplot(1, n_sp, 4)
|
170 |
+
plt.imshow(
|
171 |
+
output['scores2'][idx_in_batch].data.cpu().numpy(),
|
172 |
+
cmap='Reds'
|
173 |
+
)
|
174 |
+
plt.axis('off')
|
175 |
+
savefig('train_vis/%s.%02d.%02d.%d.png' % (
|
176 |
+
'train' if batch['train'] else 'valid',
|
177 |
+
batch['epoch_idx'],
|
178 |
+
batch['batch_idx'] // batch['log_interval'],
|
179 |
+
idx_in_batch
|
180 |
+
), dpi=300)
|
181 |
+
plt.close()
|
182 |
+
|
183 |
+
if not has_grad:
|
184 |
+
raise NoGradientError
|
185 |
+
|
186 |
+
loss = loss / n_valid_samples
|
187 |
+
|
188 |
+
return loss
|
189 |
+
|
190 |
+
|
191 |
+
def interpolate_depth(pos, depth):
|
192 |
+
device = pos.device
|
193 |
+
|
194 |
+
ids = torch.arange(0, pos.size(1), device=device)
|
195 |
+
|
196 |
+
h, w = depth.size()
|
197 |
+
|
198 |
+
i = pos[0, :]
|
199 |
+
j = pos[1, :]
|
200 |
+
|
201 |
+
# Valid corners
|
202 |
+
i_top_left = torch.floor(i).long()
|
203 |
+
j_top_left = torch.floor(j).long()
|
204 |
+
valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
|
205 |
+
|
206 |
+
i_top_right = torch.floor(i).long()
|
207 |
+
j_top_right = torch.ceil(j).long()
|
208 |
+
valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
|
209 |
+
|
210 |
+
i_bottom_left = torch.ceil(i).long()
|
211 |
+
j_bottom_left = torch.floor(j).long()
|
212 |
+
valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
|
213 |
+
|
214 |
+
i_bottom_right = torch.ceil(i).long()
|
215 |
+
j_bottom_right = torch.ceil(j).long()
|
216 |
+
valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
|
217 |
+
|
218 |
+
valid_corners = torch.min(
|
219 |
+
torch.min(valid_top_left, valid_top_right),
|
220 |
+
torch.min(valid_bottom_left, valid_bottom_right)
|
221 |
+
)
|
222 |
+
|
223 |
+
i_top_left = i_top_left[valid_corners]
|
224 |
+
j_top_left = j_top_left[valid_corners]
|
225 |
+
|
226 |
+
i_top_right = i_top_right[valid_corners]
|
227 |
+
j_top_right = j_top_right[valid_corners]
|
228 |
+
|
229 |
+
i_bottom_left = i_bottom_left[valid_corners]
|
230 |
+
j_bottom_left = j_bottom_left[valid_corners]
|
231 |
+
|
232 |
+
i_bottom_right = i_bottom_right[valid_corners]
|
233 |
+
j_bottom_right = j_bottom_right[valid_corners]
|
234 |
+
|
235 |
+
ids = ids[valid_corners]
|
236 |
+
if ids.size(0) == 0:
|
237 |
+
raise EmptyTensorError
|
238 |
+
|
239 |
+
# Valid depth
|
240 |
+
valid_depth = torch.min(
|
241 |
+
torch.min(
|
242 |
+
depth[i_top_left, j_top_left] > 0,
|
243 |
+
depth[i_top_right, j_top_right] > 0
|
244 |
+
),
|
245 |
+
torch.min(
|
246 |
+
depth[i_bottom_left, j_bottom_left] > 0,
|
247 |
+
depth[i_bottom_right, j_bottom_right] > 0
|
248 |
+
)
|
249 |
+
)
|
250 |
+
|
251 |
+
i_top_left = i_top_left[valid_depth]
|
252 |
+
j_top_left = j_top_left[valid_depth]
|
253 |
+
|
254 |
+
i_top_right = i_top_right[valid_depth]
|
255 |
+
j_top_right = j_top_right[valid_depth]
|
256 |
+
|
257 |
+
i_bottom_left = i_bottom_left[valid_depth]
|
258 |
+
j_bottom_left = j_bottom_left[valid_depth]
|
259 |
+
|
260 |
+
i_bottom_right = i_bottom_right[valid_depth]
|
261 |
+
j_bottom_right = j_bottom_right[valid_depth]
|
262 |
+
|
263 |
+
ids = ids[valid_depth]
|
264 |
+
if ids.size(0) == 0:
|
265 |
+
raise EmptyTensorError
|
266 |
+
|
267 |
+
# Interpolation
|
268 |
+
i = i[ids]
|
269 |
+
j = j[ids]
|
270 |
+
dist_i_top_left = i - i_top_left.float()
|
271 |
+
dist_j_top_left = j - j_top_left.float()
|
272 |
+
w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
|
273 |
+
w_top_right = (1 - dist_i_top_left) * dist_j_top_left
|
274 |
+
w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
|
275 |
+
w_bottom_right = dist_i_top_left * dist_j_top_left
|
276 |
+
|
277 |
+
interpolated_depth = (
|
278 |
+
w_top_left * depth[i_top_left, j_top_left] +
|
279 |
+
w_top_right * depth[i_top_right, j_top_right] +
|
280 |
+
w_bottom_left * depth[i_bottom_left, j_bottom_left] +
|
281 |
+
w_bottom_right * depth[i_bottom_right, j_bottom_right]
|
282 |
+
)
|
283 |
+
|
284 |
+
pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
|
285 |
+
|
286 |
+
return [interpolated_depth, pos, ids]
|
287 |
+
|
288 |
+
|
289 |
+
def uv_to_pos(uv):
|
290 |
+
return torch.cat([uv[1, :].view(1, -1), uv[0, :].view(1, -1)], dim=0)
|
291 |
+
|
292 |
+
|
293 |
+
def warp(
|
294 |
+
pos1,
|
295 |
+
depth1, intrinsics1, pose1, bbox1,
|
296 |
+
depth2, intrinsics2, pose2, bbox2
|
297 |
+
):
|
298 |
+
device = pos1.device
|
299 |
+
|
300 |
+
Z1, pos1, ids = interpolate_depth(pos1, depth1)
|
301 |
+
|
302 |
+
# COLMAP convention
|
303 |
+
u1 = pos1[1, :] + bbox1[1] + .5
|
304 |
+
v1 = pos1[0, :] + bbox1[0] + .5
|
305 |
+
|
306 |
+
X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
|
307 |
+
Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])
|
308 |
+
|
309 |
+
XYZ1_hom = torch.cat([
|
310 |
+
X1.view(1, -1),
|
311 |
+
Y1.view(1, -1),
|
312 |
+
Z1.view(1, -1),
|
313 |
+
torch.ones(1, Z1.size(0), device=device)
|
314 |
+
], dim=0)
|
315 |
+
XYZ2_hom = torch.chain_matmul(pose2, torch.inverse(pose1), XYZ1_hom)
|
316 |
+
XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)
|
317 |
+
|
318 |
+
uv2_hom = torch.matmul(intrinsics2, XYZ2)
|
319 |
+
uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)
|
320 |
+
|
321 |
+
u2 = uv2[0, :] - bbox2[1] - .5
|
322 |
+
v2 = uv2[1, :] - bbox2[0] - .5
|
323 |
+
uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)
|
324 |
+
|
325 |
+
annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)
|
326 |
+
|
327 |
+
ids = ids[new_ids]
|
328 |
+
pos1 = pos1[:, new_ids]
|
329 |
+
estimated_depth = XYZ2[2, new_ids]
|
330 |
+
|
331 |
+
inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05
|
332 |
+
|
333 |
+
ids = ids[inlier_mask]
|
334 |
+
if ids.size(0) == 0:
|
335 |
+
raise EmptyTensorError
|
336 |
+
|
337 |
+
pos2 = pos2[:, inlier_mask]
|
338 |
+
pos1 = pos1[:, inlier_mask]
|
339 |
+
|
340 |
+
return pos1, pos2, ids
|
third_party/d2net/lib/model.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
|
8 |
+
class DenseFeatureExtractionModule(nn.Module):
|
9 |
+
def __init__(self, finetune_feature_extraction=False, use_cuda=True):
|
10 |
+
super(DenseFeatureExtractionModule, self).__init__()
|
11 |
+
|
12 |
+
model = models.vgg16()
|
13 |
+
vgg16_layers = [
|
14 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
|
15 |
+
'pool1',
|
16 |
+
'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
|
17 |
+
'pool2',
|
18 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
|
19 |
+
'pool3',
|
20 |
+
'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
|
21 |
+
'pool4',
|
22 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
23 |
+
'pool5'
|
24 |
+
]
|
25 |
+
conv4_3_idx = vgg16_layers.index('conv4_3')
|
26 |
+
|
27 |
+
self.model = nn.Sequential(
|
28 |
+
*list(model.features.children())[: conv4_3_idx + 1]
|
29 |
+
)
|
30 |
+
|
31 |
+
self.num_channels = 512
|
32 |
+
|
33 |
+
# Fix forward parameters
|
34 |
+
for param in self.model.parameters():
|
35 |
+
param.requires_grad = False
|
36 |
+
if finetune_feature_extraction:
|
37 |
+
# Unlock conv4_3
|
38 |
+
for param in list(self.model.parameters())[-2 :]:
|
39 |
+
param.requires_grad = True
|
40 |
+
|
41 |
+
if use_cuda:
|
42 |
+
self.model = self.model.cuda()
|
43 |
+
|
44 |
+
def forward(self, batch):
|
45 |
+
output = self.model(batch)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
class SoftDetectionModule(nn.Module):
|
50 |
+
def __init__(self, soft_local_max_size=3):
|
51 |
+
super(SoftDetectionModule, self).__init__()
|
52 |
+
|
53 |
+
self.soft_local_max_size = soft_local_max_size
|
54 |
+
|
55 |
+
self.pad = self.soft_local_max_size // 2
|
56 |
+
|
57 |
+
def forward(self, batch):
|
58 |
+
b = batch.size(0)
|
59 |
+
|
60 |
+
batch = F.relu(batch)
|
61 |
+
|
62 |
+
max_per_sample = torch.max(batch.view(b, -1), dim=1)[0]
|
63 |
+
exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1))
|
64 |
+
sum_exp = (
|
65 |
+
self.soft_local_max_size ** 2 *
|
66 |
+
F.avg_pool2d(
|
67 |
+
F.pad(exp, [self.pad] * 4, mode='constant', value=1.),
|
68 |
+
self.soft_local_max_size, stride=1
|
69 |
+
)
|
70 |
+
)
|
71 |
+
local_max_score = exp / sum_exp
|
72 |
+
|
73 |
+
depth_wise_max = torch.max(batch, dim=1)[0]
|
74 |
+
depth_wise_max_score = batch / depth_wise_max.unsqueeze(1)
|
75 |
+
|
76 |
+
all_scores = local_max_score * depth_wise_max_score
|
77 |
+
score = torch.max(all_scores, dim=1)[0]
|
78 |
+
|
79 |
+
score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1)
|
80 |
+
|
81 |
+
return score
|
82 |
+
|
83 |
+
|
84 |
+
class D2Net(nn.Module):
|
85 |
+
def __init__(self, model_file=None, use_cuda=True):
|
86 |
+
super(D2Net, self).__init__()
|
87 |
+
|
88 |
+
self.dense_feature_extraction = DenseFeatureExtractionModule(
|
89 |
+
finetune_feature_extraction=True,
|
90 |
+
use_cuda=use_cuda
|
91 |
+
)
|
92 |
+
|
93 |
+
self.detection = SoftDetectionModule()
|
94 |
+
|
95 |
+
if model_file is not None:
|
96 |
+
if use_cuda:
|
97 |
+
self.load_state_dict(torch.load(model_file)['model'])
|
98 |
+
else:
|
99 |
+
self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
|
100 |
+
|
101 |
+
def forward(self, batch):
|
102 |
+
b = batch['image1'].size(0)
|
103 |
+
|
104 |
+
dense_features = self.dense_feature_extraction(
|
105 |
+
torch.cat([batch['image1'], batch['image2']], dim=0)
|
106 |
+
)
|
107 |
+
|
108 |
+
scores = self.detection(dense_features)
|
109 |
+
|
110 |
+
dense_features1 = dense_features[: b, :, :, :]
|
111 |
+
dense_features2 = dense_features[b :, :, :, :]
|
112 |
+
|
113 |
+
scores1 = scores[: b, :, :]
|
114 |
+
scores2 = scores[b :, :, :]
|
115 |
+
|
116 |
+
return {
|
117 |
+
'dense_features1': dense_features1,
|
118 |
+
'scores1': scores1,
|
119 |
+
'dense_features2': dense_features2,
|
120 |
+
'scores2': scores2
|
121 |
+
}
|
third_party/d2net/lib/model_test.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class DenseFeatureExtractionModule(nn.Module):
|
7 |
+
def __init__(self, use_relu=True, use_cuda=True):
|
8 |
+
super(DenseFeatureExtractionModule, self).__init__()
|
9 |
+
|
10 |
+
self.model = nn.Sequential(
|
11 |
+
nn.Conv2d(3, 64, 3, padding=1),
|
12 |
+
nn.ReLU(inplace=True),
|
13 |
+
nn.Conv2d(64, 64, 3, padding=1),
|
14 |
+
nn.ReLU(inplace=True),
|
15 |
+
nn.MaxPool2d(2, stride=2),
|
16 |
+
nn.Conv2d(64, 128, 3, padding=1),
|
17 |
+
nn.ReLU(inplace=True),
|
18 |
+
nn.Conv2d(128, 128, 3, padding=1),
|
19 |
+
nn.ReLU(inplace=True),
|
20 |
+
nn.MaxPool2d(2, stride=2),
|
21 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
nn.Conv2d(256, 256, 3, padding=1),
|
24 |
+
nn.ReLU(inplace=True),
|
25 |
+
nn.Conv2d(256, 256, 3, padding=1),
|
26 |
+
nn.ReLU(inplace=True),
|
27 |
+
nn.AvgPool2d(2, stride=1),
|
28 |
+
nn.Conv2d(256, 512, 3, padding=2, dilation=2),
|
29 |
+
nn.ReLU(inplace=True),
|
30 |
+
nn.Conv2d(512, 512, 3, padding=2, dilation=2),
|
31 |
+
nn.ReLU(inplace=True),
|
32 |
+
nn.Conv2d(512, 512, 3, padding=2, dilation=2),
|
33 |
+
)
|
34 |
+
self.num_channels = 512
|
35 |
+
|
36 |
+
self.use_relu = use_relu
|
37 |
+
|
38 |
+
if use_cuda:
|
39 |
+
self.model = self.model.cuda()
|
40 |
+
|
41 |
+
def forward(self, batch):
|
42 |
+
output = self.model(batch)
|
43 |
+
if self.use_relu:
|
44 |
+
output = F.relu(output)
|
45 |
+
return output
|
46 |
+
|
47 |
+
|
48 |
+
class D2Net(nn.Module):
|
49 |
+
def __init__(self, model_file=None, use_relu=True, use_cuda=True):
|
50 |
+
super(D2Net, self).__init__()
|
51 |
+
|
52 |
+
self.dense_feature_extraction = DenseFeatureExtractionModule(
|
53 |
+
use_relu=use_relu, use_cuda=use_cuda
|
54 |
+
)
|
55 |
+
|
56 |
+
self.detection = HardDetectionModule()
|
57 |
+
|
58 |
+
self.localization = HandcraftedLocalizationModule()
|
59 |
+
|
60 |
+
if model_file is not None:
|
61 |
+
if use_cuda:
|
62 |
+
self.load_state_dict(torch.load(model_file)['model'])
|
63 |
+
else:
|
64 |
+
self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
|
65 |
+
|
66 |
+
def forward(self, batch):
|
67 |
+
_, _, h, w = batch.size()
|
68 |
+
dense_features = self.dense_feature_extraction(batch)
|
69 |
+
|
70 |
+
detections = self.detection(dense_features)
|
71 |
+
|
72 |
+
displacements = self.localization(dense_features)
|
73 |
+
|
74 |
+
return {
|
75 |
+
'dense_features': dense_features,
|
76 |
+
'detections': detections,
|
77 |
+
'displacements': displacements
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
class HardDetectionModule(nn.Module):
|
82 |
+
def __init__(self, edge_threshold=5):
|
83 |
+
super(HardDetectionModule, self).__init__()
|
84 |
+
|
85 |
+
self.edge_threshold = edge_threshold
|
86 |
+
|
87 |
+
self.dii_filter = torch.tensor(
|
88 |
+
[[0, 1., 0], [0, -2., 0], [0, 1., 0]]
|
89 |
+
).view(1, 1, 3, 3)
|
90 |
+
self.dij_filter = 0.25 * torch.tensor(
|
91 |
+
[[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
|
92 |
+
).view(1, 1, 3, 3)
|
93 |
+
self.djj_filter = torch.tensor(
|
94 |
+
[[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
|
95 |
+
).view(1, 1, 3, 3)
|
96 |
+
|
97 |
+
def forward(self, batch):
|
98 |
+
b, c, h, w = batch.size()
|
99 |
+
device = batch.device
|
100 |
+
|
101 |
+
depth_wise_max = torch.max(batch, dim=1)[0]
|
102 |
+
is_depth_wise_max = (batch == depth_wise_max)
|
103 |
+
del depth_wise_max
|
104 |
+
|
105 |
+
local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
|
106 |
+
is_local_max = (batch == local_max)
|
107 |
+
del local_max
|
108 |
+
|
109 |
+
dii = F.conv2d(
|
110 |
+
batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
|
111 |
+
).view(b, c, h, w)
|
112 |
+
dij = F.conv2d(
|
113 |
+
batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
|
114 |
+
).view(b, c, h, w)
|
115 |
+
djj = F.conv2d(
|
116 |
+
batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
|
117 |
+
).view(b, c, h, w)
|
118 |
+
|
119 |
+
det = dii * djj - dij * dij
|
120 |
+
tr = dii + djj
|
121 |
+
del dii, dij, djj
|
122 |
+
|
123 |
+
threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
|
124 |
+
is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
|
125 |
+
|
126 |
+
detected = torch.min(
|
127 |
+
is_depth_wise_max,
|
128 |
+
torch.min(is_local_max, is_not_edge)
|
129 |
+
)
|
130 |
+
del is_depth_wise_max, is_local_max, is_not_edge
|
131 |
+
|
132 |
+
return detected
|
133 |
+
|
134 |
+
|
135 |
+
class HandcraftedLocalizationModule(nn.Module):
|
136 |
+
def __init__(self):
|
137 |
+
super(HandcraftedLocalizationModule, self).__init__()
|
138 |
+
|
139 |
+
self.di_filter = torch.tensor(
|
140 |
+
[[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
|
141 |
+
).view(1, 1, 3, 3)
|
142 |
+
self.dj_filter = torch.tensor(
|
143 |
+
[[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
|
144 |
+
).view(1, 1, 3, 3)
|
145 |
+
|
146 |
+
self.dii_filter = torch.tensor(
|
147 |
+
[[0, 1., 0], [0, -2., 0], [0, 1., 0]]
|
148 |
+
).view(1, 1, 3, 3)
|
149 |
+
self.dij_filter = 0.25 * torch.tensor(
|
150 |
+
[[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
|
151 |
+
).view(1, 1, 3, 3)
|
152 |
+
self.djj_filter = torch.tensor(
|
153 |
+
[[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
|
154 |
+
).view(1, 1, 3, 3)
|
155 |
+
|
156 |
+
def forward(self, batch):
|
157 |
+
b, c, h, w = batch.size()
|
158 |
+
device = batch.device
|
159 |
+
|
160 |
+
dii = F.conv2d(
|
161 |
+
batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
|
162 |
+
).view(b, c, h, w)
|
163 |
+
dij = F.conv2d(
|
164 |
+
batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
|
165 |
+
).view(b, c, h, w)
|
166 |
+
djj = F.conv2d(
|
167 |
+
batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
|
168 |
+
).view(b, c, h, w)
|
169 |
+
det = dii * djj - dij * dij
|
170 |
+
|
171 |
+
inv_hess_00 = djj / det
|
172 |
+
inv_hess_01 = -dij / det
|
173 |
+
inv_hess_11 = dii / det
|
174 |
+
del dii, dij, djj, det
|
175 |
+
|
176 |
+
di = F.conv2d(
|
177 |
+
batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1
|
178 |
+
).view(b, c, h, w)
|
179 |
+
dj = F.conv2d(
|
180 |
+
batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1
|
181 |
+
).view(b, c, h, w)
|
182 |
+
|
183 |
+
step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
|
184 |
+
step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
|
185 |
+
del inv_hess_00, inv_hess_01, inv_hess_11, di, dj
|
186 |
+
|
187 |
+
return torch.stack([step_i, step_j], dim=1)
|
third_party/d2net/lib/pyramid.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from lib.exceptions import EmptyTensorError
|
6 |
+
from lib.utils import interpolate_dense_features, upscale_positions
|
7 |
+
|
8 |
+
|
9 |
+
def process_multiscale(image, model, scales=[.5, 1, 2]):
|
10 |
+
b, _, h_init, w_init = image.size()
|
11 |
+
device = image.device
|
12 |
+
assert(b == 1)
|
13 |
+
|
14 |
+
all_keypoints = torch.zeros([3, 0])
|
15 |
+
all_descriptors = torch.zeros([
|
16 |
+
model.dense_feature_extraction.num_channels, 0
|
17 |
+
])
|
18 |
+
all_scores = torch.zeros(0)
|
19 |
+
|
20 |
+
previous_dense_features = None
|
21 |
+
banned = None
|
22 |
+
for idx, scale in enumerate(scales):
|
23 |
+
current_image = F.interpolate(
|
24 |
+
image, scale_factor=scale,
|
25 |
+
mode='bilinear', align_corners=True
|
26 |
+
)
|
27 |
+
_, _, h_level, w_level = current_image.size()
|
28 |
+
|
29 |
+
dense_features = model.dense_feature_extraction(current_image)
|
30 |
+
del current_image
|
31 |
+
|
32 |
+
_, _, h, w = dense_features.size()
|
33 |
+
|
34 |
+
# Sum the feature maps.
|
35 |
+
if previous_dense_features is not None:
|
36 |
+
dense_features += F.interpolate(
|
37 |
+
previous_dense_features, size=[h, w],
|
38 |
+
mode='bilinear', align_corners=True
|
39 |
+
)
|
40 |
+
del previous_dense_features
|
41 |
+
|
42 |
+
# Recover detections.
|
43 |
+
detections = model.detection(dense_features)
|
44 |
+
if banned is not None:
|
45 |
+
banned = F.interpolate(banned.float(), size=[h, w]).bool()
|
46 |
+
detections = torch.min(detections, ~banned)
|
47 |
+
banned = torch.max(
|
48 |
+
torch.max(detections, dim=1)[0].unsqueeze(1), banned
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
banned = torch.max(detections, dim=1)[0].unsqueeze(1)
|
52 |
+
fmap_pos = torch.nonzero(detections[0].cpu()).t()
|
53 |
+
del detections
|
54 |
+
|
55 |
+
# Recover displacements.
|
56 |
+
displacements = model.localization(dense_features)[0].cpu()
|
57 |
+
displacements_i = displacements[
|
58 |
+
0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
59 |
+
]
|
60 |
+
displacements_j = displacements[
|
61 |
+
1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
62 |
+
]
|
63 |
+
del displacements
|
64 |
+
|
65 |
+
mask = torch.min(
|
66 |
+
torch.abs(displacements_i) < 0.5,
|
67 |
+
torch.abs(displacements_j) < 0.5
|
68 |
+
)
|
69 |
+
fmap_pos = fmap_pos[:, mask]
|
70 |
+
valid_displacements = torch.stack([
|
71 |
+
displacements_i[mask],
|
72 |
+
displacements_j[mask]
|
73 |
+
], dim=0)
|
74 |
+
del mask, displacements_i, displacements_j
|
75 |
+
|
76 |
+
fmap_keypoints = fmap_pos[1 :, :].float() + valid_displacements
|
77 |
+
del valid_displacements
|
78 |
+
|
79 |
+
try:
|
80 |
+
raw_descriptors, _, ids = interpolate_dense_features(
|
81 |
+
fmap_keypoints.to(device),
|
82 |
+
dense_features[0]
|
83 |
+
)
|
84 |
+
except EmptyTensorError:
|
85 |
+
continue
|
86 |
+
fmap_pos = fmap_pos.to(device)
|
87 |
+
fmap_keypoints = fmap_keypoints.to(device)
|
88 |
+
fmap_pos = fmap_pos[:, ids]
|
89 |
+
fmap_keypoints = fmap_keypoints[:, ids]
|
90 |
+
del ids
|
91 |
+
|
92 |
+
keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
|
93 |
+
del fmap_keypoints
|
94 |
+
|
95 |
+
descriptors = F.normalize(raw_descriptors, dim=0).cpu()
|
96 |
+
del raw_descriptors
|
97 |
+
|
98 |
+
keypoints[0, :] *= h_init / h_level
|
99 |
+
keypoints[1, :] *= w_init / w_level
|
100 |
+
|
101 |
+
fmap_pos = fmap_pos.cpu()
|
102 |
+
keypoints = keypoints.cpu()
|
103 |
+
|
104 |
+
keypoints = torch.cat([
|
105 |
+
keypoints,
|
106 |
+
torch.ones([1, keypoints.size(1)]) * 1 / scale,
|
107 |
+
], dim=0)
|
108 |
+
|
109 |
+
scores = dense_features[
|
110 |
+
0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
|
111 |
+
].cpu() / (idx + 1)
|
112 |
+
del fmap_pos
|
113 |
+
|
114 |
+
all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
|
115 |
+
all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
|
116 |
+
all_scores = torch.cat([all_scores, scores], dim=0)
|
117 |
+
del keypoints, descriptors
|
118 |
+
|
119 |
+
previous_dense_features = dense_features
|
120 |
+
del dense_features
|
121 |
+
del previous_dense_features, banned
|
122 |
+
|
123 |
+
keypoints = all_keypoints.t().detach().numpy()
|
124 |
+
del all_keypoints
|
125 |
+
scores = all_scores.detach().numpy()
|
126 |
+
del all_scores
|
127 |
+
descriptors = all_descriptors.t().detach().numpy()
|
128 |
+
del all_descriptors
|
129 |
+
return keypoints, scores, descriptors
|
third_party/d2net/lib/utils.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from lib.exceptions import EmptyTensorError
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess_image(image, preprocessing=None):
|
11 |
+
image = image.astype(np.float32)
|
12 |
+
image = np.transpose(image, [2, 0, 1])
|
13 |
+
if preprocessing is None:
|
14 |
+
pass
|
15 |
+
elif preprocessing == 'caffe':
|
16 |
+
# RGB -> BGR
|
17 |
+
image = image[:: -1, :, :]
|
18 |
+
# Zero-center by mean pixel
|
19 |
+
mean = np.array([103.939, 116.779, 123.68])
|
20 |
+
image = image - mean.reshape([3, 1, 1])
|
21 |
+
elif preprocessing == 'torch':
|
22 |
+
image /= 255.0
|
23 |
+
mean = np.array([0.485, 0.456, 0.406])
|
24 |
+
std = np.array([0.229, 0.224, 0.225])
|
25 |
+
image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1])
|
26 |
+
else:
|
27 |
+
raise ValueError('Unknown preprocessing parameter.')
|
28 |
+
return image
|
29 |
+
|
30 |
+
|
31 |
+
def imshow_image(image, preprocessing=None):
|
32 |
+
if preprocessing is None:
|
33 |
+
pass
|
34 |
+
elif preprocessing == 'caffe':
|
35 |
+
mean = np.array([103.939, 116.779, 123.68])
|
36 |
+
image = image + mean.reshape([3, 1, 1])
|
37 |
+
# RGB -> BGR
|
38 |
+
image = image[:: -1, :, :]
|
39 |
+
elif preprocessing == 'torch':
|
40 |
+
mean = np.array([0.485, 0.456, 0.406])
|
41 |
+
std = np.array([0.229, 0.224, 0.225])
|
42 |
+
image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
|
43 |
+
image *= 255.0
|
44 |
+
else:
|
45 |
+
raise ValueError('Unknown preprocessing parameter.')
|
46 |
+
image = np.transpose(image, [1, 2, 0])
|
47 |
+
image = np.round(image).astype(np.uint8)
|
48 |
+
return image
|
49 |
+
|
50 |
+
|
51 |
+
def grid_positions(h, w, device, matrix=False):
|
52 |
+
lines = torch.arange(
|
53 |
+
0, h, device=device
|
54 |
+
).view(-1, 1).float().repeat(1, w)
|
55 |
+
columns = torch.arange(
|
56 |
+
0, w, device=device
|
57 |
+
).view(1, -1).float().repeat(h, 1)
|
58 |
+
if matrix:
|
59 |
+
return torch.stack([lines, columns], dim=0)
|
60 |
+
else:
|
61 |
+
return torch.cat([lines.view(1, -1), columns.view(1, -1)], dim=0)
|
62 |
+
|
63 |
+
|
64 |
+
def upscale_positions(pos, scaling_steps=0):
|
65 |
+
for _ in range(scaling_steps):
|
66 |
+
pos = pos * 2 + 0.5
|
67 |
+
return pos
|
68 |
+
|
69 |
+
|
70 |
+
def downscale_positions(pos, scaling_steps=0):
|
71 |
+
for _ in range(scaling_steps):
|
72 |
+
pos = (pos - 0.5) / 2
|
73 |
+
return pos
|
74 |
+
|
75 |
+
|
76 |
+
def interpolate_dense_features(pos, dense_features, return_corners=False):
|
77 |
+
device = pos.device
|
78 |
+
|
79 |
+
ids = torch.arange(0, pos.size(1), device=device)
|
80 |
+
|
81 |
+
_, h, w = dense_features.size()
|
82 |
+
|
83 |
+
i = pos[0, :]
|
84 |
+
j = pos[1, :]
|
85 |
+
|
86 |
+
# Valid corners
|
87 |
+
i_top_left = torch.floor(i).long()
|
88 |
+
j_top_left = torch.floor(j).long()
|
89 |
+
valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
|
90 |
+
|
91 |
+
i_top_right = torch.floor(i).long()
|
92 |
+
j_top_right = torch.ceil(j).long()
|
93 |
+
valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
|
94 |
+
|
95 |
+
i_bottom_left = torch.ceil(i).long()
|
96 |
+
j_bottom_left = torch.floor(j).long()
|
97 |
+
valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
|
98 |
+
|
99 |
+
i_bottom_right = torch.ceil(i).long()
|
100 |
+
j_bottom_right = torch.ceil(j).long()
|
101 |
+
valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
|
102 |
+
|
103 |
+
valid_corners = torch.min(
|
104 |
+
torch.min(valid_top_left, valid_top_right),
|
105 |
+
torch.min(valid_bottom_left, valid_bottom_right)
|
106 |
+
)
|
107 |
+
|
108 |
+
i_top_left = i_top_left[valid_corners]
|
109 |
+
j_top_left = j_top_left[valid_corners]
|
110 |
+
|
111 |
+
i_top_right = i_top_right[valid_corners]
|
112 |
+
j_top_right = j_top_right[valid_corners]
|
113 |
+
|
114 |
+
i_bottom_left = i_bottom_left[valid_corners]
|
115 |
+
j_bottom_left = j_bottom_left[valid_corners]
|
116 |
+
|
117 |
+
i_bottom_right = i_bottom_right[valid_corners]
|
118 |
+
j_bottom_right = j_bottom_right[valid_corners]
|
119 |
+
|
120 |
+
ids = ids[valid_corners]
|
121 |
+
if ids.size(0) == 0:
|
122 |
+
raise EmptyTensorError
|
123 |
+
|
124 |
+
# Interpolation
|
125 |
+
i = i[ids]
|
126 |
+
j = j[ids]
|
127 |
+
dist_i_top_left = i - i_top_left.float()
|
128 |
+
dist_j_top_left = j - j_top_left.float()
|
129 |
+
w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
|
130 |
+
w_top_right = (1 - dist_i_top_left) * dist_j_top_left
|
131 |
+
w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
|
132 |
+
w_bottom_right = dist_i_top_left * dist_j_top_left
|
133 |
+
|
134 |
+
descriptors = (
|
135 |
+
w_top_left * dense_features[:, i_top_left, j_top_left] +
|
136 |
+
w_top_right * dense_features[:, i_top_right, j_top_right] +
|
137 |
+
w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
|
138 |
+
w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
|
139 |
+
)
|
140 |
+
|
141 |
+
pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
|
142 |
+
|
143 |
+
if not return_corners:
|
144 |
+
return [descriptors, pos, ids]
|
145 |
+
else:
|
146 |
+
corners = torch.stack([
|
147 |
+
torch.stack([i_top_left, j_top_left], dim=0),
|
148 |
+
torch.stack([i_top_right, j_top_right], dim=0),
|
149 |
+
torch.stack([i_bottom_left, j_bottom_left], dim=0),
|
150 |
+
torch.stack([i_bottom_right, j_bottom_right], dim=0)
|
151 |
+
], dim=0)
|
152 |
+
return [descriptors, pos, ids, corners]
|
153 |
+
|
154 |
+
|
155 |
+
def savefig(filepath, fig=None, dpi=None):
|
156 |
+
# TomNorway - https://stackoverflow.com/a/53516034
|
157 |
+
if not fig:
|
158 |
+
fig = plt.gcf()
|
159 |
+
|
160 |
+
plt.subplots_adjust(0, 0, 1, 1, 0, 0)
|
161 |
+
for ax in fig.axes:
|
162 |
+
ax.axis('off')
|
163 |
+
ax.margins(0, 0)
|
164 |
+
ax.xaxis.set_major_locator(plt.NullLocator())
|
165 |
+
ax.yaxis.set_major_locator(plt.NullLocator())
|
166 |
+
|
167 |
+
fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)
|