Upload model (#1)
Upload model (edba1de555b3ac3fbc0ad8165f8fa3d84fbee170)
- attention.py +295 -0
- backbone.py +142 -0
- config.json +13 -0
- configuration_doduo.py +13 -0
- dino.py +420 -0
- geometry.py +229 -0
- matching.py +505 -0
- modeling_doduo.py +118 -0
- position.py +49 -0
- pytorch_model.bin +3 -0
- transformer.py +340 -0
- trident_conv.py +96 -0
- unimatch.py +213 -0
- utils.py +257 -0
attention.py
ADDED
@@ -0,0 +1,295 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import merge_splits, merge_splits_1d, split_feature, split_feature_1d


def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


def single_head_full_attention_1d(
    q,
    k,
    v,
    h=None,
    w=None,
):
    # q, k, v: [B, L, C]

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c**0.5

    scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]

    return out


def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c**0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = (
        torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor
    )  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(
        out.view(b_new, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


def single_head_split_window_attention_1d(
    q,
    k,
    v,
    relative_position_bias=None,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    # q, k, v: [B, L, C]

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * h

    window_size_w = w // num_splits

    q = q.view(b * h, w, c)  # [B*H, W, C]
    k = k.view(b * h, w, c)
    v = v.view(b * h, w, c)

    scale_factor = c**0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=-shift_size_w, dims=1)
        k = torch.roll(k, shifts=-shift_size_w, dims=1)
        v = torch.roll(v, shifts=-shift_size_w, dims=1)

    q = split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
    k = split_feature_1d(k, num_splits=num_splits)
    v = split_feature_1d(v, num_splits=num_splits)

    scores = (
        torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor
    )  # [B*H*K, W/K, W/K]

    if with_shift:
        # attn_mask: [K, W/K, W/K]
        scores += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]

    out = merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=shift_size_w, dims=2)

    out = out.view(b, -1, c)

    return out


class SelfAttnPropagation(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super().__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        assert flow.size(1) == 2 or flow.size(1) == 1  # flow or disparity or depth
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        value_channel = flow.size(1)

        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)).reshape(
            b * h * w, 1, c
        )  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, C*(2R+1)^2), H*W]

        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, 2*(2R+1)^2), H*W]

        flow_window = (
            flow_window.view(b, value_channel, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, value_channel)
        )  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, value_channel)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out
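
A minimal usage sketch for the attention helpers above (not part of the commit). It assumes attention.py is imported as part of the package since it uses relative imports; the import path below is a placeholder, and the tensors are random.

# Usage sketch (illustration only, not one of the uploaded files).
# Placeholder import path; adjust to wherever the package is installed.
# from doduo.attention import single_head_full_attention, single_head_split_window_attention
import torch

b, c, h, w = 2, 128, 32, 48
q = torch.randn(b, h * w, c)  # [B, L, C] with L = H*W
k = torch.randn(b, h * w, c)
v = torch.randn(b, h * w, c)

# full attention over all H*W positions
out_full = single_head_full_attention(q, k, v)  # [B, H*W, C]

# Swin-style window attention with a 2x2 split and no shift
out_win = single_head_split_window_attention(
    q, k, v, num_splits=2, with_shift=False, h=h, w=w
)  # [B, H*W, C]
assert out_full.shape == out_win.shape == (b, h * w, c)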
backbone.py
ADDED
@@ -0,0 +1,142 @@
import torch.nn as nn

from .trident_conv import MultiScaleTridentConv


class ResidualBlock(nn.Module):
    def __init__(
        self,
        in_planes,
        planes,
        norm_layer=nn.InstanceNorm2d,
        stride=1,
        dilation=1,
    ):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            stride=stride,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, dilation=dilation, padding=dilation, bias=False
        )
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)
        if not stride == 1 or in_planes != planes:
            self.norm3 = norm_layer(planes)

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class CNNEncoder(nn.Module):
    def __init__(
        self,
        output_dim=128,
        norm_layer=nn.InstanceNorm2d,
        num_output_scales=1,
        **kwargs,
    ):
        super().__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(
            3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(
            feature_dims[2],
            stride=stride,
            norm_layer=norm_layer,
        )  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(
                output_dim,
                output_dim,
                kernel_size=3,
                strides=strides,
                paddings=1,
                num_branch=self.num_branch,
            )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        layer1 = ResidualBlock(
            self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
        )
        layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
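
A short sketch of how the encoder above behaves (not part of the commit), assuming the package imports resolve; with num_output_scales=1 the single output branch is at 1/8 of the input resolution.

# Usage sketch (illustration only, not one of the uploaded files).
import torch

encoder = CNNEncoder(output_dim=128, num_output_scales=1)
images = torch.randn(2, 3, 256, 384)  # [B, 3, H, W]
features = encoder(images)            # list with one feature map per branch
print(features[0].shape)              # torch.Size([2, 128, 32, 48]) -> H/8, W/8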
config.json
ADDED
@@ -0,0 +1,13 @@
{
  "architectures": [
    "DoduoModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_doduo.DoduoConfig",
    "AutoModel": "modeling_doduo.DoduoModel"
  },
  "dino_corr_mask_ratio": 0.99,
  "model_type": "doduo",
  "torch_dtype": "float32",
  "transformers_version": "4.30.2"
}
configuration_doduo.py
ADDED
@@ -0,0 +1,13 @@
from transformers import PretrainedConfig
from typing import List

class DoduoConfig(PretrainedConfig):
    model_type = "doduo"

    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.99,
        **kwargs,
    ):
        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        super().__init__(**kwargs)
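
Because config.json maps AutoConfig/AutoModel to the custom classes above via auto_map, the repository loads through the Transformers auto classes with trust_remote_code enabled. A loading sketch (not part of the commit); "<repo-id>" is a placeholder for this model's Hub id.

# Loading sketch (illustration only, not one of the uploaded files).
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
print(config.dino_corr_mask_ratio)  # 0.99, as set in config.json

model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)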
dino.py
ADDED
@@ -0,0 +1,420 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mostly copy-paste from timm library.

https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial
import warnings

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer."""

    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        norm_last_layer=True,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
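
A minimal sketch of the DINO ViT above (not part of the commit), assuming random input; it builds a ViT-S/16 and takes the normalized tokens of the last block, which is the typical way these features are consumed downstream.

# Usage sketch (illustration only, not one of the uploaded files).
import torch

vit = vit_small(patch_size=16)
images = torch.randn(1, 3, 224, 224)

tokens = vit.get_intermediate_layers(images, n=1)[0]   # [1, 1 + 196, 384]
cls_token, patch_tokens = tokens[:, 0], tokens[:, 1:]
print(patch_tokens.shape)                              # torch.Size([1, 196, 384]) -> 14x14 patches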
geometry.py
ADDED
@@ -0,0 +1,229 @@
import torch
import torch.nn.functional as F


def coords_grid(b, h, w, homogeneous=False, device=None):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    assert device is not None

    x, y = torch.meshgrid(
        [
            torch.linspace(w_min, w_max, len_w, device=device),
            torch.linspace(h_min, h_max, len_h, device=device),
        ],
    )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2]
    c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - c) / c  # [-1, 1]


def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False):
    # img: [B, C, H, W]
    # sample_coords: [B, 2, H, W] in image scale
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=False)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]

        return img, mask

    return img


def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask)


def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    # fwd_flow, bwd_flow: [B, 2, H, W]
    # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ


def back_project(depth, intrinsics):
    # Back project 2D pixel coords to 3D points
    # depth: [B, H, W]
    # intrinsics: [B, 3, 3]
    b, h, w = depth.shape
    grid = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]

    intrinsics_inv = torch.inverse(intrinsics)  # [B, 3, 3]

    points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(
        1
    )  # [B, 3, H, W]

    return points


def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
    # Transform 3D points from reference camera to target camera
    # points_ref: [B, 3, H, W]
    # extrinsics_ref: [B, 4, 4]
    # extrinsics_tgt: [B, 4, 4]
    # extrinsics_rel: [B, 4, 4], relative pose transform
    b, _, h, w = points_ref.shape

    if extrinsics_rel is None:
        extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref))  # [B, 4, 4]

    points_tgt = (
        torch.bmm(extrinsics_rel[:, :3, :3], points_ref.view(b, 3, -1))
        + extrinsics_rel[:, :3, -1:]
    )  # [B, 3, H*W]

    points_tgt = points_tgt.view(b, 3, h, w)  # [B, 3, H, W]

    return points_tgt


def reproject(points_tgt, intrinsics, return_mask=False):
    # reproject to target view
    # points_tgt: [B, 3, H, W]
    # intrinsics: [B, 3, 3]

    b, _, h, w = points_tgt.shape

    proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w)  # [B, 3, H, W]

    X = proj_points[:, 0]
    Y = proj_points[:, 1]
    Z = proj_points[:, 2].clamp(min=1e-3)

    pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(
        b, 2, h, w
    )  # [B, 2, H, W] in image scale

    if return_mask:
        # valid mask in pixel space
        mask = (
            (pixel_coords[:, 0] >= 0)
            & (pixel_coords[:, 0] <= (w - 1))
            & (pixel_coords[:, 1] >= 0)
            & (pixel_coords[:, 1] <= (h - 1))
        )  # [B, H, W]

        return pixel_coords, mask

    return pixel_coords


def reproject_coords(
    depth_ref,
    intrinsics,
    extrinsics_ref=None,
    extrinsics_tgt=None,
    extrinsics_rel=None,
    return_mask=False,
):
    # Compute reprojection sample coords
    points_ref = back_project(depth_ref, intrinsics)  # [B, 3, H, W]
    points_tgt = camera_transform(
        points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel
    )

    if return_mask:
        reproj_coords, mask = reproject(
            points_tgt, intrinsics, return_mask=return_mask
        )  # [B, 2, H, W] in image scale

        return reproj_coords, mask

    reproj_coords = reproject(
        points_tgt, intrinsics, return_mask=return_mask
    )  # [B, 2, H, W] in image scale

    return reproj_coords


def compute_flow_with_depth_pose(
    depth_ref,
    intrinsics,
    extrinsics_ref=None,
    extrinsics_tgt=None,
    extrinsics_rel=None,
    return_mask=False,
):
    b, h, w = depth_ref.shape
    coords_init = coords_grid(b, h, w, device=depth_ref.device)  # [B, 2, H, W]

    if return_mask:
        reproj_coords, mask = reproject_coords(
            depth_ref,
            intrinsics,
            extrinsics_ref,
            extrinsics_tgt,
            extrinsics_rel=extrinsics_rel,
            return_mask=return_mask,
        )  # [B, 2, H, W]
        rigid_flow = reproj_coords - coords_init

        return rigid_flow, mask

    reproj_coords = reproject_coords(
        depth_ref,
        intrinsics,
        extrinsics_ref,
        extrinsics_tgt,
        extrinsics_rel=extrinsics_rel,
        return_mask=return_mask,
    )  # [B, 2, H, W]

    rigid_flow = reproj_coords - coords_init

    return rigid_flow
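
A minimal sketch of the warping and consistency utilities above (not part of the commit), using zero flows purely to illustrate shapes: flow_warp resamples a feature map at flow-displaced coordinates, and the consistency check flags pixels where forward and backward flow disagree beyond alpha*|flow| + beta.

# Usage sketch (illustration only, not one of the uploaded files).
import torch

b, c, h, w = 1, 64, 32, 48
feature1 = torch.randn(b, c, h, w)
fwd_flow = torch.zeros(b, 2, h, w)   # flow from frame0 to frame1
bwd_flow = torch.zeros(b, 2, h, w)   # flow from frame1 to frame0

warped = flow_warp(feature1, fwd_flow)                              # [B, C, H, W]
fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)
print(warped.shape, fwd_occ.shape)                                  # [1, 64, 32, 48] [1, 32, 48]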
matching.py
ADDED
@@ -0,0 +1,505 @@
import torch
import torch.nn.functional as F

from .geometry import coords_grid, generate_window_grid, normalize_coords


def global_correlation_softmax_prototype(
    feature0,
    feature1,
    value,
    pred_bidir_flow=False,
    corr_mask=None,
):
    """
    feature0: [B, C, H, W]
    feature1: [B, C, H, W]
    value: [B, C1, H, W]
    corr_mask: [B, H*W, H*W] or None, if not None, the value will be masked out
    """
    # global correlation
    b, c, h, w = feature0.shape
    c_value = value.size(1)
    value = value.view(b, c_value, -1).permute(0, 2, 1)  # [B, H*W, C1]

    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        value = value.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    if corr_mask is not None:
        # mask out the correlation with corr_mask
        if corr_mask.dtype == torch.bool:
            # binary mask
            correlation[corr_mask] = -65504.0
            prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
        else:
            # float mask
            # bin_mask = corr_mask < 0
            # correlation[bin_mask] = -65504.0
            prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
            # soft_mask = torch.clamp(corr_mask, min=0)
            prob = prob * corr_mask
            # normalize
            prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)
    else:
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
        # if corr_mask.dtype == torch.bool:
        #     # binary mask
        #     bin_mask = corr_mask
        #     num_true = bin_mask.sum(-1).unique().item()
        #     soft_mask = torch.ones((b, h * w, num_true)).to(corr_mask.device)
        # else:
        #     # float mask
        #     bin_mask = corr_mask >= 0
        #     num_true = bin_mask.sum(-1).unique().item()
        #     soft_mask = corr_mask[bin_mask].view(b, h * w, num_true)
        # assert soft_mask.min() >= 0
        # correlation_seleted = correlation[bin_mask].view(b, h * w, num_true)
        # prob_seleted = F.softmax(correlation_seleted, dim=-1)  # [B, H*W, num_true]
        # prob_seleted = soft_mask * prob_seleted
        # prob_seleted = prob_seleted / (prob_seleted.sum(dim=2, keepdim=True) + 1e-8)
        # prob = torch.zeros_like(correlation)
        # prob[bin_mask] = prob_seleted.view(-1)
        # else:
        #     prob = F.softmax(correlation, dim=-1)

    result = torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2)  # [B, 2, H, W]\
    return result, correlation


def local_correlation_softmax_prototype(
    feature0,
    feature1,
    value,
    radius=5,
    pred_bidir_flow=False,
    corr_mask=None,
):
    """
    softmax around argmax point
    feature0: [B, C, H, W]
    feature1: [B, C, H, W]
    value: [B, C1, H, W]
    corr_mask: [B, H*W, H*W] or None, if not None, the value will be masked out
    """
    b, c, h, w = feature0.shape
    c_value = value.size(1)
    value = value.view(b, c_value, -1).permute(0, 2, 1)  # [B, H*W, C1]

    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        value = value.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    if corr_mask is not None:
        # mask out the correlation with corr_mask
        if corr_mask.dtype == torch.bool:
            # binary mask
            correlation[corr_mask] = -65504.0
            prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
        else:
            # float mask
            # bin_mask = corr_mask < 0
            # correlation[bin_mask] = -65504.0
            prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
            # soft_mask = torch.clamp(corr_mask, min=0)
            prob = prob * corr_mask
    else:
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # get local prob
    # B, H*W, 2
    coords = coords_grid(b, h, w, device=feature0.device).flatten(2).permute(0, 2, 1)
    # B, H*W, H*W, 2
    coords = coords.unsqueeze(1).repeat(1, h * w, 1, 1)
    # B, H*W
    argmax_pos = torch.argmax(prob, dim=2)
    # B, H*W, 1, 2
    argmax_pos = argmax_pos.view(b, h * w, 1, 1).repeat(1, 1, 1, 2)
    # B, H*W, 1, 2
    pos = torch.gather(coords, 2, argmax_pos)
    # B, H*W, H*W
    valid = ((coords - pos).square().sum(dim=-1) < (radius**2)).float()
    prob = prob * valid

    # normalize
    prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)
    result = torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2)  # [B, 2, H, W]\
    return result, correlation


def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    # global correlation
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob


def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = valid_x & valid_y  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob


def local_correlation_with_flow(
    feature0,
    feature1,
    flow,
    local_radius,
    padding_mode="zeros",
    dilation=1,
):
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid * dilation  # [B, H*W, (2R+1)^2, 2]

    # flow can be zero when using features after transformer
    if not isinstance(flow, float):
        sample_coords = sample_coords + flow.view(b, 2, -1).permute(0, 2, 1).unsqueeze(
            -2
        )  # [B, H*W, (2R+1)^2, 2]
    else:
        assert flow == 0.0

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()  # [B, (2R+1)^2, H, W]

    return corr


def global_correlation_softmax_stereo(
    feature0,
    feature1,
):
    # global correlation on horizontal direction
    b, c, h, w = feature0.shape

    x_grid = torch.linspace(0, w - 1, w, device=feature0.device)  # [W]

    feature0 = feature0.permute(0, 2, 3, 1)  # [B, H, W, C]
    feature1 = feature1.permute(0, 2, 1, 3)  # [B, H, C, W]

    correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H, W, W]

    # mask subsequent positions to make disparity positive
    mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0)  # [W, W]
    valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1)  # [B, H, W, W]

    correlation[~valid_mask] = -1e9

    prob = F.softmax(correlation, dim=-1)  # [B, H, W, W]

    correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1)  # [B, H, W]

    # NOTE: unlike flow, disparity is typically positive
    disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence  # [B, H, W]

    return disparity.unsqueeze(1), prob  # feature resolution


def local_correlation_softmax_stereo(
    feature0,
    feature1,
    local_radius,
):
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous()  # [B, H*W, 2]

    local_h = 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(
        0, 0, -local_radius, local_radius, local_h, local_w, device=feature0.device
    )  # [1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1), 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1), 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = valid_x & valid_y  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode="zeros", align_corners=False
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)]
    feature0_view = (
        feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c)
    )  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)]

    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
        .contiguous()
    )  # [B, 2, H, W]

    flow = correspondence - coords_init  # flow at feature resolution
    match_prob = prob

    flow_x = -flow[:, :1]  # [B, 1, H, W]

    return flow_x, match_prob


def correlation_softmax_depth(
    feature0,
    feature1,
    intrinsics,
    pose,
    depth_candidates,
    depth_from_argmax=False,
    pred_bidir_depth=False,
):
    b, c, h, w = feature0.size()
    assert depth_candidates.dim() == 4  # [B, D, H, W]
    scale_factor = c**0.5

    if pred_bidir_depth:
        feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
            (feature1, feature0), dim=0
|
425 |
+
)
|
426 |
+
intrinsics = intrinsics.repeat(2, 1, 1)
|
427 |
+
pose = torch.cat((pose, torch.inverse(pose)), dim=0)
|
428 |
+
depth_candidates = depth_candidates.repeat(2, 1, 1, 1)
|
429 |
+
|
430 |
+
# depth candidates are actually inverse depth
|
431 |
+
warped_feature1 = warp_with_pose_depth_candidates(
|
432 |
+
feature1,
|
433 |
+
intrinsics,
|
434 |
+
pose,
|
435 |
+
1.0 / depth_candidates,
|
436 |
+
) # [B, C, D, H, W]
|
437 |
+
|
438 |
+
correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W]
|
439 |
+
|
440 |
+
match_prob = F.softmax(correlation, dim=1) # [B, D, H, W]
|
441 |
+
|
442 |
+
# for cross-task transfer (flow -> depth), extract depth with argmax at test time
|
443 |
+
if depth_from_argmax:
|
444 |
+
index = torch.argmax(match_prob, dim=1, keepdim=True)
|
445 |
+
depth = torch.gather(depth_candidates, dim=1, index=index)
|
446 |
+
else:
|
447 |
+
depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W]
|
448 |
+
|
449 |
+
return depth, match_prob
|
450 |
+
|
451 |
+
|
452 |
+
def warp_with_pose_depth_candidates(
|
453 |
+
feature1,
|
454 |
+
intrinsics,
|
455 |
+
pose,
|
456 |
+
depth,
|
457 |
+
clamp_min_depth=1e-3,
|
458 |
+
):
|
459 |
+
"""
|
460 |
+
feature1: [B, C, H, W]
|
461 |
+
intrinsics: [B, 3, 3]
|
462 |
+
pose: [B, 4, 4]
|
463 |
+
depth: [B, D, H, W]
|
464 |
+
"""
|
465 |
+
|
466 |
+
assert intrinsics.size(1) == intrinsics.size(2) == 3
|
467 |
+
assert pose.size(1) == pose.size(2) == 4
|
468 |
+
assert depth.dim() == 4
|
469 |
+
|
470 |
+
b, d, h, w = depth.size()
|
471 |
+
c = feature1.size(1)
|
472 |
+
|
473 |
+
with torch.no_grad():
|
474 |
+
# pixel coordinates
|
475 |
+
grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
|
476 |
+
# back project to 3D and transform viewpoint
|
477 |
+
points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W]
|
478 |
+
points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(1, 1, d, 1) * depth.view(
|
479 |
+
b, 1, d, h * w
|
480 |
+
) # [B, 3, D, H*W]
|
481 |
+
points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W]
|
482 |
+
# reproject to 2D image plane
|
483 |
+
points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(
|
484 |
+
b, 3, d, h * w
|
485 |
+
) # [B, 3, D, H*W]
|
486 |
+
pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W]
|
487 |
+
|
488 |
+
# normalize to [-1, 1]
|
489 |
+
x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
|
490 |
+
y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
|
491 |
+
|
492 |
+
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2]
|
493 |
+
|
494 |
+
# sample features
|
495 |
+
warped_feature = F.grid_sample(
|
496 |
+
feature1,
|
497 |
+
grid.view(b, d * h, w, 2),
|
498 |
+
mode="bilinear",
|
499 |
+
padding_mode="zeros",
|
500 |
+
align_corners=False,
|
501 |
+
).view(
|
502 |
+
b, c, d, h, w
|
503 |
+
) # [B, C, D, H, W]
|
504 |
+
|
505 |
+
return warped_feature
|
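Editor's note: a minimal, self-contained sketch (not part of the repository files) of the soft-argmax matching pattern used throughout matching.py: correlation scores over candidate locations are normalized with a softmax, and the correspondence is the probability-weighted average of the candidate coordinates. All tensor names and sizes below are illustrative assumptions.

import torch
import torch.nn.functional as F

# toy setup: one query feature matched against 9 candidate locations
c = 4                                  # feature dimension
query = torch.randn(1, 1, c)           # [B, 1, C] query feature
candidates = torch.randn(1, c, 9)      # [B, C, 9] candidate features
coords = torch.randn(1, 9, 2)          # [B, 9, 2] candidate (x, y) coordinates

corr = torch.matmul(query, candidates) / c**0.5   # [B, 1, 9] scaled dot products
prob = F.softmax(corr, dim=-1)                    # [B, 1, 9] matching distribution
correspondence = torch.matmul(prob, coords)       # [B, 1, 2] soft-argmax coordinate
print(correspondence.shape)  # torch.Size([1, 1, 2])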
modeling_doduo.py
ADDED
@@ -0,0 +1,118 @@
from typing import Dict, Tuple
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from transformers import PreTrainedModel

from .dino import vit_small
from .unimatch import UniMatch
from .configuration_doduo import DoduoConfig

class DoduoModel(PreTrainedModel):
    config_class = DoduoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        if isinstance(frame_src, Image.Image):
            frame_src = self.model.process_frame(frame_src)
            frame_dst = self.model.process_frame(frame_dst)
        assert frame_src.shape == frame_dst.shape
        return self.model(frame_src, frame_dst)

class CorrSegFlowNet(nn.Module):
    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.1,
    ):
        super().__init__()

        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        self.unimatch = UniMatch(bilinear_upsample=True)
        self.dino = vit_small(patch_size=8, num_classes=0)
        for k in self.dino.parameters():
            k.requires_grad = False

        self.transform = transforms.Compose(
            [
                lambda x: transforms.ToTensor()(x)[:3],
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def process_frame(self, frame):
        device = next(self.parameters()).device
        frame = self.transform(frame)
        frame = frame.unsqueeze(0).to(device)
        return frame

    def forward(
        self,
        frame_src,
        frame_dst,
    ):
        corr_mask = get_dino_corr_mask(
            self.dino,
            frame_src,
            frame_dst,
            mask_ratio=self.dino_corr_mask_ratio
        )

        flow, flow_low, correlation, feature0, feature1 = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=corr_mask,
        )
        return flow

@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """frame: B, C, H, W"""
    B = frame.shape[0]
    out = model.get_intermediate_layers(frame, n=1)[0]
    out = out[:, 1:, :]  # we discard the [CLS] token
    h, w = int(frame.shape[2] / model.patch_embed.patch_size), int(
        frame.shape[3] / model.patch_embed.patch_size
    )
    dim = out.shape[-1]
    out = out.reshape(B, -1, dim)
    if return_h_w:
        return out, h, w
    return out

@torch.no_grad()
def get_dino_corr_mask(
    model, frame_src, frame_dst, mask_ratio
):
    # frame_src: B x C x H x W
    # frame_dst: B x C x H x W
    # mask_ratio: ratio of pixels to be masked
    # return: B x h*w x h*w
    feat_1, h, w = extract_dino_feature(model, frame_src, return_h_w=True)
    feat_2 = extract_dino_feature(model, frame_dst)

    feat_1_norm = F.normalize(feat_1, dim=2, p=2)
    feat_2_norm = F.normalize(feat_2, dim=2, p=2)
    aff_raw = torch.einsum("bnc,bmc->bnm", [feat_1_norm, feat_2_norm])

    if mask_ratio <= 0:
        # no corr mask
        corr_mask = None
    else:
        if aff_raw.dtype == torch.float16:
            aff_raw = aff_raw.float()
        aff_percentile = torch.quantile(aff_raw, mask_ratio, 2, keepdim=True)
        # True for masked
        corr_mask = aff_raw < aff_percentile
    return corr_mask
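Editor's note: a small self-contained sketch (illustration only, not part of the repository) of the quantile-based masking idea in get_dino_corr_mask: for each source token, affinities below the mask_ratio quantile are marked True, i.e. suppressed. Shapes and values below are made up for demonstration.

import torch

# toy affinity matrix between 3 source tokens and 5 target tokens
aff = torch.randn(1, 3, 5)          # [B, N, M]
mask_ratio = 0.4                    # fraction of lowest-affinity entries to mask

# per-source-token threshold at the given quantile, then mark everything below it
threshold = torch.quantile(aff, mask_ratio, dim=2, keepdim=True)  # [B, N, 1]
corr_mask = aff < threshold         # [B, N, M], True where correlation is suppressed
print(corr_mask.float().mean(dim=2))  # roughly mask_ratio per row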
position.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py

import math

import torch
import torch.nn as nn


class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very similar to the one used by
    the Attention is all you need paper, generalized to work on images."""

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x = tensor_list.tensors  # [B, C, H, W]
        # mask = tensor_list.mask  # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
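Editor's note: a compact, self-contained re-derivation (illustration only, not part of the file) of the sin/cos pattern that PositionEmbeddingSine builds for a single row; the module above applies the same construction along both axes and concatenates the results. The small sizes are arbitrary.

import torch

W, num_pos_feats, temperature = 8, 4, 10000
x_embed = torch.arange(1, W + 1, dtype=torch.float32)            # cumulative sum of ones
dim_t = temperature ** (2 * (torch.arange(num_pos_feats) // 2) / num_pos_feats)
pos_x = x_embed[:, None] / dim_t                                  # [W, num_pos_feats]
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
print(pos_x.shape)  # torch.Size([8, 4])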
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5340e547b103a76197fe408906930793aea54ed21a4c2f7845e73d1ca5810022
size 103562927
transformer.py
ADDED
@@ -0,0 +1,340 @@
import torch
import torch.nn as nn

from .attention import (
    single_head_full_attention,
    single_head_full_attention_1d,
    single_head_split_window_attention,
    single_head_split_window_attention_1d,
)
from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d


class TransformerLayer(nn.Module):
    def __init__(
        self,
        d_model=128,
        nhead=1,
        no_ffn=False,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # for stereo: 2d attn in self-attn, 1d attn in cross-attn
        is_self_attn = (query - key).abs().max() < 1e-6

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if attn_type == "swin" and attn_num_splits > 1:  # self, cross-attn: both swin 2d
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )

        elif attn_type == "self_swin2d_cross_1d":  # self-attn: swin 2d, cross-attn: full 1d
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # full 2d attn
                        message = single_head_full_attention(query, key, value)  # [N, L, C]

                else:
                    # cross attn 1d
                    message = single_head_full_attention_1d(
                        query,
                        key,
                        value,
                        h=height,
                        w=width,
                    )

        elif attn_type == "self_swin2d_cross_swin1d":  # self-attn: swin 2d, cross-attn: swin 1d
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        # self attn shift window
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # full 2d attn
                        message = single_head_full_attention(query, key, value)  # [N, L, C]
                else:
                    if attn_num_splits > 1:
                        assert shifted_window_attn_mask_1d is not None
                        # cross attn 1d shift
                        message = single_head_split_window_attention_1d(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask_1d,
                        )
                    else:
                        message = single_head_full_attention_1d(
                            query,
                            key,
                            value,
                            h=height,
                            w=width,
                        )

        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN."""

    def __init__(
        self,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            ffn_dim_expansion=ffn_dim_expansion,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(
            source,
            source,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        # cross attention and ffn
        source = self.cross_attn_ffn(
            source,
            target,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        return source


class FeatureTransformer(nn.Module):
    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    ffn_dim_expansion=ffn_dim_expansion,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_type="swin",
        attn_num_splits=None,
        **kwargs,
    ):

        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        # 2d attention
        if "swin" in attn_type and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # 1d attention
        if "swin1d" in attn_type and attn_num_splits > 1:
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
                input_w=w,
                window_size_w=window_size_w,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K, W/K, W/K]
        else:
            shifted_window_attn_mask_1d = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for i, layer in enumerate(self.layers):
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                attn_type=attn_type,
                with_shift="swin" in attn_type and attn_num_splits > 1 and i % 2 == 1,
                attn_num_splits=attn_num_splits,
                shifted_window_attn_mask=shifted_window_attn_mask,
                shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            )

            # update feature1
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1
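Editor's note: an illustrative, self-contained sketch (not part of the file) of the batching trick used in FeatureTransformer.forward: the two attention directions are stacked along the batch dimension so one pass updates both feature maps; the attention layers themselves are elided here.

import torch

b, l, c = 2, 6, 8
feature0 = torch.randn(b, l, c)
feature1 = torch.randn(b, l, c)

concat0 = torch.cat((feature0, feature1), dim=0)  # queries            [2B, L, C]
concat1 = torch.cat((feature1, feature0), dim=0)  # key/value targets, swapped order

# ... attention layers would update concat0 against concat1 here ...

# after the layers, splitting the batch back recovers the two feature maps
out0, out1 = concat0.chunk(chunks=2, dim=0)
assert out0.shape == feature0.shape and out1.shape == feature1.shape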
trident_conv.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair


class MultiScaleTridentConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            outputs = [
                F.conv2d(
                    input, self.weight, self.bias, stride, padding, self.dilation, self.groups
                )
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.strides[-1],
                    self.paddings[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
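Editor's note: a self-contained sketch (illustration only, not part of the file) of the weight sharing that MultiScaleTridentConv implements: a single convolution kernel applied with branch-specific strides yields feature maps at different resolutions. Sizes are arbitrary.

import torch
import torch.nn.functional as F

weight = torch.randn(16, 3, 3, 3)                     # shared [out_c, in_c, kH, kW] kernel
inputs = [torch.randn(1, 3, 64, 64) for _ in range(2)]  # one input per branch
strides = [1, 2]                                      # branch-specific strides

outputs = [
    F.conv2d(x, weight, bias=None, stride=s, padding=1)
    for x, s in zip(inputs, strides)
]
print([o.shape for o in outputs])  # [1, 16, 64, 64] and [1, 16, 32, 32]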
unimatch.py
ADDED
@@ -0,0 +1,213 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import CNNEncoder
from .geometry import coords_grid
from .matching import (
    global_correlation_softmax_prototype,
    local_correlation_softmax_prototype,
)
from .transformer import FeatureTransformer
from .utils import feature_add_position


class UniMatch(nn.Module):
    def __init__(
        self,
        num_scales=1,
        feature_channels=128,
        upsample_factor=8,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        bilinear_upsample=False,
        corr_fn="global",
    ):
        super().__init__()

        self.feature_channels = feature_channels
        self.num_scales = num_scales
        self.upsample_factor = upsample_factor
        self.bilinear_upsample = bilinear_upsample
        if corr_fn == "global":
            self.corr_fn = global_correlation_softmax_prototype
        elif corr_fn == "local":
            self.corr_fn = local_correlation_softmax_prototype
        else:
            raise NotImplementedError(f"Correlation function {corr_fn} not implemented")

        # CNN
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # convex upsampling similar to RAFT
        # concat feature0 and low res flow as input
        if not bilinear_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
            )

    def extract_feature(self, img0, img1):
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def correlate_feature(self, feature0, feature1, attn_splits=2, attn_type="swin"):
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape
        feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.view(b, c, -1)  # [B, C, H*W]
        correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
            c**0.5
        )  # [B, H, W, H, W]
        correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]
        return correlation

    def forward(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
        return_feature=False,
        bidirectional=False,
        cycle_consistency=False,
        corr_mask=None,
    ):
        # list of features, resolution low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features
        assert self.num_scales == 1  # multi-scale depth model is not supported yet
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

        if cycle_consistency:
            # get both directions of features
            feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                (feature1, feature0), dim=0
            )

        # add position to features
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        # Transformer
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape
        # downsampled_img0 = F.interpolate(img0, size=(h, w), mode="bilinear", align_corners=False)
        flow_coords = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
        # values = torch.cat((downsampled_img0, flow_coords), dim=1)  # [B, 5, H, W]
        # correlation and softmax
        query_results, correlation = self.corr_fn(
            feature0, feature1, flow_coords, pred_bidir_flow=bidirectional, corr_mask=corr_mask
        )
        if bidirectional:
            flow_coords = torch.cat((flow_coords, flow_coords), dim=0)
            up_feature = torch.cat((feature0, feature1), dim=0)
        else:
            up_feature = feature0
        flow = query_results - flow_coords
        flow_up = self.upsample_flow(flow, up_feature, bilinear=self.bilinear_upsample)
        if return_feature:
            return flow_up, flow, correlation, feature0, feature1
        else:
            return flow_up, flow, correlation

    def forward_features(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
    ):

        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features
        assert self.num_scales == 1  # multi-scale depth model is not supported yet
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
        # add position to features
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        # Transformer
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, is_depth=False):
        if bilinear:
            multiplier = 1 if is_depth else upsample_factor
            up_flow = (
                F.interpolate(
                    flow, scale_factor=upsample_factor, mode="bilinear", align_corners=False
                )
                * multiplier
            )
        else:
            concat = torch.cat((flow, feature), dim=1)
            mask = self.upsampler(concat)
            up_flow = upsample_flow_with_mask(
                flow, mask, upsample_factor=self.upsample_factor, is_depth=is_depth
            )
        return up_flow


def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
    # convex upsampling following raft

    mask = up_mask
    b, flow_channel, h, w = flow.shape
    mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
    mask = torch.softmax(mask, dim=2)

    multiplier = 1 if is_depth else upsample_factor
    up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
    up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

    up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
    up_flow = up_flow.reshape(
        b, flow_channel, upsample_factor * h, upsample_factor * w
    )  # [B, 2, K*H, K*W]

    return up_flow
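Editor's note: a self-contained shape check (illustration only, not part of the file) of the RAFT-style convex upsampling in upsample_flow_with_mask: a [B, 9*K*K, H, W] mask blends a 3x3 neighborhood of the coarse flow into a K-times finer flow. The sizes below are arbitrary and the steps mirror the function above.

import torch
import torch.nn.functional as F

b, h, w, k = 1, 8, 12, 8
flow = torch.randn(b, 2, h, w)              # coarse flow
mask = torch.randn(b, 9 * k * k, h, w)      # predicted upsampling weights

mask = torch.softmax(mask.view(b, 1, 9, k, k, h, w), dim=2)       # normalize over 3x3 neighbors
up_flow = F.unfold(k * flow, [3, 3], padding=1).view(b, 2, 9, 1, 1, h, w)
up_flow = torch.sum(mask * up_flow, dim=2)                        # [B, 2, K, K, H, W]
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).reshape(b, 2, k * h, k * w)
print(up_flow.shape)  # torch.Size([1, 2, 64, 96])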
utils.py
ADDED
@@ -0,0 +1,257 @@
import torch
import torch.nn.functional as F

from .position import PositionEmbeddingSine


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    assert device is not None

    x, y = torch.meshgrid(
        [
            torch.linspace(w_min, w_max, len_w, device=device),
            torch.linspace(h_min, h_max, len_h, device=device),
        ],
    )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2]
    c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - c) / c  # [-1, 1]


def normalize_img(img0, img1):
    # loaded images are in [0, 255]
    # normalize by ImageNet mean and std
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
    img0 = (img0 / 255.0 - mean) / std
    img1 = (img1 / 255.0 - mean) / std

    return img0, img1


def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = (
            feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(b_new, h_new, w_new, c)
        )  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = (
            feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(b_new, c, h_new, w_new)
        )  # [B*K*K, C, H/K, W/K]

    return feature


def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
        merge = (
            splits.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(new_b, num_splits * h, num_splits * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
        merge = (
            splits.permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(new_b, c, num_splits * h, num_splits * w)
        )  # [B, C, H, W]

    return merge


def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=torch.device("cuda"),
):
    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # calculate attention mask for SW-MSA
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = split_feature(
        img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True
    )

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask


def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits > 1:  # add position in splited window
        feature0_splits = split_feature(feature0, num_splits=attn_splits)
        feature1_splits = split_feature(feature1, num_splits=attn_splits)

        position = pos_enc(feature0_splits)

        feature0_splits = feature0_splits + position
        feature1_splits = feature1_splits + position

        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
    else:
        position = pos_enc(feature0)

        feature0 = feature0 + position
        feature1 = feature1 + position

    return feature0, feature1


def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
    # convex upsampling following raft

    mask = up_mask
    b, flow_channel, h, w = flow.shape
    mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
    mask = torch.softmax(mask, dim=2)

    multiplier = 1 if is_depth else upsample_factor
    up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
    up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

    up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
    up_flow = up_flow.reshape(
        b, flow_channel, upsample_factor * h, upsample_factor * w
    )  # [B, 2, K*H, K*W]

    return up_flow


def split_feature_1d(
    feature,
    num_splits=2,
):
    # feature: [B, W, C]
    b, w, c = feature.size()
    assert w % num_splits == 0

    b_new = b * num_splits
    w_new = w // num_splits

    feature = feature.view(b, num_splits, w // num_splits, c).view(
        b_new, w_new, c
    )  # [B*K, W/K, C]

    return feature


def merge_splits_1d(
    splits,
    h,
    num_splits=2,
):
    b, w, c = splits.size()
    new_b = b // num_splits // h

    splits = splits.view(new_b, h, num_splits, w, c)
    merge = splits.view(new_b, h, num_splits * w, c)  # [B, H, W, C]

    return merge


def window_partition_1d(x, window_size_w):
    """
    Args:
        x: (B, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, C)
    """
    B, W, C = x.shape
    x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
    return x


def generate_shift_window_attn_mask_1d(
    input_w, window_size_w, shift_size_w, device=torch.device("cuda")
):
    # calculate attention mask for SW-MSA
    img_mask = torch.zeros((1, input_w, 1)).to(device)  # 1 W 1
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    cnt = 0
    for w in w_slices:
        img_mask[:, w, :] = cnt
        cnt += 1

    mask_windows = window_partition_1d(img_mask, window_size_w)  # nW, window_size, 1
    mask_windows = mask_windows.view(-1, window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
        2
    )  # nW, window_size, window_size
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask
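Editor's note: a self-contained round-trip sketch (illustration only, not part of the file) of the window split/merge used by split_feature and merge_splits: a [B, C, H, W] map is cut into K*K non-overlapping windows folded into the batch dimension and then reassembled exactly. Sizes are arbitrary.

import torch

b, c, h, w, k = 2, 8, 16, 16, 2
x = torch.randn(b, c, h, w)

# split into K*K windows, each window becomes its own batch element
splits = (
    x.view(b, c, k, h // k, k, w // k)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(b * k * k, c, h // k, w // k)
)  # [B*K*K, C, H/K, W/K]

# merge the windows back into the full-resolution map
merged = (
    splits.view(b, k, k, c, h // k, w // k)
    .permute(0, 3, 1, 4, 2, 5)
    .reshape(b, c, h, w)
)
assert torch.equal(merged, x)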