camenduru commited on
Commit
a11ab8b
β€’
1 Parent(s): 6bb3280

Create briarmbg.py

Browse files
Files changed (1) hide show
  1. briarmbg.py +462 -0
briarmbg.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RMBG1.4 (diffusers implementation)
2
+ # Found on huggingface space of several projects
3
+ # Not sure which project is the source of this file
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+
11
+ class REBNCONV(nn.Module):
12
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
13
+ super(REBNCONV, self).__init__()
14
+
15
+ self.conv_s1 = nn.Conv2d(
16
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
17
+ )
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self, x):
22
+ hx = x
23
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
24
+
25
+ return xout
26
+
27
+
28
+ def _upsample_like(src, tar):
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+ return src
31
+
32
+
33
+ ### RSU-7 ###
34
+ class RSU7(nn.Module):
35
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
36
+ super(RSU7, self).__init__()
37
+
38
+ self.in_ch = in_ch
39
+ self.mid_ch = mid_ch
40
+ self.out_ch = out_ch
41
+
42
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
43
+
44
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
45
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
46
+
47
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
48
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
49
+
50
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
51
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
52
+
53
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
54
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
55
+
56
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
57
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
58
+
59
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
60
+
61
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
62
+
63
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
64
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
65
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
69
+
70
+ def forward(self, x):
71
+ b, c, h, w = x.shape
72
+
73
+ hx = x
74
+ hxin = self.rebnconvin(hx)
75
+
76
+ hx1 = self.rebnconv1(hxin)
77
+ hx = self.pool1(hx1)
78
+
79
+ hx2 = self.rebnconv2(hx)
80
+ hx = self.pool2(hx2)
81
+
82
+ hx3 = self.rebnconv3(hx)
83
+ hx = self.pool3(hx3)
84
+
85
+ hx4 = self.rebnconv4(hx)
86
+ hx = self.pool4(hx4)
87
+
88
+ hx5 = self.rebnconv5(hx)
89
+ hx = self.pool5(hx5)
90
+
91
+ hx6 = self.rebnconv6(hx)
92
+
93
+ hx7 = self.rebnconv7(hx6)
94
+
95
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
96
+ hx6dup = _upsample_like(hx6d, hx5)
97
+
98
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
99
+ hx5dup = _upsample_like(hx5d, hx4)
100
+
101
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
102
+ hx4dup = _upsample_like(hx4d, hx3)
103
+
104
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
105
+ hx3dup = _upsample_like(hx3d, hx2)
106
+
107
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
108
+ hx2dup = _upsample_like(hx2d, hx1)
109
+
110
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+ hx = x
146
+
147
+ hxin = self.rebnconvin(hx)
148
+
149
+ hx1 = self.rebnconv1(hxin)
150
+ hx = self.pool1(hx1)
151
+
152
+ hx2 = self.rebnconv2(hx)
153
+ hx = self.pool2(hx2)
154
+
155
+ hx3 = self.rebnconv3(hx)
156
+ hx = self.pool3(hx3)
157
+
158
+ hx4 = self.rebnconv4(hx)
159
+ hx = self.pool4(hx4)
160
+
161
+ hx5 = self.rebnconv5(hx)
162
+
163
+ hx6 = self.rebnconv6(hx5)
164
+
165
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
166
+ hx5dup = _upsample_like(hx5d, hx4)
167
+
168
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
169
+ hx4dup = _upsample_like(hx4d, hx3)
170
+
171
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
172
+ hx3dup = _upsample_like(hx3d, hx2)
173
+
174
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
175
+ hx2dup = _upsample_like(hx2d, hx1)
176
+
177
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
178
+
179
+ return hx1d + hxin
180
+
181
+
182
+ ### RSU-5 ###
183
+ class RSU5(nn.Module):
184
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
185
+ super(RSU5, self).__init__()
186
+
187
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
188
+
189
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
190
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
191
+
192
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
193
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
194
+
195
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
196
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
197
+
198
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
199
+
200
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
201
+
202
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
203
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
204
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
205
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
206
+
207
+ def forward(self, x):
208
+ hx = x
209
+
210
+ hxin = self.rebnconvin(hx)
211
+
212
+ hx1 = self.rebnconv1(hxin)
213
+ hx = self.pool1(hx1)
214
+
215
+ hx2 = self.rebnconv2(hx)
216
+ hx = self.pool2(hx2)
217
+
218
+ hx3 = self.rebnconv3(hx)
219
+ hx = self.pool3(hx3)
220
+
221
+ hx4 = self.rebnconv4(hx)
222
+
223
+ hx5 = self.rebnconv5(hx4)
224
+
225
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
226
+ hx4dup = _upsample_like(hx4d, hx3)
227
+
228
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
229
+ hx3dup = _upsample_like(hx3d, hx2)
230
+
231
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
232
+ hx2dup = _upsample_like(hx2d, hx1)
233
+
234
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
235
+
236
+ return hx1d + hxin
237
+
238
+
239
+ ### RSU-4 ###
240
+ class RSU4(nn.Module):
241
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
242
+ super(RSU4, self).__init__()
243
+
244
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
245
+
246
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
247
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
248
+
249
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
250
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
251
+
252
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
253
+
254
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
255
+
256
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
259
+
260
+ def forward(self, x):
261
+ hx = x
262
+
263
+ hxin = self.rebnconvin(hx)
264
+
265
+ hx1 = self.rebnconv1(hxin)
266
+ hx = self.pool1(hx1)
267
+
268
+ hx2 = self.rebnconv2(hx)
269
+ hx = self.pool2(hx2)
270
+
271
+ hx3 = self.rebnconv3(hx)
272
+
273
+ hx4 = self.rebnconv4(hx3)
274
+
275
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
276
+ hx3dup = _upsample_like(hx3d, hx2)
277
+
278
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
279
+ hx2dup = _upsample_like(hx2d, hx1)
280
+
281
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
282
+
283
+ return hx1d + hxin
284
+
285
+
286
+ ### RSU-4F ###
287
+ class RSU4F(nn.Module):
288
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
289
+ super(RSU4F, self).__init__()
290
+
291
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
292
+
293
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
294
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
295
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
296
+
297
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
298
+
299
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
300
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
301
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
302
+
303
+ def forward(self, x):
304
+ hx = x
305
+
306
+ hxin = self.rebnconvin(hx)
307
+
308
+ hx1 = self.rebnconv1(hxin)
309
+ hx2 = self.rebnconv2(hx1)
310
+ hx3 = self.rebnconv3(hx2)
311
+
312
+ hx4 = self.rebnconv4(hx3)
313
+
314
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
315
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
316
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
317
+
318
+ return hx1d + hxin
319
+
320
+
321
+ class myrebnconv(nn.Module):
322
+ def __init__(
323
+ self,
324
+ in_ch=3,
325
+ out_ch=1,
326
+ kernel_size=3,
327
+ stride=1,
328
+ padding=1,
329
+ dilation=1,
330
+ groups=1,
331
+ ):
332
+ super(myrebnconv, self).__init__()
333
+
334
+ self.conv = nn.Conv2d(
335
+ in_ch,
336
+ out_ch,
337
+ kernel_size=kernel_size,
338
+ stride=stride,
339
+ padding=padding,
340
+ dilation=dilation,
341
+ groups=groups,
342
+ )
343
+ self.bn = nn.BatchNorm2d(out_ch)
344
+ self.rl = nn.ReLU(inplace=True)
345
+
346
+ def forward(self, x):
347
+ return self.rl(self.bn(self.conv(x)))
348
+
349
+
350
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
351
+ def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
352
+ super(BriaRMBG, self).__init__()
353
+ in_ch = config["in_ch"]
354
+ out_ch = config["out_ch"]
355
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
356
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
357
+
358
+ self.stage1 = RSU7(64, 32, 64)
359
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
360
+
361
+ self.stage2 = RSU6(64, 32, 128)
362
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
363
+
364
+ self.stage3 = RSU5(128, 64, 256)
365
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
366
+
367
+ self.stage4 = RSU4(256, 128, 512)
368
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
369
+
370
+ self.stage5 = RSU4F(512, 256, 512)
371
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
372
+
373
+ self.stage6 = RSU4F(512, 256, 512)
374
+
375
+ # decoder
376
+ self.stage5d = RSU4F(1024, 256, 512)
377
+ self.stage4d = RSU4(1024, 128, 256)
378
+ self.stage3d = RSU5(512, 64, 128)
379
+ self.stage2d = RSU6(256, 32, 64)
380
+ self.stage1d = RSU7(128, 16, 64)
381
+
382
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
383
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
384
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
385
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
386
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
387
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
388
+
389
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
390
+
391
+ def forward(self, x):
392
+ hx = x
393
+
394
+ hxin = self.conv_in(hx)
395
+ # hx = self.pool_in(hxin)
396
+
397
+ # stage 1
398
+ hx1 = self.stage1(hxin)
399
+ hx = self.pool12(hx1)
400
+
401
+ # stage 2
402
+ hx2 = self.stage2(hx)
403
+ hx = self.pool23(hx2)
404
+
405
+ # stage 3
406
+ hx3 = self.stage3(hx)
407
+ hx = self.pool34(hx3)
408
+
409
+ # stage 4
410
+ hx4 = self.stage4(hx)
411
+ hx = self.pool45(hx4)
412
+
413
+ # stage 5
414
+ hx5 = self.stage5(hx)
415
+ hx = self.pool56(hx5)
416
+
417
+ # stage 6
418
+ hx6 = self.stage6(hx)
419
+ hx6up = _upsample_like(hx6, hx5)
420
+
421
+ # -------------------- decoder --------------------
422
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
423
+ hx5dup = _upsample_like(hx5d, hx4)
424
+
425
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
426
+ hx4dup = _upsample_like(hx4d, hx3)
427
+
428
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
429
+ hx3dup = _upsample_like(hx3d, hx2)
430
+
431
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
432
+ hx2dup = _upsample_like(hx2d, hx1)
433
+
434
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
435
+
436
+ # side output
437
+ d1 = self.side1(hx1d)
438
+ d1 = _upsample_like(d1, x)
439
+
440
+ d2 = self.side2(hx2d)
441
+ d2 = _upsample_like(d2, x)
442
+
443
+ d3 = self.side3(hx3d)
444
+ d3 = _upsample_like(d3, x)
445
+
446
+ d4 = self.side4(hx4d)
447
+ d4 = _upsample_like(d4, x)
448
+
449
+ d5 = self.side5(hx5d)
450
+ d5 = _upsample_like(d5, x)
451
+
452
+ d6 = self.side6(hx6)
453
+ d6 = _upsample_like(d6, x)
454
+
455
+ return [
456
+ F.sigmoid(d1),
457
+ F.sigmoid(d2),
458
+ F.sigmoid(d3),
459
+ F.sigmoid(d4),
460
+ F.sigmoid(d5),
461
+ F.sigmoid(d6),
462
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]