Spaces:
Running
Running
sunder-ali
commited on
Commit
•
3d4805e
1
Parent(s):
9b43092
Upload team15_SAKDNNet.py
Browse files- models/team15_SAKDNNet.py +238 -0
models/team15_SAKDNNet.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
from einops.layers.torch import Rearrange
|
6 |
+
from timm.models.layers import trunc_normal_, DropPath
|
7 |
+
|
8 |
+
|
9 |
+
class SAST(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
12 |
+
super(SAST, self).__init__()
|
13 |
+
self.input_dim = input_dim
|
14 |
+
self.output_dim = output_dim
|
15 |
+
self.head_dim = head_dim
|
16 |
+
self.scale = self.head_dim ** -0.5
|
17 |
+
self.n_heads = input_dim//head_dim
|
18 |
+
self.window_size = window_size
|
19 |
+
self.type=type
|
20 |
+
self.embedding_layer = nn.Linear(self.input_dim, 3*self.input_dim, bias=True)
|
21 |
+
|
22 |
+
self.relative_position_params = nn.Parameter(torch.zeros((2 * window_size - 1)*(2 * window_size -1), self.n_heads))
|
23 |
+
|
24 |
+
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
25 |
+
|
26 |
+
trunc_normal_(self.relative_position_params, std=.02)
|
27 |
+
self.relative_position_params = torch.nn.Parameter(self.relative_position_params.view(2*window_size-1, 2*window_size-1, self.n_heads).transpose(1,2).transpose(0,1))
|
28 |
+
|
29 |
+
def maskgen(self, h, w, p, shift):
|
30 |
+
maskatt = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
31 |
+
if self.type == 'W':
|
32 |
+
return maskatt
|
33 |
+
|
34 |
+
s = p - shift
|
35 |
+
maskatt[-1, :, :s, :, s:, :] = True
|
36 |
+
maskatt[-1, :, s:, :, :s, :] = True
|
37 |
+
maskatt[:, -1, :, :s, :, s:] = True
|
38 |
+
maskatt[:, -1, :, s:, :, :s] = True
|
39 |
+
maskatt = rearrange(maskatt, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
40 |
+
return maskatt
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
|
44 |
+
if self.type!='W': x = torch.roll(x, shifts=(-(self.window_size//2), -(self.window_size//2)), dims=(1,2))
|
45 |
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
46 |
+
h_windows = x.size(1)
|
47 |
+
w_windows = x.size(2)
|
48 |
+
|
49 |
+
|
50 |
+
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
51 |
+
qkv = self.embedding_layer(x)
|
52 |
+
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
53 |
+
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
54 |
+
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
55 |
+
if self.type != 'W':
|
56 |
+
maskatt = self.maskgen(h_windows, w_windows, self.window_size, shift=self.window_size//2)
|
57 |
+
sim = sim.masked_fill_(maskatt, float("-inf"))
|
58 |
+
|
59 |
+
probs = nn.functional.softmax(sim, dim=-1)
|
60 |
+
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
61 |
+
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
62 |
+
output = self.linear(output)
|
63 |
+
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
64 |
+
|
65 |
+
if self.type!='W': output = torch.roll(output, shifts=(self.window_size//2, self.window_size//2), dims=(1,2))
|
66 |
+
return output
|
67 |
+
|
68 |
+
def relative_embedding(self):
|
69 |
+
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
70 |
+
relation = cord[:, None, :] - cord[None, :, :] + self.window_size -1
|
71 |
+
return self.relative_position_params[:, relation[:,:,0].long(), relation[:,:,1].long()]
|
72 |
+
|
73 |
+
|
74 |
+
class DRFE(nn.Module):
|
75 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
76 |
+
|
77 |
+
super(DRFE, self).__init__()
|
78 |
+
self.input_dim = input_dim
|
79 |
+
self.output_dim = output_dim
|
80 |
+
assert type in ['W', 'SW']
|
81 |
+
self.type = type
|
82 |
+
if input_resolution <= window_size:
|
83 |
+
self.type = 'W'
|
84 |
+
|
85 |
+
self.ln1 = nn.LayerNorm(input_dim)
|
86 |
+
self.msa = SAST(input_dim, input_dim, head_dim, window_size, self.type)
|
87 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
88 |
+
self.ln2 = nn.LayerNorm(input_dim)
|
89 |
+
self.mlp = nn.Sequential(
|
90 |
+
nn.Linear(input_dim, 4 * input_dim),
|
91 |
+
nn.GELU(),
|
92 |
+
nn.Linear(4 * input_dim, output_dim),
|
93 |
+
)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x = x + self.drop_path(self.msa(self.ln1(x)))
|
97 |
+
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
class STCB(nn.Module):
|
102 |
+
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
103 |
+
|
104 |
+
super(STCB, self).__init__()
|
105 |
+
self.conv_dim = conv_dim
|
106 |
+
self.trans_dim = trans_dim
|
107 |
+
self.head_dim = head_dim
|
108 |
+
self.window_size = window_size
|
109 |
+
self.drop_path = drop_path
|
110 |
+
self.type = type
|
111 |
+
self.input_resolution = input_resolution
|
112 |
+
|
113 |
+
assert self.type in ['W', 'SW']
|
114 |
+
if self.input_resolution <= self.window_size:
|
115 |
+
self.type = 'W'
|
116 |
+
|
117 |
+
self.trans_block = DRFE(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, self.type, self.input_resolution)
|
118 |
+
self.conv1_1 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
|
119 |
+
self.conv1_2 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
|
120 |
+
|
121 |
+
self.conv_block = nn.Sequential(
|
122 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
123 |
+
nn.ReLU(True),
|
124 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
129 |
+
conv_x = self.conv_block(conv_x) + conv_x
|
130 |
+
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
131 |
+
trans_x = self.trans_block(trans_x)
|
132 |
+
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
133 |
+
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
134 |
+
x = x + res
|
135 |
+
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class SAKDNNet(nn.Module):
|
140 |
+
|
141 |
+
def __init__(self, in_nc=3, config=[2,2,2,2,2,2,2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
142 |
+
super(SAKDNNet, self).__init__()
|
143 |
+
self.config = config
|
144 |
+
self.dim = dim
|
145 |
+
self.head_dim = 32
|
146 |
+
self.window_size = 8
|
147 |
+
|
148 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
149 |
+
|
150 |
+
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
151 |
+
|
152 |
+
begin = 0
|
153 |
+
self.m_down1 = [STCB(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
|
154 |
+
for i in range(config[0])] + \
|
155 |
+
[nn.Conv2d(dim, 2*dim, 2, 2, 0, bias=False)]
|
156 |
+
|
157 |
+
begin += config[0]
|
158 |
+
self.m_down2 = [STCB(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
|
159 |
+
for i in range(config[1])] + \
|
160 |
+
[nn.Conv2d(2*dim, 4*dim, 2, 2, 0, bias=False)]
|
161 |
+
|
162 |
+
begin += config[1]
|
163 |
+
self.m_down3 = [STCB(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
|
164 |
+
for i in range(config[2])] + \
|
165 |
+
[nn.Conv2d(4*dim, 8*dim, 2, 2, 0, bias=False)]
|
166 |
+
|
167 |
+
begin += config[2]
|
168 |
+
self.m_body = [STCB(4*dim, 4*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//8)
|
169 |
+
for i in range(config[3])]
|
170 |
+
|
171 |
+
begin += config[3]
|
172 |
+
self.m_up3 = [nn.ConvTranspose2d(8*dim, 4*dim, 2, 2, 0, bias=False),] + \
|
173 |
+
[STCB(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
|
174 |
+
for i in range(config[4])]
|
175 |
+
|
176 |
+
begin += config[4]
|
177 |
+
self.m_up2 = [nn.ConvTranspose2d(4*dim, 2*dim, 2, 2, 0, bias=False),] + \
|
178 |
+
[STCB(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
|
179 |
+
for i in range(config[5])]
|
180 |
+
|
181 |
+
begin += config[5]
|
182 |
+
self.m_up1 = [nn.ConvTranspose2d(2*dim, dim, 2, 2, 0, bias=False),] + \
|
183 |
+
[STCB(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
|
184 |
+
for i in range(config[6])]
|
185 |
+
|
186 |
+
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
187 |
+
|
188 |
+
self.m_head = nn.Sequential(*self.m_head)
|
189 |
+
self.m_down1 = nn.Sequential(*self.m_down1)
|
190 |
+
self.m_down2 = nn.Sequential(*self.m_down2)
|
191 |
+
self.m_down3 = nn.Sequential(*self.m_down3)
|
192 |
+
self.m_body = nn.Sequential(*self.m_body)
|
193 |
+
self.m_up3 = nn.Sequential(*self.m_up3)
|
194 |
+
self.m_up2 = nn.Sequential(*self.m_up2)
|
195 |
+
self.m_up1 = nn.Sequential(*self.m_up1)
|
196 |
+
self.m_tail = nn.Sequential(*self.m_tail)
|
197 |
+
|
198 |
+
def forward(self, x0):
|
199 |
+
|
200 |
+
h, w = x0.size()[-2:]
|
201 |
+
paddingBottom = int(np.ceil(h/64)*64-h)
|
202 |
+
paddingRight = int(np.ceil(w/64)*64-w)
|
203 |
+
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
204 |
+
|
205 |
+
x1 = self.m_head(x0)
|
206 |
+
x2 = self.m_down1(x1)
|
207 |
+
x3 = self.m_down2(x2)
|
208 |
+
x4 = self.m_down3(x3)
|
209 |
+
x = self.m_body(x4)
|
210 |
+
x = self.m_up3(x+x4)
|
211 |
+
x = self.m_up2(x+x3)
|
212 |
+
x = self.m_up1(x+x2)
|
213 |
+
x = self.m_tail(x+x1)
|
214 |
+
|
215 |
+
x = x[..., :h, :w]
|
216 |
+
|
217 |
+
return x
|
218 |
+
|
219 |
+
|
220 |
+
def _init_weights(self, m):
|
221 |
+
if isinstance(m, nn.Linear):
|
222 |
+
trunc_normal_(m.weight, std=.02)
|
223 |
+
if m.bias is not None:
|
224 |
+
nn.init.constant_(m.bias, 0)
|
225 |
+
elif isinstance(m, nn.LayerNorm):
|
226 |
+
nn.init.constant_(m.bias, 0)
|
227 |
+
nn.init.constant_(m.weight, 1.0)
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
if __name__ == '__main__':
|
232 |
+
|
233 |
+
# torch.cuda.empty_cache()
|
234 |
+
net = SAKDNNet()
|
235 |
+
|
236 |
+
x = torch.randn((2, 3, 64, 128))
|
237 |
+
x = net(x)
|
238 |
+
print(x.shape)
|