Spaces:
Sleeping
Sleeping
amazinghaha
commited on
Upload 106 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- model/CoordAttention.py +110 -0
- model/Vision_Transformer_with_mask.py +990 -0
- model/__pycache__/CoordAttention.cpython-38.pyc +0 -0
- model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc +0 -0
- model/__pycache__/features.cpython-38.pyc +0 -0
- model/__pycache__/helpers.cpython-38.pyc +0 -0
- model/__pycache__/hub.cpython-38.pyc +0 -0
- model/__pycache__/registry.cpython-38.pyc +0 -0
- model/features.py +284 -0
- model/helpers.py +508 -0
- model/hub.py +96 -0
- model/layers/__init__.py +40 -0
- model/layers/__pycache__/__init__.cpython-38.pyc +0 -0
- model/layers/__pycache__/activations.cpython-38.pyc +0 -0
- model/layers/__pycache__/activations_jit.cpython-38.pyc +0 -0
- model/layers/__pycache__/activations_me.cpython-38.pyc +0 -0
- model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc +0 -0
- model/layers/__pycache__/blur_pool.cpython-38.pyc +0 -0
- model/layers/__pycache__/bottleneck_attn.cpython-38.pyc +0 -0
- model/layers/__pycache__/cbam.cpython-38.pyc +0 -0
- model/layers/__pycache__/classifier.cpython-38.pyc +0 -0
- model/layers/__pycache__/cond_conv2d.cpython-38.pyc +0 -0
- model/layers/__pycache__/config.cpython-38.pyc +0 -0
- model/layers/__pycache__/conv2d_same.cpython-38.pyc +0 -0
- model/layers/__pycache__/conv_bn_act.cpython-38.pyc +0 -0
- model/layers/__pycache__/create_act.cpython-38.pyc +0 -0
- model/layers/__pycache__/create_attn.cpython-38.pyc +0 -0
- model/layers/__pycache__/create_conv2d.cpython-38.pyc +0 -0
- model/layers/__pycache__/create_norm_act.cpython-38.pyc +0 -0
- model/layers/__pycache__/drop.cpython-38.pyc +0 -0
- model/layers/__pycache__/eca.cpython-38.pyc +0 -0
- model/layers/__pycache__/evo_norm.cpython-38.pyc +0 -0
- model/layers/__pycache__/gather_excite.cpython-38.pyc +0 -0
- model/layers/__pycache__/global_context.cpython-38.pyc +0 -0
- model/layers/__pycache__/halo_attn.cpython-38.pyc +0 -0
- model/layers/__pycache__/helpers.cpython-38.pyc +0 -0
- model/layers/__pycache__/inplace_abn.cpython-38.pyc +0 -0
- model/layers/__pycache__/involution.cpython-38.pyc +0 -0
- model/layers/__pycache__/lambda_layer.cpython-38.pyc +0 -0
- model/layers/__pycache__/linear.cpython-38.pyc +0 -0
- model/layers/__pycache__/mixed_conv2d.cpython-38.pyc +0 -0
- model/layers/__pycache__/mlp.cpython-38.pyc +0 -0
- model/layers/__pycache__/non_local_attn.cpython-38.pyc +0 -0
- model/layers/__pycache__/norm.cpython-38.pyc +0 -0
- model/layers/__pycache__/norm_act.cpython-38.pyc +0 -0
- model/layers/__pycache__/padding.cpython-38.pyc +0 -0
- model/layers/__pycache__/patch_embed.cpython-38.pyc +0 -0
- model/layers/__pycache__/pool2d_same.cpython-38.pyc +0 -0
- model/layers/__pycache__/selective_kernel.cpython-38.pyc +0 -0
- model/layers/__pycache__/separable_conv.cpython-38.pyc +0 -0
model/CoordAttention.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class h_sigmoid(nn.Module):
|
7 |
+
def __init__(self, inplace=True):
|
8 |
+
super(h_sigmoid, self).__init__()
|
9 |
+
self.relu = nn.ReLU6(inplace=inplace)
|
10 |
+
|
11 |
+
def forward(self, x):
|
12 |
+
return self.relu(x + 3) / 6
|
13 |
+
|
14 |
+
|
15 |
+
class h_swish(nn.Module):
|
16 |
+
def __init__(self, inplace=True):
|
17 |
+
super(h_swish, self).__init__()
|
18 |
+
self.sigmoid = h_sigmoid(inplace=inplace)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return x * self.sigmoid(x)
|
22 |
+
|
23 |
+
|
24 |
+
class CoordAtt(nn.Module):
|
25 |
+
def __init__(self, inp, oup, reduction=32):
|
26 |
+
super(CoordAtt, self).__init__()
|
27 |
+
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
|
28 |
+
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
|
29 |
+
|
30 |
+
mip = max(8, inp // reduction)
|
31 |
+
|
32 |
+
self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
|
33 |
+
self.bn1 = nn.BatchNorm2d(mip)
|
34 |
+
|
35 |
+
self.bn2 = nn.BatchNorm2d(1)
|
36 |
+
self.bn3 = nn.BatchNorm2d(1)
|
37 |
+
self.act = h_swish()
|
38 |
+
|
39 |
+
self.bn4 = nn.BatchNorm2d(mip)
|
40 |
+
self.bn5 = nn.BatchNorm2d(mip)
|
41 |
+
|
42 |
+
self.bn6 = nn.BatchNorm2d(1)
|
43 |
+
self.bn7 = nn.BatchNorm2d(1)
|
44 |
+
|
45 |
+
self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
|
46 |
+
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
x = torch.unsqueeze(x, 1) #2 1 2304 196
|
50 |
+
identity = x
|
51 |
+
|
52 |
+
n, c, h, w = x.size()#2 1 2304 196
|
53 |
+
x_h = self.bn2(self.pool_h(x))#2 1 2304 1
|
54 |
+
x_w = self.bn3(self.pool_w(x).permute(0, 1, 3, 2)) #2 1 196 1
|
55 |
+
identity_x_w = x_w
|
56 |
+
identity_x_h = x_h
|
57 |
+
y = torch.cat([x_h, x_w], dim=2)
|
58 |
+
y = self.conv1(y) #2 8 2500 1
|
59 |
+
y = self.bn1(y)
|
60 |
+
y = self.act(y)
|
61 |
+
|
62 |
+
x_h, x_w = torch.split(y, [h, w], dim=2) #2 8 2304 1 | 2 8 196 1
|
63 |
+
x_h = self.bn4(x_h)+identity_x_h
|
64 |
+
x_w = self.bn5(x_w)+identity_x_w
|
65 |
+
x_w = x_w.permute(0, 1, 3, 2)
|
66 |
+
|
67 |
+
a_h = self.bn6(self.conv_h(x_h)).sigmoid() #2 1 2304 1
|
68 |
+
a_w = self.bn7(self.conv_w(x_w)).sigmoid() #24 1 1 196
|
69 |
+
|
70 |
+
out = identity * a_w * a_h #点×
|
71 |
+
out = torch.squeeze(out, 1)
|
72 |
+
return out
|
73 |
+
|
74 |
+
class CoordAtt_ori(nn.Module):
|
75 |
+
def __init__(self, inp, oup, reduction=32):
|
76 |
+
super(CoordAtt_ori, self).__init__()
|
77 |
+
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
|
78 |
+
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
|
79 |
+
|
80 |
+
mip = max(8, inp // reduction)
|
81 |
+
|
82 |
+
self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
|
83 |
+
self.bn1 = nn.BatchNorm2d(mip)
|
84 |
+
self.act = h_swish()
|
85 |
+
|
86 |
+
self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
|
87 |
+
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x = torch.unsqueeze(x, 1)
|
91 |
+
identity = x
|
92 |
+
|
93 |
+
n, c, h, w = x.size()
|
94 |
+
x_h = self.pool_h(x)
|
95 |
+
x_w = self.pool_w(x).permute(0, 1, 3, 2)
|
96 |
+
|
97 |
+
y = torch.cat([x_h, x_w], dim=2)
|
98 |
+
y = self.conv1(y)
|
99 |
+
y = self.bn1(y)
|
100 |
+
y = self.act(y)
|
101 |
+
|
102 |
+
x_h, x_w = torch.split(y, [h, w], dim=2)
|
103 |
+
x_w = x_w.permute(0, 1, 3, 2)
|
104 |
+
|
105 |
+
a_h = self.conv_h(x_h).sigmoid()
|
106 |
+
a_w = self.conv_w(x_w).sigmoid()
|
107 |
+
|
108 |
+
out = identity * a_w * a_h
|
109 |
+
out = torch.squeeze(out, 1)
|
110 |
+
return out
|
model/Vision_Transformer_with_mask.py
ADDED
@@ -0,0 +1,990 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Vision Transformer (ViT) in PyTorch
|
2 |
+
|
3 |
+
A PyTorch implement of Vision Transformers as described in:
|
4 |
+
|
5 |
+
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
|
6 |
+
- https://arxiv.org/abs/2010.11929
|
7 |
+
|
8 |
+
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
|
9 |
+
- https://arxiv.org/abs/2106.10270
|
10 |
+
|
11 |
+
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
12 |
+
|
13 |
+
DeiT model defs and weights from https://github.com/facebookresearch/deit,
|
14 |
+
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
15 |
+
|
16 |
+
Acknowledgments:
|
17 |
+
* The paper authors for releasing code and weights, thanks!
|
18 |
+
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
19 |
+
for some einops/einsum fun
|
20 |
+
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
21 |
+
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
22 |
+
|
23 |
+
Hacked together by / Copyright 2021 Ross Wightman
|
24 |
+
"""
|
25 |
+
import math
|
26 |
+
import logging
|
27 |
+
from functools import partial
|
28 |
+
from collections import OrderedDict
|
29 |
+
from copy import deepcopy
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.nn as nn
|
33 |
+
import torch.nn.functional as F
|
34 |
+
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
35 |
+
|
36 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
37 |
+
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
|
38 |
+
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
|
39 |
+
from .registry import register_model
|
40 |
+
|
41 |
+
_logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
def _cfg(url='', **kwargs):
|
45 |
+
return {
|
46 |
+
'url': url,
|
47 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
48 |
+
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
49 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
50 |
+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
51 |
+
**kwargs
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
default_cfgs = {
|
56 |
+
# patch models (weights from official Google JAX impl)
|
57 |
+
'vit_tiny_patch16_224': _cfg(
|
58 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
59 |
+
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
60 |
+
'vit_tiny_patch16_384': _cfg(
|
61 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
62 |
+
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
63 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
64 |
+
'vit_small_patch32_224': _cfg(
|
65 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
66 |
+
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
67 |
+
'vit_small_patch32_384': _cfg(
|
68 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
69 |
+
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
70 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
71 |
+
'vit_small_patch16_224': _cfg(
|
72 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
73 |
+
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
74 |
+
'vit_small_patch16_384': _cfg(
|
75 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
76 |
+
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
77 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
78 |
+
'vit_base_patch32_224': _cfg(
|
79 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
80 |
+
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
81 |
+
'vit_base_patch32_384': _cfg(
|
82 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
83 |
+
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
84 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
85 |
+
'vit_base_patch16_224': _cfg(
|
86 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
87 |
+
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
88 |
+
'vit_base_patch16_384': _cfg(
|
89 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
90 |
+
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
91 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
92 |
+
'vit_large_patch32_224': _cfg(
|
93 |
+
url='', # no official model weights for this combo, only for in21k
|
94 |
+
),
|
95 |
+
'vit_large_patch32_384': _cfg(
|
96 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
|
97 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
98 |
+
'vit_large_patch16_224': _cfg(
|
99 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
100 |
+
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
101 |
+
'vit_large_patch16_384': _cfg(
|
102 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
103 |
+
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
104 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
105 |
+
|
106 |
+
# patch models, imagenet21k (weights from official Google JAX impl)
|
107 |
+
'vit_tiny_patch16_224_in21k': _cfg(
|
108 |
+
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
109 |
+
num_classes=21843),
|
110 |
+
'vit_small_patch32_224_in21k': _cfg(
|
111 |
+
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
112 |
+
num_classes=21843),
|
113 |
+
'vit_small_patch16_224_in21k': _cfg(
|
114 |
+
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
115 |
+
num_classes=21843),
|
116 |
+
'vit_base_patch32_224_in21k': _cfg(
|
117 |
+
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
|
118 |
+
num_classes=21843),
|
119 |
+
'vit_base_patch16_224_in21k': _cfg(
|
120 |
+
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
|
121 |
+
num_classes=21843),
|
122 |
+
'vit_large_patch32_224_in21k': _cfg(
|
123 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
|
124 |
+
num_classes=21843),
|
125 |
+
'vit_large_patch16_224_in21k': _cfg(
|
126 |
+
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
|
127 |
+
num_classes=21843),
|
128 |
+
'vit_huge_patch14_224_in21k': _cfg(
|
129 |
+
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
|
130 |
+
hf_hub='timm/vit_huge_patch14_224_in21k',
|
131 |
+
num_classes=21843),
|
132 |
+
|
133 |
+
# deit models (FB weights)
|
134 |
+
'deit_tiny_patch16_224': _cfg(
|
135 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
|
136 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
137 |
+
'deit_small_patch16_224': _cfg(
|
138 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
|
139 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
140 |
+
'deit_base_patch16_224': _cfg(
|
141 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
|
142 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
143 |
+
'deit_base_patch16_384': _cfg(
|
144 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
145 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
|
146 |
+
'deit_tiny_distilled_patch16_224': _cfg(
|
147 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
148 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
149 |
+
'deit_small_distilled_patch16_224': _cfg(
|
150 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
151 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
152 |
+
'deit_base_distilled_patch16_224': _cfg(
|
153 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
154 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
155 |
+
'deit_base_distilled_patch16_384': _cfg(
|
156 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
157 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
|
158 |
+
classifier=('head', 'head_dist')),
|
159 |
+
|
160 |
+
# ViT ImageNet-21K-P pretraining by MILL
|
161 |
+
'vit_base_patch16_224_miil_in21k': _cfg(
|
162 |
+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
|
163 |
+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
|
164 |
+
),
|
165 |
+
'vit_base_patch16_224_miil': _cfg(
|
166 |
+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
|
167 |
+
'/vit_base_patch16_224_1k_miil_84_4.pth',
|
168 |
+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
|
169 |
+
),
|
170 |
+
}
|
171 |
+
|
172 |
+
|
173 |
+
class CrossAttention(nn.Module):
|
174 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
175 |
+
super().__init__()
|
176 |
+
self.num_heads = num_heads
|
177 |
+
head_dim = dim // num_heads
|
178 |
+
self.scale = qk_scale or head_dim ** -0.5 #这行多了个qk_scale #0.125
|
179 |
+
|
180 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
181 |
+
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
182 |
+
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
|
183 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
184 |
+
self.proj = nn.Linear(dim, dim)
|
185 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
|
189 |
+
B, N, C = x.shape #2 512 768
|
190 |
+
q = self.wq(x[:, 0:int(N/2), ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#2 12 256 64
|
191 |
+
k = self.wk(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
192 |
+
v = self.wv(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
193 |
+
|
194 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
195 |
+
attn = attn.softmax(dim=-1)
|
196 |
+
attn = self.attn_drop(attn)
|
197 |
+
|
198 |
+
x = (attn @ v).transpose(1, 2).reshape(B, int(N/2), C) #变成了B/2 2 256 768
|
199 |
+
x = self.proj(x)
|
200 |
+
x = self.proj_drop(x)
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
class Attention(nn.Module):
|
206 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.):
|
207 |
+
super().__init__()
|
208 |
+
self.num_heads = num_heads
|
209 |
+
head_dim = dim // num_heads
|
210 |
+
self.scale = qk_scale or head_dim ** -0.5
|
211 |
+
|
212 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
213 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
214 |
+
self.proj = nn.Linear(dim, dim)
|
215 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
216 |
+
|
217 |
+
def forward(self, data):
|
218 |
+
b,c,h = data.shape
|
219 |
+
x,atten_mask = data[:,0:int(c/2),...],data[:,int(c/2):,...]
|
220 |
+
B, N, C = x.shape
|
221 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
222 |
+
q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple)
|
223 |
+
|
224 |
+
|
225 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49
|
226 |
+
if atten_mask.sum() != 0:
|
227 |
+
atten_mask = atten_mask.unsqueeze(1) # 2,1,49,49
|
228 |
+
attn = attn + atten_mask
|
229 |
+
attn = attn.softmax(dim=-1)
|
230 |
+
attn = self.attn_drop(attn)
|
231 |
+
|
232 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
233 |
+
x = self.proj(x)
|
234 |
+
x = self.proj_drop(x)
|
235 |
+
return x
|
236 |
+
|
237 |
+
class Attention_ori(nn.Module):
|
238 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.):
|
239 |
+
super().__init__()
|
240 |
+
self.num_heads = num_heads
|
241 |
+
head_dim = dim // num_heads
|
242 |
+
self.scale = qk_scale or head_dim ** -0.5
|
243 |
+
|
244 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
245 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
246 |
+
self.proj = nn.Linear(dim, dim)
|
247 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
248 |
+
|
249 |
+
def forward(self, x):
|
250 |
+
|
251 |
+
B, N, C = x.shape
|
252 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
253 |
+
q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple)
|
254 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49
|
255 |
+
|
256 |
+
attn = attn.softmax(dim=-1)
|
257 |
+
attn = self.attn_drop(attn)
|
258 |
+
|
259 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
260 |
+
x = self.proj(x)
|
261 |
+
x = self.proj_drop(x)
|
262 |
+
return x
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
|
267 |
+
class Block(nn.Module):
|
268 |
+
|
269 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
270 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
271 |
+
super().__init__()
|
272 |
+
self.norm1 = norm_layer(dim)
|
273 |
+
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
274 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
275 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
276 |
+
self.norm2 = norm_layer(dim)
|
277 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
278 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
279 |
+
|
280 |
+
def forward(self, data):
|
281 |
+
b,c,h = data.shape
|
282 |
+
x,mask = data[:,0:int(c/2),...],data[:,int(c/2):,...]
|
283 |
+
x = x + self.drop_path(self.attn(torch.cat([self.norm1(x),mask],dim=1)))
|
284 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
285 |
+
return torch.cat([x,mask],dim=1)
|
286 |
+
|
287 |
+
class mask_PatchEmbed(nn.Module):
|
288 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, norm_layer=None, flatten=True):
|
289 |
+
super().__init__()
|
290 |
+
img_size = to_2tuple(img_size)
|
291 |
+
patch_size = to_2tuple(patch_size)
|
292 |
+
self.img_size = img_size
|
293 |
+
self.patch_size = patch_size
|
294 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
295 |
+
self.flatten = flatten
|
296 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
297 |
+
self.proj = nn.Conv2d(in_chans, 1, kernel_size=patch_size, stride=patch_size).requires_grad_(False)
|
298 |
+
nn.init.ones_(self.proj.weight)
|
299 |
+
nn.init.zeros_(self.proj.bias)
|
300 |
+
def forward(self, x):
|
301 |
+
B, C, H, W = x.shape
|
302 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
303 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
304 |
+
x = self.proj(x)
|
305 |
+
if self.flatten:
|
306 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
307 |
+
return x
|
308 |
+
|
309 |
+
class VisionTransformer(nn.Module):
|
310 |
+
""" Vision Transformer
|
311 |
+
|
312 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
313 |
+
- https://arxiv.org/abs/2010.11929
|
314 |
+
|
315 |
+
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
|
316 |
+
- https://arxiv.org/abs/2012.12877
|
317 |
+
"""
|
318 |
+
|
319 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
320 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
|
321 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
|
322 |
+
act_layer=None,as_backbone=True, weight_init=''):
|
323 |
+
"""
|
324 |
+
Args:
|
325 |
+
img_size (int, tuple): input image size
|
326 |
+
patch_size (int, tuple): patch size
|
327 |
+
in_chans (int): number of input channels
|
328 |
+
num_classes (int): number of classes for classification head
|
329 |
+
embed_dim (int): embedding dimension
|
330 |
+
depth (int): depth of transformer
|
331 |
+
num_heads (int): number of attention heads
|
332 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
333 |
+
qkv_bias (bool): enable bias for qkv if True
|
334 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
335 |
+
distilled (bool): model includes a distillation token and head as in DeiT models
|
336 |
+
drop_rate (float): dropout rate
|
337 |
+
attn_drop_rate (float): attention dropout rate
|
338 |
+
drop_path_rate (float): stochastic depth rate
|
339 |
+
embed_layer (nn.Module): patch embedding layer
|
340 |
+
norm_layer: (nn.Module): normalization layer
|
341 |
+
weight_init: (str): weight init scheme
|
342 |
+
"""
|
343 |
+
super().__init__()
|
344 |
+
self.num_classes = num_classes
|
345 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
346 |
+
self.num_tokens = 2 if distilled else 1
|
347 |
+
self.num_heads = num_heads
|
348 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
349 |
+
act_layer = act_layer or nn.GELU
|
350 |
+
self.as_backbone = as_backbone #是否分类任务,如果不是,class不加上去
|
351 |
+
self.patch_embed = embed_layer(
|
352 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
353 |
+
self.mask_embed = mask_PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans)
|
354 |
+
num_patches = self.patch_embed.num_patches
|
355 |
+
|
356 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
357 |
+
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
|
358 |
+
if not self.as_backbone:
|
359 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
360 |
+
else:
|
361 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
362 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
363 |
+
|
364 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
365 |
+
self.blocks = nn.Sequential(*[
|
366 |
+
Block(
|
367 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
368 |
+
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
369 |
+
for i in range(depth)])
|
370 |
+
self.norm = norm_layer(embed_dim)
|
371 |
+
|
372 |
+
# Representation layer
|
373 |
+
if representation_size and not distilled:
|
374 |
+
self.num_features = representation_size
|
375 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
376 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
377 |
+
('act', nn.Tanh())
|
378 |
+
]))
|
379 |
+
else:
|
380 |
+
self.pre_logits = nn.Identity()
|
381 |
+
if not self.as_backbone:
|
382 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
383 |
+
# Classifier head(s)
|
384 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
385 |
+
self.head_dist = None
|
386 |
+
if distilled:
|
387 |
+
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
388 |
+
|
389 |
+
self.init_weights(weight_init)
|
390 |
+
|
391 |
+
def init_weights(self, mode=''):
|
392 |
+
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
|
393 |
+
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
394 |
+
trunc_normal_(self.pos_embed, std=.02)
|
395 |
+
if self.dist_token is not None:
|
396 |
+
trunc_normal_(self.dist_token, std=.02)
|
397 |
+
if mode.startswith('jax'):
|
398 |
+
# leave cls token as zeros to match jax impl
|
399 |
+
named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
|
400 |
+
else:
|
401 |
+
trunc_normal_(self.cls_token, std=.02)
|
402 |
+
self.apply(_init_vit_weights)
|
403 |
+
|
404 |
+
def _init_weights(self, m):
|
405 |
+
# this fn left here for compat with downstream users
|
406 |
+
_init_vit_weights(m)
|
407 |
+
|
408 |
+
@torch.jit.ignore()
|
409 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
410 |
+
_load_weights(self, checkpoint_path, prefix)
|
411 |
+
|
412 |
+
@torch.jit.ignore
|
413 |
+
def no_weight_decay(self):
|
414 |
+
return {'pos_embed', 'cls_token', 'dist_token'}
|
415 |
+
|
416 |
+
def get_classifier(self):
|
417 |
+
if self.dist_token is None:
|
418 |
+
return self.head
|
419 |
+
else:
|
420 |
+
return self.head, self.head_dist
|
421 |
+
|
422 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
423 |
+
self.num_classes = num_classes
|
424 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
425 |
+
if self.num_tokens == 2:
|
426 |
+
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
427 |
+
|
428 |
+
def forward_features(self, data):
|
429 |
+
x,mask = data[:,0,:,:].unsqueeze(1),data[:,1,:,:].unsqueeze(1)
|
430 |
+
x = self.patch_embed(x)#B N C
|
431 |
+
atten_mask = torch.zeros_like(x) # 2 49 768
|
432 |
+
if mask.sum() != 0:
|
433 |
+
mask = self.mask_embed(mask) ###
|
434 |
+
mask.squeeze_(dim=2)
|
435 |
+
mask[mask != 0] = 1 ### H W数目token C编码长度
|
436 |
+
k1 = mask[:, None, :]
|
437 |
+
k2 = torch.ones_like(mask)[:, :, None]
|
438 |
+
k3 = k1 * k2
|
439 |
+
atten_mask = (1.0 - k3) * (-1e6)
|
440 |
+
atten_mask.requires_grad_(True)
|
441 |
+
self.atten_mask = atten_mask
|
442 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
443 |
+
if not self.as_backbone:
|
444 |
+
if self.dist_token is None:
|
445 |
+
x = torch.cat((cls_token, x), dim=1)
|
446 |
+
else:
|
447 |
+
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
448 |
+
x = self.pos_drop(x + self.pos_embed) #2 49 768
|
449 |
+
x = self.blocks(torch.cat([x,atten_mask],dim=1))
|
450 |
+
b,c,h = x.shape
|
451 |
+
x = x[:,0:int(c/2),...]
|
452 |
+
x = self.norm(x)
|
453 |
+
if self.as_backbone:
|
454 |
+
# x = self.avgpool(x.transpose(1, 2)) # B C 1
|
455 |
+
# x = torch.flatten(x, 1)
|
456 |
+
return x
|
457 |
+
if self.dist_token is None:
|
458 |
+
return self.pre_logits(x[:, 0])
|
459 |
+
else:
|
460 |
+
return x[:, 0], x[:, 1]
|
461 |
+
|
462 |
+
def forward(self, data):
|
463 |
+
x = self.forward_features(data) #2 49 768
|
464 |
+
if self.as_backbone:
|
465 |
+
return x
|
466 |
+
else:
|
467 |
+
if self.head_dist is not None:
|
468 |
+
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
|
469 |
+
if self.training and not torch.jit.is_scripting():
|
470 |
+
# during inference, return the average of both classifier predictions
|
471 |
+
return x, x_dist
|
472 |
+
else:
|
473 |
+
return (x + x_dist) / 2
|
474 |
+
else:
|
475 |
+
x = self.head(x)
|
476 |
+
return x
|
477 |
+
|
478 |
+
|
479 |
+
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
|
480 |
+
""" ViT weight initialization
|
481 |
+
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
482 |
+
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
483 |
+
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
484 |
+
"""
|
485 |
+
if isinstance(module, nn.Linear):
|
486 |
+
if name.startswith('head'):
|
487 |
+
nn.init.zeros_(module.weight)
|
488 |
+
nn.init.constant_(module.bias, head_bias)
|
489 |
+
elif name.startswith('pre_logits'):
|
490 |
+
lecun_normal_(module.weight)
|
491 |
+
nn.init.zeros_(module.bias)
|
492 |
+
else:
|
493 |
+
if jax_impl:
|
494 |
+
nn.init.xavier_uniform_(module.weight)
|
495 |
+
if module.bias is not None:
|
496 |
+
if 'mlp' in name:
|
497 |
+
nn.init.normal_(module.bias, std=1e-6)
|
498 |
+
else:
|
499 |
+
nn.init.zeros_(module.bias)
|
500 |
+
else:
|
501 |
+
trunc_normal_(module.weight, std=.02)
|
502 |
+
if module.bias is not None:
|
503 |
+
nn.init.zeros_(module.bias)
|
504 |
+
elif jax_impl and isinstance(module, nn.Conv2d):
|
505 |
+
# NOTE conv was left to pytorch default in my original init
|
506 |
+
lecun_normal_(module.weight)
|
507 |
+
if module.bias is not None:
|
508 |
+
nn.init.zeros_(module.bias)
|
509 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
510 |
+
nn.init.zeros_(module.bias)
|
511 |
+
nn.init.ones_(module.weight)
|
512 |
+
|
513 |
+
|
514 |
+
@torch.no_grad()
|
515 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
516 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
517 |
+
"""
|
518 |
+
import numpy as np
|
519 |
+
|
520 |
+
def _n2p(w, t=True):
|
521 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
522 |
+
w = w.flatten()
|
523 |
+
if t:
|
524 |
+
if w.ndim == 4:
|
525 |
+
w = w.transpose([3, 2, 0, 1])
|
526 |
+
elif w.ndim == 3:
|
527 |
+
w = w.transpose([2, 0, 1])
|
528 |
+
elif w.ndim == 2:
|
529 |
+
w = w.transpose([1, 0])
|
530 |
+
return torch.from_numpy(w)
|
531 |
+
|
532 |
+
w = np.load(checkpoint_path)
|
533 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
534 |
+
prefix = 'opt/target/'
|
535 |
+
|
536 |
+
if hasattr(model.patch_embed, 'backbone'):
|
537 |
+
# hybrid
|
538 |
+
backbone = model.patch_embed.backbone
|
539 |
+
stem_only = not hasattr(backbone, 'stem')
|
540 |
+
stem = backbone if stem_only else backbone.stem
|
541 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
542 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
543 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
544 |
+
if not stem_only:
|
545 |
+
for i, stage in enumerate(backbone.stages):
|
546 |
+
for j, block in enumerate(stage.blocks):
|
547 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
548 |
+
for r in range(3):
|
549 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
550 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
551 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
552 |
+
if block.downsample is not None:
|
553 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
554 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
555 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
556 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
557 |
+
else:
|
558 |
+
embed_conv_w = adapt_input_conv(
|
559 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
560 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
561 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
562 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
563 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
564 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
565 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
566 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
567 |
+
model.pos_embed.copy_(pos_embed_w)
|
568 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
569 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
570 |
+
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
571 |
+
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
572 |
+
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
573 |
+
if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
574 |
+
model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
575 |
+
model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
576 |
+
for i, block in enumerate(model.blocks.children()):
|
577 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
578 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
579 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
580 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
581 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
582 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
583 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
584 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
585 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
586 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
587 |
+
for r in range(2):
|
588 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
589 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
590 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
591 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
592 |
+
|
593 |
+
|
594 |
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
595 |
+
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
596 |
+
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
597 |
+
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
598 |
+
ntok_new = posemb_new.shape[1]
|
599 |
+
if num_tokens:
|
600 |
+
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
601 |
+
ntok_new -= num_tokens
|
602 |
+
else:
|
603 |
+
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
604 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
605 |
+
if not len(gs_new): # backwards compatibility
|
606 |
+
gs_new = [int(math.sqrt(ntok_new))] * 2
|
607 |
+
assert len(gs_new) >= 2
|
608 |
+
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
|
609 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
610 |
+
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
|
611 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
612 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
613 |
+
return posemb
|
614 |
+
|
615 |
+
|
616 |
+
def checkpoint_filter_fn(state_dict, model):
|
617 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
618 |
+
out_dict = {}
|
619 |
+
if 'model' in state_dict:
|
620 |
+
# For deit models
|
621 |
+
state_dict = state_dict['model']
|
622 |
+
for k, v in state_dict.items():
|
623 |
+
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
624 |
+
# For old models that I trained prior to conv based patchification
|
625 |
+
O, I, H, W = model.patch_embed.proj.weight.shape
|
626 |
+
v = v.reshape(O, -1, H, W)
|
627 |
+
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
628 |
+
# To resize pos embedding when using model at different size from pretrained weights
|
629 |
+
v = resize_pos_embed(
|
630 |
+
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
631 |
+
out_dict[k] = v
|
632 |
+
return out_dict
|
633 |
+
|
634 |
+
|
635 |
+
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
636 |
+
default_cfg = default_cfg or default_cfgs[variant]
|
637 |
+
if kwargs.get('features_only', None):
|
638 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
639 |
+
|
640 |
+
# NOTE this extra code to support handling of repr size for in21k pretrained models
|
641 |
+
default_num_classes = default_cfg['num_classes']
|
642 |
+
num_classes = kwargs.get('num_classes', default_num_classes)
|
643 |
+
repr_size = kwargs.pop('representation_size', None)
|
644 |
+
if repr_size is not None and num_classes != default_num_classes:
|
645 |
+
# Remove representation layer if fine-tuning. This may not always be the desired action,
|
646 |
+
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
|
647 |
+
_logger.warning("Removing representation layer for fine-tuning.")
|
648 |
+
repr_size = None
|
649 |
+
|
650 |
+
model = build_model_with_cfg(
|
651 |
+
VisionTransformer, variant, pretrained,
|
652 |
+
default_cfg=default_cfg,
|
653 |
+
representation_size=repr_size,
|
654 |
+
pretrained_filter_fn=checkpoint_filter_fn,
|
655 |
+
pretrained_custom_load='npz' in default_cfg['url'],
|
656 |
+
**kwargs)
|
657 |
+
return model
|
658 |
+
|
659 |
+
|
660 |
+
@register_model
|
661 |
+
def vit_tiny_patch16_224(pretrained=False, **kwargs):
|
662 |
+
""" ViT-Tiny (Vit-Ti/16)
|
663 |
+
"""
|
664 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
665 |
+
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
666 |
+
return model
|
667 |
+
|
668 |
+
|
669 |
+
@register_model
|
670 |
+
def vit_tiny_patch16_384(pretrained=False, **kwargs):
|
671 |
+
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
|
672 |
+
"""
|
673 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
674 |
+
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
|
675 |
+
return model
|
676 |
+
|
677 |
+
|
678 |
+
@register_model
|
679 |
+
def vit_small_patch32_224(pretrained=False, **kwargs):
|
680 |
+
""" ViT-Small (ViT-S/32)
|
681 |
+
"""
|
682 |
+
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
683 |
+
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
|
684 |
+
return model
|
685 |
+
|
686 |
+
|
687 |
+
@register_model
|
688 |
+
def vit_small_patch32_384(pretrained=False, **kwargs):
|
689 |
+
""" ViT-Small (ViT-S/32) at 384x384.
|
690 |
+
"""
|
691 |
+
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
692 |
+
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
|
693 |
+
return model
|
694 |
+
|
695 |
+
|
696 |
+
@register_model
|
697 |
+
def vit_small_patch16_224(pretrained=False, **kwargs):
|
698 |
+
""" ViT-Small (ViT-S/16)
|
699 |
+
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
|
700 |
+
"""
|
701 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
702 |
+
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
703 |
+
return model
|
704 |
+
|
705 |
+
|
706 |
+
@register_model
|
707 |
+
def vit_small_patch16_384(pretrained=False, **kwargs):
|
708 |
+
""" ViT-Small (ViT-S/16)
|
709 |
+
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
|
710 |
+
"""
|
711 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
712 |
+
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
|
713 |
+
return model
|
714 |
+
|
715 |
+
|
716 |
+
@register_model
|
717 |
+
def vit_base_patch32_224(pretrained=False, **kwargs):
|
718 |
+
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
719 |
+
"""
|
720 |
+
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
721 |
+
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
|
722 |
+
return model
|
723 |
+
|
724 |
+
|
725 |
+
@register_model
|
726 |
+
def vit_base_patch32_384(pretrained=False, **kwargs):
|
727 |
+
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
728 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
729 |
+
"""
|
730 |
+
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
731 |
+
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
|
732 |
+
return model
|
733 |
+
|
734 |
+
|
735 |
+
@register_model
|
736 |
+
def vit_base_patch16_224(pretrained=False, **kwargs):
|
737 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
738 |
+
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
739 |
+
"""
|
740 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
741 |
+
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
742 |
+
return model
|
743 |
+
|
744 |
+
|
745 |
+
@register_model
|
746 |
+
def vit_base_patch16_384(pretrained=False, **kwargs):
|
747 |
+
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
748 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
749 |
+
"""
|
750 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
751 |
+
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
752 |
+
return model
|
753 |
+
|
754 |
+
|
755 |
+
@register_model
|
756 |
+
def vit_large_patch32_224(pretrained=False, **kwargs):
|
757 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
758 |
+
"""
|
759 |
+
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
760 |
+
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
|
761 |
+
return model
|
762 |
+
|
763 |
+
|
764 |
+
@register_model
|
765 |
+
def vit_large_patch32_384(pretrained=False, **kwargs):
|
766 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
767 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
768 |
+
"""
|
769 |
+
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
770 |
+
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
|
771 |
+
return model
|
772 |
+
|
773 |
+
|
774 |
+
@register_model
|
775 |
+
def vit_large_patch16_224(pretrained=False, **kwargs):
|
776 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
777 |
+
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
778 |
+
"""
|
779 |
+
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
780 |
+
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
781 |
+
return model
|
782 |
+
|
783 |
+
|
784 |
+
@register_model
|
785 |
+
def vit_large_patch16_384(pretrained=False, **kwargs):
|
786 |
+
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
787 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
788 |
+
"""
|
789 |
+
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
790 |
+
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
791 |
+
return model
|
792 |
+
|
793 |
+
|
794 |
+
@register_model
|
795 |
+
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
|
796 |
+
""" ViT-Tiny (Vit-Ti/16).
|
797 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
798 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
799 |
+
"""
|
800 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
801 |
+
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
802 |
+
return model
|
803 |
+
|
804 |
+
|
805 |
+
@register_model
|
806 |
+
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
|
807 |
+
""" ViT-Small (ViT-S/16)
|
808 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
809 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
810 |
+
"""
|
811 |
+
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
812 |
+
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
813 |
+
return model
|
814 |
+
|
815 |
+
|
816 |
+
@register_model
|
817 |
+
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
|
818 |
+
""" ViT-Small (ViT-S/16)
|
819 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
820 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
821 |
+
"""
|
822 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
823 |
+
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
824 |
+
return model
|
825 |
+
|
826 |
+
|
827 |
+
@register_model
|
828 |
+
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
|
829 |
+
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
830 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
831 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
832 |
+
"""
|
833 |
+
model_kwargs = dict(
|
834 |
+
patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
835 |
+
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
836 |
+
return model
|
837 |
+
|
838 |
+
|
839 |
+
@register_model
|
840 |
+
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
|
841 |
+
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
842 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
843 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
844 |
+
"""
|
845 |
+
model_kwargs = dict(
|
846 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
847 |
+
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
848 |
+
return model
|
849 |
+
|
850 |
+
|
851 |
+
@register_model
|
852 |
+
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
|
853 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
854 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
855 |
+
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
|
856 |
+
"""
|
857 |
+
model_kwargs = dict(
|
858 |
+
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
|
859 |
+
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
860 |
+
return model
|
861 |
+
|
862 |
+
|
863 |
+
@register_model
|
864 |
+
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
|
865 |
+
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
866 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
867 |
+
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
|
868 |
+
"""
|
869 |
+
model_kwargs = dict(
|
870 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
871 |
+
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
872 |
+
return model
|
873 |
+
|
874 |
+
|
875 |
+
@register_model
|
876 |
+
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
|
877 |
+
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
878 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
879 |
+
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
|
880 |
+
"""
|
881 |
+
model_kwargs = dict(
|
882 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
|
883 |
+
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
|
884 |
+
return model
|
885 |
+
|
886 |
+
|
887 |
+
@register_model
|
888 |
+
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
889 |
+
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
890 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
891 |
+
"""
|
892 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
893 |
+
model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
894 |
+
return model
|
895 |
+
|
896 |
+
|
897 |
+
@register_model
|
898 |
+
def deit_small_patch16_224(pretrained=False, **kwargs):
|
899 |
+
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
900 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
901 |
+
"""
|
902 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
903 |
+
model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
904 |
+
return model
|
905 |
+
|
906 |
+
|
907 |
+
@register_model
|
908 |
+
def deit_base_patch16_224(pretrained=False, **kwargs):
|
909 |
+
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
910 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
911 |
+
"""
|
912 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
913 |
+
model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
914 |
+
return model
|
915 |
+
|
916 |
+
|
917 |
+
@register_model
|
918 |
+
def deit_base_patch16_384(pretrained=False, **kwargs):
|
919 |
+
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
920 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
921 |
+
"""
|
922 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
923 |
+
model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
924 |
+
return model
|
925 |
+
|
926 |
+
|
927 |
+
@register_model
|
928 |
+
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
929 |
+
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
930 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
931 |
+
"""
|
932 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
933 |
+
model = _create_vision_transformer(
|
934 |
+
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
935 |
+
return model
|
936 |
+
|
937 |
+
|
938 |
+
@register_model
|
939 |
+
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
940 |
+
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
941 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
942 |
+
"""
|
943 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
944 |
+
model = _create_vision_transformer(
|
945 |
+
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
946 |
+
return model
|
947 |
+
|
948 |
+
|
949 |
+
@register_model
|
950 |
+
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
951 |
+
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
952 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
953 |
+
"""
|
954 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
955 |
+
model = _create_vision_transformer(
|
956 |
+
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
957 |
+
return model
|
958 |
+
|
959 |
+
|
960 |
+
@register_model
|
961 |
+
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
962 |
+
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
963 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
964 |
+
"""
|
965 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
966 |
+
model = _create_vision_transformer(
|
967 |
+
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
968 |
+
return model
|
969 |
+
|
970 |
+
|
971 |
+
@register_model
|
972 |
+
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
|
973 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
974 |
+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
975 |
+
"""
|
976 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
977 |
+
model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
|
978 |
+
return model
|
979 |
+
|
980 |
+
|
981 |
+
@register_model
|
982 |
+
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
|
983 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
984 |
+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
985 |
+
"""
|
986 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
987 |
+
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
|
988 |
+
return model
|
989 |
+
|
990 |
+
|
model/__pycache__/CoordAttention.cpython-38.pyc
ADDED
Binary file (3.58 kB). View file
|
|
model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc
ADDED
Binary file (37.2 kB). View file
|
|
model/__pycache__/features.cpython-38.pyc
ADDED
Binary file (12.4 kB). View file
|
|
model/__pycache__/helpers.cpython-38.pyc
ADDED
Binary file (14.8 kB). View file
|
|
model/__pycache__/hub.cpython-38.pyc
ADDED
Binary file (3.45 kB). View file
|
|
model/__pycache__/registry.cpython-38.pyc
ADDED
Binary file (4.77 kB). View file
|
|
model/features.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" PyTorch Feature Extraction Helpers
|
2 |
+
|
3 |
+
A collection of classes, functions, modules to help extract features from models
|
4 |
+
and provide a common interface for describing them.
|
5 |
+
|
6 |
+
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
|
7 |
+
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
8 |
+
|
9 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
10 |
+
"""
|
11 |
+
from collections import OrderedDict, defaultdict
|
12 |
+
from copy import deepcopy
|
13 |
+
from functools import partial
|
14 |
+
from typing import Dict, List, Tuple
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class FeatureInfo:
|
21 |
+
|
22 |
+
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
23 |
+
prev_reduction = 1
|
24 |
+
for fi in feature_info:
|
25 |
+
# sanity check the mandatory fields, there may be additional fields depending on the model
|
26 |
+
assert 'num_chs' in fi and fi['num_chs'] > 0
|
27 |
+
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
28 |
+
prev_reduction = fi['reduction']
|
29 |
+
assert 'module' in fi
|
30 |
+
self.out_indices = out_indices
|
31 |
+
self.info = feature_info
|
32 |
+
|
33 |
+
def from_other(self, out_indices: Tuple[int]):
|
34 |
+
return FeatureInfo(deepcopy(self.info), out_indices)
|
35 |
+
|
36 |
+
def get(self, key, idx=None):
|
37 |
+
""" Get value by key at specified index (indices)
|
38 |
+
if idx == None, returns value for key at each output index
|
39 |
+
if idx is an integer, return value for that feature module index (ignoring output indices)
|
40 |
+
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
41 |
+
"""
|
42 |
+
if idx is None:
|
43 |
+
return [self.info[i][key] for i in self.out_indices]
|
44 |
+
if isinstance(idx, (tuple, list)):
|
45 |
+
return [self.info[i][key] for i in idx]
|
46 |
+
else:
|
47 |
+
return self.info[idx][key]
|
48 |
+
|
49 |
+
def get_dicts(self, keys=None, idx=None):
|
50 |
+
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
51 |
+
"""
|
52 |
+
if idx is None:
|
53 |
+
if keys is None:
|
54 |
+
return [self.info[i] for i in self.out_indices]
|
55 |
+
else:
|
56 |
+
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
57 |
+
if isinstance(idx, (tuple, list)):
|
58 |
+
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
59 |
+
else:
|
60 |
+
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
61 |
+
|
62 |
+
def channels(self, idx=None):
|
63 |
+
""" feature channels accessor
|
64 |
+
"""
|
65 |
+
return self.get('num_chs', idx)
|
66 |
+
|
67 |
+
def reduction(self, idx=None):
|
68 |
+
""" feature reduction (output stride) accessor
|
69 |
+
"""
|
70 |
+
return self.get('reduction', idx)
|
71 |
+
|
72 |
+
def module_name(self, idx=None):
|
73 |
+
""" feature module name accessor
|
74 |
+
"""
|
75 |
+
return self.get('module', idx)
|
76 |
+
|
77 |
+
def __getitem__(self, item):
|
78 |
+
return self.info[item]
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
return len(self.info)
|
82 |
+
|
83 |
+
|
84 |
+
class FeatureHooks:
|
85 |
+
""" Feature Hook Helper
|
86 |
+
|
87 |
+
This module helps with the setup and extraction of hooks for extracting features from
|
88 |
+
internal nodes in a model by node name. This works quite well in eager Python but needs
|
89 |
+
redesign for torcscript.
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
93 |
+
# setup feature hooks
|
94 |
+
modules = {k: v for k, v in named_modules}
|
95 |
+
for i, h in enumerate(hooks):
|
96 |
+
hook_name = h['module']
|
97 |
+
m = modules[hook_name]
|
98 |
+
hook_id = out_map[i] if out_map else hook_name
|
99 |
+
hook_fn = partial(self._collect_output_hook, hook_id)
|
100 |
+
hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
|
101 |
+
if hook_type == 'forward_pre':
|
102 |
+
m.register_forward_pre_hook(hook_fn)
|
103 |
+
elif hook_type == 'forward':
|
104 |
+
m.register_forward_hook(hook_fn)
|
105 |
+
else:
|
106 |
+
assert False, "Unsupported hook type"
|
107 |
+
self._feature_outputs = defaultdict(OrderedDict)
|
108 |
+
|
109 |
+
def _collect_output_hook(self, hook_id, *args):
|
110 |
+
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
111 |
+
if isinstance(x, tuple):
|
112 |
+
x = x[0] # unwrap input tuple
|
113 |
+
self._feature_outputs[x.device][hook_id] = x
|
114 |
+
|
115 |
+
def get_output(self, device) -> Dict[str, torch.tensor]:
|
116 |
+
output = self._feature_outputs[device]
|
117 |
+
self._feature_outputs[device] = OrderedDict() # clear after reading
|
118 |
+
return output
|
119 |
+
|
120 |
+
|
121 |
+
def _module_list(module, flatten_sequential=False):
|
122 |
+
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
123 |
+
ml = []
|
124 |
+
for name, module in module.named_children():
|
125 |
+
if flatten_sequential and isinstance(module, nn.Sequential):
|
126 |
+
# first level of Sequential containers is flattened into containing model
|
127 |
+
for child_name, child_module in module.named_children():
|
128 |
+
combined = [name, child_name]
|
129 |
+
ml.append(('_'.join(combined), '.'.join(combined), child_module))
|
130 |
+
else:
|
131 |
+
ml.append((name, name, module))
|
132 |
+
return ml
|
133 |
+
|
134 |
+
|
135 |
+
def _get_feature_info(net, out_indices):
|
136 |
+
feature_info = getattr(net, 'feature_info')
|
137 |
+
if isinstance(feature_info, FeatureInfo):
|
138 |
+
return feature_info.from_other(out_indices)
|
139 |
+
elif isinstance(feature_info, (list, tuple)):
|
140 |
+
return FeatureInfo(net.feature_info, out_indices)
|
141 |
+
else:
|
142 |
+
assert False, "Provided feature_info is not valid"
|
143 |
+
|
144 |
+
|
145 |
+
def _get_return_layers(feature_info, out_map):
|
146 |
+
module_names = feature_info.module_name()
|
147 |
+
return_layers = {}
|
148 |
+
for i, name in enumerate(module_names):
|
149 |
+
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
150 |
+
return return_layers
|
151 |
+
|
152 |
+
|
153 |
+
class FeatureDictNet(nn.ModuleDict):
|
154 |
+
""" Feature extractor with OrderedDict return
|
155 |
+
|
156 |
+
Wrap a model and extract features as specified by the out indices, the network is
|
157 |
+
partially re-built from contained modules.
|
158 |
+
|
159 |
+
There is a strong assumption that the modules have been registered into the model in the same
|
160 |
+
order as they are used. There should be no reuse of the same nn.Module more than once, including
|
161 |
+
trivial modules like `self.relu = nn.ReLU`.
|
162 |
+
|
163 |
+
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
|
164 |
+
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
165 |
+
All Sequential containers that are directly assigned to the original model will have their
|
166 |
+
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
167 |
+
|
168 |
+
Arguments:
|
169 |
+
model (nn.Module): model from which we will extract the features
|
170 |
+
out_indices (tuple[int]): model output indices to extract features for
|
171 |
+
out_map (sequence): list or tuple specifying desired return id for each out index,
|
172 |
+
otherwise str(index) is used
|
173 |
+
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
174 |
+
vs select element [0]
|
175 |
+
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
176 |
+
"""
|
177 |
+
def __init__(
|
178 |
+
self, model,
|
179 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
180 |
+
super(FeatureDictNet, self).__init__()
|
181 |
+
self.feature_info = _get_feature_info(model, out_indices)
|
182 |
+
self.concat = feature_concat
|
183 |
+
self.return_layers = {}
|
184 |
+
return_layers = _get_return_layers(self.feature_info, out_map)
|
185 |
+
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
186 |
+
remaining = set(return_layers.keys())
|
187 |
+
layers = OrderedDict()
|
188 |
+
for new_name, old_name, module in modules:
|
189 |
+
layers[new_name] = module
|
190 |
+
if old_name in remaining:
|
191 |
+
# return id has to be consistently str type for torchscript
|
192 |
+
self.return_layers[new_name] = str(return_layers[old_name])
|
193 |
+
remaining.remove(old_name)
|
194 |
+
if not remaining:
|
195 |
+
break
|
196 |
+
assert not remaining and len(self.return_layers) == len(return_layers), \
|
197 |
+
f'Return layers ({remaining}) are not present in model'
|
198 |
+
self.update(layers)
|
199 |
+
|
200 |
+
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
201 |
+
out = OrderedDict()
|
202 |
+
for name, module in self.items():
|
203 |
+
x = module(x)
|
204 |
+
if name in self.return_layers:
|
205 |
+
out_id = self.return_layers[name]
|
206 |
+
if isinstance(x, (tuple, list)):
|
207 |
+
# If model tap is a tuple or list, concat or select first element
|
208 |
+
# FIXME this may need to be more generic / flexible for some nets
|
209 |
+
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
210 |
+
else:
|
211 |
+
out[out_id] = x
|
212 |
+
return out
|
213 |
+
|
214 |
+
def forward(self, x) -> Dict[str, torch.Tensor]:
|
215 |
+
return self._collect(x)
|
216 |
+
|
217 |
+
|
218 |
+
class FeatureListNet(FeatureDictNet):
|
219 |
+
""" Feature extractor with list return
|
220 |
+
|
221 |
+
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
222 |
+
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
223 |
+
"""
|
224 |
+
def __init__(
|
225 |
+
self, model,
|
226 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
227 |
+
super(FeatureListNet, self).__init__(
|
228 |
+
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
229 |
+
flatten_sequential=flatten_sequential)
|
230 |
+
|
231 |
+
def forward(self, x) -> (List[torch.Tensor]):
|
232 |
+
return list(self._collect(x).values())
|
233 |
+
|
234 |
+
|
235 |
+
class FeatureHookNet(nn.ModuleDict):
|
236 |
+
""" FeatureHookNet
|
237 |
+
|
238 |
+
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
|
239 |
+
|
240 |
+
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
|
241 |
+
network in any way.
|
242 |
+
|
243 |
+
If `no_rewrite` is False, the model will be re-written as in the
|
244 |
+
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
|
245 |
+
|
246 |
+
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
247 |
+
"""
|
248 |
+
def __init__(
|
249 |
+
self, model,
|
250 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
251 |
+
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
252 |
+
super(FeatureHookNet, self).__init__()
|
253 |
+
assert not torch.jit.is_scripting()
|
254 |
+
self.feature_info = _get_feature_info(model, out_indices)
|
255 |
+
self.out_as_dict = out_as_dict
|
256 |
+
layers = OrderedDict()
|
257 |
+
hooks = []
|
258 |
+
if no_rewrite:
|
259 |
+
assert not flatten_sequential
|
260 |
+
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
261 |
+
model.reset_classifier(0)
|
262 |
+
layers['body'] = model
|
263 |
+
hooks.extend(self.feature_info.get_dicts())
|
264 |
+
else:
|
265 |
+
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
266 |
+
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
267 |
+
for f in self.feature_info.get_dicts()}
|
268 |
+
for new_name, old_name, module in modules:
|
269 |
+
layers[new_name] = module
|
270 |
+
for fn, fm in module.named_modules(prefix=old_name):
|
271 |
+
if fn in remaining:
|
272 |
+
hooks.append(dict(module=fn, hook_type=remaining[fn]))
|
273 |
+
del remaining[fn]
|
274 |
+
if not remaining:
|
275 |
+
break
|
276 |
+
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
277 |
+
self.update(layers)
|
278 |
+
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
for name, module in self.items():
|
282 |
+
x = module(x)
|
283 |
+
out = self.hooks.get_output(x.device)
|
284 |
+
return out if self.out_as_dict else list(out.values())
|
model/helpers.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Model creation / weight loading / state_dict helpers
|
2 |
+
|
3 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
from collections import OrderedDict
|
9 |
+
from copy import deepcopy
|
10 |
+
from typing import Any, Callable, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
|
16 |
+
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
17 |
+
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
|
18 |
+
from .layers import Conv2dSame, Linear
|
19 |
+
|
20 |
+
|
21 |
+
_logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def load_state_dict(checkpoint_path, use_ema=False):
|
25 |
+
if checkpoint_path and os.path.isfile(checkpoint_path):
|
26 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
27 |
+
state_dict_key = 'state_dict'
|
28 |
+
if isinstance(checkpoint, dict):
|
29 |
+
if use_ema and 'state_dict_ema' in checkpoint:
|
30 |
+
state_dict_key = 'state_dict_ema'
|
31 |
+
if state_dict_key and state_dict_key in checkpoint:
|
32 |
+
new_state_dict = OrderedDict()
|
33 |
+
for k, v in checkpoint[state_dict_key].items():
|
34 |
+
# strip `module.` prefix
|
35 |
+
name = k[7:] if k.startswith('module') else k
|
36 |
+
new_state_dict[name] = v
|
37 |
+
state_dict = new_state_dict
|
38 |
+
else:
|
39 |
+
state_dict = checkpoint
|
40 |
+
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
41 |
+
return state_dict
|
42 |
+
else:
|
43 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
44 |
+
raise FileNotFoundError()
|
45 |
+
|
46 |
+
|
47 |
+
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
|
48 |
+
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
49 |
+
# numpy checkpoint, try to load via model specific load_pretrained fn
|
50 |
+
if hasattr(model, 'load_pretrained'):
|
51 |
+
model.load_pretrained(checkpoint_path)
|
52 |
+
else:
|
53 |
+
raise NotImplementedError('Model cannot load numpy checkpoint')
|
54 |
+
return
|
55 |
+
state_dict = load_state_dict(checkpoint_path, use_ema)
|
56 |
+
model.load_state_dict(state_dict, strict=strict)
|
57 |
+
|
58 |
+
|
59 |
+
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
60 |
+
resume_epoch = None
|
61 |
+
if os.path.isfile(checkpoint_path):
|
62 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
63 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
64 |
+
if log_info:
|
65 |
+
_logger.info('Restoring model state from checkpoint...')
|
66 |
+
new_state_dict = OrderedDict()
|
67 |
+
for k, v in checkpoint['state_dict'].items():
|
68 |
+
name = k[7:] if k.startswith('module') else k
|
69 |
+
new_state_dict[name] = v
|
70 |
+
model.load_state_dict(new_state_dict)
|
71 |
+
|
72 |
+
if optimizer is not None and 'optimizer' in checkpoint:
|
73 |
+
if log_info:
|
74 |
+
_logger.info('Restoring optimizer state from checkpoint...')
|
75 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
76 |
+
|
77 |
+
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
78 |
+
if log_info:
|
79 |
+
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
80 |
+
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
81 |
+
|
82 |
+
if 'epoch' in checkpoint:
|
83 |
+
resume_epoch = checkpoint['epoch']
|
84 |
+
if 'version' in checkpoint and checkpoint['version'] > 1:
|
85 |
+
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
86 |
+
|
87 |
+
if log_info:
|
88 |
+
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
89 |
+
else:
|
90 |
+
model.load_state_dict(checkpoint)
|
91 |
+
if log_info:
|
92 |
+
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
93 |
+
return resume_epoch
|
94 |
+
else:
|
95 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
96 |
+
raise FileNotFoundError()
|
97 |
+
|
98 |
+
|
99 |
+
def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
|
100 |
+
r"""Loads a custom (read non .pth) weight file
|
101 |
+
|
102 |
+
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
|
103 |
+
a passed in custom load fun, or the `load_pretrained` model member fn.
|
104 |
+
|
105 |
+
If the object is already present in `model_dir`, it's deserialized and returned.
|
106 |
+
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
|
107 |
+
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
model: The instantiated model to load weights into
|
111 |
+
default_cfg (dict): Default pretrained model cfg
|
112 |
+
load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
|
113 |
+
'laod_pretrained' on the model will be called if it exists
|
114 |
+
progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
|
115 |
+
check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
|
116 |
+
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
|
117 |
+
digits of the SHA256 hash of the contents of the file. The hash is used to
|
118 |
+
ensure unique names and to verify the contents of the file. Default: False
|
119 |
+
"""
|
120 |
+
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
|
121 |
+
pretrained_url = default_cfg.get('url', None)
|
122 |
+
if not pretrained_url:
|
123 |
+
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
124 |
+
return
|
125 |
+
cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
|
126 |
+
|
127 |
+
if load_fn is not None:
|
128 |
+
load_fn(model, cached_file)
|
129 |
+
elif hasattr(model, 'load_pretrained'):
|
130 |
+
model.load_pretrained(cached_file)
|
131 |
+
else:
|
132 |
+
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
133 |
+
|
134 |
+
|
135 |
+
def adapt_input_conv(in_chans, conv_weight):
|
136 |
+
conv_type = conv_weight.dtype
|
137 |
+
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
138 |
+
O, I, J, K = conv_weight.shape
|
139 |
+
if in_chans == 1:
|
140 |
+
if I > 3:
|
141 |
+
assert conv_weight.shape[1] % 3 == 0
|
142 |
+
# For models with space2depth stems
|
143 |
+
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
144 |
+
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
145 |
+
else:
|
146 |
+
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
147 |
+
elif in_chans != 3:
|
148 |
+
if I != 3:
|
149 |
+
raise NotImplementedError('Weight format not supported by conversion.')
|
150 |
+
else:
|
151 |
+
# NOTE this strategy should be better than random init, but there could be other combinations of
|
152 |
+
# the original RGB input layer weights that'd work better for specific cases.
|
153 |
+
repeat = int(math.ceil(in_chans / 3))
|
154 |
+
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
155 |
+
conv_weight *= (3 / float(in_chans))
|
156 |
+
conv_weight = conv_weight.to(conv_type)
|
157 |
+
return conv_weight
|
158 |
+
|
159 |
+
|
160 |
+
def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
|
161 |
+
""" Load pretrained checkpoint
|
162 |
+
|
163 |
+
Args:
|
164 |
+
model (nn.Module) : PyTorch model module
|
165 |
+
default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
|
166 |
+
num_classes (int): num_classes for model
|
167 |
+
in_chans (int): in_chans for model
|
168 |
+
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
|
169 |
+
strict (bool): strict load of checkpoint
|
170 |
+
progress (bool): enable progress bar for weight download
|
171 |
+
|
172 |
+
"""
|
173 |
+
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
|
174 |
+
pretrained_url = default_cfg.get('url', None)
|
175 |
+
hf_hub_id = default_cfg.get('hf_hub', None)
|
176 |
+
if not pretrained_url and not hf_hub_id:
|
177 |
+
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
178 |
+
return
|
179 |
+
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
|
180 |
+
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
|
181 |
+
state_dict = load_state_dict_from_hf(hf_hub_id)
|
182 |
+
else:
|
183 |
+
_logger.info(f'Loading pretrained weights from url ({pretrained_url})')
|
184 |
+
state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
|
185 |
+
if filter_fn is not None:
|
186 |
+
# for backwards compat with filter fn that take one arg, try one first, the two
|
187 |
+
try:
|
188 |
+
state_dict = filter_fn(state_dict)
|
189 |
+
except TypeError:
|
190 |
+
state_dict = filter_fn(state_dict, model)
|
191 |
+
|
192 |
+
input_convs = default_cfg.get('first_conv', None)
|
193 |
+
if input_convs is not None and in_chans != 3:
|
194 |
+
if isinstance(input_convs, str):
|
195 |
+
input_convs = (input_convs,)
|
196 |
+
for input_conv_name in input_convs:
|
197 |
+
weight_name = input_conv_name + '.weight'
|
198 |
+
try:
|
199 |
+
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
|
200 |
+
_logger.info(
|
201 |
+
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
|
202 |
+
except NotImplementedError as e:
|
203 |
+
del state_dict[weight_name]
|
204 |
+
strict = False
|
205 |
+
_logger.warning(
|
206 |
+
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
207 |
+
|
208 |
+
classifiers = default_cfg.get('classifier', None)
|
209 |
+
label_offset = default_cfg.get('label_offset', 0)
|
210 |
+
if classifiers is not None:
|
211 |
+
if isinstance(classifiers, str):
|
212 |
+
classifiers = (classifiers,)
|
213 |
+
if num_classes != default_cfg['num_classes']:
|
214 |
+
for classifier_name in classifiers:
|
215 |
+
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
216 |
+
del state_dict[classifier_name + '.weight']
|
217 |
+
del state_dict[classifier_name + '.bias']
|
218 |
+
strict = False
|
219 |
+
elif label_offset > 0:
|
220 |
+
for classifier_name in classifiers:
|
221 |
+
# special case for pretrained weights with an extra background class in pretrained weights
|
222 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
223 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
224 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
225 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
226 |
+
|
227 |
+
model.load_state_dict(state_dict, strict=strict)
|
228 |
+
|
229 |
+
|
230 |
+
def extract_layer(model, layer):
|
231 |
+
layer = layer.split('.')
|
232 |
+
module = model
|
233 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
234 |
+
module = model.module
|
235 |
+
if not hasattr(model, 'module') and layer[0] == 'module':
|
236 |
+
layer = layer[1:]
|
237 |
+
for l in layer:
|
238 |
+
if hasattr(module, l):
|
239 |
+
if not l.isdigit():
|
240 |
+
module = getattr(module, l)
|
241 |
+
else:
|
242 |
+
module = module[int(l)]
|
243 |
+
else:
|
244 |
+
return module
|
245 |
+
return module
|
246 |
+
|
247 |
+
|
248 |
+
def set_layer(model, layer, val):
|
249 |
+
layer = layer.split('.')
|
250 |
+
module = model
|
251 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
252 |
+
module = model.module
|
253 |
+
lst_index = 0
|
254 |
+
module2 = module
|
255 |
+
for l in layer:
|
256 |
+
if hasattr(module2, l):
|
257 |
+
if not l.isdigit():
|
258 |
+
module2 = getattr(module2, l)
|
259 |
+
else:
|
260 |
+
module2 = module2[int(l)]
|
261 |
+
lst_index += 1
|
262 |
+
lst_index -= 1
|
263 |
+
for l in layer[:lst_index]:
|
264 |
+
if not l.isdigit():
|
265 |
+
module = getattr(module, l)
|
266 |
+
else:
|
267 |
+
module = module[int(l)]
|
268 |
+
l = layer[lst_index]
|
269 |
+
setattr(module, l, val)
|
270 |
+
|
271 |
+
|
272 |
+
def adapt_model_from_string(parent_module, model_string):
|
273 |
+
separator = '***'
|
274 |
+
state_dict = {}
|
275 |
+
lst_shape = model_string.split(separator)
|
276 |
+
for k in lst_shape:
|
277 |
+
k = k.split(':')
|
278 |
+
key = k[0]
|
279 |
+
shape = k[1][1:-1].split(',')
|
280 |
+
if shape[0] != '':
|
281 |
+
state_dict[key] = [int(i) for i in shape]
|
282 |
+
|
283 |
+
new_module = deepcopy(parent_module)
|
284 |
+
for n, m in parent_module.named_modules():
|
285 |
+
old_module = extract_layer(parent_module, n)
|
286 |
+
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
287 |
+
if isinstance(old_module, Conv2dSame):
|
288 |
+
conv = Conv2dSame
|
289 |
+
else:
|
290 |
+
conv = nn.Conv2d
|
291 |
+
s = state_dict[n + '.weight']
|
292 |
+
in_channels = s[1]
|
293 |
+
out_channels = s[0]
|
294 |
+
g = 1
|
295 |
+
if old_module.groups > 1:
|
296 |
+
in_channels = out_channels
|
297 |
+
g = in_channels
|
298 |
+
new_conv = conv(
|
299 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
300 |
+
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
301 |
+
groups=g, stride=old_module.stride)
|
302 |
+
set_layer(new_module, n, new_conv)
|
303 |
+
if isinstance(old_module, nn.BatchNorm2d):
|
304 |
+
new_bn = nn.BatchNorm2d(
|
305 |
+
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
306 |
+
affine=old_module.affine, track_running_stats=True)
|
307 |
+
set_layer(new_module, n, new_bn)
|
308 |
+
if isinstance(old_module, nn.Linear):
|
309 |
+
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
310 |
+
num_features = state_dict[n + '.weight'][1]
|
311 |
+
new_fc = Linear(
|
312 |
+
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
313 |
+
set_layer(new_module, n, new_fc)
|
314 |
+
if hasattr(new_module, 'num_features'):
|
315 |
+
new_module.num_features = num_features
|
316 |
+
new_module.eval()
|
317 |
+
parent_module.eval()
|
318 |
+
|
319 |
+
return new_module
|
320 |
+
|
321 |
+
|
322 |
+
def adapt_model_from_file(parent_module, model_variant):
|
323 |
+
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
|
324 |
+
with open(adapt_file, 'r') as f:
|
325 |
+
return adapt_model_from_string(parent_module, f.read().strip())
|
326 |
+
|
327 |
+
|
328 |
+
def default_cfg_for_features(default_cfg):
|
329 |
+
default_cfg = deepcopy(default_cfg)
|
330 |
+
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
331 |
+
to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size?
|
332 |
+
for tr in to_remove:
|
333 |
+
default_cfg.pop(tr, None)
|
334 |
+
return default_cfg
|
335 |
+
|
336 |
+
|
337 |
+
def overlay_external_default_cfg(default_cfg, kwargs):
|
338 |
+
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
|
339 |
+
"""
|
340 |
+
external_default_cfg = kwargs.pop('external_default_cfg', None)
|
341 |
+
if external_default_cfg:
|
342 |
+
default_cfg.pop('url', None) # url should come from external cfg
|
343 |
+
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
|
344 |
+
default_cfg.update(external_default_cfg)
|
345 |
+
|
346 |
+
|
347 |
+
def set_default_kwargs(kwargs, names, default_cfg):
|
348 |
+
for n in names:
|
349 |
+
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
|
350 |
+
# default_cfg has one input_size=(C, H ,W) entry
|
351 |
+
if n == 'img_size':
|
352 |
+
input_size = default_cfg.get('input_size', None)
|
353 |
+
if input_size is not None:
|
354 |
+
assert len(input_size) == 3
|
355 |
+
kwargs.setdefault(n, input_size[-2:])
|
356 |
+
elif n == 'in_chans':
|
357 |
+
input_size = default_cfg.get('input_size', None)
|
358 |
+
if input_size is not None:
|
359 |
+
assert len(input_size) == 3
|
360 |
+
kwargs.setdefault(n, input_size[0])
|
361 |
+
else:
|
362 |
+
default_val = default_cfg.get(n, None)
|
363 |
+
if default_val is not None:
|
364 |
+
kwargs.setdefault(n, default_cfg[n])
|
365 |
+
|
366 |
+
|
367 |
+
def filter_kwargs(kwargs, names):
|
368 |
+
if not kwargs or not names:
|
369 |
+
return
|
370 |
+
for n in names:
|
371 |
+
kwargs.pop(n, None)
|
372 |
+
|
373 |
+
|
374 |
+
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
|
375 |
+
""" Update the default_cfg and kwargs before passing to model
|
376 |
+
|
377 |
+
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
|
378 |
+
could/should be replaced by an improved configuration mechanism
|
379 |
+
|
380 |
+
Args:
|
381 |
+
default_cfg: input default_cfg (updated in-place)
|
382 |
+
kwargs: keyword args passed to model build fn (updated in-place)
|
383 |
+
kwargs_filter: keyword arg keys that must be removed before model __init__
|
384 |
+
"""
|
385 |
+
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
386 |
+
overlay_external_default_cfg(default_cfg, kwargs)
|
387 |
+
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
388 |
+
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
|
389 |
+
if default_cfg.get('fixed_input_size', False):
|
390 |
+
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
|
391 |
+
default_kwarg_names += ('img_size',)
|
392 |
+
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
|
393 |
+
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
394 |
+
filter_kwargs(kwargs, names=kwargs_filter)
|
395 |
+
|
396 |
+
|
397 |
+
def build_model_with_cfg(
|
398 |
+
model_cls: Callable,
|
399 |
+
variant: str,
|
400 |
+
pretrained: bool,
|
401 |
+
default_cfg: dict,
|
402 |
+
model_cfg: Optional[Any] = None,
|
403 |
+
feature_cfg: Optional[dict] = None,
|
404 |
+
pretrained_strict: bool = True,
|
405 |
+
pretrained_filter_fn: Optional[Callable] = None,
|
406 |
+
pretrained_custom_load: bool = False,
|
407 |
+
kwargs_filter: Optional[Tuple[str]] = None,
|
408 |
+
**kwargs):
|
409 |
+
""" Build model with specified default_cfg and optional model_cfg
|
410 |
+
|
411 |
+
This helper fn aids in the construction of a model including:
|
412 |
+
* handling default_cfg and associated pretained weight loading
|
413 |
+
* passing through optional model_cfg for models with config based arch spec
|
414 |
+
* features_only model adaptation
|
415 |
+
* pruning config / model adaptation
|
416 |
+
|
417 |
+
Args:
|
418 |
+
model_cls (nn.Module): model class
|
419 |
+
variant (str): model variant name
|
420 |
+
pretrained (bool): load pretrained weights
|
421 |
+
default_cfg (dict): model's default pretrained/task config
|
422 |
+
model_cfg (Optional[Dict]): model's architecture config
|
423 |
+
feature_cfg (Optional[Dict]: feature extraction adapter config
|
424 |
+
pretrained_strict (bool): load pretrained weights strictly
|
425 |
+
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
|
426 |
+
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
|
427 |
+
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
|
428 |
+
**kwargs: model args passed through to model __init__
|
429 |
+
"""
|
430 |
+
pruned = kwargs.pop('pruned', False)
|
431 |
+
features = False
|
432 |
+
feature_cfg = feature_cfg or {}
|
433 |
+
default_cfg = deepcopy(default_cfg) if default_cfg else {}
|
434 |
+
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
|
435 |
+
default_cfg.setdefault('architecture', variant)
|
436 |
+
|
437 |
+
# Setup for feature extraction wrapper done at end of this fn
|
438 |
+
if kwargs.pop('features_only', False):
|
439 |
+
features = True
|
440 |
+
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
441 |
+
if 'out_indices' in kwargs:
|
442 |
+
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
443 |
+
|
444 |
+
# Build the model
|
445 |
+
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
446 |
+
model.default_cfg = default_cfg
|
447 |
+
|
448 |
+
if pruned:
|
449 |
+
model = adapt_model_from_file(model, variant)
|
450 |
+
|
451 |
+
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
452 |
+
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
453 |
+
if pretrained:
|
454 |
+
if pretrained_custom_load:
|
455 |
+
load_custom_pretrained(model)
|
456 |
+
else:
|
457 |
+
load_pretrained(
|
458 |
+
model,
|
459 |
+
num_classes=num_classes_pretrained,
|
460 |
+
in_chans=kwargs.get('in_chans', 3),
|
461 |
+
filter_fn=pretrained_filter_fn,
|
462 |
+
strict=pretrained_strict)
|
463 |
+
|
464 |
+
# Wrap the model in a feature extraction module if enabled
|
465 |
+
if features:
|
466 |
+
feature_cls = FeatureListNet
|
467 |
+
if 'feature_cls' in feature_cfg:
|
468 |
+
feature_cls = feature_cfg.pop('feature_cls')
|
469 |
+
if isinstance(feature_cls, str):
|
470 |
+
feature_cls = feature_cls.lower()
|
471 |
+
if 'hook' in feature_cls:
|
472 |
+
feature_cls = FeatureHookNet
|
473 |
+
else:
|
474 |
+
assert False, f'Unknown feature class {feature_cls}'
|
475 |
+
model = feature_cls(model, **feature_cfg)
|
476 |
+
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
477 |
+
|
478 |
+
return model
|
479 |
+
|
480 |
+
|
481 |
+
def model_parameters(model, exclude_head=False):
|
482 |
+
if exclude_head:
|
483 |
+
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
484 |
+
return [p for p in model.parameters()][:-2]
|
485 |
+
else:
|
486 |
+
return model.parameters()
|
487 |
+
|
488 |
+
|
489 |
+
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
490 |
+
if not depth_first and include_root:
|
491 |
+
fn(module=module, name=name)
|
492 |
+
for child_name, child_module in module.named_children():
|
493 |
+
child_name = '.'.join((name, child_name)) if name else child_name
|
494 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
495 |
+
if depth_first and include_root:
|
496 |
+
fn(module=module, name=name)
|
497 |
+
return module
|
498 |
+
|
499 |
+
|
500 |
+
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
501 |
+
if not depth_first and include_root:
|
502 |
+
yield name, module
|
503 |
+
for child_name, child_module in module.named_children():
|
504 |
+
child_name = '.'.join((name, child_name)) if name else child_name
|
505 |
+
yield from named_modules(
|
506 |
+
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
507 |
+
if depth_first and include_root:
|
508 |
+
yield name, module
|
model/hub.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from functools import partial
|
5 |
+
from typing import Union, Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
|
9 |
+
try:
|
10 |
+
from torch.hub import get_dir
|
11 |
+
except ImportError:
|
12 |
+
from torch.hub import _get_torch_home as get_dir
|
13 |
+
|
14 |
+
from timm import __version__
|
15 |
+
try:
|
16 |
+
from huggingface_hub import hf_hub_url
|
17 |
+
from huggingface_hub import cached_download
|
18 |
+
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
|
19 |
+
except ImportError:
|
20 |
+
hf_hub_url = None
|
21 |
+
cached_download = None
|
22 |
+
|
23 |
+
_logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def get_cache_dir(child_dir=''):
|
27 |
+
"""
|
28 |
+
Returns the location of the directory where models are cached (and creates it if necessary).
|
29 |
+
"""
|
30 |
+
# Issue warning to move data if old env is set
|
31 |
+
if os.getenv('TORCH_MODEL_ZOO'):
|
32 |
+
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
33 |
+
|
34 |
+
hub_dir = get_dir()
|
35 |
+
child_dir = () if not child_dir else (child_dir,)
|
36 |
+
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
|
37 |
+
os.makedirs(model_dir, exist_ok=True)
|
38 |
+
return model_dir
|
39 |
+
|
40 |
+
|
41 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
42 |
+
parts = urlparse(url)
|
43 |
+
filename = os.path.basename(parts.path)
|
44 |
+
cached_file = os.path.join(get_cache_dir(), filename)
|
45 |
+
if not os.path.exists(cached_file):
|
46 |
+
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
47 |
+
hash_prefix = None
|
48 |
+
if check_hash:
|
49 |
+
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
50 |
+
hash_prefix = r.group(1) if r else None
|
51 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
52 |
+
return cached_file
|
53 |
+
|
54 |
+
|
55 |
+
def has_hf_hub(necessary=False):
|
56 |
+
if hf_hub_url is None and necessary:
|
57 |
+
# if no HF Hub module installed and it is necessary to continue, raise error
|
58 |
+
raise RuntimeError(
|
59 |
+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
60 |
+
return hf_hub_url is not None
|
61 |
+
|
62 |
+
|
63 |
+
def hf_split(hf_id):
|
64 |
+
rev_split = hf_id.split('@')
|
65 |
+
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
|
66 |
+
hf_model_id = rev_split[0]
|
67 |
+
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
|
68 |
+
return hf_model_id, hf_revision
|
69 |
+
|
70 |
+
|
71 |
+
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
72 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
73 |
+
text = reader.read()
|
74 |
+
return json.loads(text)
|
75 |
+
|
76 |
+
|
77 |
+
def _download_from_hf(model_id: str, filename: str):
|
78 |
+
hf_model_id, hf_revision = hf_split(model_id)
|
79 |
+
url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
|
80 |
+
return cached_download(url, cache_dir=get_cache_dir('hf'))
|
81 |
+
|
82 |
+
|
83 |
+
def load_model_config_from_hf(model_id: str):
|
84 |
+
assert has_hf_hub(True)
|
85 |
+
cached_file = _download_from_hf(model_id, 'config.json')
|
86 |
+
default_cfg = load_cfg_from_json(cached_file)
|
87 |
+
default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
88 |
+
model_name = default_cfg.get('architecture')
|
89 |
+
return default_cfg, model_name
|
90 |
+
|
91 |
+
|
92 |
+
def load_state_dict_from_hf(model_id: str):
|
93 |
+
assert has_hf_hub(True)
|
94 |
+
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
|
95 |
+
state_dict = torch.load(cached_file, map_location='cpu')
|
96 |
+
return state_dict
|
model/layers/__init__.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .activations import *
|
2 |
+
from .adaptive_avgmax_pool import \
|
3 |
+
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
4 |
+
from .blur_pool import BlurPool2d
|
5 |
+
from .classifier import ClassifierHead, create_classifier
|
6 |
+
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
7 |
+
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
|
8 |
+
set_layer_config
|
9 |
+
from .conv2d_same import Conv2dSame, conv2d_same
|
10 |
+
from .conv_bn_act import ConvBnAct
|
11 |
+
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
12 |
+
from .create_attn import get_attn, create_attn
|
13 |
+
from .create_conv2d import create_conv2d
|
14 |
+
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
|
15 |
+
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
16 |
+
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
17 |
+
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
18 |
+
from .gather_excite import GatherExcite
|
19 |
+
from .global_context import GlobalContext
|
20 |
+
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
|
21 |
+
from .inplace_abn import InplaceAbn
|
22 |
+
from .involution import Involution
|
23 |
+
from .linear import Linear
|
24 |
+
from .mixed_conv2d import MixedConv2d
|
25 |
+
from .mlp import Mlp, GluMlp, GatedMlp
|
26 |
+
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
27 |
+
from .norm import GroupNorm, LayerNorm2d
|
28 |
+
from .norm_act import BatchNormAct2d, GroupNormAct
|
29 |
+
from .padding import get_padding, get_same_padding, pad_same
|
30 |
+
from .patch_embed import PatchEmbed
|
31 |
+
from .pool2d_same import AvgPool2dSame, create_pool2d
|
32 |
+
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
33 |
+
from .selective_kernel import SelectiveKernel
|
34 |
+
from .separable_conv import SeparableConv2d, SeparableConvBnAct
|
35 |
+
from .space_to_depth import SpaceToDepthModule
|
36 |
+
from .split_attn import SplitAttn
|
37 |
+
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
38 |
+
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
39 |
+
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
40 |
+
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
|
model/layers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.1 kB). View file
|
|
model/layers/__pycache__/activations.cpython-38.pyc
ADDED
Binary file (6.66 kB). View file
|
|
model/layers/__pycache__/activations_jit.cpython-38.pyc
ADDED
Binary file (4.14 kB). View file
|
|
model/layers/__pycache__/activations_me.cpython-38.pyc
ADDED
Binary file (8.92 kB). View file
|
|
model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc
ADDED
Binary file (4.78 kB). View file
|
|
model/layers/__pycache__/blur_pool.cpython-38.pyc
ADDED
Binary file (2.1 kB). View file
|
|
model/layers/__pycache__/bottleneck_attn.cpython-38.pyc
ADDED
Binary file (4.76 kB). View file
|
|
model/layers/__pycache__/cbam.cpython-38.pyc
ADDED
Binary file (5.17 kB). View file
|
|
model/layers/__pycache__/classifier.cpython-38.pyc
ADDED
Binary file (2.24 kB). View file
|
|
model/layers/__pycache__/cond_conv2d.cpython-38.pyc
ADDED
Binary file (3.84 kB). View file
|
|
model/layers/__pycache__/config.cpython-38.pyc
ADDED
Binary file (3.44 kB). View file
|
|
model/layers/__pycache__/conv2d_same.cpython-38.pyc
ADDED
Binary file (1.96 kB). View file
|
|
model/layers/__pycache__/conv_bn_act.cpython-38.pyc
ADDED
Binary file (1.66 kB). View file
|
|
model/layers/__pycache__/create_act.cpython-38.pyc
ADDED
Binary file (3.65 kB). View file
|
|
model/layers/__pycache__/create_attn.cpython-38.pyc
ADDED
Binary file (2.09 kB). View file
|
|
model/layers/__pycache__/create_conv2d.cpython-38.pyc
ADDED
Binary file (1.09 kB). View file
|
|
model/layers/__pycache__/create_norm_act.cpython-38.pyc
ADDED
Binary file (2.33 kB). View file
|
|
model/layers/__pycache__/drop.cpython-38.pyc
ADDED
Binary file (5.74 kB). View file
|
|
model/layers/__pycache__/eca.cpython-38.pyc
ADDED
Binary file (6.15 kB). View file
|
|
model/layers/__pycache__/evo_norm.cpython-38.pyc
ADDED
Binary file (3.39 kB). View file
|
|
model/layers/__pycache__/gather_excite.cpython-38.pyc
ADDED
Binary file (3.11 kB). View file
|
|
model/layers/__pycache__/global_context.cpython-38.pyc
ADDED
Binary file (2.43 kB). View file
|
|
model/layers/__pycache__/halo_attn.cpython-38.pyc
ADDED
Binary file (5.59 kB). View file
|
|
model/layers/__pycache__/helpers.cpython-38.pyc
ADDED
Binary file (1.03 kB). View file
|
|
model/layers/__pycache__/inplace_abn.cpython-38.pyc
ADDED
Binary file (3.18 kB). View file
|
|
model/layers/__pycache__/involution.cpython-38.pyc
ADDED
Binary file (1.83 kB). View file
|
|
model/layers/__pycache__/lambda_layer.cpython-38.pyc
ADDED
Binary file (3.01 kB). View file
|
|
model/layers/__pycache__/linear.cpython-38.pyc
ADDED
Binary file (1.08 kB). View file
|
|
model/layers/__pycache__/mixed_conv2d.cpython-38.pyc
ADDED
Binary file (2.29 kB). View file
|
|
model/layers/__pycache__/mlp.cpython-38.pyc
ADDED
Binary file (3.81 kB). View file
|
|
model/layers/__pycache__/non_local_attn.cpython-38.pyc
ADDED
Binary file (5.64 kB). View file
|
|
model/layers/__pycache__/norm.cpython-38.pyc
ADDED
Binary file (1.52 kB). View file
|
|
model/layers/__pycache__/norm_act.cpython-38.pyc
ADDED
Binary file (3.07 kB). View file
|
|
model/layers/__pycache__/padding.cpython-38.pyc
ADDED
Binary file (1.8 kB). View file
|
|
model/layers/__pycache__/patch_embed.cpython-38.pyc
ADDED
Binary file (1.68 kB). View file
|
|
model/layers/__pycache__/pool2d_same.cpython-38.pyc
ADDED
Binary file (3.12 kB). View file
|
|
model/layers/__pycache__/selective_kernel.cpython-38.pyc
ADDED
Binary file (5.51 kB). View file
|
|
model/layers/__pycache__/separable_conv.cpython-38.pyc
ADDED
Binary file (2.98 kB). View file
|
|