Video Classification
vision
Andy1621 committed
Commit e8df2d2
1 Parent(s): 97b6ead

Upload uniformer.py

Files changed (1):
  uniformer.py: +379 -0
uniformer.py ADDED
@@ -0,0 +1,379 @@
from collections import OrderedDict
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import trunc_normal_, DropPath, to_2tuple


def conv_3xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
    return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (2, stride, stride), (1, 0, 0), groups=groups)

def conv_1xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
    return nn.Conv3d(inp, oup, (1, kernel_size, kernel_size), (1, stride, stride), (0, 0, 0), groups=groups)

def conv_3xnxn_std(inp, oup, kernel_size=3, stride=3, groups=1):
    return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (1, stride, stride), (1, 0, 0), groups=groups)

def conv_1x1x1(inp, oup, groups=1):
    return nn.Conv3d(inp, oup, (1, 1, 1), (1, 1, 1), (0, 0, 0), groups=groups)

def conv_3x3x3(inp, oup, groups=1):
    return nn.Conv3d(inp, oup, (3, 3, 3), (1, 1, 1), (1, 1, 1), groups=groups)

def conv_5x5x5(inp, oup, groups=1):
    return nn.Conv3d(inp, oup, (5, 5, 5), (1, 1, 1), (2, 2, 2), groups=groups)

def bn_3d(dim):
    return nn.BatchNorm3d(dim)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CMlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = conv_1x1x1(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = conv_1x1x1(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


# Convolutional block used in the shallow stages: token relations are modelled locally
# with a depthwise 5x5x5 convolution instead of self-attention.
class CBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.norm1 = bn_3d(dim)
        self.conv1 = conv_1x1x1(dim, dim, 1)
        self.conv2 = conv_1x1x1(dim, dim, 1)
        self.attn = conv_5x5x5(dim, dim, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = bn_3d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


# Self-attention block used in the deep stages: joint spatiotemporal attention
# over the flattened T*H*W token sequence.
class SABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, T, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, T, H, W)
        return x


# Divided space-time attention block: temporal attention across the T frames at each
# spatial location, followed by spatial attention over H*W within each frame.
class SplitSABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.t_norm = norm_layer(dim)
        self.t_attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, T, H, W = x.shape
        attn = x.view(B, C, T, H * W).permute(0, 3, 2, 1).contiguous()
        attn = attn.view(B * H * W, T, C)
        attn = attn + self.drop_path(self.t_attn(self.t_norm(attn)))
        attn = attn.view(B, H * W, T, C).permute(0, 2, 1, 3).contiguous()
        attn = attn.view(B * T, H * W, C)
        residual = x.view(B, C, T, H * W).permute(0, 2, 3, 1).contiguous()
        residual = residual.view(B * T, H * W, C)
        attn = residual + self.drop_path(self.attn(self.norm1(attn)))
        attn = attn.view(B, T * H * W, C)
        out = attn + self.drop_path(self.mlp(self.norm2(attn)))
        out = out.transpose(1, 2).reshape(B, C, T, H, W)
        return out


class SpeicalPatchEmbed(nn.Module):
    """ Video to Patch Embedding for the first stage
    (3D convolution with temporal kernel 3 and temporal stride 2).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.norm = nn.LayerNorm(embed_dim)
        self.proj = conv_3xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])

    def forward(self, x):
        B, C, T, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        B, C, T, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        return x


class PatchEmbed(nn.Module):
    """ Video to Patch Embedding used for spatial downsampling between stages.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, std=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.norm = nn.LayerNorm(embed_dim)
        if std:
            self.proj = conv_3xnxn_std(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])
        else:
            self.proj = conv_1xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])

    def forward(self, x):
        B, C, T, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        B, C, T, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        return x


class Uniformer(nn.Module):
    """ UniFormer backbone for video classification.
    Adapted from the Vision Transformer implementation of
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, depth=[5, 8, 20, 7], num_classes=400, img_size=224, in_chans=3, embed_dim=[64, 128, 320, 512],
                 head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0.3, attn_drop_rate=0., drop_path_rate=0., norm_layer=None, split=False, std=False):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # honour a user-supplied norm_layer; default to LayerNorm with eps=1e-6
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed1 = SpeicalPatchEmbed(
            img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
        self.patch_embed2 = PatchEmbed(
            img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1], std=std)
        self.patch_embed3 = PatchEmbed(
            img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2], std=std)
        self.patch_embed4 = PatchEmbed(
            img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3], std=std)

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule
        num_heads = [dim // head_dim for dim in embed_dim]
        self.blocks1 = nn.ModuleList([
            CBlock(
                dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth[0])])
        self.blocks2 = nn.ModuleList([
            CBlock(
                dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)
            for i in range(depth[1])])
        if split:
            self.blocks3 = nn.ModuleList([
                SplitSABlock(
                    dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
                for i in range(depth[2])])
            self.blocks4 = nn.ModuleList([
                SplitSABlock(
                    dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
                for i in range(depth[3])])
        else:
            self.blocks3 = nn.ModuleList([
                SABlock(
                    dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
                for i in range(depth[2])])
            self.blocks4 = nn.ModuleList([
                SABlock(
                    dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
                for i in range(depth[3])])
        self.norm = bn_3d(embed_dim[-1])

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                # embed_dim is a list of per-stage widths; project from the last stage
                ('fc', nn.Linear(embed_dim[-1], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

        for name, p in self.named_parameters():
            # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
            # are multiplied by 0*0, which is hard for the model to move out of.
            if 't_attn.qkv.weight' in name:
                nn.init.constant_(p, 0)
            if 't_attn.qkv.bias' in name:
                nn.init.constant_(p, 0)
            if 't_attn.proj.weight' in name:
                nn.init.constant_(p, 1)
            if 't_attn.proj.bias' in name:
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        # self.embed_dim is the list of per-stage widths; the head takes the last one
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed1(x)
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        x = self.patch_embed2(x)
        for blk in self.blocks2:
            x = blk(x)
        x = self.patch_embed3(x)
        for blk in self.blocks3:
            x = blk(x)
        x = self.patch_embed4(x)
        for blk in self.blocks4:
            x = blk(x)
        x = self.norm(x)
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.flatten(2).mean(-1)  # global average pool over T, H and W
        x = self.head(x)
        return x


def uniformer_small():
    return Uniformer(
        depth=[3, 4, 8, 3], embed_dim=[64, 128, 320, 512],
        head_dim=64, drop_rate=0.1)

def uniformer_base():
    return Uniformer(
        depth=[5, 8, 20, 7], embed_dim=[64, 128, 320, 512],
        head_dim=64, drop_rate=0.3)
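
For reference, a minimal smoke-test sketch, not part of the uploaded file. The names model, video and logits are illustrative; it assumes torch and timm are installed and that the file above is importable as uniformer:

import torch
from uniformer import uniformer_small  # assumes the file above is saved as uniformer.py

model = uniformer_small()
model.eval()

# UniFormer expects 5D clips shaped (batch, channels, frames, height, width);
# the first patch embedding halves the temporal dimension, so an even frame count is used.
video = torch.randn(1, 3, 16, 224, 224)

with torch.no_grad():
    logits = model(video)

print(logits.shape)  # torch.Size([1, 400]) with the default num_classes=400

uniformer_base() drops in the same way; per the definitions above it only changes the stage depths and the dropout rate.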