RamAnanth1 committed
Commit 015a3b5
1 Parent(s): 7da7768

Create model_edge.py

Files changed (1):
  1. model_edge.py +639 -0
model_edge.py ADDED
@@ -0,0 +1,639 @@
"""
Author: Zhuo Su, Wenzhe Liu
Date: Feb 18, 2021
"""

import math

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import img2tensor

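# Each entry below maps the 16 backbone layers of PiDiNet to a convolution
# type: 'cv' = vanilla convolution, 'cd' = central pixel-difference
# convolution (CPDC), 'ad' = angular pixel-difference convolution (APDC),
# 'rd' = radial pixel-difference convolution (RPDC). 'carv4' cycles
# cd/ad/rd with a vanilla conv every fourth layer.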
nets = {
    'baseline': {
        'layer0': 'cv',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'cv',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'cv',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'cv',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'c-v15': {
        'layer0': 'cd',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'cv',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'cv',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'cv',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'a-v15': {
        'layer0': 'ad',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'cv',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'cv',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'cv',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'r-v15': {
        'layer0': 'rd',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'cv',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'cv',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'cv',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'cvvv4': {
        'layer0': 'cd',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'cd',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'cd',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'cd',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'avvv4': {
        'layer0': 'ad',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'ad',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'ad',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'ad',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'rvvv4': {
        'layer0': 'rd',
        'layer1': 'cv',
        'layer2': 'cv',
        'layer3': 'cv',
        'layer4': 'rd',
        'layer5': 'cv',
        'layer6': 'cv',
        'layer7': 'cv',
        'layer8': 'rd',
        'layer9': 'cv',
        'layer10': 'cv',
        'layer11': 'cv',
        'layer12': 'rd',
        'layer13': 'cv',
        'layer14': 'cv',
        'layer15': 'cv',
    },
    'cccv4': {
        'layer0': 'cd',
        'layer1': 'cd',
        'layer2': 'cd',
        'layer3': 'cv',
        'layer4': 'cd',
        'layer5': 'cd',
        'layer6': 'cd',
        'layer7': 'cv',
        'layer8': 'cd',
        'layer9': 'cd',
        'layer10': 'cd',
        'layer11': 'cv',
        'layer12': 'cd',
        'layer13': 'cd',
        'layer14': 'cd',
        'layer15': 'cv',
    },
    'aaav4': {
        'layer0': 'ad',
        'layer1': 'ad',
        'layer2': 'ad',
        'layer3': 'cv',
        'layer4': 'ad',
        'layer5': 'ad',
        'layer6': 'ad',
        'layer7': 'cv',
        'layer8': 'ad',
        'layer9': 'ad',
        'layer10': 'ad',
        'layer11': 'cv',
        'layer12': 'ad',
        'layer13': 'ad',
        'layer14': 'ad',
        'layer15': 'cv',
    },
    'rrrv4': {
        'layer0': 'rd',
        'layer1': 'rd',
        'layer2': 'rd',
        'layer3': 'cv',
        'layer4': 'rd',
        'layer5': 'rd',
        'layer6': 'rd',
        'layer7': 'cv',
        'layer8': 'rd',
        'layer9': 'rd',
        'layer10': 'rd',
        'layer11': 'cv',
        'layer12': 'rd',
        'layer13': 'rd',
        'layer14': 'rd',
        'layer15': 'cv',
    },
    'c16': {
        'layer0': 'cd',
        'layer1': 'cd',
        'layer2': 'cd',
        'layer3': 'cd',
        'layer4': 'cd',
        'layer5': 'cd',
        'layer6': 'cd',
        'layer7': 'cd',
        'layer8': 'cd',
        'layer9': 'cd',
        'layer10': 'cd',
        'layer11': 'cd',
        'layer12': 'cd',
        'layer13': 'cd',
        'layer14': 'cd',
        'layer15': 'cd',
    },
    'a16': {
        'layer0': 'ad',
        'layer1': 'ad',
        'layer2': 'ad',
        'layer3': 'ad',
        'layer4': 'ad',
        'layer5': 'ad',
        'layer6': 'ad',
        'layer7': 'ad',
        'layer8': 'ad',
        'layer9': 'ad',
        'layer10': 'ad',
        'layer11': 'ad',
        'layer12': 'ad',
        'layer13': 'ad',
        'layer14': 'ad',
        'layer15': 'ad',
    },
    'r16': {
        'layer0': 'rd',
        'layer1': 'rd',
        'layer2': 'rd',
        'layer3': 'rd',
        'layer4': 'rd',
        'layer5': 'rd',
        'layer6': 'rd',
        'layer7': 'rd',
        'layer8': 'rd',
        'layer9': 'rd',
        'layer10': 'rd',
        'layer11': 'rd',
        'layer12': 'rd',
        'layer13': 'rd',
        'layer14': 'rd',
        'layer15': 'rd',
    },
    'carv4': {
        'layer0': 'cd',
        'layer1': 'ad',
        'layer2': 'rd',
        'layer3': 'cv',
        'layer4': 'cd',
        'layer5': 'ad',
        'layer6': 'rd',
        'layer7': 'cv',
        'layer8': 'cd',
        'layer9': 'ad',
        'layer10': 'rd',
        'layer11': 'cv',
        'layer12': 'cd',
        'layer13': 'ad',
        'layer14': 'rd',
        'layer15': 'cv',
    },
}

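# createConvFunc maps an op code to a functional convolution. 'cd' computes
# sum_k w_k * (x(p+k) - x(p)) by subtracting a 1x1 conv with the summed
# weights; 'ad' subtracts a clockwise-shifted copy of the weights (angular
# differences); 'rd' scatters the eight outer weights into a 5x5 kernel
# with their negatives on the inner ring (radial differences).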
def createConvFunc(op_type):
    assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
    if op_type == 'cv':
        return F.conv2d

    if op_type == 'cd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for cd_conv should be 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
            assert padding == dilation, 'padding for cd_conv set wrong'

            weights_c = weights.sum(dim=[2, 3], keepdim=True)
            yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
            y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y - yc
        return func
    elif op_type == 'ad':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for ad_conv should be 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
            assert padding == dilation, 'padding for ad_conv set wrong'

            shape = weights.shape
            weights = weights.view(shape[0], shape[1], -1)
            weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)  # clockwise
            y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func
    elif op_type == 'rd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for rd_conv should be 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
            padding = 2 * dilation

            shape = weights.shape
            if weights.is_cuda:
                buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
            else:
                buffer = torch.zeros(shape[0], shape[1], 5 * 5)
            weights = weights.view(shape[0], shape[1], -1)
            buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
            buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
            buffer[:, :, 12] = 0
            buffer = buffer.view(shape[0], shape[1], 5, 5)
            y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func
    else:
        print('impossible to be here unless you force that')
        return None

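# Conv2d mirrors nn.Conv2d's parameter handling and initialization, but
# delegates the forward pass to the supplied pdc function, so the same
# weights can act as a vanilla or pixel-difference convolution.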
class Conv2d(nn.Module):
    def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.pdc = pdc

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

class CSAM(nn.Module):
    """
    Compact Spatial Attention Module
    """
    def __init__(self, channels):
        super(CSAM, self).__init__()

        mid_channels = 4
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        y = self.relu1(x)
        y = self.conv1(y)
        y = self.conv2(y)
        y = self.sigmoid(y)

        return x * y

class CDCM(nn.Module):
    """
    Compact Dilation Convolution based Module
    """
    def __init__(self, in_channels, out_channels):
        super(CDCM, self).__init__()

        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
        self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
        self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
        self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        x = self.relu1(x)
        x = self.conv1(x)
        x1 = self.conv2_1(x)
        x2 = self.conv2_2(x)
        x3 = self.conv2_3(x)
        x4 = self.conv2_4(x)
        return x1 + x2 + x3 + x4


class MapReduce(nn.Module):
    """
    Reduce feature maps into a single edge map
    """
    def __init__(self, channels):
        super(MapReduce, self).__init__()
        self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
        nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        return self.conv(x)


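# PDCBlock is a residual unit: a depthwise 3x3 pixel-difference convolution,
# ReLU, then a pointwise 1x1 convolution; with stride > 1 the input is
# max-pooled and passed through a 1x1 shortcut before the addition.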
class PDCBlock(nn.Module):
    def __init__(self, pdc, inplane, ouplane, stride=1):
        super(PDCBlock, self).__init__()
        self.stride = stride

        if self.stride > 1:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
        self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        if self.stride > 1:
            x = self.pool(x)
        y = self.conv1(x)
        y = self.relu2(y)
        y = self.conv2(y)
        if self.stride > 1:
            x = self.shortcut(x)
        y = y + x
        return y

class PDCBlock_converted(nn.Module):
    """
    CPDC, APDC can be converted to vanilla 3x3 convolution
    RPDC can be converted to vanilla 5x5 convolution
    """
    def __init__(self, pdc, inplane, ouplane, stride=1):
        super(PDCBlock_converted, self).__init__()
        self.stride = stride

        if self.stride > 1:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
        if pdc == 'rd':
            self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
        else:
            self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        if self.stride > 1:
            x = self.pool(x)
        y = self.conv1(x)
        y = self.relu2(y)
        y = self.conv2(y)
        if self.stride > 1:
            x = self.shortcut(x)
        y = y + x
        return y

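# PiDiNet: an init block followed by four stages of four PDC blocks
# (C, 2C, 4C, 4C channels). Each stage feeds an optional CDCM/CSAM head and
# a MapReduce layer that yields one side edge map; a 1x1 classifier fuses
# the four side maps into the final prediction.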
class PiDiNet(nn.Module):
    def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
        super(PiDiNet, self).__init__()
        self.sa = sa
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        self.fuseplanes = []

        self.inplane = inplane
        if convert:
            if pdcs[0] == 'rd':
                init_kernel_size = 5
                init_padding = 2
            else:
                init_kernel_size = 3
                init_padding = 1
            self.init_block = nn.Conv2d(3, self.inplane,
                                        kernel_size=init_kernel_size, padding=init_padding, bias=False)
            block_class = PDCBlock_converted
        else:
            self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
            block_class = PDCBlock

        self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
        self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
        self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # C

        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
        self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
        self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
        self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 2C

        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
        self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
        self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
        self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
        self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
        self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
        self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        self.conv_reduces = nn.ModuleList()
        if self.sa and self.dil is not None:
            self.attentions = nn.ModuleList()
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.attentions.append(CSAM(self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        elif self.sa:
            self.attentions = nn.ModuleList()
            for i in range(4):
                self.attentions.append(CSAM(self.fuseplanes[i]))
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
        elif self.dil is not None:
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        else:
            for i in range(4):
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))

        self.classifier = nn.Conv2d(4, 1, kernel_size=1)  # has bias
        nn.init.constant_(self.classifier.weight, 0.25)
        nn.init.constant_(self.classifier.bias, 0)

        # print('initialization done')

    def get_weights(self):
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        H, W = x.size()[2:]

        x = self.init_block(x)

        x1 = self.block1_1(x)
        x1 = self.block1_2(x1)
        x1 = self.block1_3(x1)

        x2 = self.block2_1(x1)
        x2 = self.block2_2(x2)
        x2 = self.block2_3(x2)
        x2 = self.block2_4(x2)

        x3 = self.block3_1(x2)
        x3 = self.block3_2(x3)
        x3 = self.block3_3(x3)
        x3 = self.block3_4(x3)

        x4 = self.block4_1(x3)
        x4 = self.block4_2(x4)
        x4 = self.block4_3(x4)
        x4 = self.block4_4(x4)

        x_fuses = []
        if self.sa and self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](self.dilations[i](xi)))
        elif self.sa:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](xi))
        elif self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.dilations[i](xi))
        else:
            x_fuses = [x1, x2, x3, x4]

        e1 = self.conv_reduces[0](x_fuses[0])
        e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)

        e2 = self.conv_reduces[1](x_fuses[1])
        e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)

        e3 = self.conv_reduces[2](x_fuses[2])
        e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)

        e4 = self.conv_reduces[3](x_fuses[3])
        e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)

        outputs = [e1, e2, e3, e4]

        output = self.classifier(torch.cat(outputs, dim=1))
        # if not self.training:
        #     return torch.sigmoid(output)

        outputs.append(output)
        outputs = [torch.sigmoid(r) for r in outputs]
        return outputs

def config_model(model):
    model_options = list(nets.keys())
    assert model in model_options, \
        'unrecognized model, please choose from %s' % str(model_options)

    # print(str(nets[model]))

    pdcs = []
    for i in range(16):
        layer_name = 'layer%d' % i
        op = nets[model][layer_name]
        pdcs.append(createConvFunc(op))

    return pdcs

def pidinet():
    pdcs = config_model('carv4')
    dil = 24  # if args.dil else None
    return PiDiNet(60, pdcs, dil=dil, sa=True)
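
A minimal usage sketch (not part of this commit): assuming a trained PiDiNet checkpoint whose state dict uses 'module.'-prefixed keys (the filename below is hypothetical), the model can be run on a BGR image roughly like this:

import cv2
import torch

model = pidinet()
ckpt = torch.load('table5_pidinet.pth', map_location='cpu')['state_dict']  # hypothetical path
model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
model.eval()

img = cv2.imread('input.png')  # HxWx3, BGR, uint8
x = img2tensor(img).unsqueeze(0) / 255.  # img2tensor is imported at the top of the file
with torch.no_grad():
    edge = model(x)[-1]  # final fused edge map, sigmoid-activated, values in [0, 1]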