danielsapit commited on
Commit
34d1ef9
1 Parent(s): fe525fc

Upload network_fbcnn.py

Browse files
Files changed (1) hide show
  1. network_fbcnn.py +337 -0
network_fbcnn.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torchvision.models as models
7
+
8
+ '''
9
+ # --------------------------------------------
10
+ # Advanced nn.Sequential
11
+ # https://github.com/xinntao/BasicSR
12
+ # --------------------------------------------
13
+ '''
14
+
15
+
16
+ def sequential(*args):
17
+ """Advanced nn.Sequential.
18
+
19
+ Args:
20
+ nn.Sequential, nn.Module
21
+
22
+ Returns:
23
+ nn.Sequential
24
+ """
25
+ if len(args) == 1:
26
+ if isinstance(args[0], OrderedDict):
27
+ raise NotImplementedError('sequential does not support OrderedDict input.')
28
+ return args[0] # No sequential is needed.
29
+ modules = []
30
+ for module in args:
31
+ if isinstance(module, nn.Sequential):
32
+ for submodule in module.children():
33
+ modules.append(submodule)
34
+ elif isinstance(module, nn.Module):
35
+ modules.append(module)
36
+ return nn.Sequential(*modules)
37
+
38
+ # --------------------------------------------
39
+ # return nn.Sequantial of (Conv + BN + ReLU)
40
+ # --------------------------------------------
41
+ def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
42
+ L = []
43
+ for t in mode:
44
+ if t == 'C':
45
+ L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
46
+ elif t == 'T':
47
+ L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
48
+ elif t == 'B':
49
+ L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
50
+ elif t == 'I':
51
+ L.append(nn.InstanceNorm2d(out_channels, affine=True))
52
+ elif t == 'R':
53
+ L.append(nn.ReLU(inplace=True))
54
+ elif t == 'r':
55
+ L.append(nn.ReLU(inplace=False))
56
+ elif t == 'L':
57
+ L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
58
+ elif t == 'l':
59
+ L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
60
+ elif t == '2':
61
+ L.append(nn.PixelShuffle(upscale_factor=2))
62
+ elif t == '3':
63
+ L.append(nn.PixelShuffle(upscale_factor=3))
64
+ elif t == '4':
65
+ L.append(nn.PixelShuffle(upscale_factor=4))
66
+ elif t == 'U':
67
+ L.append(nn.Upsample(scale_factor=2, mode='nearest'))
68
+ elif t == 'u':
69
+ L.append(nn.Upsample(scale_factor=3, mode='nearest'))
70
+ elif t == 'v':
71
+ L.append(nn.Upsample(scale_factor=4, mode='nearest'))
72
+ elif t == 'M':
73
+ L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
74
+ elif t == 'A':
75
+ L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
76
+ else:
77
+ raise NotImplementedError('Undefined type: '.format(t))
78
+ return sequential(*L)
79
+
80
+ # --------------------------------------------
81
+ # Res Block: x + conv(relu(conv(x)))
82
+ # --------------------------------------------
83
+ class ResBlock(nn.Module):
84
+ def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
85
+ super(ResBlock, self).__init__()
86
+
87
+ assert in_channels == out_channels, 'Only support in_channels==out_channels.'
88
+ if mode[0] in ['R', 'L']:
89
+ mode = mode[0].lower() + mode[1:]
90
+
91
+ self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
92
+
93
+ def forward(self, x):
94
+ res = self.res(x)
95
+ return x + res
96
+
97
+ # --------------------------------------------
98
+ # conv + subp (+ relu)
99
+ # --------------------------------------------
100
+ def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
101
+ assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
102
+ up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope)
103
+ return up1
104
+
105
+
106
+ # --------------------------------------------
107
+ # nearest_upsample + conv (+ R)
108
+ # --------------------------------------------
109
+ def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
110
+ assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
111
+ if mode[0] == '2':
112
+ uc = 'UC'
113
+ elif mode[0] == '3':
114
+ uc = 'uC'
115
+ elif mode[0] == '4':
116
+ uc = 'vC'
117
+ mode = mode.replace(mode[0], uc)
118
+ up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
119
+ return up1
120
+
121
+
122
+ # --------------------------------------------
123
+ # convTranspose (+ relu)
124
+ # --------------------------------------------
125
+ def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
126
+ assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
127
+ kernel_size = int(mode[0])
128
+ stride = int(mode[0])
129
+ mode = mode.replace(mode[0], 'T')
130
+ up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
131
+ return up1
132
+
133
+
134
+ '''
135
+ # --------------------------------------------
136
+ # Downsampler
137
+ # Kai Zhang, https://github.com/cszn/KAIR
138
+ # --------------------------------------------
139
+ # downsample_strideconv
140
+ # downsample_maxpool
141
+ # downsample_avgpool
142
+ # --------------------------------------------
143
+ '''
144
+
145
+
146
+ # --------------------------------------------
147
+ # strideconv (+ relu)
148
+ # --------------------------------------------
149
+ def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
150
+ assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
151
+ kernel_size = int(mode[0])
152
+ stride = int(mode[0])
153
+ mode = mode.replace(mode[0], 'C')
154
+ down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
155
+ return down1
156
+
157
+
158
+ # --------------------------------------------
159
+ # maxpooling + conv (+ relu)
160
+ # --------------------------------------------
161
+ def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
162
+ assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
163
+ kernel_size_pool = int(mode[0])
164
+ stride_pool = int(mode[0])
165
+ mode = mode.replace(mode[0], 'MC')
166
+ pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
167
+ pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
168
+ return sequential(pool, pool_tail)
169
+
170
+
171
+ # --------------------------------------------
172
+ # averagepooling + conv (+ relu)
173
+ # --------------------------------------------
174
+ def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
175
+ assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
176
+ kernel_size_pool = int(mode[0])
177
+ stride_pool = int(mode[0])
178
+ mode = mode.replace(mode[0], 'AC')
179
+ pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
180
+ pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
181
+ return sequential(pool, pool_tail)
182
+
183
+
184
+
185
+ class QFAttention(nn.Module):
186
+ def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
187
+ super(QFAttention, self).__init__()
188
+
189
+ assert in_channels == out_channels, 'Only support in_channels==out_channels.'
190
+ if mode[0] in ['R', 'L']:
191
+ mode = mode[0].lower() + mode[1:]
192
+
193
+ self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
194
+
195
+ def forward(self, x, gamma, beta):
196
+ gamma = gamma.unsqueeze(-1).unsqueeze(-1)
197
+ beta = beta.unsqueeze(-1).unsqueeze(-1)
198
+ res = (gamma)*self.res(x) + beta
199
+ return x + res
200
+
201
+
202
+ class FBCNN(nn.Module):
203
+ def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv',
204
+ upsample_mode='convtranspose'):
205
+ super(FBCNN, self).__init__()
206
+
207
+ self.m_head = conv(in_nc, nc[0], bias=True, mode='C')
208
+ self.nb = nb
209
+ self.nc = nc
210
+ # downsample
211
+ if downsample_mode == 'avgpool':
212
+ downsample_block = downsample_avgpool
213
+ elif downsample_mode == 'maxpool':
214
+ downsample_block = downsample_maxpool
215
+ elif downsample_mode == 'strideconv':
216
+ downsample_block = downsample_strideconv
217
+ else:
218
+ raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
219
+
220
+ self.m_down1 = sequential(
221
+ *[ResBlock(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
222
+ downsample_block(nc[0], nc[1], bias=True, mode='2'))
223
+ self.m_down2 = sequential(
224
+ *[ResBlock(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
225
+ downsample_block(nc[1], nc[2], bias=True, mode='2'))
226
+ self.m_down3 = sequential(
227
+ *[ResBlock(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
228
+ downsample_block(nc[2], nc[3], bias=True, mode='2'))
229
+
230
+ self.m_body_encoder = sequential(
231
+ *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])
232
+
233
+ self.m_body_decoder = sequential(
234
+ *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])
235
+
236
+ # upsample
237
+ if upsample_mode == 'upconv':
238
+ upsample_block = upsample_upconv
239
+ elif upsample_mode == 'pixelshuffle':
240
+ upsample_block = upsample_pixelshuffle
241
+ elif upsample_mode == 'convtranspose':
242
+ upsample_block = upsample_convtranspose
243
+ else:
244
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
245
+
246
+ self.m_up3 = nn.ModuleList([upsample_block(nc[3], nc[2], bias=True, mode='2'),
247
+ *[QFAttention(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]])
248
+
249
+ self.m_up2 = nn.ModuleList([upsample_block(nc[2], nc[1], bias=True, mode='2'),
250
+ *[QFAttention(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]])
251
+
252
+ self.m_up1 = nn.ModuleList([upsample_block(nc[1], nc[0], bias=True, mode='2'),
253
+ *[QFAttention(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]])
254
+
255
+
256
+ self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')
257
+
258
+
259
+ self.qf_pred = sequential(*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
260
+ torch.nn.AdaptiveAvgPool2d((1,1)),
261
+ torch.nn.Flatten(),
262
+ torch.nn.Linear(512, 512),
263
+ nn.ReLU(),
264
+ torch.nn.Linear(512, 512),
265
+ nn.ReLU(),
266
+ torch.nn.Linear(512, 1),
267
+ nn.Sigmoid()
268
+ )
269
+
270
+ self.qf_embed = sequential(torch.nn.Linear(1, 512),
271
+ nn.ReLU(),
272
+ torch.nn.Linear(512, 512),
273
+ nn.ReLU(),
274
+ torch.nn.Linear(512, 512),
275
+ nn.ReLU()
276
+ )
277
+
278
+ self.to_gamma_3 = sequential(torch.nn.Linear(512, nc[2]),nn.Sigmoid())
279
+ self.to_beta_3 = sequential(torch.nn.Linear(512, nc[2]),nn.Tanh())
280
+ self.to_gamma_2 = sequential(torch.nn.Linear(512, nc[1]),nn.Sigmoid())
281
+ self.to_beta_2 = sequential(torch.nn.Linear(512, nc[1]),nn.Tanh())
282
+ self.to_gamma_1 = sequential(torch.nn.Linear(512, nc[0]),nn.Sigmoid())
283
+ self.to_beta_1 = sequential(torch.nn.Linear(512, nc[0]),nn.Tanh())
284
+
285
+
286
+ def forward(self, x, qf_input=None):
287
+
288
+ h, w = x.size()[-2:]
289
+ paddingBottom = int(np.ceil(h / 8) * 8 - h)
290
+ paddingRight = int(np.ceil(w / 8) * 8 - w)
291
+ x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x)
292
+
293
+ x1 = self.m_head(x)
294
+ x2 = self.m_down1(x1)
295
+ x3 = self.m_down2(x2)
296
+ x4 = self.m_down3(x3)
297
+ x = self.m_body_encoder(x4)
298
+ qf = self.qf_pred(x)
299
+ x = self.m_body_decoder(x)
300
+ qf_embedding = self.qf_embed(qf_input) if qf_input is not None else self.qf_embed(qf)
301
+ gamma_3 = self.to_gamma_3(qf_embedding)
302
+ beta_3 = self.to_beta_3(qf_embedding)
303
+
304
+ gamma_2 = self.to_gamma_2(qf_embedding)
305
+ beta_2 = self.to_beta_2(qf_embedding)
306
+
307
+ gamma_1 = self.to_gamma_1(qf_embedding)
308
+ beta_1 = self.to_beta_1(qf_embedding)
309
+
310
+
311
+ x = x + x4
312
+ x = self.m_up3[0](x)
313
+ for i in range(self.nb):
314
+ x = self.m_up3[i+1](x, gamma_3,beta_3)
315
+
316
+ x = x + x3
317
+
318
+ x = self.m_up2[0](x)
319
+ for i in range(self.nb):
320
+ x = self.m_up2[i+1](x, gamma_2, beta_2)
321
+ x = x + x2
322
+
323
+ x = self.m_up1[0](x)
324
+ for i in range(self.nb):
325
+ x = self.m_up1[i+1](x, gamma_1, beta_1)
326
+
327
+ x = x + x1
328
+ x = self.m_tail(x)
329
+ x = x[..., :h, :w]
330
+
331
+ return x, qf
332
+
333
+ if __name__ == "__main__":
334
+ x = torch.randn(1, 3, 96, 96)#.cuda()#.to(torch.device('cuda'))
335
+ fbar=FBAR()
336
+ y,qf = fbar(x)
337
+ print(y.shape,qf.shape)