Upload restormer_arch.py
Browse files- restormer_arch.py +285 -0
restormer_arch.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Restormer: Efficient Transformer for High-Resolution Image Restoration
|
2 |
+
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
|
3 |
+
## https://arxiv.org/abs/2111.09881
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from pdb import set_trace as stx
|
10 |
+
import numbers
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
##########################################################################
|
17 |
+
## Layer Norm
|
18 |
+
|
19 |
+
def to_3d(x):
    """Flatten the spatial axes of a 4-D tensor and move channels last.

    Produces the same values as
    ``rearrange(x, 'b c h w -> b (h w) c')`` using native tensor ops.

    Args:
        x: Tensor of shape ``(b, c, h, w)``.

    Returns:
        Tensor of shape ``(b, h * w, c)``.

    Raises:
        ValueError: If ``x`` does not have exactly four dimensions
            (raised by the shape unpacking below).
    """
    b, c, h, w = x.shape  # unpacking doubles as the 4-D check
    return x.reshape(b, c, h * w).transpose(1, 2)
|
21 |
+
|
22 |
+
def to_4d(x, h, w):
    """Restore a channels-last token tensor to a 4-D feature map.

    Produces the same values as
    ``rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)`` using native
    tensor ops.

    Args:
        x: Tensor of shape ``(b, h * w, c)``.
        h: Target height.
        w: Target width.

    Returns:
        Tensor of shape ``(b, c, h, w)``.

    Raises:
        ValueError: If ``x`` is not 3-D.
        RuntimeError: If the token axis length is not ``h * w``
            (raised by ``reshape``).
    """
    b, _, c = x.shape
    return x.transpose(1, 2).reshape(b, c, h, w)
|
24 |
+
|
25 |
+
class BiasFree_LayerNorm(nn.Module):
    """Scale-only layer normalization over the last dimension.

    The input is divided by ``sqrt(var + eps)`` (biased variance along
    the last axis) and multiplied by a learnable per-feature weight.
    The mean is not subtracted from the numerator and there is no
    additive bias.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
        eps: Value added to the variance before the square root.

    Raises:
        ValueError: If ``normalized_shape`` has more than one dimension.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # Explicit raise instead of `assert`, which is stripped under -O.
        if len(normalized_shape) != 1:
            raise ValueError(
                "normalized_shape must be one-dimensional, got %r"
                % (tuple(normalized_shape),)
            )

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the weight."""
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + self.eps) * self.weight
|
40 |
+
|
41 |
+
class WithBias_LayerNorm(nn.Module):
    """Layer normalization over the last dimension with weight and bias.

    Computes ``(x - mean) / sqrt(var + eps) * weight + bias`` where the
    mean and the biased variance are taken along the last axis.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
        eps: Value added to the variance before the square root.

    Raises:
        ValueError: If ``normalized_shape`` has more than one dimension.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # Explicit raise instead of `assert`, which is stripped under -O.
        if len(normalized_shape) != 1:
            raise ValueError(
                "normalized_shape must be one-dimensional, got %r"
                % (tuple(normalized_shape),)
            )

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then scale and shift."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + self.eps) * self.weight + self.bias
|
58 |
+
|
59 |
+
|
60 |
+
class LayerNorm(nn.Module):
    """Channel-wise layer normalization for ``(b, c, h, w)`` feature maps.

    The input is reshaped with ``to_3d`` so that channels are the last
    axis, normalized by the selected body module, and reshaped back
    with ``to_4d``.

    Args:
        dim: Number of channels to normalize over.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``;
            any other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            # Every other value falls through to the weight+bias variant.
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        """Normalize ``x`` over its channel axis; the shape is preserved."""
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
##########################################################################
|
75 |
+
## Gated-Dconv Feed-Forward Network (GDFN)
|
76 |
+
class FeedForward(nn.Module):
    """Gated feed-forward block built from 1x1 and depthwise 3x3 convs.

    A 1x1 conv expands the input to ``2 * hidden_features`` channels, a
    depthwise 3x3 conv mixes each channel spatially, the result is
    split into two halves, and ``gelu(first) * second`` is projected
    back to ``dim`` channels by a 1x1 conv.

    Args:
        dim: Number of input and output channels.
        ffn_expansion_factor: Multiplier for the hidden width;
            ``hidden_features = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        # groups == channels makes this a depthwise convolution.
        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the gated feed-forward transform; the shape is preserved."""
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2  # first half gates the second
        x = self.project_out(x)
        return x
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
##########################################################################
|
98 |
+
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
99 |
+
class Attention(nn.Module):
    """Multi-head self-attention computed across channels.

    Queries, keys and values come from a 1x1 conv followed by a
    depthwise 3x3 conv. Per head, q and k are L2-normalized along the
    flattened spatial axis, so the attention map has shape
    ``(c_per_head, c_per_head)`` rather than ``(h*w, h*w)``. A learnable
    per-head temperature scales the logits before the softmax.

    Args:
        dim: Number of input and output channels.
        num_heads: Number of attention heads; must divide ``dim``.
        bias: Whether the convolutions have a bias term.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        # Fail at construction rather than on the first forward pass.
        if dim % num_heads != 0:
            raise ValueError(
                "dim (%d) must be divisible by num_heads (%d)" % (dim, num_heads)
            )
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply channel-wise attention to ``x`` of shape ``(b, c, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # 'b (head c) h w -> b head c (h w)': split channels into heads
        # and flatten the spatial grid.
        per_head = (b, self.num_heads, c // self.num_heads, h * w)
        q = q.reshape(per_head)
        k = k.reshape(per_head)
        v = v.reshape(per_head)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # (c_per_head x c_per_head) attention map for every head.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v

        # 'b head c (h w) -> b (head c) h w': merge heads, restore the grid.
        out = out.reshape(b, c, h, w)

        out = self.project_out(out)
        return out
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
##########################################################################
|
137 |
+
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention followed by feed-forward.

    Computes ``x = x + attn(norm1(x))`` and then
    ``x = x + ffn(norm2(x))``.

    Args:
        dim: Number of channels.
        num_heads: Number of attention heads passed to ``Attention``.
        ffn_expansion_factor: Hidden-width multiplier passed to
            ``FeedForward``.
        bias: Whether the convolutions in both sub-layers use a bias.
        LayerNorm_type: Normalization variant passed to ``LayerNorm``.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        """Apply both residual sub-layers to ``x``."""
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
##########################################################################
|
155 |
+
## Overlapped image patch embedding with 3x3 Conv
|
156 |
+
class OverlapPatchEmbed(nn.Module):
    """Project an image to ``embed_dim`` feature channels with a 3x3 conv.

    The convolution uses stride 1 and padding 1, so the spatial size is
    unchanged.

    Args:
        in_c: Number of input channels.
        embed_dim: Number of output feature channels.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        """Return the embedded feature map for ``x``."""
        x = self.proj(x)

        return x
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
##########################################################################
|
170 |
+
## Resizing modules
|
171 |
+
class Downsample(nn.Module):
    """Halve the spatial size and double the channel count.

    A bias-free 3x3 conv maps ``n_feat`` to ``n_feat // 2`` channels,
    then ``PixelUnshuffle(2)`` folds each 2x2 spatial block into
    channels, giving ``(n_feat // 2) * 4`` channels at half resolution.

    Args:
        n_feat: Number of input channels.
    """

    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2),
        )

    def forward(self, x):
        """Return the downsampled feature map."""
        return self.body(x)
|
180 |
+
|
181 |
+
class Upsample(nn.Module):
    """Double the spatial size and halve the channel count.

    A bias-free 3x3 conv maps ``n_feat`` to ``n_feat * 2`` channels,
    then ``PixelShuffle(2)`` unfolds groups of four channels into 2x2
    spatial blocks, giving ``n_feat // 2`` channels at double
    resolution.

    Args:
        n_feat: Number of input channels.
    """

    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        )

    def forward(self, x):
        """Return the upsampled feature map."""
        return self.body(x)
|
190 |
+
|
191 |
+
##########################################################################
|
192 |
+
##---------- Restormer -----------------------
|
193 |
+
class Restormer(nn.Module):
    """Four-level encoder-decoder of TransformerBlocks for image restoration.

    The encoder downsamples three times (each step halves H and W and
    doubles the channels); the decoder upsamples three times and
    concatenates the matching encoder output at every level. Levels 3
    and 2 reduce the concatenated channels with a 1x1 conv; level 1
    keeps the doubled width. A refinement stage and a 3x3 output conv
    follow.

    Args:
        inp_channels: Number of input image channels.
        out_channels: Number of output image channels.
        dim: Channel width at level 1; level ``k`` uses
            ``dim * 2 ** (k - 1)``.
        num_blocks: Number of TransformerBlocks at levels 1-4.
        num_refinement_blocks: Number of blocks in the refinement stage.
        heads: Number of attention heads at levels 1-4.
        ffn_expansion_factor: Hidden-width multiplier of every
            feed-forward block.
        bias: Whether the convolutions inside the blocks, the
            channel-reduction convs and the output conv use a bias.
        LayerNorm_type: ``'WithBias'`` or ``'BiasFree'``.
        dual_pixel_task: True for dual-pixel defocus deblurring only
            (also set ``inp_channels=6``). Adds a learned 1x1 skip from
            the embedded input instead of the global image residual.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=(4, 6, 6, 8),  # tuples: immutable defaults
        num_refinement_blocks=4,
        heads=(1, 2, 4, 8),
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',  ## Other option 'BiasFree'
        dual_pixel_task=False  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(Restormer, self).__init__()

        def make_stage(channels, num_heads, depth):
            # A stack of `depth` identical TransformerBlocks at one width.
            return nn.Sequential(*[
                TransformerBlock(
                    dim=channels,
                    num_heads=num_heads,
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(depth)
            ])

        dim_l2 = int(dim * 2 ** 1)
        dim_l3 = int(dim * 2 ** 2)
        dim_l4 = int(dim * 2 ** 3)

        # Modules are created in the original order so that attribute
        # names (state_dict keys) and seeded initialization are unchanged.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = make_stage(dim, heads[0], num_blocks[0])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = make_stage(dim_l2, heads[1], num_blocks[1])

        self.down2_3 = Downsample(dim_l2)  ## From Level 2 to Level 3
        self.encoder_level3 = make_stage(dim_l3, heads[2], num_blocks[2])

        self.down3_4 = Downsample(dim_l3)  ## From Level 3 to Level 4
        self.latent = make_stage(dim_l4, heads[3], num_blocks[3])

        self.up4_3 = Upsample(dim_l4)  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(dim_l4, dim_l3, kernel_size=1, bias=bias)
        self.decoder_level3 = make_stage(dim_l3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(dim_l3)  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(dim_l3, dim_l2, kernel_size=1, bias=bias)
        self.decoder_level2 = make_stage(dim_l2, heads[1], num_blocks[1])

        self.up2_1 = Upsample(dim_l2)  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)

        # Level 1 of the decoder runs on the concatenated (2 * dim) width.
        self.decoder_level1 = make_stage(dim_l2, heads[0], num_blocks[0])

        self.refinement = make_stage(dim_l2, heads[0], num_refinement_blocks)

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, dim_l2, kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(dim_l2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        """Restore ``inp_img`` of shape ``(b, inp_channels, H, W)``.

        Three 2x downsampling steps require H and W to be multiples of
        8. Other sizes are reflect-padded on the bottom/right up to the
        next multiple and the output is cropped back to ``(H, W)``;
        inputs that are already multiples of 8 take the original path.
        Reflect padding itself needs the pad to be smaller than the
        padded dimension.

        Returns:
            Tensor of shape ``(b, out_channels, H, W)``.
        """
        factor = 8  # 2 ** 3: three Downsample stages
        height, width = inp_img.shape[-2:]
        pad_h = (-height) % factor
        pad_w = (-width) % factor
        padded = bool(pad_h or pad_w)
        if padded:
            inp_img = F.pad(inp_img, (0, pad_w, 0, pad_h), mode='reflect')

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            # Global residual: the network predicts a correction to the input.
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        if padded:
            out_dec_level1 = out_dec_level1[:, :, :height, :width]

        return out_dec_level1
|
285 |
+
|