sunder-ali commited on
Commit
3d4805e
1 Parent(s): 9b43092

Upload team15_SAKDNNet.py

Browse files
Files changed (1) hide show
  1. models/team15_SAKDNNet.py +238 -0
models/team15_SAKDNNet.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from einops.layers.torch import Rearrange
6
+ from timm.models.layers import trunc_normal_, DropPath
7
+
8
+
9
+ class SAST(nn.Module):
10
+
11
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
12
+ super(SAST, self).__init__()
13
+ self.input_dim = input_dim
14
+ self.output_dim = output_dim
15
+ self.head_dim = head_dim
16
+ self.scale = self.head_dim ** -0.5
17
+ self.n_heads = input_dim//head_dim
18
+ self.window_size = window_size
19
+ self.type=type
20
+ self.embedding_layer = nn.Linear(self.input_dim, 3*self.input_dim, bias=True)
21
+
22
+ self.relative_position_params = nn.Parameter(torch.zeros((2 * window_size - 1)*(2 * window_size -1), self.n_heads))
23
+
24
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
25
+
26
+ trunc_normal_(self.relative_position_params, std=.02)
27
+ self.relative_position_params = torch.nn.Parameter(self.relative_position_params.view(2*window_size-1, 2*window_size-1, self.n_heads).transpose(1,2).transpose(0,1))
28
+
29
+ def maskgen(self, h, w, p, shift):
30
+ maskatt = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
31
+ if self.type == 'W':
32
+ return maskatt
33
+
34
+ s = p - shift
35
+ maskatt[-1, :, :s, :, s:, :] = True
36
+ maskatt[-1, :, s:, :, :s, :] = True
37
+ maskatt[:, -1, :, :s, :, s:] = True
38
+ maskatt[:, -1, :, s:, :, :s] = True
39
+ maskatt = rearrange(maskatt, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
40
+ return maskatt
41
+
42
+ def forward(self, x):
43
+
44
+ if self.type!='W': x = torch.roll(x, shifts=(-(self.window_size//2), -(self.window_size//2)), dims=(1,2))
45
+ x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
46
+ h_windows = x.size(1)
47
+ w_windows = x.size(2)
48
+
49
+
50
+ x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
51
+ qkv = self.embedding_layer(x)
52
+ q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
53
+ sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
54
+ sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
55
+ if self.type != 'W':
56
+ maskatt = self.maskgen(h_windows, w_windows, self.window_size, shift=self.window_size//2)
57
+ sim = sim.masked_fill_(maskatt, float("-inf"))
58
+
59
+ probs = nn.functional.softmax(sim, dim=-1)
60
+ output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
61
+ output = rearrange(output, 'h b w p c -> b w p (h c)')
62
+ output = self.linear(output)
63
+ output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
64
+
65
+ if self.type!='W': output = torch.roll(output, shifts=(self.window_size//2, self.window_size//2), dims=(1,2))
66
+ return output
67
+
68
+ def relative_embedding(self):
69
+ cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
70
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size -1
71
+ return self.relative_position_params[:, relation[:,:,0].long(), relation[:,:,1].long()]
72
+
73
+
74
+ class DRFE(nn.Module):
75
+ def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
76
+
77
+ super(DRFE, self).__init__()
78
+ self.input_dim = input_dim
79
+ self.output_dim = output_dim
80
+ assert type in ['W', 'SW']
81
+ self.type = type
82
+ if input_resolution <= window_size:
83
+ self.type = 'W'
84
+
85
+ self.ln1 = nn.LayerNorm(input_dim)
86
+ self.msa = SAST(input_dim, input_dim, head_dim, window_size, self.type)
87
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
88
+ self.ln2 = nn.LayerNorm(input_dim)
89
+ self.mlp = nn.Sequential(
90
+ nn.Linear(input_dim, 4 * input_dim),
91
+ nn.GELU(),
92
+ nn.Linear(4 * input_dim, output_dim),
93
+ )
94
+
95
+ def forward(self, x):
96
+ x = x + self.drop_path(self.msa(self.ln1(x)))
97
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
98
+ return x
99
+
100
+
101
+ class STCB(nn.Module):
102
+ def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
103
+
104
+ super(STCB, self).__init__()
105
+ self.conv_dim = conv_dim
106
+ self.trans_dim = trans_dim
107
+ self.head_dim = head_dim
108
+ self.window_size = window_size
109
+ self.drop_path = drop_path
110
+ self.type = type
111
+ self.input_resolution = input_resolution
112
+
113
+ assert self.type in ['W', 'SW']
114
+ if self.input_resolution <= self.window_size:
115
+ self.type = 'W'
116
+
117
+ self.trans_block = DRFE(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, self.type, self.input_resolution)
118
+ self.conv1_1 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
119
+ self.conv1_2 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
120
+
121
+ self.conv_block = nn.Sequential(
122
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
123
+ nn.ReLU(True),
124
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
125
+ )
126
+
127
+ def forward(self, x):
128
+ conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
129
+ conv_x = self.conv_block(conv_x) + conv_x
130
+ trans_x = Rearrange('b c h w -> b h w c')(trans_x)
131
+ trans_x = self.trans_block(trans_x)
132
+ trans_x = Rearrange('b h w c -> b c h w')(trans_x)
133
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
134
+ x = x + res
135
+
136
+ return x
137
+
138
+
139
+ class SAKDNNet(nn.Module):
140
+
141
+ def __init__(self, in_nc=3, config=[2,2,2,2,2,2,2], dim=64, drop_path_rate=0.0, input_resolution=256):
142
+ super(SAKDNNet, self).__init__()
143
+ self.config = config
144
+ self.dim = dim
145
+ self.head_dim = 32
146
+ self.window_size = 8
147
+
148
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
149
+
150
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
151
+
152
+ begin = 0
153
+ self.m_down1 = [STCB(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
154
+ for i in range(config[0])] + \
155
+ [nn.Conv2d(dim, 2*dim, 2, 2, 0, bias=False)]
156
+
157
+ begin += config[0]
158
+ self.m_down2 = [STCB(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
159
+ for i in range(config[1])] + \
160
+ [nn.Conv2d(2*dim, 4*dim, 2, 2, 0, bias=False)]
161
+
162
+ begin += config[1]
163
+ self.m_down3 = [STCB(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
164
+ for i in range(config[2])] + \
165
+ [nn.Conv2d(4*dim, 8*dim, 2, 2, 0, bias=False)]
166
+
167
+ begin += config[2]
168
+ self.m_body = [STCB(4*dim, 4*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//8)
169
+ for i in range(config[3])]
170
+
171
+ begin += config[3]
172
+ self.m_up3 = [nn.ConvTranspose2d(8*dim, 4*dim, 2, 2, 0, bias=False),] + \
173
+ [STCB(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
174
+ for i in range(config[4])]
175
+
176
+ begin += config[4]
177
+ self.m_up2 = [nn.ConvTranspose2d(4*dim, 2*dim, 2, 2, 0, bias=False),] + \
178
+ [STCB(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
179
+ for i in range(config[5])]
180
+
181
+ begin += config[5]
182
+ self.m_up1 = [nn.ConvTranspose2d(2*dim, dim, 2, 2, 0, bias=False),] + \
183
+ [STCB(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
184
+ for i in range(config[6])]
185
+
186
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
187
+
188
+ self.m_head = nn.Sequential(*self.m_head)
189
+ self.m_down1 = nn.Sequential(*self.m_down1)
190
+ self.m_down2 = nn.Sequential(*self.m_down2)
191
+ self.m_down3 = nn.Sequential(*self.m_down3)
192
+ self.m_body = nn.Sequential(*self.m_body)
193
+ self.m_up3 = nn.Sequential(*self.m_up3)
194
+ self.m_up2 = nn.Sequential(*self.m_up2)
195
+ self.m_up1 = nn.Sequential(*self.m_up1)
196
+ self.m_tail = nn.Sequential(*self.m_tail)
197
+
198
+ def forward(self, x0):
199
+
200
+ h, w = x0.size()[-2:]
201
+ paddingBottom = int(np.ceil(h/64)*64-h)
202
+ paddingRight = int(np.ceil(w/64)*64-w)
203
+ x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
204
+
205
+ x1 = self.m_head(x0)
206
+ x2 = self.m_down1(x1)
207
+ x3 = self.m_down2(x2)
208
+ x4 = self.m_down3(x3)
209
+ x = self.m_body(x4)
210
+ x = self.m_up3(x+x4)
211
+ x = self.m_up2(x+x3)
212
+ x = self.m_up1(x+x2)
213
+ x = self.m_tail(x+x1)
214
+
215
+ x = x[..., :h, :w]
216
+
217
+ return x
218
+
219
+
220
+ def _init_weights(self, m):
221
+ if isinstance(m, nn.Linear):
222
+ trunc_normal_(m.weight, std=.02)
223
+ if m.bias is not None:
224
+ nn.init.constant_(m.bias, 0)
225
+ elif isinstance(m, nn.LayerNorm):
226
+ nn.init.constant_(m.bias, 0)
227
+ nn.init.constant_(m.weight, 1.0)
228
+
229
+
230
+
231
+ if __name__ == '__main__':
232
+
233
+ # torch.cuda.empty_cache()
234
+ net = SAKDNNet()
235
+
236
+ x = torch.randn((2, 3, 64, 128))
237
+ x = net(x)
238
+ print(x.shape)