stevetod commited on
Commit
5189ac9
1 Parent(s): 60d3094
Files changed (14) hide show
  1. attention.py +295 -0
  2. backbone.py +142 -0
  3. config.json +13 -0
  4. configuration_doduo.py +13 -0
  5. dino.py +420 -0
  6. geometry.py +229 -0
  7. matching.py +505 -0
  8. modeling_doduo.py +118 -0
  9. position.py +49 -0
  10. pytorch_model.bin +3 -0
  11. transformer.py +340 -0
  12. trident_conv.py +96 -0
  13. unimatch.py +213 -0
  14. utils.py +257 -0
attention.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .utils import merge_splits, merge_splits_1d, split_feature, split_feature_1d
6
+
7
+
8
+ def single_head_full_attention(q, k, v):
9
+ # q, k, v: [B, L, C]
10
+ assert q.dim() == k.dim() == v.dim() == 3
11
+
12
+ scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L]
13
+ attn = torch.softmax(scores, dim=2) # [B, L, L]
14
+ out = torch.matmul(attn, v) # [B, L, C]
15
+
16
+ return out
17
+
18
+
19
+ def single_head_full_attention_1d(
20
+ q,
21
+ k,
22
+ v,
23
+ h=None,
24
+ w=None,
25
+ ):
26
+ # q, k, v: [B, L, C]
27
+
28
+ assert h is not None and w is not None
29
+ assert q.size(1) == h * w
30
+
31
+ b, _, c = q.size()
32
+
33
+ q = q.view(b, h, w, c) # [B, H, W, C]
34
+ k = k.view(b, h, w, c)
35
+ v = v.view(b, h, w, c)
36
+
37
+ scale_factor = c**0.5
38
+
39
+ scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W]
40
+
41
+ attn = torch.softmax(scores, dim=-1)
42
+
43
+ out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C]
44
+
45
+ return out
46
+
47
+
48
+ def single_head_split_window_attention(
49
+ q,
50
+ k,
51
+ v,
52
+ num_splits=1,
53
+ with_shift=False,
54
+ h=None,
55
+ w=None,
56
+ attn_mask=None,
57
+ ):
58
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
59
+ # q, k, v: [B, L, C]
60
+ assert q.dim() == k.dim() == v.dim() == 3
61
+
62
+ assert h is not None and w is not None
63
+ assert q.size(1) == h * w
64
+
65
+ b, _, c = q.size()
66
+
67
+ b_new = b * num_splits * num_splits
68
+
69
+ window_size_h = h // num_splits
70
+ window_size_w = w // num_splits
71
+
72
+ q = q.view(b, h, w, c) # [B, H, W, C]
73
+ k = k.view(b, h, w, c)
74
+ v = v.view(b, h, w, c)
75
+
76
+ scale_factor = c**0.5
77
+
78
+ if with_shift:
79
+ assert attn_mask is not None # compute once
80
+ shift_size_h = window_size_h // 2
81
+ shift_size_w = window_size_w // 2
82
+
83
+ q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
84
+ k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
85
+ v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
86
+
87
+ q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
88
+ k = split_feature(k, num_splits=num_splits, channel_last=True)
89
+ v = split_feature(v, num_splits=num_splits, channel_last=True)
90
+
91
+ scores = (
92
+ torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor
93
+ ) # [B*K*K, H/K*W/K, H/K*W/K]
94
+
95
+ if with_shift:
96
+ scores += attn_mask.repeat(b, 1, 1)
97
+
98
+ attn = torch.softmax(scores, dim=-1)
99
+
100
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
101
+
102
+ out = merge_splits(
103
+ out.view(b_new, h // num_splits, w // num_splits, c),
104
+ num_splits=num_splits,
105
+ channel_last=True,
106
+ ) # [B, H, W, C]
107
+
108
+ # shift back
109
+ if with_shift:
110
+ out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
111
+
112
+ out = out.view(b, -1, c)
113
+
114
+ return out
115
+
116
+
117
+ def single_head_split_window_attention_1d(
118
+ q,
119
+ k,
120
+ v,
121
+ relative_position_bias=None,
122
+ num_splits=1,
123
+ with_shift=False,
124
+ h=None,
125
+ w=None,
126
+ attn_mask=None,
127
+ ):
128
+ # q, k, v: [B, L, C]
129
+
130
+ assert h is not None and w is not None
131
+ assert q.size(1) == h * w
132
+
133
+ b, _, c = q.size()
134
+
135
+ b_new = b * num_splits * h
136
+
137
+ window_size_w = w // num_splits
138
+
139
+ q = q.view(b * h, w, c) # [B*H, W, C]
140
+ k = k.view(b * h, w, c)
141
+ v = v.view(b * h, w, c)
142
+
143
+ scale_factor = c**0.5
144
+
145
+ if with_shift:
146
+ assert attn_mask is not None # compute once
147
+ shift_size_w = window_size_w // 2
148
+
149
+ q = torch.roll(q, shifts=-shift_size_w, dims=1)
150
+ k = torch.roll(k, shifts=-shift_size_w, dims=1)
151
+ v = torch.roll(v, shifts=-shift_size_w, dims=1)
152
+
153
+ q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C]
154
+ k = split_feature_1d(k, num_splits=num_splits)
155
+ v = split_feature_1d(v, num_splits=num_splits)
156
+
157
+ scores = (
158
+ torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) / scale_factor
159
+ ) # [B*H*K, W/K, W/K]
160
+
161
+ if with_shift:
162
+ # attn_mask: [K, W/K, W/K]
163
+ scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K]
164
+
165
+ attn = torch.softmax(scores, dim=-1)
166
+
167
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C]
168
+
169
+ out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C]
170
+
171
+ # shift back
172
+ if with_shift:
173
+ out = torch.roll(out, shifts=shift_size_w, dims=2)
174
+
175
+ out = out.view(b, -1, c)
176
+
177
+ return out
178
+
179
+
180
+ class SelfAttnPropagation(nn.Module):
181
+ """
182
+ flow propagation with self-attention on feature
183
+ query: feature0, key: feature0, value: flow
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ in_channels,
189
+ **kwargs,
190
+ ):
191
+ super().__init__()
192
+
193
+ self.q_proj = nn.Linear(in_channels, in_channels)
194
+ self.k_proj = nn.Linear(in_channels, in_channels)
195
+
196
+ for p in self.parameters():
197
+ if p.dim() > 1:
198
+ nn.init.xavier_uniform_(p)
199
+
200
+ def forward(
201
+ self,
202
+ feature0,
203
+ flow,
204
+ local_window_attn=False,
205
+ local_window_radius=1,
206
+ **kwargs,
207
+ ):
208
+ # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
209
+ if local_window_attn:
210
+ return self.forward_local_window_attn(
211
+ feature0, flow, local_window_radius=local_window_radius
212
+ )
213
+
214
+ b, c, h, w = feature0.size()
215
+
216
+ query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
217
+
218
+ # a note: the ``correct'' implementation should be:
219
+ # ``query = self.q_proj(query), key = self.k_proj(query)''
220
+ # this problem is observed while cleaning up the code
221
+ # however, this doesn't affect the performance since the projection is a linear operation,
222
+ # thus the two projection matrices for key can be merged
223
+ # so I just leave it as is in order to not re-train all models :)
224
+ query = self.q_proj(query) # [B, H*W, C]
225
+ key = self.k_proj(query) # [B, H*W, C]
226
+
227
+ value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
228
+
229
+ scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5) # [B, H*W, H*W]
230
+ prob = torch.softmax(scores, dim=-1)
231
+
232
+ out = torch.matmul(prob, value) # [B, H*W, 2]
233
+ out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
234
+
235
+ return out
236
+
237
+ def forward_local_window_attn(
238
+ self,
239
+ feature0,
240
+ flow,
241
+ local_window_radius=1,
242
+ ):
243
+ assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth
244
+ assert local_window_radius > 0
245
+
246
+ b, c, h, w = feature0.size()
247
+
248
+ value_channel = flow.size(1)
249
+
250
+ feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)).reshape(
251
+ b * h * w, 1, c
252
+ ) # [B*H*W, 1, C]
253
+
254
+ kernel_size = 2 * local_window_radius + 1
255
+
256
+ feature0_proj = (
257
+ self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
258
+ .permute(0, 2, 1)
259
+ .reshape(b, c, h, w)
260
+ )
261
+
262
+ feature0_window = F.unfold(
263
+ feature0_proj, kernel_size=kernel_size, padding=local_window_radius
264
+ ) # [B, C*(2R+1)^2), H*W]
265
+
266
+ feature0_window = (
267
+ feature0_window.view(b, c, kernel_size**2, h, w)
268
+ .permute(0, 3, 4, 1, 2)
269
+ .reshape(b * h * w, c, kernel_size**2)
270
+ ) # [B*H*W, C, (2R+1)^2]
271
+
272
+ flow_window = F.unfold(
273
+ flow, kernel_size=kernel_size, padding=local_window_radius
274
+ ) # [B, 2*(2R+1)^2), H*W]
275
+
276
+ flow_window = (
277
+ flow_window.view(b, value_channel, kernel_size**2, h, w)
278
+ .permute(0, 3, 4, 2, 1)
279
+ .reshape(b * h * w, kernel_size**2, value_channel)
280
+ ) # [B*H*W, (2R+1)^2, 2]
281
+
282
+ scores = torch.matmul(feature0_reshape, feature0_window) / (
283
+ c**0.5
284
+ ) # [B*H*W, 1, (2R+1)^2]
285
+
286
+ prob = torch.softmax(scores, dim=-1)
287
+
288
+ out = (
289
+ torch.matmul(prob, flow_window)
290
+ .view(b, h, w, value_channel)
291
+ .permute(0, 3, 1, 2)
292
+ .contiguous()
293
+ ) # [B, 2, H, W]
294
+
295
+ return out
backbone.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .trident_conv import MultiScaleTridentConv
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_planes,
10
+ planes,
11
+ norm_layer=nn.InstanceNorm2d,
12
+ stride=1,
13
+ dilation=1,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.conv1 = nn.Conv2d(
18
+ in_planes,
19
+ planes,
20
+ kernel_size=3,
21
+ dilation=dilation,
22
+ padding=dilation,
23
+ stride=stride,
24
+ bias=False,
25
+ )
26
+ self.conv2 = nn.Conv2d(
27
+ planes, planes, kernel_size=3, dilation=dilation, padding=dilation, bias=False
28
+ )
29
+ self.relu = nn.ReLU(inplace=True)
30
+
31
+ self.norm1 = norm_layer(planes)
32
+ self.norm2 = norm_layer(planes)
33
+ if not stride == 1 or in_planes != planes:
34
+ self.norm3 = norm_layer(planes)
35
+
36
+ if stride == 1 and in_planes == planes:
37
+ self.downsample = None
38
+ else:
39
+ self.downsample = nn.Sequential(
40
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
41
+ )
42
+
43
+ def forward(self, x):
44
+ y = x
45
+ y = self.relu(self.norm1(self.conv1(y)))
46
+ y = self.relu(self.norm2(self.conv2(y)))
47
+
48
+ if self.downsample is not None:
49
+ x = self.downsample(x)
50
+
51
+ return self.relu(x + y)
52
+
53
+
54
+ class CNNEncoder(nn.Module):
55
+ def __init__(
56
+ self,
57
+ output_dim=128,
58
+ norm_layer=nn.InstanceNorm2d,
59
+ num_output_scales=1,
60
+ **kwargs,
61
+ ):
62
+ super().__init__()
63
+ self.num_branch = num_output_scales
64
+
65
+ feature_dims = [64, 96, 128]
66
+
67
+ self.conv1 = nn.Conv2d(
68
+ 3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
69
+ ) # 1/2
70
+ self.norm1 = norm_layer(feature_dims[0])
71
+ self.relu1 = nn.ReLU(inplace=True)
72
+
73
+ self.in_planes = feature_dims[0]
74
+ self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
75
+ self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
76
+
77
+ # highest resolution 1/4 or 1/8
78
+ stride = 2 if num_output_scales == 1 else 1
79
+ self.layer3 = self._make_layer(
80
+ feature_dims[2],
81
+ stride=stride,
82
+ norm_layer=norm_layer,
83
+ ) # 1/4 or 1/8
84
+
85
+ self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
86
+
87
+ if self.num_branch > 1:
88
+ if self.num_branch == 4:
89
+ strides = (1, 2, 4, 8)
90
+ elif self.num_branch == 3:
91
+ strides = (1, 2, 4)
92
+ elif self.num_branch == 2:
93
+ strides = (1, 2)
94
+ else:
95
+ raise ValueError
96
+
97
+ self.trident_conv = MultiScaleTridentConv(
98
+ output_dim,
99
+ output_dim,
100
+ kernel_size=3,
101
+ strides=strides,
102
+ paddings=1,
103
+ num_branch=self.num_branch,
104
+ )
105
+
106
+ for m in self.modules():
107
+ if isinstance(m, nn.Conv2d):
108
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
109
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
110
+ if m.weight is not None:
111
+ nn.init.constant_(m.weight, 1)
112
+ if m.bias is not None:
113
+ nn.init.constant_(m.bias, 0)
114
+
115
+ def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
116
+ layer1 = ResidualBlock(
117
+ self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
118
+ )
119
+ layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
120
+
121
+ layers = (layer1, layer2)
122
+
123
+ self.in_planes = dim
124
+ return nn.Sequential(*layers)
125
+
126
+ def forward(self, x):
127
+ x = self.conv1(x)
128
+ x = self.norm1(x)
129
+ x = self.relu1(x)
130
+
131
+ x = self.layer1(x) # 1/2
132
+ x = self.layer2(x) # 1/4
133
+ x = self.layer3(x) # 1/8 or 1/4
134
+
135
+ x = self.conv2(x)
136
+
137
+ if self.num_branch > 1:
138
+ out = self.trident_conv([x] * self.num_branch) # high to low res
139
+ else:
140
+ out = [x]
141
+
142
+ return out
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DoduoModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_doduo.DoduoConfig",
7
+ "AutoModel": "modeling_doduo.DoduoModel"
8
+ },
9
+ "dino_corr_mask_ratio": 0.99,
10
+ "model_type": "doduo",
11
+ "torch_dtype": "float32",
12
+ "transformers_version": "4.30.2"
13
+ }
configuration_doduo.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class DoduoConfig(PretrainedConfig):
5
+ model_type = "doduo"
6
+
7
+ def __init__(
8
+ self,
9
+ dino_corr_mask_ratio: float = 0.99,
10
+ **kwargs,
11
+ ):
12
+ self.dino_corr_mask_ratio = dino_corr_mask_ratio
13
+ super().__init__(**kwargs)
dino.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Mostly copy-paste from timm library.
15
+
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+ """
18
+ import math
19
+ from functools import partial
20
+ import warnings
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+
26
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
27
+ if drop_prob == 0.0 or not training:
28
+ return x
29
+ keep_prob = 1 - drop_prob
30
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
31
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
32
+ random_tensor.floor_() # binarize
33
+ output = x.div(keep_prob) * random_tensor
34
+ return output
35
+
36
+
37
+ class DropPath(nn.Module):
38
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
39
+
40
+ def __init__(self, drop_prob=None):
41
+ super().__init__()
42
+ self.drop_prob = drop_prob
43
+
44
+ def forward(self, x):
45
+ return drop_path(x, self.drop_prob, self.training)
46
+
47
+
48
+ class Mlp(nn.Module):
49
+ def __init__(
50
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
51
+ ):
52
+ super().__init__()
53
+ out_features = out_features or in_features
54
+ hidden_features = hidden_features or in_features
55
+ self.fc1 = nn.Linear(in_features, hidden_features)
56
+ self.act = act_layer()
57
+ self.fc2 = nn.Linear(hidden_features, out_features)
58
+ self.drop = nn.Dropout(drop)
59
+
60
+ def forward(self, x):
61
+ x = self.fc1(x)
62
+ x = self.act(x)
63
+ x = self.drop(x)
64
+ x = self.fc2(x)
65
+ x = self.drop(x)
66
+ return x
67
+
68
+
69
+ class Attention(nn.Module):
70
+ def __init__(
71
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
72
+ ):
73
+ super().__init__()
74
+ self.num_heads = num_heads
75
+ head_dim = dim // num_heads
76
+ self.scale = qk_scale or head_dim**-0.5
77
+
78
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
79
+ self.attn_drop = nn.Dropout(attn_drop)
80
+ self.proj = nn.Linear(dim, dim)
81
+ self.proj_drop = nn.Dropout(proj_drop)
82
+
83
+ def forward(self, x):
84
+ B, N, C = x.shape
85
+ qkv = (
86
+ self.qkv(x)
87
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
88
+ .permute(2, 0, 3, 1, 4)
89
+ )
90
+ q, k, v = qkv[0], qkv[1], qkv[2]
91
+
92
+ attn = (q @ k.transpose(-2, -1)) * self.scale
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+
96
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+ return x, attn
100
+
101
+
102
+ class Block(nn.Module):
103
+ def __init__(
104
+ self,
105
+ dim,
106
+ num_heads,
107
+ mlp_ratio=4.0,
108
+ qkv_bias=False,
109
+ qk_scale=None,
110
+ drop=0.0,
111
+ attn_drop=0.0,
112
+ drop_path=0.0,
113
+ act_layer=nn.GELU,
114
+ norm_layer=nn.LayerNorm,
115
+ ):
116
+ super().__init__()
117
+ self.norm1 = norm_layer(dim)
118
+ self.attn = Attention(
119
+ dim,
120
+ num_heads=num_heads,
121
+ qkv_bias=qkv_bias,
122
+ qk_scale=qk_scale,
123
+ attn_drop=attn_drop,
124
+ proj_drop=drop,
125
+ )
126
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
127
+ self.norm2 = norm_layer(dim)
128
+ mlp_hidden_dim = int(dim * mlp_ratio)
129
+ self.mlp = Mlp(
130
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
131
+ )
132
+
133
+ def forward(self, x, return_attention=False):
134
+ y, attn = self.attn(self.norm1(x))
135
+ if return_attention:
136
+ return attn
137
+ x = x + self.drop_path(y)
138
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
139
+ return x
140
+
141
+
142
+ class PatchEmbed(nn.Module):
143
+ """Image to Patch Embedding."""
144
+
145
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
146
+ super().__init__()
147
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
148
+ self.img_size = img_size
149
+ self.patch_size = patch_size
150
+ self.num_patches = num_patches
151
+
152
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
153
+
154
+ def forward(self, x):
155
+ B, C, H, W = x.shape
156
+ x = self.proj(x).flatten(2).transpose(1, 2)
157
+ return x
158
+
159
+
160
+ class VisionTransformer(nn.Module):
161
+ """Vision Transformer."""
162
+
163
+ def __init__(
164
+ self,
165
+ img_size=[224],
166
+ patch_size=16,
167
+ in_chans=3,
168
+ num_classes=0,
169
+ embed_dim=768,
170
+ depth=12,
171
+ num_heads=12,
172
+ mlp_ratio=4.0,
173
+ qkv_bias=False,
174
+ qk_scale=None,
175
+ drop_rate=0.0,
176
+ attn_drop_rate=0.0,
177
+ drop_path_rate=0.0,
178
+ norm_layer=nn.LayerNorm,
179
+ **kwargs
180
+ ):
181
+ super().__init__()
182
+ self.num_features = self.embed_dim = embed_dim
183
+
184
+ self.patch_embed = PatchEmbed(
185
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
186
+ )
187
+ num_patches = self.patch_embed.num_patches
188
+
189
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
190
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
191
+ self.pos_drop = nn.Dropout(p=drop_rate)
192
+
193
+ dpr = [
194
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
195
+ ] # stochastic depth decay rule
196
+ self.blocks = nn.ModuleList(
197
+ [
198
+ Block(
199
+ dim=embed_dim,
200
+ num_heads=num_heads,
201
+ mlp_ratio=mlp_ratio,
202
+ qkv_bias=qkv_bias,
203
+ qk_scale=qk_scale,
204
+ drop=drop_rate,
205
+ attn_drop=attn_drop_rate,
206
+ drop_path=dpr[i],
207
+ norm_layer=norm_layer,
208
+ )
209
+ for i in range(depth)
210
+ ]
211
+ )
212
+ self.norm = norm_layer(embed_dim)
213
+
214
+ # Classifier head
215
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
216
+
217
+ trunc_normal_(self.pos_embed, std=0.02)
218
+ trunc_normal_(self.cls_token, std=0.02)
219
+ self.apply(self._init_weights)
220
+
221
+ def _init_weights(self, m):
222
+ if isinstance(m, nn.Linear):
223
+ trunc_normal_(m.weight, std=0.02)
224
+ if isinstance(m, nn.Linear) and m.bias is not None:
225
+ nn.init.constant_(m.bias, 0)
226
+ elif isinstance(m, nn.LayerNorm):
227
+ nn.init.constant_(m.bias, 0)
228
+ nn.init.constant_(m.weight, 1.0)
229
+
230
+ def interpolate_pos_encoding(self, x, w, h):
231
+ npatch = x.shape[1] - 1
232
+ N = self.pos_embed.shape[1] - 1
233
+ if npatch == N and w == h:
234
+ return self.pos_embed
235
+ class_pos_embed = self.pos_embed[:, 0]
236
+ patch_pos_embed = self.pos_embed[:, 1:]
237
+ dim = x.shape[-1]
238
+ w0 = w // self.patch_embed.patch_size
239
+ h0 = h // self.patch_embed.patch_size
240
+ # we add a small number to avoid floating point error in the interpolation
241
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
242
+ w0, h0 = w0 + 0.1, h0 + 0.1
243
+ patch_pos_embed = nn.functional.interpolate(
244
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
245
+ 0, 3, 1, 2
246
+ ),
247
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
248
+ mode="bicubic",
249
+ )
250
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
251
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
252
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
253
+
254
+ def prepare_tokens(self, x):
255
+ B, nc, w, h = x.shape
256
+ x = self.patch_embed(x) # patch linear embedding
257
+
258
+ # add the [CLS] token to the embed patch tokens
259
+ cls_tokens = self.cls_token.expand(B, -1, -1)
260
+ x = torch.cat((cls_tokens, x), dim=1)
261
+
262
+ # add positional encoding to each token
263
+ x = x + self.interpolate_pos_encoding(x, w, h)
264
+
265
+ return self.pos_drop(x)
266
+
267
+ def forward(self, x):
268
+ x = self.prepare_tokens(x)
269
+ for blk in self.blocks:
270
+ x = blk(x)
271
+ x = self.norm(x)
272
+ return x[:, 0]
273
+
274
+ def get_last_selfattention(self, x):
275
+ x = self.prepare_tokens(x)
276
+ for i, blk in enumerate(self.blocks):
277
+ if i < len(self.blocks) - 1:
278
+ x = blk(x)
279
+ else:
280
+ # return attention of the last block
281
+ return blk(x, return_attention=True)
282
+
283
+ def get_intermediate_layers(self, x, n=1):
284
+ x = self.prepare_tokens(x)
285
+ # we return the output tokens from the `n` last blocks
286
+ output = []
287
+ for i, blk in enumerate(self.blocks):
288
+ x = blk(x)
289
+ if len(self.blocks) - i <= n:
290
+ output.append(self.norm(x))
291
+ return output
292
+
293
+
294
+ def vit_tiny(patch_size=16, **kwargs):
295
+ model = VisionTransformer(
296
+ patch_size=patch_size,
297
+ embed_dim=192,
298
+ depth=12,
299
+ num_heads=3,
300
+ mlp_ratio=4,
301
+ qkv_bias=True,
302
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
303
+ **kwargs
304
+ )
305
+ return model
306
+
307
+
308
+ def vit_small(patch_size=16, **kwargs):
309
+ model = VisionTransformer(
310
+ patch_size=patch_size,
311
+ embed_dim=384,
312
+ depth=12,
313
+ num_heads=6,
314
+ mlp_ratio=4,
315
+ qkv_bias=True,
316
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
317
+ **kwargs
318
+ )
319
+ return model
320
+
321
+
322
+ def vit_base(patch_size=16, **kwargs):
323
+ model = VisionTransformer(
324
+ patch_size=patch_size,
325
+ embed_dim=768,
326
+ depth=12,
327
+ num_heads=12,
328
+ mlp_ratio=4,
329
+ qkv_bias=True,
330
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
331
+ **kwargs
332
+ )
333
+ return model
334
+
335
+
336
+ class DINOHead(nn.Module):
337
+ def __init__(
338
+ self,
339
+ in_dim,
340
+ out_dim,
341
+ use_bn=False,
342
+ norm_last_layer=True,
343
+ nlayers=3,
344
+ hidden_dim=2048,
345
+ bottleneck_dim=256,
346
+ ):
347
+ super().__init__()
348
+ nlayers = max(nlayers, 1)
349
+ if nlayers == 1:
350
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
351
+ else:
352
+ layers = [nn.Linear(in_dim, hidden_dim)]
353
+ if use_bn:
354
+ layers.append(nn.BatchNorm1d(hidden_dim))
355
+ layers.append(nn.GELU())
356
+ for _ in range(nlayers - 2):
357
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
358
+ if use_bn:
359
+ layers.append(nn.BatchNorm1d(hidden_dim))
360
+ layers.append(nn.GELU())
361
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
362
+ self.mlp = nn.Sequential(*layers)
363
+ self.apply(self._init_weights)
364
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
365
+ self.last_layer.weight_g.data.fill_(1)
366
+ if norm_last_layer:
367
+ self.last_layer.weight_g.requires_grad = False
368
+
369
+ def _init_weights(self, m):
370
+ if isinstance(m, nn.Linear):
371
+ trunc_normal_(m.weight, std=0.02)
372
+ if isinstance(m, nn.Linear) and m.bias is not None:
373
+ nn.init.constant_(m.bias, 0)
374
+
375
+ def forward(self, x):
376
+ x = self.mlp(x)
377
+ x = nn.functional.normalize(x, dim=-1, p=2)
378
+ x = self.last_layer(x)
379
+ return x
380
+
381
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
382
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
383
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
384
+ def norm_cdf(x):
385
+ # Computes standard normal cumulative distribution function
386
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
387
+
388
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
389
+ warnings.warn(
390
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
391
+ "The distribution of values may be incorrect.",
392
+ stacklevel=2,
393
+ )
394
+
395
+ with torch.no_grad():
396
+ # Values are generated by using a truncated uniform distribution and
397
+ # then using the inverse CDF for the normal distribution.
398
+ # Get upper and lower cdf values
399
+ lower = norm_cdf((a - mean) / std)
400
+ upper = norm_cdf((b - mean) / std)
401
+
402
+ # Uniformly fill tensor with values from [l, u], then translate to
403
+ # [2l-1, 2u-1].
404
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
405
+
406
+ # Use inverse cdf transform for normal distribution to get truncated
407
+ # standard normal
408
+ tensor.erfinv_()
409
+
410
+ # Transform to proper mean, std
411
+ tensor.mul_(std * math.sqrt(2.0))
412
+ tensor.add_(mean)
413
+
414
+ # Clamp to ensure it's in the proper range
415
+ tensor.clamp_(min=a, max=b)
416
+ return tensor
417
+
418
+
419
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
420
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
geometry.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def coords_grid(b, h, w, homogeneous=False, device=None):
6
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
7
+
8
+ stacks = [x, y]
9
+
10
+ if homogeneous:
11
+ ones = torch.ones_like(x) # [H, W]
12
+ stacks.append(ones)
13
+
14
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
15
+
16
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
17
+
18
+ if device is not None:
19
+ grid = grid.to(device)
20
+
21
+ return grid
22
+
23
+
24
+ def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
25
+ assert device is not None
26
+
27
+ x, y = torch.meshgrid(
28
+ [
29
+ torch.linspace(w_min, w_max, len_w, device=device),
30
+ torch.linspace(h_min, h_max, len_h, device=device),
31
+ ],
32
+ )
33
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
34
+
35
+ return grid
36
+
37
+
38
+ def normalize_coords(coords, h, w):
39
+ # coords: [B, H, W, 2]
40
+ c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
41
+ return (coords - c) / c # [-1, 1]
42
+
43
+
44
+ def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False):
45
+ # img: [B, C, H, W]
46
+ # sample_coords: [B, 2, H, W] in image scale
47
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
48
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
49
+
50
+ b, _, h, w = sample_coords.shape
51
+
52
+ # Normalize to [-1, 1]
53
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
54
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
55
+
56
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
57
+
58
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=False)
59
+
60
+ if return_mask:
61
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
62
+
63
+ return img, mask
64
+
65
+ return img
66
+
67
+
68
+ def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
69
+ b, c, h, w = feature.size()
70
+ assert flow.size(1) == 2
71
+
72
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
73
+
74
+ return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask)
75
+
76
+
77
+ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
78
+ # fwd_flow, bwd_flow: [B, 2, H, W]
79
+ # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
80
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
81
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
82
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
83
+
84
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
85
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
86
+
87
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
88
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
89
+
90
+ threshold = alpha * flow_mag + beta
91
+
92
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
93
+ bwd_occ = (diff_bwd > threshold).float()
94
+
95
+ return fwd_occ, bwd_occ
96
+
97
+
98
+ def back_project(depth, intrinsics):
99
+ # Back project 2D pixel coords to 3D points
100
+ # depth: [B, H, W]
101
+ # intrinsics: [B, 3, 3]
102
+ b, h, w = depth.shape
103
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
104
+
105
+ intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3]
106
+
107
+ points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(
108
+ 1
109
+ ) # [B, 3, H, W]
110
+
111
+ return points
112
+
113
+
114
+ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
115
+ # Transform 3D points from reference camera to target camera
116
+ # points_ref: [B, 3, H, W]
117
+ # extrinsics_ref: [B, 4, 4]
118
+ # extrinsics_tgt: [B, 4, 4]
119
+ # extrinsics_rel: [B, 4, 4], relative pose transform
120
+ b, _, h, w = points_ref.shape
121
+
122
+ if extrinsics_rel is None:
123
+ extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4]
124
+
125
+ points_tgt = (
126
+ torch.bmm(extrinsics_rel[:, :3, :3], points_ref.view(b, 3, -1))
127
+ + extrinsics_rel[:, :3, -1:]
128
+ ) # [B, 3, H*W]
129
+
130
+ points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W]
131
+
132
+ return points_tgt
133
+
134
+
135
+ def reproject(points_tgt, intrinsics, return_mask=False):
136
+ # reproject to target view
137
+ # points_tgt: [B, 3, H, W]
138
+ # intrinsics: [B, 3, 3]
139
+
140
+ b, _, h, w = points_tgt.shape
141
+
142
+ proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W]
143
+
144
+ X = proj_points[:, 0]
145
+ Y = proj_points[:, 1]
146
+ Z = proj_points[:, 2].clamp(min=1e-3)
147
+
148
+ pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(
149
+ b, 2, h, w
150
+ ) # [B, 2, H, W] in image scale
151
+
152
+ if return_mask:
153
+ # valid mask in pixel space
154
+ mask = (
155
+ (pixel_coords[:, 0] >= 0)
156
+ & (pixel_coords[:, 0] <= (w - 1))
157
+ & (pixel_coords[:, 1] >= 0)
158
+ & (pixel_coords[:, 1] <= (h - 1))
159
+ ) # [B, H, W]
160
+
161
+ return pixel_coords, mask
162
+
163
+ return pixel_coords
164
+
165
+
166
+ def reproject_coords(
167
+ depth_ref,
168
+ intrinsics,
169
+ extrinsics_ref=None,
170
+ extrinsics_tgt=None,
171
+ extrinsics_rel=None,
172
+ return_mask=False,
173
+ ):
174
+ # Compute reprojection sample coords
175
+ points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W]
176
+ points_tgt = camera_transform(
177
+ points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel
178
+ )
179
+
180
+ if return_mask:
181
+ reproj_coords, mask = reproject(
182
+ points_tgt, intrinsics, return_mask=return_mask
183
+ ) # [B, 2, H, W] in image scale
184
+
185
+ return reproj_coords, mask
186
+
187
+ reproj_coords = reproject(
188
+ points_tgt, intrinsics, return_mask=return_mask
189
+ ) # [B, 2, H, W] in image scale
190
+
191
+ return reproj_coords
192
+
193
+
194
+ def compute_flow_with_depth_pose(
195
+ depth_ref,
196
+ intrinsics,
197
+ extrinsics_ref=None,
198
+ extrinsics_tgt=None,
199
+ extrinsics_rel=None,
200
+ return_mask=False,
201
+ ):
202
+ b, h, w = depth_ref.shape
203
+ coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W]
204
+
205
+ if return_mask:
206
+ reproj_coords, mask = reproject_coords(
207
+ depth_ref,
208
+ intrinsics,
209
+ extrinsics_ref,
210
+ extrinsics_tgt,
211
+ extrinsics_rel=extrinsics_rel,
212
+ return_mask=return_mask,
213
+ ) # [B, 2, H, W]
214
+ rigid_flow = reproj_coords - coords_init
215
+
216
+ return rigid_flow, mask
217
+
218
+ reproj_coords = reproject_coords(
219
+ depth_ref,
220
+ intrinsics,
221
+ extrinsics_ref,
222
+ extrinsics_tgt,
223
+ extrinsics_rel=extrinsics_rel,
224
+ return_mask=return_mask,
225
+ ) # [B, 2, H, W]
226
+
227
+ rigid_flow = reproj_coords - coords_init
228
+
229
+ return rigid_flow
matching.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .geometry import coords_grid, generate_window_grid, normalize_coords
5
+
6
+
7
+ def global_correlation_softmax_prototype(
8
+ feature0,
9
+ feature1,
10
+ value,
11
+ pred_bidir_flow=False,
12
+ corr_mask=None,
13
+ ):
14
+ """
15
+ feature0: [B, C, H, W]
16
+ feature1: [B, C, H, W]
17
+ value: [B, C1, H, W]
18
+ corr_mask: [B, H*W, H*W] or None, if not None, the value will be masked out
19
+ """
20
+ # global correlation
21
+ b, c, h, w = feature0.shape
22
+ c_value = value.size(1)
23
+ value = value.view(b, c_value, -1).permute(0, 2, 1) # [B, H*W, C1]
24
+
25
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
26
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
27
+
28
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
29
+ c**0.5
30
+ ) # [B, H, W, H, W]
31
+
32
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
33
+
34
+ if pred_bidir_flow:
35
+ correlation = torch.cat(
36
+ (correlation, correlation.permute(0, 2, 1)), dim=0
37
+ ) # [2*B, H*W, H*W]
38
+ value = value.repeat(2, 1, 1) # [2*B, H*W, 2]
39
+ b = b * 2
40
+
41
+ if corr_mask is not None:
42
+ # mask out the correlation with corr_mask
43
+ if corr_mask.dtype == torch.bool:
44
+ # binary mask
45
+ correlation[corr_mask] = -65504.0
46
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
47
+ else:
48
+ # float mask
49
+ # bin_mask = corr_mask < 0
50
+ # correlation[bin_mask] = -65504.0
51
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
52
+ # soft_mask = torch.clamp(corr_mask, min=0)
53
+ prob = prob * corr_mask
54
+ # normalize
55
+ prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)
56
+ else:
57
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
58
+ # if corr_mask.dtype == torch.bool:
59
+ # # binary mask
60
+ # bin_mask = corr_mask
61
+ # num_true = bin_mask.sum(-1).unique().item()
62
+ # soft_mask = torch.ones((b, h * w, num_true)).to(corr_mask.device)
63
+ # else:
64
+ # # float mask
65
+ # bin_mask = corr_mask >= 0
66
+ # num_true = bin_mask.sum(-1).unique().item()
67
+ # soft_mask = corr_mask[bin_mask].view(b, h * w, num_true)
68
+ # assert soft_mask.min() >= 0
69
+ # correlation_seleted = correlation[bin_mask].view(b, h * w, num_true)
70
+ # prob_seleted = F.softmax(correlation_seleted, dim=-1) # [B, H*W, num_true]
71
+ # prob_seleted = soft_mask * prob_seleted
72
+ # prob_seleted = prob_seleted / (prob_seleted.sum(dim=2, keepdim=True) + 1e-8)
73
+ # prob = torch.zeros_like(correlation)
74
+ # prob[bin_mask] = prob_seleted.view(-1)
75
+ # else:
76
+ # prob = F.softmax(correlation, dim=-1)
77
+
78
+ result = torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2) # [B, 2, H, W]\
79
+ return result, correlation
80
+
81
+
82
+ def local_correlation_softmax_prototype(
83
+ feature0,
84
+ feature1,
85
+ value,
86
+ radius=5,
87
+ pred_bidir_flow=False,
88
+ corr_mask=None,
89
+ ):
90
+ """
91
+ softmax around argmax point
92
+ feature0: [B, C, H, W]
93
+ feature1: [B, C, H, W]
94
+ value: [B, C1, H, W]
95
+ corr_mask: [B, H*W, H*W] or None, if not None, the value will be masked out
96
+ """
97
+ b, c, h, w = feature0.shape
98
+ c_value = value.size(1)
99
+ value = value.view(b, c_value, -1).permute(0, 2, 1) # [B, H*W, C1]
100
+
101
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
102
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
103
+
104
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
105
+ c**0.5
106
+ ) # [B, H, W, H, W]
107
+
108
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
109
+
110
+ if pred_bidir_flow:
111
+ correlation = torch.cat(
112
+ (correlation, correlation.permute(0, 2, 1)), dim=0
113
+ ) # [2*B, H*W, H*W]
114
+ value = value.repeat(2, 1, 1) # [2*B, H*W, 2]
115
+ b = b * 2
116
+
117
+ if corr_mask is not None:
118
+ # mask out the correlation with corr_mask
119
+ if corr_mask.dtype == torch.bool:
120
+ # binary mask
121
+ correlation[corr_mask] = -65504.0
122
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
123
+ else:
124
+ # float mask
125
+ # bin_mask = corr_mask < 0
126
+ # correlation[bin_mask] = -65504.0
127
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
128
+ # soft_mask = torch.clamp(corr_mask, min=0)
129
+ prob = prob * corr_mask
130
+ else:
131
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
132
+
133
+ # get local prob
134
+ # B, H*W, 2
135
+ coords = coords_grid(b, h, w, device=feature0.device).flatten(2).permute(0, 2, 1)
136
+ # B, H*W, H*W, 2
137
+ coords = coords.unsqueeze(1).repeat(1, h * w, 1, 1)
138
+ # B, H*W
139
+ argmax_pos = torch.argmax(prob, dim=2)
140
+ # B, H*W, 1, 2
141
+ argmax_pos = argmax_pos.view(b, h * w, 1, 1).repeat(1, 1, 1, 2)
142
+ # B, H*W, 1, 2
143
+ pos = torch.gather(coords, 2, argmax_pos)
144
+ # B, H*W, H*W
145
+ valid = ((coords - pos).square().sum(dim=-1) < (radius**2)).float()
146
+ prob = prob * valid
147
+
148
+ # normalize
149
+ prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)
150
+ result = torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2) # [B, 2, H, W]\
151
+ return result, correlation
152
+
153
+
154
+ def global_correlation_softmax(
155
+ feature0,
156
+ feature1,
157
+ pred_bidir_flow=False,
158
+ ):
159
+ # global correlation
160
+ b, c, h, w = feature0.shape
161
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
162
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
163
+
164
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
165
+ c**0.5
166
+ ) # [B, H, W, H, W]
167
+
168
+ # flow from softmax
169
+ init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
170
+ grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
171
+
172
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
173
+
174
+ if pred_bidir_flow:
175
+ correlation = torch.cat(
176
+ (correlation, correlation.permute(0, 2, 1)), dim=0
177
+ ) # [2*B, H*W, H*W]
178
+ init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
179
+ grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
180
+ b = b * 2
181
+
182
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
183
+
184
+ correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
185
+
186
+ # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
187
+ flow = correspondence - init_grid
188
+
189
+ return flow, prob
190
+
191
+
192
+ def local_correlation_softmax(
193
+ feature0,
194
+ feature1,
195
+ local_radius,
196
+ padding_mode="zeros",
197
+ ):
198
+ b, c, h, w = feature0.size()
199
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
200
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
201
+
202
+ local_h = 2 * local_radius + 1
203
+ local_w = 2 * local_radius + 1
204
+
205
+ window_grid = generate_window_grid(
206
+ -local_radius,
207
+ local_radius,
208
+ -local_radius,
209
+ local_radius,
210
+ local_h,
211
+ local_w,
212
+ device=feature0.device,
213
+ ) # [2R+1, 2R+1, 2]
214
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
215
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
216
+
217
+ sample_coords_softmax = sample_coords
218
+
219
+ # exclude coords that are out of image space
220
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (
221
+ sample_coords[:, :, :, 0] < w
222
+ ) # [B, H*W, (2R+1)^2]
223
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (
224
+ sample_coords[:, :, :, 1] < h
225
+ ) # [B, H*W, (2R+1)^2]
226
+
227
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
228
+
229
+ # normalize coordinates to [-1, 1]
230
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
231
+ window_feature = F.grid_sample(
232
+ feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
233
+ ).permute(
234
+ 0, 2, 1, 3
235
+ ) # [B, H*W, C, (2R+1)^2]
236
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
237
+
238
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
239
+ c**0.5
240
+ ) # [B, H*W, (2R+1)^2]
241
+
242
+ # mask invalid locations
243
+ corr[~valid] = -1e9
244
+
245
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
246
+
247
+ correspondence = (
248
+ torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
249
+ .squeeze(-2)
250
+ .view(b, h, w, 2)
251
+ .permute(0, 3, 1, 2)
252
+ ) # [B, 2, H, W]
253
+
254
+ flow = correspondence - coords_init
255
+ match_prob = prob
256
+
257
+ return flow, match_prob
258
+
259
+
260
+ def local_correlation_with_flow(
261
+ feature0,
262
+ feature1,
263
+ flow,
264
+ local_radius,
265
+ padding_mode="zeros",
266
+ dilation=1,
267
+ ):
268
+ b, c, h, w = feature0.size()
269
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
270
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
271
+
272
+ local_h = 2 * local_radius + 1
273
+ local_w = 2 * local_radius + 1
274
+
275
+ window_grid = generate_window_grid(
276
+ -local_radius,
277
+ local_radius,
278
+ -local_radius,
279
+ local_radius,
280
+ local_h,
281
+ local_w,
282
+ device=feature0.device,
283
+ ) # [2R+1, 2R+1, 2]
284
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
285
+ sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2]
286
+
287
+ # flow can be zero when using features after transformer
288
+ if not isinstance(flow, float):
289
+ sample_coords = sample_coords + flow.view(b, 2, -1).permute(0, 2, 1).unsqueeze(
290
+ -2
291
+ ) # [B, H*W, (2R+1)^2, 2]
292
+ else:
293
+ assert flow == 0.0
294
+
295
+ # normalize coordinates to [-1, 1]
296
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
297
+ window_feature = F.grid_sample(
298
+ feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
299
+ ).permute(
300
+ 0, 2, 1, 3
301
+ ) # [B, H*W, C, (2R+1)^2]
302
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
303
+
304
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
305
+ c**0.5
306
+ ) # [B, H*W, (2R+1)^2]
307
+
308
+ corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W]
309
+
310
+ return corr
311
+
312
+
313
+ def global_correlation_softmax_stereo(
314
+ feature0,
315
+ feature1,
316
+ ):
317
+ # global correlation on horizontal direction
318
+ b, c, h, w = feature0.shape
319
+
320
+ x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W]
321
+
322
+ feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C]
323
+ feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W]
324
+
325
+ correlation = torch.matmul(feature0, feature1) / (c**0.5) # [B, H, W, W]
326
+
327
+ # mask subsequent positions to make disparity positive
328
+ mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W]
329
+ valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W]
330
+
331
+ correlation[~valid_mask] = -1e9
332
+
333
+ prob = F.softmax(correlation, dim=-1) # [B, H, W, W]
334
+
335
+ correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W]
336
+
337
+ # NOTE: unlike flow, disparity is typically positive
338
+ disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W]
339
+
340
+ return disparity.unsqueeze(1), prob # feature resolution
341
+
342
+
343
+ def local_correlation_softmax_stereo(
344
+ feature0,
345
+ feature1,
346
+ local_radius,
347
+ ):
348
+ b, c, h, w = feature0.size()
349
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
350
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2]
351
+
352
+ local_h = 1
353
+ local_w = 2 * local_radius + 1
354
+
355
+ window_grid = generate_window_grid(
356
+ 0, 0, -local_radius, local_radius, local_h, local_w, device=feature0.device
357
+ ) # [1, 2R+1, 2]
358
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2]
359
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2]
360
+
361
+ sample_coords_softmax = sample_coords
362
+
363
+ # exclude coords that are out of image space
364
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (
365
+ sample_coords[:, :, :, 0] < w
366
+ ) # [B, H*W, (2R+1)^2]
367
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (
368
+ sample_coords[:, :, :, 1] < h
369
+ ) # [B, H*W, (2R+1)^2]
370
+
371
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
372
+
373
+ # normalize coordinates to [-1, 1]
374
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
375
+ window_feature = F.grid_sample(
376
+ feature1, sample_coords_norm, padding_mode="zeros", align_corners=False
377
+ ).permute(
378
+ 0, 2, 1, 3
379
+ ) # [B, H*W, C, (2R+1)]
380
+ feature0_view = (
381
+ feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c)
382
+ ) # [B, H*W, 1, C]
383
+
384
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
385
+ c**0.5
386
+ ) # [B, H*W, (2R+1)]
387
+
388
+ # mask invalid locations
389
+ corr[~valid] = -1e9
390
+
391
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)]
392
+
393
+ correspondence = (
394
+ torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
395
+ .squeeze(-2)
396
+ .view(b, h, w, 2)
397
+ .permute(0, 3, 1, 2)
398
+ .contiguous()
399
+ ) # [B, 2, H, W]
400
+
401
+ flow = correspondence - coords_init # flow at feature resolution
402
+ match_prob = prob
403
+
404
+ flow_x = -flow[:, :1] # [B, 1, H, W]
405
+
406
+ return flow_x, match_prob
407
+
408
+
409
+ def correlation_softmax_depth(
410
+ feature0,
411
+ feature1,
412
+ intrinsics,
413
+ pose,
414
+ depth_candidates,
415
+ depth_from_argmax=False,
416
+ pred_bidir_depth=False,
417
+ ):
418
+ b, c, h, w = feature0.size()
419
+ assert depth_candidates.dim() == 4 # [B, D, H, W]
420
+ scale_factor = c**0.5
421
+
422
+ if pred_bidir_depth:
423
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
424
+ (feature1, feature0), dim=0
425
+ )
426
+ intrinsics = intrinsics.repeat(2, 1, 1)
427
+ pose = torch.cat((pose, torch.inverse(pose)), dim=0)
428
+ depth_candidates = depth_candidates.repeat(2, 1, 1, 1)
429
+
430
+ # depth candidates are actually inverse depth
431
+ warped_feature1 = warp_with_pose_depth_candidates(
432
+ feature1,
433
+ intrinsics,
434
+ pose,
435
+ 1.0 / depth_candidates,
436
+ ) # [B, C, D, H, W]
437
+
438
+ correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W]
439
+
440
+ match_prob = F.softmax(correlation, dim=1) # [B, D, H, W]
441
+
442
+ # for cross-task transfer (flow -> depth), extract depth with argmax at test time
443
+ if depth_from_argmax:
444
+ index = torch.argmax(match_prob, dim=1, keepdim=True)
445
+ depth = torch.gather(depth_candidates, dim=1, index=index)
446
+ else:
447
+ depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W]
448
+
449
+ return depth, match_prob
450
+
451
+
452
+ def warp_with_pose_depth_candidates(
453
+ feature1,
454
+ intrinsics,
455
+ pose,
456
+ depth,
457
+ clamp_min_depth=1e-3,
458
+ ):
459
+ """
460
+ feature1: [B, C, H, W]
461
+ intrinsics: [B, 3, 3]
462
+ pose: [B, 4, 4]
463
+ depth: [B, D, H, W]
464
+ """
465
+
466
+ assert intrinsics.size(1) == intrinsics.size(2) == 3
467
+ assert pose.size(1) == pose.size(2) == 4
468
+ assert depth.dim() == 4
469
+
470
+ b, d, h, w = depth.size()
471
+ c = feature1.size(1)
472
+
473
+ with torch.no_grad():
474
+ # pixel coordinates
475
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
476
+ # back project to 3D and transform viewpoint
477
+ points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W]
478
+ points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(1, 1, d, 1) * depth.view(
479
+ b, 1, d, h * w
480
+ ) # [B, 3, D, H*W]
481
+ points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W]
482
+ # reproject to 2D image plane
483
+ points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(
484
+ b, 3, d, h * w
485
+ ) # [B, 3, D, H*W]
486
+ pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W]
487
+
488
+ # normalize to [-1, 1]
489
+ x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
490
+ y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
491
+
492
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2]
493
+
494
+ # sample features
495
+ warped_feature = F.grid_sample(
496
+ feature1,
497
+ grid.view(b, d * h, w, 2),
498
+ mode="bilinear",
499
+ padding_mode="zeros",
500
+ align_corners=False,
501
+ ).view(
502
+ b, c, d, h, w
503
+ ) # [B, C, D, H, W]
504
+
505
+ return warped_feature
modeling_doduo.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+ from PIL import Image
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from torchvision import transforms
9
+ from transformers import PreTrainedModel
10
+
11
+ from .dino import vit_small
12
+ from .unimatch import UniMatch
13
+ from .configuration_doduo import DoduoConfig
14
+
15
+ class DoduoModel(PreTrainedModel):
16
+ config_class = DoduoConfig
17
+
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ self.model = CorrSegFlowNet(
21
+ dino_corr_mask_ratio=config.dino_corr_mask_ratio
22
+ )
23
+
24
+ def forward(self, frame_src, frame_dst):
25
+ if isinstance(frame_src, Image.Image):
26
+ frame_src = self.model.process_frame(frame_src)
27
+ frame_dst = self.model.process_frame(frame_dst)
28
+ assert frame_src.shape == frame_dst.shape
29
+ return self.model(frame_src, frame_dst)
30
+
31
+ class CorrSegFlowNet(nn.Module):
32
+ def __init__(
33
+ self,
34
+ dino_corr_mask_ratio: float = 0.1,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.dino_corr_mask_ratio = dino_corr_mask_ratio
39
+ self.unimatch = UniMatch(bilinear_upsample=True)
40
+ self.dino = vit_small(patch_size=8, num_classes=0)
41
+ for k in self.dino.parameters():
42
+ k.requires_grad = False
43
+
44
+ self.transform = transforms.Compose(
45
+ [
46
+ lambda x: transforms.ToTensor()(x)[:3],
47
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
48
+ ]
49
+ )
50
+
51
+ def process_frame(self, frame):
52
+ device = next(self.parameters()).device
53
+ frame = self.transform(frame)
54
+ frame = frame.unsqueeze(0).to(device)
55
+ return frame
56
+
57
+ def forward(
58
+ self,
59
+ frame_src,
60
+ frame_dst,
61
+ ):
62
+ corr_mask = get_dino_corr_mask(
63
+ self.dino,
64
+ frame_src,
65
+ frame_dst,
66
+ mask_ratio=self.dino_corr_mask_ratio
67
+ )
68
+
69
+ flow, flow_low, correlation, feature0, feature1 = self.unimatch(
70
+ frame_src,
71
+ frame_dst,
72
+ return_feature=True,
73
+ bidirectional=False,
74
+ cycle_consistency=False,
75
+ corr_mask=corr_mask,
76
+ )
77
+ return flow
78
+
79
+ @torch.no_grad()
80
+ def extract_dino_feature(model, frame, return_h_w=False):
81
+ """frame: B, C, H, W"""
82
+ B = frame.shape[0]
83
+ out = model.get_intermediate_layers(frame, n=1)[0]
84
+ out = out[:, 1:, :] # we discard the [CLS] token
85
+ h, w = int(frame.shape[2] / model.patch_embed.patch_size), int(
86
+ frame.shape[3] / model.patch_embed.patch_size
87
+ )
88
+ dim = out.shape[-1]
89
+ out = out.reshape(B, -1, dim)
90
+ if return_h_w:
91
+ return out, h, w
92
+ return out
93
+
94
+ @torch.no_grad()
95
+ def get_dino_corr_mask(
96
+ model, frame_src, frame_dst, mask_ratio
97
+ ):
98
+ # frame_src: B x C x H x W
99
+ # frame_dst: B x C x H x W
100
+ # mask_ratio: ratio of pixels to be masked
101
+ # return: B x h*w x h*w
102
+ feat_1, h, w = extract_dino_feature(model, frame_src, return_h_w=True)
103
+ feat_2 = extract_dino_feature(model, frame_dst)
104
+
105
+ feat_1_norm = F.normalize(feat_1, dim=2, p=2)
106
+ feat_2_norm = F.normalize(feat_2, dim=2, p=2)
107
+ aff_raw = torch.einsum("bnc,bmc->bnm", [feat_1_norm, feat_2_norm])
108
+
109
+ if mask_ratio <= 0:
110
+ # no corr mask
111
+ corr_mask = None
112
+ else:
113
+ if aff_raw.dtype == torch.float16:
114
+ aff_raw = aff_raw.float()
115
+ aff_percentile = torch.quantile(aff_raw, mask_ratio, 2, keepdim=True)
116
+ # True for masked
117
+ corr_mask = aff_raw < aff_percentile
118
+ return corr_mask
position.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class PositionEmbeddingSine(nn.Module):
11
+ """This is a more standard version of the position embedding, very similar to the one used by
12
+ the Attention is all you need paper, generalized to work on images."""
13
+
14
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
15
+ super().__init__()
16
+ self.num_pos_feats = num_pos_feats
17
+ self.temperature = temperature
18
+ self.normalize = normalize
19
+ if scale is not None and normalize is False:
20
+ raise ValueError("normalize should be True if scale is passed")
21
+ if scale is None:
22
+ scale = 2 * math.pi
23
+ self.scale = scale
24
+
25
+ def forward(self, x):
26
+ # x = tensor_list.tensors # [B, C, H, W]
27
+ # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
28
+ b, c, h, w = x.size()
29
+ mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
30
+ y_embed = mask.cumsum(1, dtype=torch.float32)
31
+ x_embed = mask.cumsum(2, dtype=torch.float32)
32
+ if self.normalize:
33
+ eps = 1e-6
34
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
35
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
36
+
37
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
38
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
39
+
40
+ pos_x = x_embed[:, :, :, None] / dim_t
41
+ pos_y = y_embed[:, :, :, None] / dim_t
42
+ pos_x = torch.stack(
43
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
44
+ ).flatten(3)
45
+ pos_y = torch.stack(
46
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
47
+ ).flatten(3)
48
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
49
+ return pos
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5340e547b103a76197fe408906930793aea54ed21a4c2f7845e73d1ca5810022
3
+ size 103562927
transformer.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .attention import (
5
+ single_head_full_attention,
6
+ single_head_full_attention_1d,
7
+ single_head_split_window_attention,
8
+ single_head_split_window_attention_1d,
9
+ )
10
+ from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d
11
+
12
+
13
+ class TransformerLayer(nn.Module):
14
+ def __init__(
15
+ self,
16
+ d_model=128,
17
+ nhead=1,
18
+ no_ffn=False,
19
+ ffn_dim_expansion=4,
20
+ ):
21
+ super().__init__()
22
+
23
+ self.dim = d_model
24
+ self.nhead = nhead
25
+ self.no_ffn = no_ffn
26
+
27
+ # multi-head attention
28
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
29
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
30
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
31
+
32
+ self.merge = nn.Linear(d_model, d_model, bias=False)
33
+
34
+ self.norm1 = nn.LayerNorm(d_model)
35
+
36
+ # no ffn after self-attn, with ffn after cross-attn
37
+ if not self.no_ffn:
38
+ in_channels = d_model * 2
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
41
+ nn.GELU(),
42
+ nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
43
+ )
44
+
45
+ self.norm2 = nn.LayerNorm(d_model)
46
+
47
+ def forward(
48
+ self,
49
+ source,
50
+ target,
51
+ height=None,
52
+ width=None,
53
+ shifted_window_attn_mask=None,
54
+ shifted_window_attn_mask_1d=None,
55
+ attn_type="swin",
56
+ with_shift=False,
57
+ attn_num_splits=None,
58
+ ):
59
+ # source, target: [B, L, C]
60
+ query, key, value = source, target, target
61
+
62
+ # for stereo: 2d attn in self-attn, 1d attn in cross-attn
63
+ is_self_attn = (query - key).abs().max() < 1e-6
64
+
65
+ # single-head attention
66
+ query = self.q_proj(query) # [B, L, C]
67
+ key = self.k_proj(key) # [B, L, C]
68
+ value = self.v_proj(value) # [B, L, C]
69
+
70
+ if attn_type == "swin" and attn_num_splits > 1: # self, cross-attn: both swin 2d
71
+ if self.nhead > 1:
72
+ # we observe that multihead attention slows down the speed and increases the memory consumption
73
+ # without bringing obvious performance gains and thus the implementation is removed
74
+ raise NotImplementedError
75
+ else:
76
+ message = single_head_split_window_attention(
77
+ query,
78
+ key,
79
+ value,
80
+ num_splits=attn_num_splits,
81
+ with_shift=with_shift,
82
+ h=height,
83
+ w=width,
84
+ attn_mask=shifted_window_attn_mask,
85
+ )
86
+
87
+ elif attn_type == "self_swin2d_cross_1d": # self-attn: swin 2d, cross-attn: full 1d
88
+ if self.nhead > 1:
89
+ raise NotImplementedError
90
+ else:
91
+ if is_self_attn:
92
+ if attn_num_splits > 1:
93
+ message = single_head_split_window_attention(
94
+ query,
95
+ key,
96
+ value,
97
+ num_splits=attn_num_splits,
98
+ with_shift=with_shift,
99
+ h=height,
100
+ w=width,
101
+ attn_mask=shifted_window_attn_mask,
102
+ )
103
+ else:
104
+ # full 2d attn
105
+ message = single_head_full_attention(query, key, value) # [N, L, C]
106
+
107
+ else:
108
+ # cross attn 1d
109
+ message = single_head_full_attention_1d(
110
+ query,
111
+ key,
112
+ value,
113
+ h=height,
114
+ w=width,
115
+ )
116
+
117
+ elif attn_type == "self_swin2d_cross_swin1d": # self-attn: swin 2d, cross-attn: swin 1d
118
+ if self.nhead > 1:
119
+ raise NotImplementedError
120
+ else:
121
+ if is_self_attn:
122
+ if attn_num_splits > 1:
123
+ # self attn shift window
124
+ message = single_head_split_window_attention(
125
+ query,
126
+ key,
127
+ value,
128
+ num_splits=attn_num_splits,
129
+ with_shift=with_shift,
130
+ h=height,
131
+ w=width,
132
+ attn_mask=shifted_window_attn_mask,
133
+ )
134
+ else:
135
+ # full 2d attn
136
+ message = single_head_full_attention(query, key, value) # [N, L, C]
137
+ else:
138
+ if attn_num_splits > 1:
139
+ assert shifted_window_attn_mask_1d is not None
140
+ # cross attn 1d shift
141
+ message = single_head_split_window_attention_1d(
142
+ query,
143
+ key,
144
+ value,
145
+ num_splits=attn_num_splits,
146
+ with_shift=with_shift,
147
+ h=height,
148
+ w=width,
149
+ attn_mask=shifted_window_attn_mask_1d,
150
+ )
151
+ else:
152
+ message = single_head_full_attention_1d(
153
+ query,
154
+ key,
155
+ value,
156
+ h=height,
157
+ w=width,
158
+ )
159
+
160
+ else:
161
+ message = single_head_full_attention(query, key, value) # [B, L, C]
162
+
163
+ message = self.merge(message) # [B, L, C]
164
+ message = self.norm1(message)
165
+
166
+ if not self.no_ffn:
167
+ message = self.mlp(torch.cat([source, message], dim=-1))
168
+ message = self.norm2(message)
169
+
170
+ return source + message
171
+
172
+
173
+ class TransformerBlock(nn.Module):
174
+ """self attention + cross attention + FFN."""
175
+
176
+ def __init__(
177
+ self,
178
+ d_model=128,
179
+ nhead=1,
180
+ ffn_dim_expansion=4,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.self_attn = TransformerLayer(
185
+ d_model=d_model,
186
+ nhead=nhead,
187
+ no_ffn=True,
188
+ ffn_dim_expansion=ffn_dim_expansion,
189
+ )
190
+
191
+ self.cross_attn_ffn = TransformerLayer(
192
+ d_model=d_model,
193
+ nhead=nhead,
194
+ ffn_dim_expansion=ffn_dim_expansion,
195
+ )
196
+
197
+ def forward(
198
+ self,
199
+ source,
200
+ target,
201
+ height=None,
202
+ width=None,
203
+ shifted_window_attn_mask=None,
204
+ shifted_window_attn_mask_1d=None,
205
+ attn_type="swin",
206
+ with_shift=False,
207
+ attn_num_splits=None,
208
+ ):
209
+ # source, target: [B, L, C]
210
+
211
+ # self attention
212
+ source = self.self_attn(
213
+ source,
214
+ source,
215
+ height=height,
216
+ width=width,
217
+ shifted_window_attn_mask=shifted_window_attn_mask,
218
+ attn_type=attn_type,
219
+ with_shift=with_shift,
220
+ attn_num_splits=attn_num_splits,
221
+ )
222
+
223
+ # cross attention and ffn
224
+ source = self.cross_attn_ffn(
225
+ source,
226
+ target,
227
+ height=height,
228
+ width=width,
229
+ shifted_window_attn_mask=shifted_window_attn_mask,
230
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
231
+ attn_type=attn_type,
232
+ with_shift=with_shift,
233
+ attn_num_splits=attn_num_splits,
234
+ )
235
+
236
+ return source
237
+
238
+
239
+ class FeatureTransformer(nn.Module):
240
+ def __init__(
241
+ self,
242
+ num_layers=6,
243
+ d_model=128,
244
+ nhead=1,
245
+ ffn_dim_expansion=4,
246
+ ):
247
+ super().__init__()
248
+
249
+ self.d_model = d_model
250
+ self.nhead = nhead
251
+
252
+ self.layers = nn.ModuleList(
253
+ [
254
+ TransformerBlock(
255
+ d_model=d_model,
256
+ nhead=nhead,
257
+ ffn_dim_expansion=ffn_dim_expansion,
258
+ )
259
+ for i in range(num_layers)
260
+ ]
261
+ )
262
+
263
+ for p in self.parameters():
264
+ if p.dim() > 1:
265
+ nn.init.xavier_uniform_(p)
266
+
267
+ def forward(
268
+ self,
269
+ feature0,
270
+ feature1,
271
+ attn_type="swin",
272
+ attn_num_splits=None,
273
+ **kwargs,
274
+ ):
275
+
276
+ b, c, h, w = feature0.shape
277
+ assert self.d_model == c
278
+
279
+ feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
280
+ feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
281
+
282
+ # 2d attention
283
+ if "swin" in attn_type and attn_num_splits > 1:
284
+ # global and refine use different number of splits
285
+ window_size_h = h // attn_num_splits
286
+ window_size_w = w // attn_num_splits
287
+
288
+ # compute attn mask once
289
+ shifted_window_attn_mask = generate_shift_window_attn_mask(
290
+ input_resolution=(h, w),
291
+ window_size_h=window_size_h,
292
+ window_size_w=window_size_w,
293
+ shift_size_h=window_size_h // 2,
294
+ shift_size_w=window_size_w // 2,
295
+ device=feature0.device,
296
+ ) # [K*K, H/K*W/K, H/K*W/K]
297
+ else:
298
+ shifted_window_attn_mask = None
299
+
300
+ # 1d attention
301
+ if "swin1d" in attn_type and attn_num_splits > 1:
302
+ window_size_w = w // attn_num_splits
303
+
304
+ # compute attn mask once
305
+ shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
306
+ input_w=w,
307
+ window_size_w=window_size_w,
308
+ shift_size_w=window_size_w // 2,
309
+ device=feature0.device,
310
+ ) # [K, W/K, W/K]
311
+ else:
312
+ shifted_window_attn_mask_1d = None
313
+
314
+ # concat feature0 and feature1 in batch dimension to compute in parallel
315
+ concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
316
+ concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
317
+
318
+ for i, layer in enumerate(self.layers):
319
+ concat0 = layer(
320
+ concat0,
321
+ concat1,
322
+ height=h,
323
+ width=w,
324
+ attn_type=attn_type,
325
+ with_shift="swin" in attn_type and attn_num_splits > 1 and i % 2 == 1,
326
+ attn_num_splits=attn_num_splits,
327
+ shifted_window_attn_mask=shifted_window_attn_mask,
328
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
329
+ )
330
+
331
+ # update feature1
332
+ concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
333
+
334
+ feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
335
+
336
+ # reshape back
337
+ feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
338
+ feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
339
+
340
+ return feature0, feature1
trident_conv.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torch.nn.modules.utils import _pair
8
+
9
+
10
+ class MultiScaleTridentConv(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride=1,
17
+ strides=1,
18
+ paddings=0,
19
+ dilations=1,
20
+ dilation=1,
21
+ groups=1,
22
+ num_branch=1,
23
+ test_branch_idx=-1,
24
+ bias=False,
25
+ norm=None,
26
+ activation=None,
27
+ ):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.kernel_size = _pair(kernel_size)
32
+ self.num_branch = num_branch
33
+ self.stride = _pair(stride)
34
+ self.groups = groups
35
+ self.with_bias = bias
36
+ self.dilation = dilation
37
+ if isinstance(paddings, int):
38
+ paddings = [paddings] * self.num_branch
39
+ if isinstance(dilations, int):
40
+ dilations = [dilations] * self.num_branch
41
+ if isinstance(strides, int):
42
+ strides = [strides] * self.num_branch
43
+ self.paddings = [_pair(padding) for padding in paddings]
44
+ self.dilations = [_pair(dilation) for dilation in dilations]
45
+ self.strides = [_pair(stride) for stride in strides]
46
+ self.test_branch_idx = test_branch_idx
47
+ self.norm = norm
48
+ self.activation = activation
49
+
50
+ assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51
+
52
+ self.weight = nn.Parameter(
53
+ torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54
+ )
55
+ if bias:
56
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
57
+ else:
58
+ self.bias = None
59
+
60
+ nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61
+ if self.bias is not None:
62
+ nn.init.constant_(self.bias, 0)
63
+
64
+ def forward(self, inputs):
65
+ num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66
+ assert len(inputs) == num_branch
67
+
68
+ if self.training or self.test_branch_idx == -1:
69
+ outputs = [
70
+ F.conv2d(
71
+ input, self.weight, self.bias, stride, padding, self.dilation, self.groups
72
+ )
73
+ for input, stride, padding in zip(inputs, self.strides, self.paddings)
74
+ ]
75
+ else:
76
+ outputs = [
77
+ F.conv2d(
78
+ inputs[0],
79
+ self.weight,
80
+ self.bias,
81
+ self.strides[self.test_branch_idx]
82
+ if self.test_branch_idx == -1
83
+ else self.strides[-1],
84
+ self.paddings[self.test_branch_idx]
85
+ if self.test_branch_idx == -1
86
+ else self.paddings[-1],
87
+ self.dilation,
88
+ self.groups,
89
+ )
90
+ ]
91
+
92
+ if self.norm is not None:
93
+ outputs = [self.norm(x) for x in outputs]
94
+ if self.activation is not None:
95
+ outputs = [self.activation(x) for x in outputs]
96
+ return outputs
unimatch.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .backbone import CNNEncoder
6
+ from .geometry import coords_grid
7
+ from .matching import (
8
+ global_correlation_softmax_prototype,
9
+ local_correlation_softmax_prototype,
10
+ )
11
+ from .transformer import FeatureTransformer
12
+ from .utils import feature_add_position
13
+
14
+
15
+ class UniMatch(nn.Module):
16
+ def __init__(
17
+ self,
18
+ num_scales=1,
19
+ feature_channels=128,
20
+ upsample_factor=8,
21
+ num_head=1,
22
+ ffn_dim_expansion=4,
23
+ num_transformer_layers=6,
24
+ bilinear_upsample=False,
25
+ corr_fn="global",
26
+ ):
27
+ super().__init__()
28
+
29
+ self.feature_channels = feature_channels
30
+ self.num_scales = num_scales
31
+ self.upsample_factor = upsample_factor
32
+ self.bilinear_upsample = bilinear_upsample
33
+ if corr_fn == "global":
34
+ self.corr_fn = global_correlation_softmax_prototype
35
+ elif corr_fn == "local":
36
+ self.corr_fn = local_correlation_softmax_prototype
37
+ else:
38
+ raise NotImplementedError(f"Correlation function {corr_fn} not implemented")
39
+
40
+ # CNN
41
+ self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
42
+
43
+ # Transformer
44
+ self.transformer = FeatureTransformer(
45
+ num_layers=num_transformer_layers,
46
+ d_model=feature_channels,
47
+ nhead=num_head,
48
+ ffn_dim_expansion=ffn_dim_expansion,
49
+ )
50
+
51
+ # convex upsampling similar to RAFT
52
+ # concat feature0 and low res flow as input
53
+ if not bilinear_upsample:
54
+ self.upsampler = nn.Sequential(
55
+ nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
58
+ )
59
+
60
+ def extract_feature(self, img0, img1):
61
+ concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
62
+ features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
63
+
64
+ # reverse: resolution from low to high
65
+ features = features[::-1]
66
+
67
+ feature0, feature1 = [], []
68
+
69
+ for i in range(len(features)):
70
+ feature = features[i]
71
+ chunks = torch.chunk(feature, 2, 0) # tuple
72
+ feature0.append(chunks[0])
73
+ feature1.append(chunks[1])
74
+
75
+ return feature0, feature1
76
+
77
+ def correlate_feature(self, feature0, feature1, attn_splits=2, attn_type="swin"):
78
+ feature0, feature1 = feature_add_position(
79
+ feature0, feature1, attn_splits, self.feature_channels
80
+ )
81
+ feature0, feature1 = self.transformer(
82
+ feature0,
83
+ feature1,
84
+ attn_type=attn_type,
85
+ attn_num_splits=attn_splits,
86
+ )
87
+ b, c, h, w = feature0.shape
88
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
89
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
90
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
91
+ c**0.5
92
+ ) # [B, H, W, H, W]
93
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
94
+ return correlation
95
+
96
+ def forward(
97
+ self,
98
+ img0,
99
+ img1,
100
+ attn_type="swin",
101
+ attn_splits=2,
102
+ return_feature=False,
103
+ bidirectional=False,
104
+ cycle_consistency=False,
105
+ corr_mask=None,
106
+ ):
107
+ # list of features, resolution low to high
108
+ feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
109
+ assert self.num_scales == 1 # multi-scale depth model is not supported yet
110
+ scale_idx = 0
111
+ feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
112
+
113
+ if cycle_consistency:
114
+ # get both directions of features
115
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
116
+ (feature1, feature0), dim=0
117
+ )
118
+
119
+ # add position to features
120
+ feature0, feature1 = feature_add_position(
121
+ feature0, feature1, attn_splits, self.feature_channels
122
+ )
123
+
124
+ # Transformer
125
+ feature0, feature1 = self.transformer(
126
+ feature0,
127
+ feature1,
128
+ attn_type=attn_type,
129
+ attn_num_splits=attn_splits,
130
+ )
131
+ b, c, h, w = feature0.shape
132
+ # downsampled_img0 = F.interpolate(img0, size=(h, w), mode="bilinear", align_corners=False)
133
+ flow_coords = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
134
+ # values = torch.cat((downsampled_img0, flow_coords), dim=1) # [B, 5, H, W]
135
+ # correlation and softmax
136
+ query_results, correlation = self.corr_fn(
137
+ feature0, feature1, flow_coords, pred_bidir_flow=bidirectional, corr_mask=corr_mask
138
+ )
139
+ if bidirectional:
140
+ flow_coords = torch.cat((flow_coords, flow_coords), dim=0)
141
+ up_feature = torch.cat((feature0, feature1), dim=0)
142
+ else:
143
+ up_feature = feature0
144
+ flow = query_results - flow_coords
145
+ flow_up = self.upsample_flow(flow, up_feature, bilinear=self.bilinear_upsample)
146
+ if return_feature:
147
+ return flow_up, flow, correlation, feature0, feature1
148
+ else:
149
+ return flow_up, flow, correlation
150
+
151
+ def forward_features(
152
+ self,
153
+ img0,
154
+ img1,
155
+ attn_type="swin",
156
+ attn_splits=2,
157
+ ):
158
+
159
+ feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
160
+ assert self.num_scales == 1 # multi-scale depth model is not supported yet
161
+ scale_idx = 0
162
+ feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
163
+ # add position to features
164
+ feature0, feature1 = feature_add_position(
165
+ feature0, feature1, attn_splits, self.feature_channels
166
+ )
167
+
168
+ # Transformer
169
+ feature0, feature1 = self.transformer(
170
+ feature0,
171
+ feature1,
172
+ attn_type=attn_type,
173
+ attn_num_splits=attn_splits,
174
+ )
175
+ return feature0, feature1
176
+
177
+ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, is_depth=False):
178
+ if bilinear:
179
+ multiplier = 1 if is_depth else upsample_factor
180
+ up_flow = (
181
+ F.interpolate(
182
+ flow, scale_factor=upsample_factor, mode="bilinear", align_corners=False
183
+ )
184
+ * multiplier
185
+ )
186
+ else:
187
+ concat = torch.cat((flow, feature), dim=1)
188
+ mask = self.upsampler(concat)
189
+ up_flow = upsample_flow_with_mask(
190
+ flow, mask, upsample_factor=self.upsample_factor, is_depth=is_depth
191
+ )
192
+ return up_flow
193
+
194
+
195
+ def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
196
+ # convex upsampling following raft
197
+
198
+ mask = up_mask
199
+ b, flow_channel, h, w = flow.shape
200
+ mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
201
+ mask = torch.softmax(mask, dim=2)
202
+
203
+ multiplier = 1 if is_depth else upsample_factor
204
+ up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
205
+ up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
206
+
207
+ up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
208
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
209
+ up_flow = up_flow.reshape(
210
+ b, flow_channel, upsample_factor * h, upsample_factor * w
211
+ ) # [B, 2, K*H, K*W]
212
+
213
+ return up_flow
utils.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .position import PositionEmbeddingSine
5
+
6
+
7
+ def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
8
+ assert device is not None
9
+
10
+ x, y = torch.meshgrid(
11
+ [
12
+ torch.linspace(w_min, w_max, len_w, device=device),
13
+ torch.linspace(h_min, h_max, len_h, device=device),
14
+ ],
15
+ )
16
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
17
+
18
+ return grid
19
+
20
+
21
+ def normalize_coords(coords, h, w):
22
+ # coords: [B, H, W, 2]
23
+ c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
24
+ return (coords - c) / c # [-1, 1]
25
+
26
+
27
+ def normalize_img(img0, img1):
28
+ # loaded images are in [0, 255]
29
+ # normalize by ImageNet mean and std
30
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
31
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
32
+ img0 = (img0 / 255.0 - mean) / std
33
+ img1 = (img1 / 255.0 - mean) / std
34
+
35
+ return img0, img1
36
+
37
+
38
+ def split_feature(
39
+ feature,
40
+ num_splits=2,
41
+ channel_last=False,
42
+ ):
43
+ if channel_last: # [B, H, W, C]
44
+ b, h, w, c = feature.size()
45
+ assert h % num_splits == 0 and w % num_splits == 0
46
+
47
+ b_new = b * num_splits * num_splits
48
+ h_new = h // num_splits
49
+ w_new = w // num_splits
50
+
51
+ feature = (
52
+ feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c)
53
+ .permute(0, 1, 3, 2, 4, 5)
54
+ .reshape(b_new, h_new, w_new, c)
55
+ ) # [B*K*K, H/K, W/K, C]
56
+ else: # [B, C, H, W]
57
+ b, c, h, w = feature.size()
58
+ assert h % num_splits == 0 and w % num_splits == 0
59
+
60
+ b_new = b * num_splits * num_splits
61
+ h_new = h // num_splits
62
+ w_new = w // num_splits
63
+
64
+ feature = (
65
+ feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits)
66
+ .permute(0, 2, 4, 1, 3, 5)
67
+ .reshape(b_new, c, h_new, w_new)
68
+ ) # [B*K*K, C, H/K, W/K]
69
+
70
+ return feature
71
+
72
+
73
+ def merge_splits(
74
+ splits,
75
+ num_splits=2,
76
+ channel_last=False,
77
+ ):
78
+ if channel_last: # [B*K*K, H/K, W/K, C]
79
+ b, h, w, c = splits.size()
80
+ new_b = b // num_splits // num_splits
81
+
82
+ splits = splits.view(new_b, num_splits, num_splits, h, w, c)
83
+ merge = (
84
+ splits.permute(0, 1, 3, 2, 4, 5)
85
+ .contiguous()
86
+ .view(new_b, num_splits * h, num_splits * w, c)
87
+ ) # [B, H, W, C]
88
+ else: # [B*K*K, C, H/K, W/K]
89
+ b, c, h, w = splits.size()
90
+ new_b = b // num_splits // num_splits
91
+
92
+ splits = splits.view(new_b, num_splits, num_splits, c, h, w)
93
+ merge = (
94
+ splits.permute(0, 3, 1, 4, 2, 5)
95
+ .contiguous()
96
+ .view(new_b, c, num_splits * h, num_splits * w)
97
+ ) # [B, C, H, W]
98
+
99
+ return merge
100
+
101
+
102
+ def generate_shift_window_attn_mask(
103
+ input_resolution,
104
+ window_size_h,
105
+ window_size_w,
106
+ shift_size_h,
107
+ shift_size_w,
108
+ device=torch.device("cuda"),
109
+ ):
110
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
111
+ # calculate attention mask for SW-MSA
112
+ h, w = input_resolution
113
+ img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
114
+ h_slices = (
115
+ slice(0, -window_size_h),
116
+ slice(-window_size_h, -shift_size_h),
117
+ slice(-shift_size_h, None),
118
+ )
119
+ w_slices = (
120
+ slice(0, -window_size_w),
121
+ slice(-window_size_w, -shift_size_w),
122
+ slice(-shift_size_w, None),
123
+ )
124
+ cnt = 0
125
+ for h in h_slices:
126
+ for w in w_slices:
127
+ img_mask[:, h, w, :] = cnt
128
+ cnt += 1
129
+
130
+ mask_windows = split_feature(
131
+ img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True
132
+ )
133
+
134
+ mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
135
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
136
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
137
+ attn_mask == 0, float(0.0)
138
+ )
139
+
140
+ return attn_mask
141
+
142
+
143
+ def feature_add_position(feature0, feature1, attn_splits, feature_channels):
144
+ pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
145
+
146
+ if attn_splits > 1: # add position in splited window
147
+ feature0_splits = split_feature(feature0, num_splits=attn_splits)
148
+ feature1_splits = split_feature(feature1, num_splits=attn_splits)
149
+
150
+ position = pos_enc(feature0_splits)
151
+
152
+ feature0_splits = feature0_splits + position
153
+ feature1_splits = feature1_splits + position
154
+
155
+ feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
156
+ feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
157
+ else:
158
+ position = pos_enc(feature0)
159
+
160
+ feature0 = feature0 + position
161
+ feature1 = feature1 + position
162
+
163
+ return feature0, feature1
164
+
165
+
166
+ def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
167
+ # convex upsampling following raft
168
+
169
+ mask = up_mask
170
+ b, flow_channel, h, w = flow.shape
171
+ mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
172
+ mask = torch.softmax(mask, dim=2)
173
+
174
+ multiplier = 1 if is_depth else upsample_factor
175
+ up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
176
+ up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
177
+
178
+ up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
179
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
180
+ up_flow = up_flow.reshape(
181
+ b, flow_channel, upsample_factor * h, upsample_factor * w
182
+ ) # [B, 2, K*H, K*W]
183
+
184
+ return up_flow
185
+
186
+
187
+ def split_feature_1d(
188
+ feature,
189
+ num_splits=2,
190
+ ):
191
+ # feature: [B, W, C]
192
+ b, w, c = feature.size()
193
+ assert w % num_splits == 0
194
+
195
+ b_new = b * num_splits
196
+ w_new = w // num_splits
197
+
198
+ feature = feature.view(b, num_splits, w // num_splits, c).view(
199
+ b_new, w_new, c
200
+ ) # [B*K, W/K, C]
201
+
202
+ return feature
203
+
204
+
205
+ def merge_splits_1d(
206
+ splits,
207
+ h,
208
+ num_splits=2,
209
+ ):
210
+ b, w, c = splits.size()
211
+ new_b = b // num_splits // h
212
+
213
+ splits = splits.view(new_b, h, num_splits, w, c)
214
+ merge = splits.view(new_b, h, num_splits * w, c) # [B, H, W, C]
215
+
216
+ return merge
217
+
218
+
219
+ def window_partition_1d(x, window_size_w):
220
+ """
221
+ Args:
222
+ x: (B, W, C)
223
+ window_size (int): window size
224
+
225
+ Returns:
226
+ windows: (num_windows*B, window_size, C)
227
+ """
228
+ B, W, C = x.shape
229
+ x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
230
+ return x
231
+
232
+
233
+ def generate_shift_window_attn_mask_1d(
234
+ input_w, window_size_w, shift_size_w, device=torch.device("cuda")
235
+ ):
236
+ # calculate attention mask for SW-MSA
237
+ img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1
238
+ w_slices = (
239
+ slice(0, -window_size_w),
240
+ slice(-window_size_w, -shift_size_w),
241
+ slice(-shift_size_w, None),
242
+ )
243
+ cnt = 0
244
+ for w in w_slices:
245
+ img_mask[:, w, :] = cnt
246
+ cnt += 1
247
+
248
+ mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1
249
+ mask_windows = mask_windows.view(-1, window_size_w)
250
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
251
+ 2
252
+ ) # nW, window_size, window_size
253
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
254
+ attn_mask == 0, float(0.0)
255
+ )
256
+
257
+ return attn_mask