sidharthism commited on
Commit
2fc8f4b
1 Parent(s): 621050a
cloth_segmentation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .generate_cloth_mask import *
cloth_segmentation/checkpoints/cloth_segm_u2net_latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f71fad2bc11789a996acc507d1a5a1602ae0edefc2b9aba1cd198be5cc9c1a44
3
+ size 176625341
cloth_segmentation/networks/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .u2net import U2NET
cloth_segmentation/networks/u2net.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
8
+ super(REBNCONV, self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(
11
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
12
+ )
13
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
14
+ self.relu_s1 = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+
27
+ src = F.upsample(src, size=tar.shape[2:], mode="bilinear")
28
+
29
+ return src
30
+
31
+
32
+ ### RSU-7 ###
33
+ class RSU7(nn.Module): # UNet07DRES(nn.Module):
34
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
35
+ super(RSU7, self).__init__()
36
+
37
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
38
+
39
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
40
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41
+
42
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
43
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+
56
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
57
+
58
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
59
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
60
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
61
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
64
+
65
+ def forward(self, x):
66
+
67
+ hx = x
68
+ hxin = self.rebnconvin(hx)
69
+
70
+ hx1 = self.rebnconv1(hxin)
71
+ hx = self.pool1(hx1)
72
+
73
+ hx2 = self.rebnconv2(hx)
74
+ hx = self.pool2(hx2)
75
+
76
+ hx3 = self.rebnconv3(hx)
77
+ hx = self.pool3(hx3)
78
+
79
+ hx4 = self.rebnconv4(hx)
80
+ hx = self.pool4(hx4)
81
+
82
+ hx5 = self.rebnconv5(hx)
83
+ hx = self.pool5(hx5)
84
+
85
+ hx6 = self.rebnconv6(hx)
86
+
87
+ hx7 = self.rebnconv7(hx6)
88
+
89
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
90
+ hx6dup = _upsample_like(hx6d, hx5)
91
+
92
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
93
+ hx5dup = _upsample_like(hx5d, hx4)
94
+
95
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
96
+ hx4dup = _upsample_like(hx4d, hx3)
97
+
98
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
99
+ hx3dup = _upsample_like(hx3d, hx2)
100
+
101
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
102
+ hx2dup = _upsample_like(hx2d, hx1)
103
+
104
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
105
+
106
+ """
107
+ del hx1, hx2, hx3, hx4, hx5, hx6, hx7
108
+ del hx6d, hx5d, hx3d, hx2d
109
+ del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
110
+ """
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module): # UNet06DRES(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+
146
+ hx = x
147
+
148
+ hxin = self.rebnconvin(hx)
149
+
150
+ hx1 = self.rebnconv1(hxin)
151
+ hx = self.pool1(hx1)
152
+
153
+ hx2 = self.rebnconv2(hx)
154
+ hx = self.pool2(hx2)
155
+
156
+ hx3 = self.rebnconv3(hx)
157
+ hx = self.pool3(hx3)
158
+
159
+ hx4 = self.rebnconv4(hx)
160
+ hx = self.pool4(hx4)
161
+
162
+ hx5 = self.rebnconv5(hx)
163
+
164
+ hx6 = self.rebnconv6(hx5)
165
+
166
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
167
+ hx5dup = _upsample_like(hx5d, hx4)
168
+
169
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
170
+ hx4dup = _upsample_like(hx4d, hx3)
171
+
172
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
173
+ hx3dup = _upsample_like(hx3d, hx2)
174
+
175
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
176
+ hx2dup = _upsample_like(hx2d, hx1)
177
+
178
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
179
+
180
+ """
181
+ del hx1, hx2, hx3, hx4, hx5, hx6
182
+ del hx5d, hx4d, hx3d, hx2d
183
+ del hx2dup, hx3dup, hx4dup, hx5dup
184
+ """
185
+
186
+ return hx1d + hxin
187
+
188
+
189
+ ### RSU-5 ###
190
+ class RSU5(nn.Module): # UNet05DRES(nn.Module):
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5, self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
213
+
214
+ def forward(self, x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
234
+ hx4dup = _upsample_like(hx4d, hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
237
+ hx3dup = _upsample_like(hx3d, hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
240
+ hx2dup = _upsample_like(hx2d, hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
243
+
244
+ """
245
+ del hx1, hx2, hx3, hx4, hx5
246
+ del hx4d, hx3d, hx2d
247
+ del hx2dup, hx3dup, hx4dup
248
+ """
249
+
250
+ return hx1d + hxin
251
+
252
+
253
+ ### RSU-4 ###
254
+ class RSU4(nn.Module): # UNet04DRES(nn.Module):
255
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
256
+ super(RSU4, self).__init__()
257
+
258
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
259
+
260
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
261
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
262
+
263
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
264
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
265
+
266
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
267
+
268
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
269
+
270
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
271
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
272
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
273
+
274
+ def forward(self, x):
275
+
276
+ hx = x
277
+
278
+ hxin = self.rebnconvin(hx)
279
+
280
+ hx1 = self.rebnconv1(hxin)
281
+ hx = self.pool1(hx1)
282
+
283
+ hx2 = self.rebnconv2(hx)
284
+ hx = self.pool2(hx2)
285
+
286
+ hx3 = self.rebnconv3(hx)
287
+
288
+ hx4 = self.rebnconv4(hx3)
289
+
290
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
291
+ hx3dup = _upsample_like(hx3d, hx2)
292
+
293
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
294
+ hx2dup = _upsample_like(hx2d, hx1)
295
+
296
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
297
+
298
+ """
299
+ del hx1, hx2, hx3, hx4
300
+ del hx3d, hx2d
301
+ del hx2dup, hx3dup
302
+ """
303
+
304
+ return hx1d + hxin
305
+
306
+
307
+ ### RSU-4F ###
308
+ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
309
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
310
+ super(RSU4F, self).__init__()
311
+
312
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
313
+
314
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
315
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
316
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
317
+
318
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
319
+
320
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
321
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
322
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
323
+
324
+ def forward(self, x):
325
+
326
+ hx = x
327
+
328
+ hxin = self.rebnconvin(hx)
329
+
330
+ hx1 = self.rebnconv1(hxin)
331
+ hx2 = self.rebnconv2(hx1)
332
+ hx3 = self.rebnconv3(hx2)
333
+
334
+ hx4 = self.rebnconv4(hx3)
335
+
336
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
337
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
338
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
339
+
340
+ """
341
+ del hx1, hx2, hx3, hx4
342
+ del hx3d, hx2d
343
+ """
344
+
345
+ return hx1d + hxin
346
+
347
+
348
+ ##### U^2-Net ####
349
+ class U2NET(nn.Module):
350
+ def __init__(self, in_ch=3, out_ch=1):
351
+ super(U2NET, self).__init__()
352
+
353
+ self.stage1 = RSU7(in_ch, 32, 64)
354
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage2 = RSU6(64, 32, 128)
357
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage3 = RSU5(128, 64, 256)
360
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage4 = RSU4(256, 128, 512)
363
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage5 = RSU4F(512, 256, 512)
366
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage6 = RSU4F(512, 256, 512)
369
+
370
+ # decoder
371
+ self.stage5d = RSU4F(1024, 256, 512)
372
+ self.stage4d = RSU4(1024, 128, 256)
373
+ self.stage3d = RSU5(512, 64, 128)
374
+ self.stage2d = RSU6(256, 32, 64)
375
+ self.stage1d = RSU7(128, 16, 64)
376
+
377
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
378
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
379
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
380
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
381
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
382
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
383
+
384
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
385
+
386
+ def forward(self, x):
387
+
388
+ hx = x
389
+
390
+ # stage 1
391
+ hx1 = self.stage1(hx)
392
+ hx = self.pool12(hx1)
393
+
394
+ # stage 2
395
+ hx2 = self.stage2(hx)
396
+ hx = self.pool23(hx2)
397
+
398
+ # stage 3
399
+ hx3 = self.stage3(hx)
400
+ hx = self.pool34(hx3)
401
+
402
+ # stage 4
403
+ hx4 = self.stage4(hx)
404
+ hx = self.pool45(hx4)
405
+
406
+ # stage 5
407
+ hx5 = self.stage5(hx)
408
+ hx = self.pool56(hx5)
409
+
410
+ # stage 6
411
+ hx6 = self.stage6(hx)
412
+ hx6up = _upsample_like(hx6, hx5)
413
+
414
+ # -------------------- decoder --------------------
415
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
416
+ hx5dup = _upsample_like(hx5d, hx4)
417
+
418
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
419
+ hx4dup = _upsample_like(hx4d, hx3)
420
+
421
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
422
+ hx3dup = _upsample_like(hx3d, hx2)
423
+
424
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
425
+ hx2dup = _upsample_like(hx2d, hx1)
426
+
427
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
428
+
429
+ # side output
430
+ d1 = self.side1(hx1d)
431
+
432
+ d2 = self.side2(hx2d)
433
+ d2 = _upsample_like(d2, d1)
434
+
435
+ d3 = self.side3(hx3d)
436
+ d3 = _upsample_like(d3, d1)
437
+
438
+ d4 = self.side4(hx4d)
439
+ d4 = _upsample_like(d4, d1)
440
+
441
+ d5 = self.side5(hx5d)
442
+ d5 = _upsample_like(d5, d1)
443
+
444
+ d6 = self.side6(hx6)
445
+ d6 = _upsample_like(d6, d1)
446
+
447
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
448
+
449
+ """
450
+ del hx1, hx2, hx3, hx4, hx5, hx6
451
+ del hx5d, hx4d, hx3d, hx2d, hx1d
452
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
453
+ """
454
+
455
+ return d0, d1, d2, d3, d4, d5, d6
456
+
457
+
458
+ ### U^2-Net small ###
459
+ class U2NETP(nn.Module):
460
+ def __init__(self, in_ch=3, out_ch=1):
461
+ super(U2NETP, self).__init__()
462
+
463
+ self.stage1 = RSU7(in_ch, 16, 64)
464
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
465
+
466
+ self.stage2 = RSU6(64, 16, 64)
467
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
468
+
469
+ self.stage3 = RSU5(64, 16, 64)
470
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
471
+
472
+ self.stage4 = RSU4(64, 16, 64)
473
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
474
+
475
+ self.stage5 = RSU4F(64, 16, 64)
476
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
477
+
478
+ self.stage6 = RSU4F(64, 16, 64)
479
+
480
+ # decoder
481
+ self.stage5d = RSU4F(128, 16, 64)
482
+ self.stage4d = RSU4(128, 16, 64)
483
+ self.stage3d = RSU5(128, 16, 64)
484
+ self.stage2d = RSU6(128, 16, 64)
485
+ self.stage1d = RSU7(128, 16, 64)
486
+
487
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
488
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
489
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
490
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
491
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
492
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
493
+
494
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
495
+
496
+ def forward(self, x):
497
+
498
+ hx = x
499
+
500
+ # stage 1
501
+ hx1 = self.stage1(hx)
502
+ hx = self.pool12(hx1)
503
+
504
+ # stage 2
505
+ hx2 = self.stage2(hx)
506
+ hx = self.pool23(hx2)
507
+
508
+ # stage 3
509
+ hx3 = self.stage3(hx)
510
+ hx = self.pool34(hx3)
511
+
512
+ # stage 4
513
+ hx4 = self.stage4(hx)
514
+ hx = self.pool45(hx4)
515
+
516
+ # stage 5
517
+ hx5 = self.stage5(hx)
518
+ hx = self.pool56(hx5)
519
+
520
+ # stage 6
521
+ hx6 = self.stage6(hx)
522
+ hx6up = _upsample_like(hx6, hx5)
523
+
524
+ # decoder
525
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
526
+ hx5dup = _upsample_like(hx5d, hx4)
527
+
528
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
529
+ hx4dup = _upsample_like(hx4d, hx3)
530
+
531
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
532
+ hx3dup = _upsample_like(hx3d, hx2)
533
+
534
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
535
+ hx2dup = _upsample_like(hx2d, hx1)
536
+
537
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
538
+
539
+ # side output
540
+ d1 = self.side1(hx1d)
541
+
542
+ d2 = self.side2(hx2d)
543
+ d2 = _upsample_like(d2, d1)
544
+
545
+ d3 = self.side3(hx3d)
546
+ d3 = _upsample_like(d3, d1)
547
+
548
+ d4 = self.side4(hx4d)
549
+ d4 = _upsample_like(d4, d1)
550
+
551
+ d5 = self.side5(hx5d)
552
+ d5 = _upsample_like(d5, d1)
553
+
554
+ d6 = self.side6(hx6)
555
+ d6 = _upsample_like(d6, d1)
556
+
557
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
558
+
559
+ """
560
+ del hx1, hx2, hx3, hx4, hx5, hx6
561
+ del hx5d, hx4d, hx3d, hx2d, hx1d
562
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
563
+ """
564
+
565
+ return d0, d1, d2, d3, d4, d5, d6