52Hz committed on
Commit
2f91f89
1 Parent(s): d625d36

Create CMFNet.py

Browse files
Files changed (1) hide show
  1. model/CMFNet.py +210 -0
model/CMFNet.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils import network_parameters
4
+ from thop import profile
5
+ from model.block import SAB, CAB, PAB, conv, SAM, conv3x3, conv_down
6
+
7
##########################################################################
## U-Net
# Number of attention blocks stacked at each encoder/decoder level.
bn = 2  # block number-1
10
+
11
class Encoder(nn.Module):
    """Three-level U-Net encoder built from one configurable attention block type.

    Channel widths grow by ``scale_unetfeats`` per level:
    level 1 uses ``n_feat``, level 2 ``n_feat + scale_unetfeats``,
    level 3 ``n_feat + 2 * scale_unetfeats``.

    Args:
        n_feat: channel count at the finest level.
        kernel_size: conv kernel size forwarded to the attention blocks.
        reduction: channel-reduction ratio forwarded to the attention blocks.
        act: activation module forwarded to the attention blocks.
        bias: whether convolutions use a bias term.
        scale_unetfeats: channel increment added at each deeper level.
        block: block flavour, one of ``'CAB'``, ``'PAB'``, ``'SAB'``.

    Raises:
        ValueError: if ``block`` is not a supported name (the original code
            silently fell through and crashed later with AttributeError).
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
        super(Encoder, self).__init__()
        # Dispatch table replaces three copy-pasted branches and lets us
        # fail fast on a misspelled block name.
        block_types = {'CAB': CAB, 'PAB': PAB, 'SAB': SAB}
        if block not in block_types:
            raise ValueError("Unknown block type: {!r} (expected 'CAB', 'PAB' or 'SAB')".format(block))
        Block = block_types[block]
        self.encoder_level1 = nn.Sequential(
            *[Block(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        self.encoder_level2 = nn.Sequential(
            *[Block(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        self.encoder_level3 = nn.Sequential(
            *[Block(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, x):
        """Return the per-level feature maps ``[enc1, enc2, enc3]``, finest first."""
        enc1 = self.encoder_level1(x)
        x = self.down12(enc1)
        enc2 = self.encoder_level2(x)
        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        return [enc1, enc2, enc3]
39
+
40
class Decoder(nn.Module):
    """Three-level U-Net decoder mirroring :class:`Encoder`.

    Decodes coarse-to-fine, applying an attention block to each skip
    connection before it is merged by :class:`SkipUpSample`.

    Args:
        n_feat: channel count at the finest level.
        kernel_size: conv kernel size forwarded to the attention blocks.
        reduction: channel-reduction ratio forwarded to the attention blocks.
        act: activation module forwarded to the attention blocks.
        bias: whether convolutions use a bias term.
        scale_unetfeats: channel increment added at each deeper level.
        block: block flavour, one of ``'CAB'``, ``'PAB'``, ``'SAB'``.

    Raises:
        ValueError: if ``block`` is not a supported name (the original code
            silently fell through and crashed later with AttributeError).
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
        super(Decoder, self).__init__()
        # Dispatch table replaces three copy-pasted branches and lets us
        # fail fast on a misspelled block name.
        block_types = {'CAB': CAB, 'PAB': PAB, 'SAB': SAB}
        if block not in block_types:
            raise ValueError("Unknown block type: {!r} (expected 'CAB', 'PAB' or 'SAB')".format(block))
        Block = block_types[block]
        self.decoder_level1 = nn.Sequential(
            *[Block(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        self.decoder_level2 = nn.Sequential(
            *[Block(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        self.decoder_level3 = nn.Sequential(
            *[Block(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(bn)])
        # Attention applied to each skip connection before merging.
        self.skip_attn1 = Block(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = Block(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        """Decode ``[enc1, enc2, enc3]`` and return ``[dec1, dec2, dec3]``, finest first."""
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)
        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)
        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)
        return [dec1, dec2, dec3]
78
+
79
+ ##########################################################################
80
+ ##---------- Resizing Modules ----------
81
class DownSample(nn.Module):
    """Halve the spatial resolution and widen channels by ``s_factor``.

    A bilinear 0.5x resize followed by a 1x1 convolution mapping
    ``in_channels`` -> ``in_channels + s_factor``.
    """

    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        resize = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        widen = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(resize, widen)

    def forward(self, x):
        """Return ``x`` at half resolution with ``s_factor`` extra channels."""
        return self.down(x)
90
+
91
class UpSample(nn.Module):
    """Double the spatial resolution and shrink channels by ``s_factor``.

    A bilinear 2x resize followed by a 1x1 convolution mapping
    ``in_channels + s_factor`` -> ``in_channels``.
    """

    def __init__(self, in_channels, s_factor):
        super(UpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, narrow)

    def forward(self, x):
        """Return ``x`` at double resolution with ``s_factor`` fewer channels."""
        return self.up(x)
100
+
101
class SkipUpSample(nn.Module):
    """Upsample ``x`` (2x bilinear + 1x1 conv) and add a skip tensor ``y``.

    The conv maps ``in_channels + s_factor`` -> ``in_channels`` so the
    upsampled tensor matches the skip connection channel-wise.
    """

    def __init__(self, in_channels, s_factor):
        super(SkipUpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, narrow)

    def forward(self, x, y):
        """Return ``upsample(x) + y``."""
        return self.up(x) + y
111
+
112
+ ##########################################################################
113
+ # Mixed Residual Module
114
class Mix(nn.Module):
    """Learnable three-way mixing of feature maps with a single scalar weight.

    The scalar ``w`` is squashed through a sigmoid to give the weight of the
    middle input; the two outer inputs share the remaining weight equally,
    so the three coefficients always sum to 1.

    Args:
        m: initial (pre-sigmoid) value of the mixing weight.
    """

    def __init__(self, m=1):
        super(Mix, self).__init__()
        # The original wrapped an nn.Parameter inside a second nn.Parameter;
        # one wrap is sufficient and equivalent.
        self.w = nn.Parameter(torch.FloatTensor([m]), requires_grad=True)
        self.mix_block = nn.Sigmoid()

    def forward(self, fea1, fea2, feat3):
        """Return ``(mixed, factor)``.

        ``fea2`` is weighted by ``sigmoid(w)``; ``fea1`` and ``feat3`` each
        get ``(1 - sigmoid(w)) / 2``. ``factor`` is the sigmoid weight itself.
        """
        factor = self.mix_block(self.w)
        other = (1 - factor) / 2
        output = fea1 * other.expand_as(fea1) + fea2 * factor.expand_as(fea2) + feat3 * other.expand_as(feat3)
        return output, factor
127
+
128
+ ##########################################################################
129
+ # Architecture
130
class CMFNet(nn.Module):
    """Compound multi-branch fusion network.

    Three parallel shallow heads feed three U-Nets — one per attention
    flavour (CAB, PAB, SAB). Each branch is refined by a SAM module that
    yields a feature map and a restored image; the three images are blended
    by a learnable :class:`Mix`, the three feature maps are concatenated and
    projected back to image space, and both results are summed for the final
    output.
    """

    def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, kernel_size=3, reduction=4, bias=False):
        super(CMFNet, self).__init__()

        # NOTE(review): a single PReLU instance is shared by all three
        # shallow heads and passed to every encoder/decoder, so its slope
        # parameter is shared network-wide — confirm this is intentional.
        p_act = nn.PReLU()

        def shallow_head():
            return nn.Sequential(conv(in_c, n_feat // 2, kernel_size, bias=bias), p_act,
                                 conv(n_feat // 2, n_feat, kernel_size, bias=bias))

        self.shallow_feat1 = shallow_head()
        self.shallow_feat2 = shallow_head()
        self.shallow_feat3 = shallow_head()

        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'CAB')
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'CAB')

        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'PAB')
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'PAB')

        self.stage3_encoder = Encoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'SAB')
        self.stage3_decoder = Decoder(n_feat, kernel_size, reduction, p_act, bias, scale_unetfeats, 'SAB')

        self.sam1o = SAM(n_feat, kernel_size=3, bias=bias)
        self.sam2o = SAM(n_feat, kernel_size=3, bias=bias)
        self.sam3o = SAM(n_feat, kernel_size=3, bias=bias)

        self.mix = Mix(1)
        self.add123 = conv(out_c, out_c, kernel_size, bias=bias)
        self.concat123 = conv(n_feat * 3, n_feat, kernel_size, bias=bias)
        self.tail = conv(n_feat, out_c, kernel_size, bias=bias)

    def forward(self, x):
        """Run all three branches on ``x`` and fuse them.

        Returns a list:
        ``[final + mixed_img, mixed_img, mix_factor, img1, img2, img3, final]``.
        """
        # Independent shallow heads.
        s1 = self.shallow_feat1(x)
        s2 = self.shallow_feat2(x)
        s3 = self.shallow_feat3(x)

        # Branch 1: channel-attention U-Net, refined by SAM.
        feat1, img1 = self.sam1o(self.stage1_decoder(self.stage1_encoder(s1))[0], x)
        # Branch 2: pixel-attention U-Net, refined by SAM.
        feat2, img2 = self.sam2o(self.stage2_decoder(self.stage2_encoder(s2))[0], x)
        # Branch 3: spatial-attention U-Net, refined by SAM.
        feat3, img3 = self.sam3o(self.stage3_decoder(self.stage3_encoder(s3))[0], x)

        # Learnable blend of the three restored images.
        mixed, factor = self.mix(img1, img2, img3)
        mixed_img = self.add123(mixed)

        # Fuse the three SAM feature maps and project back to image space.
        x_final = self.tail(self.concat123(torch.cat([feat1, feat2, feat3], 1)))

        return [x_final + mixed_img, mixed_img, factor, img1, img2, img3, x_final]
194
+
195
+
196
if __name__ == "__main__":
    import time

    # Smoke-test the network: dump its module tree, run one forward pass,
    # and report parameter count, FLOPs, and rough wall-clock time.
    net = CMFNet()

    for idx, m in enumerate(net.modules()):
        print(idx, "-", m)

    s = time.time()
    rgb = torch.ones(1, 3, 256, 256, dtype=torch.float, requires_grad=False)
    out = net(rgb)
    flops, params = profile(net, inputs=(rgb,))
    print('parameters:', params)
    print('flops', flops)
    print('time: {:.4f}ms'.format((time.time() - s) * 10))
210
+