AlexZou committed on
Commit
60e23e3
1 Parent(s): e9724ef

Upload SCET.py

Files changed (1)
  1. models/SCET.py +276 -0
models/SCET.py ADDED
@@ -0,0 +1,276 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange
import numbers


# LayerNorm

def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)

## Gated-Dconv Feed-Forward Network (GDFN)
class GFeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(GFeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2  # gating: the GELU-activated branch modulates the second branch
        x = self.project_out(x)
        return x

##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # attention is computed across channels (c x c per head), not across
        # spatial positions, so the cost stays linear in h * w
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, dim=48, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type=WithBias_LayerNorm):
        super(TransformerBlock, self).__init__()

        # LayerNorm() only checks for the string 'BiasFree', so the class default
        # passed here still selects the with-bias variant
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = GFeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


class BackBoneBlock(nn.Module):
    def __init__(self, num, fm, **args):
        super().__init__()
        # build `num` copies of fm(**args); they are applied sequentially in forward
        self.arr = nn.ModuleList([])
        for _ in range(num):
            self.arr.append(fm(**args))

    def forward(self, x):
        for block in self.arr:
            x = block(x)
        return x

class PAConv(nn.Module):

    def __init__(self, nf, k_size=3):
        super(PAConv, self).__init__()
        self.k2 = nn.Conv2d(nf, nf, 1)  # 1x1 convolution nf->nf
        self.sigmoid = nn.Sigmoid()
        self.k3 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)  # 3x3 convolution
        self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)  # 3x3 convolution

    def forward(self, x):
        y = self.k2(x)
        y = self.sigmoid(y)  # per-pixel attention map in (0, 1)

        out = torch.mul(self.k3(x), y)  # gate the 3x3 features with the attention map
        out = self.k4(out)

        return out

class SCPA(nn.Module):
    """SCPA is modified from SCNet (Jiang-Jiang Liu et al., Improving Convolutional
    Networks with Self-Calibrated Convolutions. In CVPR, 2020).
    GitHub: https://github.com/MCG-NKU/SCNet
    """

    def __init__(self, nf, reduction=2, stride=1, dilation=1):
        super(SCPA, self).__init__()
        group_width = nf // reduction

        self.conv1_a = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)
        self.conv1_b = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)

        self.k1 = nn.Sequential(
            nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                bias=False)
        )

        self.PAConv = PAConv(group_width)

        self.conv3 = nn.Conv2d(
            group_width * reduction, nf, kernel_size=1, bias=False)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        residual = x

        # two parallel branches: a plain 3x3 branch (k1) and a pixel-attention branch (PAConv)
        out_a = self.conv1_a(x)
        out_b = self.conv1_b(x)
        out_a = self.lrelu(out_a)
        out_b = self.lrelu(out_b)

        out_a = self.k1(out_a)
        out_b = self.PAConv(out_b)
        out_a = self.lrelu(out_a)
        out_b = self.lrelu(out_b)

        out = self.conv3(torch.cat([out_a, out_b], dim=1))
        out += residual

        return out

class SCET(nn.Module):
    def __init__(self, hiddenDim=32, mlpDim=128, scaleFactor=2):
        super().__init__()
        self.conv3 = nn.Conv2d(3, hiddenDim,
                               kernel_size=3, padding=1)

        # note: lamRes/lamX are kept in a plain tuple, so they are not registered
        # as module parameters and are not used in forward; mlpDim is also unused
        lamRes = torch.nn.Parameter(torch.ones(1))
        lamX = torch.nn.Parameter(torch.ones(1))
        self.adaptiveWeight = (lamRes, lamX)
        if scaleFactor == 3:
            num_heads = 7
        else:
            num_heads = 8
        self.path1 = nn.Sequential(
            BackBoneBlock(16, SCPA, nf=hiddenDim, reduction=2, stride=1, dilation=1),
            BackBoneBlock(1, TransformerBlock,
                          dim=hiddenDim, num_heads=num_heads, ffn_expansion_factor=2.66,
                          bias=False, LayerNorm_type=WithBias_LayerNorm),
            nn.Conv2d(hiddenDim, hiddenDim, kernel_size=3, padding=1),
            nn.PixelShuffle(scaleFactor),
            nn.Conv2d(hiddenDim // (scaleFactor ** 2),
                      3, kernel_size=3, padding=1),
        )

        self.path2 = nn.Sequential(
            nn.PixelShuffle(scaleFactor),
            nn.Conv2d(hiddenDim // (scaleFactor ** 2),
                      3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = self.conv3(x)
        x1, x2 = self.path1(x), self.path2(x)
        return x1 + x2

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            # get_root_logger / load_checkpoint are not imported at the top of this
            # file; they are assumed to come from mmedit / mmcv, as in
            # mmediting-style restoration models
            from mmcv.runner import load_checkpoint
            from mmedit.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

if __name__ == '__main__':

    from torchstat import stat
    import time
    import torchsummary

    net = SCET(32, 128, 4).cuda()
    torchsummary.summary(net, (3, 48, 48))