nikunjkdtechnoland commited on
Commit
e9d702e
1 Parent(s): d06defe

add bg remover

Browse files
Files changed (5) hide show
  1. app.py +1 -1
  2. bgremove/bg_remove_cnn.py +454 -0
  3. bgremover.py +53 -0
  4. only_gradio_server.py +8 -27
  5. requirements.txt +1 -2
app.py CHANGED
@@ -10,7 +10,7 @@ options_list = list(object_names.values())
10
  # Create Gradio interface
11
  iface = gr.Interface(fn=process_images,
12
  inputs=[gr.Image(type='filepath', label='Main Image where object identify', width="60%", height="60%"),
13
- gr.Image(type='filepath', label='Object Image which placed on Main Image (PNG file only RGBA Channel)', image_mode="RGBA", width="60%", height="60%"),
14
  gr.Dropdown(options_list, label='Replace Object Name (Default = chair)')],
15
  outputs=gr.Image(type='numpy', label='Final Result', width="60%", height="60%"),
16
  title="AI Based Image Processing",
 
10
  # Create Gradio interface
11
  iface = gr.Interface(fn=process_images,
12
  inputs=[gr.Image(type='filepath', label='Main Image where object identify', width="60%", height="60%"),
13
+ gr.Image(type='filepath', label='Object Image which placed on Main Image', image_mode="RGBA", width="60%", height="60%"),
14
  gr.Dropdown(options_list, label='Replace Object Name (Default = chair)')],
15
  outputs=gr.Image(type='numpy', label='Final Result', width="60%", height="60%"),
16
  title="AI Based Image Processing",
bgremove/bg_remove_cnn.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.in_ch = in_ch
35
+ self.mid_ch = mid_ch
36
+ self.out_ch = out_ch
37
+
38
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
39
+
40
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
41
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
42
+
43
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
44
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
45
+
46
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
47
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
48
+
49
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
50
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
51
+
52
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
53
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
54
+
55
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
56
+
57
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
58
+
59
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
61
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
62
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
63
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
64
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
65
+
66
+ def forward(self,x):
67
+ b, c, h, w = x.shape
68
+
69
+ hx = x
70
+ hxin = self.rebnconvin(hx)
71
+
72
+ hx1 = self.rebnconv1(hxin)
73
+ hx = self.pool1(hx1)
74
+
75
+ hx2 = self.rebnconv2(hx)
76
+ hx = self.pool2(hx2)
77
+
78
+ hx3 = self.rebnconv3(hx)
79
+ hx = self.pool3(hx3)
80
+
81
+ hx4 = self.rebnconv4(hx)
82
+ hx = self.pool4(hx4)
83
+
84
+ hx5 = self.rebnconv5(hx)
85
+ hx = self.pool5(hx5)
86
+
87
+ hx6 = self.rebnconv6(hx)
88
+
89
+ hx7 = self.rebnconv7(hx6)
90
+
91
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
92
+ hx6dup = _upsample_like(hx6d,hx5)
93
+
94
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
95
+ hx5dup = _upsample_like(hx5d,hx4)
96
+
97
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
98
+ hx4dup = _upsample_like(hx4d,hx3)
99
+
100
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
101
+ hx3dup = _upsample_like(hx3d,hx2)
102
+
103
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
104
+ hx2dup = _upsample_like(hx2d,hx1)
105
+
106
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
107
+
108
+ return hx1d + hxin
109
+
110
+
111
+ ### RSU-6 ###
112
+ class RSU6(nn.Module):
113
+
114
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
115
+ super(RSU6,self).__init__()
116
+
117
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
118
+
119
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
120
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
127
+
128
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
129
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
130
+
131
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
132
+
133
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
134
+
135
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
136
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
137
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
138
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
139
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
140
+
141
+ def forward(self,x):
142
+
143
+ hx = x
144
+
145
+ hxin = self.rebnconvin(hx)
146
+
147
+ hx1 = self.rebnconv1(hxin)
148
+ hx = self.pool1(hx1)
149
+
150
+ hx2 = self.rebnconv2(hx)
151
+ hx = self.pool2(hx2)
152
+
153
+ hx3 = self.rebnconv3(hx)
154
+ hx = self.pool3(hx3)
155
+
156
+ hx4 = self.rebnconv4(hx)
157
+ hx = self.pool4(hx4)
158
+
159
+ hx5 = self.rebnconv5(hx)
160
+
161
+ hx6 = self.rebnconv6(hx5)
162
+
163
+
164
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
165
+ hx5dup = _upsample_like(hx5d,hx4)
166
+
167
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
168
+ hx4dup = _upsample_like(hx4d,hx3)
169
+
170
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
171
+ hx3dup = _upsample_like(hx3d,hx2)
172
+
173
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
174
+ hx2dup = _upsample_like(hx2d,hx1)
175
+
176
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
177
+
178
+ return hx1d + hxin
179
+
180
+ ### RSU-5 ###
181
+ class RSU5(nn.Module):
182
+
183
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
184
+ super(RSU5,self).__init__()
185
+
186
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
187
+
188
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
189
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
193
+
194
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
195
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
196
+
197
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
198
+
199
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
200
+
201
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
202
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
203
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
204
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
205
+
206
+ def forward(self,x):
207
+
208
+ hx = x
209
+
210
+ hxin = self.rebnconvin(hx)
211
+
212
+ hx1 = self.rebnconv1(hxin)
213
+ hx = self.pool1(hx1)
214
+
215
+ hx2 = self.rebnconv2(hx)
216
+ hx = self.pool2(hx2)
217
+
218
+ hx3 = self.rebnconv3(hx)
219
+ hx = self.pool3(hx3)
220
+
221
+ hx4 = self.rebnconv4(hx)
222
+
223
+ hx5 = self.rebnconv5(hx4)
224
+
225
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
226
+ hx4dup = _upsample_like(hx4d,hx3)
227
+
228
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
229
+ hx3dup = _upsample_like(hx3d,hx2)
230
+
231
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
232
+ hx2dup = _upsample_like(hx2d,hx1)
233
+
234
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
235
+
236
+ return hx1d + hxin
237
+
238
+ ### RSU-4 ###
239
+ class RSU4(nn.Module):
240
+
241
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
242
+ super(RSU4,self).__init__()
243
+
244
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
245
+
246
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
247
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
248
+
249
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
250
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
251
+
252
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
253
+
254
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
255
+
256
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
+
260
+ def forward(self,x):
261
+
262
+ hx = x
263
+
264
+ hxin = self.rebnconvin(hx)
265
+
266
+ hx1 = self.rebnconv1(hxin)
267
+ hx = self.pool1(hx1)
268
+
269
+ hx2 = self.rebnconv2(hx)
270
+ hx = self.pool2(hx2)
271
+
272
+ hx3 = self.rebnconv3(hx)
273
+
274
+ hx4 = self.rebnconv4(hx3)
275
+
276
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
277
+ hx3dup = _upsample_like(hx3d,hx2)
278
+
279
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
280
+ hx2dup = _upsample_like(hx2d,hx1)
281
+
282
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
283
+
284
+ return hx1d + hxin
285
+
286
+ ### RSU-4F ###
287
+ class RSU4F(nn.Module):
288
+
289
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
290
+ super(RSU4F,self).__init__()
291
+
292
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
293
+
294
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
295
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
296
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
297
+
298
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
299
+
300
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
301
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
302
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
303
+
304
+ def forward(self,x):
305
+
306
+ hx = x
307
+
308
+ hxin = self.rebnconvin(hx)
309
+
310
+ hx1 = self.rebnconv1(hxin)
311
+ hx2 = self.rebnconv2(hx1)
312
+ hx3 = self.rebnconv3(hx2)
313
+
314
+ hx4 = self.rebnconv4(hx3)
315
+
316
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
317
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
318
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
319
+
320
+ return hx1d + hxin
321
+
322
+
323
+ class myrebnconv(nn.Module):
324
+ def __init__(self, in_ch=3,
325
+ out_ch=1,
326
+ kernel_size=3,
327
+ stride=1,
328
+ padding=1,
329
+ dilation=1,
330
+ groups=1):
331
+ super(myrebnconv,self).__init__()
332
+
333
+ self.conv = nn.Conv2d(in_ch,
334
+ out_ch,
335
+ kernel_size=kernel_size,
336
+ stride=stride,
337
+ padding=padding,
338
+ dilation=dilation,
339
+ groups=groups)
340
+ self.bn = nn.BatchNorm2d(out_ch)
341
+ self.rl = nn.ReLU(inplace=True)
342
+
343
+ def forward(self,x):
344
+ return self.rl(self.bn(self.conv(x)))
345
+
346
+
347
+ class BriaRMBG(nn.Module):
348
+
349
+ def __init__(self,in_ch=3,out_ch=1):
350
+ super(BriaRMBG,self).__init__()
351
+
352
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
353
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
354
+
355
+ self.stage1 = RSU7(64,32,64)
356
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
357
+
358
+ self.stage2 = RSU6(64,32,128)
359
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
360
+
361
+ self.stage3 = RSU5(128,64,256)
362
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
363
+
364
+ self.stage4 = RSU4(256,128,512)
365
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
366
+
367
+ self.stage5 = RSU4F(512,256,512)
368
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
369
+
370
+ self.stage6 = RSU4F(512,256,512)
371
+
372
+ # decoder
373
+ self.stage5d = RSU4F(1024,256,512)
374
+ self.stage4d = RSU4(1024,128,256)
375
+ self.stage3d = RSU5(512,64,128)
376
+ self.stage2d = RSU6(256,32,64)
377
+ self.stage1d = RSU7(128,16,64)
378
+
379
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
380
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
381
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
382
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
383
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
384
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
385
+
386
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
387
+
388
+ def forward(self,x):
389
+
390
+ hx = x
391
+
392
+ hxin = self.conv_in(hx)
393
+ #hx = self.pool_in(hxin)
394
+
395
+ #stage 1
396
+ hx1 = self.stage1(hxin)
397
+ hx = self.pool12(hx1)
398
+
399
+ #stage 2
400
+ hx2 = self.stage2(hx)
401
+ hx = self.pool23(hx2)
402
+
403
+ #stage 3
404
+ hx3 = self.stage3(hx)
405
+ hx = self.pool34(hx3)
406
+
407
+ #stage 4
408
+ hx4 = self.stage4(hx)
409
+ hx = self.pool45(hx4)
410
+
411
+ #stage 5
412
+ hx5 = self.stage5(hx)
413
+ hx = self.pool56(hx5)
414
+
415
+ #stage 6
416
+ hx6 = self.stage6(hx)
417
+ hx6up = _upsample_like(hx6,hx5)
418
+
419
+ #-------------------- decoder --------------------
420
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
421
+ hx5dup = _upsample_like(hx5d,hx4)
422
+
423
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
424
+ hx4dup = _upsample_like(hx4d,hx3)
425
+
426
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
427
+ hx3dup = _upsample_like(hx3d,hx2)
428
+
429
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
430
+ hx2dup = _upsample_like(hx2d,hx1)
431
+
432
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
433
+
434
+
435
+ #side output
436
+ d1 = self.side1(hx1d)
437
+ d1 = _upsample_like(d1,x)
438
+
439
+ d2 = self.side2(hx2d)
440
+ d2 = _upsample_like(d2,x)
441
+
442
+ d3 = self.side3(hx3d)
443
+ d3 = _upsample_like(d3,x)
444
+
445
+ d4 = self.side4(hx4d)
446
+ d4 = _upsample_like(d4,x)
447
+
448
+ d5 = self.side5(hx5d)
449
+ d5 = _upsample_like(d5,x)
450
+
451
+ d6 = self.side6(hx6)
452
+ d6 = _upsample_like(d6,x)
453
+
454
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
bgremover.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ from bgremove.bg_remove_cnn import BriaRMBG
6
+ from PIL import Image
7
+
8
+ net = BriaRMBG()
9
+ model_path = "./pretrained-model/bgremove.pth"
10
+
11
+ if torch.cuda.is_available():
12
+ net.load_state_dict(torch.load(model_path))
13
+ net = net.cuda()
14
+ else:
15
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
16
+ net.eval()
17
+
18
+
19
+ def resize_image(image):
20
+ image = image.convert('RGB')
21
+ model_input_size = (1024, 1024)
22
+ image = image.resize(model_input_size, Image.BILINEAR)
23
+ return image
24
+
25
+
26
+ def process(image):
27
+ # prepare input
28
+ orig_image = Image.fromarray(image)
29
+ w, h = orig_im_size = orig_image.size
30
+ image = resize_image(orig_image)
31
+ im_np = np.array(image)
32
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
33
+ im_tensor = torch.unsqueeze(im_tensor, 0)
34
+ im_tensor = torch.divide(im_tensor, 255.0)
35
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
36
+ if torch.cuda.is_available():
37
+ im_tensor = im_tensor.cuda()
38
+
39
+ # inference
40
+ result = net(im_tensor)
41
+ # post process
42
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
43
+ ma = torch.max(result)
44
+ mi = torch.min(result)
45
+ result = (result - mi) / (ma - mi)
46
+ # image to pil
47
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
48
+ pil_im = Image.fromarray(np.squeeze(im_array))
49
+ # paste the mask on the original image
50
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
51
+ new_im.paste(orig_image, mask=pil_im)
52
+ # new_orig_image = orig_image.convert('RGBA')
53
+ return new_im
only_gradio_server.py CHANGED
@@ -1,19 +1,12 @@
1
  import os
2
- import base64
3
  import io
4
- import uuid
5
  from ultralytics import YOLO
6
  import cv2
7
- import torch
8
  import numpy as np
9
  from PIL import Image
10
- from torchvision import transforms
11
- import imageio.v2 as imageio
12
- from utils.tools import get_config
13
- import torch.nn.functional as F
14
  from iopaint.single_processing import batch_inpaint_cv2
15
- from pathlib import Path
16
  import gradio as gr
 
17
 
18
  # set current working directory cache instead of default
19
  os.environ["TORCH_HOME"] = "./pretrained-model"
@@ -60,18 +53,6 @@ def process_images(input_image, append_image, default_class="chair"):
60
  if not append_image:
61
  raise gr.Error("Please upload an object image.")
62
 
63
- # Check if the append_image is a PNG file with RGBA mode
64
- try:
65
- with Image.open(append_image) as img:
66
- if img.format != 'PNG' or img.mode != 'RGBA':
67
- raise gr.Error("Please upload a valid PNG file with RGBA mode for the object image.")
68
- except Exception as e:
69
- raise gr.Error("Failed to validate object image: Upload new image")
70
-
71
- # Static paths
72
- config_path = Path('configs/config.yaml')
73
- model_path = Path('pretrained-model/torch_model.p')
74
-
75
  # Resize input image and get base64 data of resized image
76
  img = resize_image(input_image)
77
 
@@ -121,7 +102,7 @@ def process_images(input_image, append_image, default_class="chair"):
121
  resized_mask = cv2.resize(dilated_mask, (img.shape[1], img.shape[0]))
122
 
123
  # call repainting and merge function
124
- output_numpy = repaitingAndMerge(append_image,str(model_path), str(config_path),width, height, x_point, y_point, img, resized_mask)
125
  # Return the output numpy image in the API response
126
  return output_numpy
127
 
@@ -129,10 +110,7 @@ def process_images(input_image, append_image, default_class="chair"):
129
  if not class_found:
130
  raise gr.Error(f'{default_class} object not found in the image')
131
 
132
- def repaitingAndMerge(append_image_path, model_path, config_path, width, height, xposition, yposition, input_base, mask_base):
133
- config = get_config(config_path)
134
- device = torch.device("cpu")
135
-
136
  # lama inpainting start
137
  print("lama inpainting start")
138
  inpaint_result_np = batch_inpaint_cv2('lama', 'cpu', input_base, mask_base)
@@ -148,14 +126,17 @@ def repaitingAndMerge(append_image_path, model_path, config_path, width, height,
148
  resized_image = cv2.resize(append_image, (width, height), interpolation=cv2.INTER_AREA)
149
  # Convert the resized image to RGBA format (assuming it's in BGRA format)
150
  resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGBA)
 
151
  # Create a PIL Image from the resized image with transparent background
152
- append_image_pil = Image.fromarray(resized_image)
 
 
 
153
 
154
  # Paste the append image onto the final image
155
  final_image.paste(append_image_pil, (xposition, yposition), append_image_pil)
156
  # Save the resulting image
157
  print("merge end")
158
-
159
  # Convert the final image to base64
160
  with io.BytesIO() as output_buffer:
161
  final_image.save(output_buffer, format='PNG')
 
1
  import os
 
2
  import io
 
3
  from ultralytics import YOLO
4
  import cv2
 
5
  import numpy as np
6
  from PIL import Image
 
 
 
 
7
  from iopaint.single_processing import batch_inpaint_cv2
 
8
  import gradio as gr
9
+ from bgremover import process
10
 
11
  # set current working directory cache instead of default
12
  os.environ["TORCH_HOME"] = "./pretrained-model"
 
53
  if not append_image:
54
  raise gr.Error("Please upload an object image.")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Resize input image and get base64 data of resized image
57
  img = resize_image(input_image)
58
 
 
102
  resized_mask = cv2.resize(dilated_mask, (img.shape[1], img.shape[0]))
103
 
104
  # call repainting and merge function
105
+ output_numpy = repaitingAndMerge(append_image,width, height, x_point, y_point, img, resized_mask)
106
  # Return the output numpy image in the API response
107
  return output_numpy
108
 
 
110
  if not class_found:
111
  raise gr.Error(f'{default_class} object not found in the image')
112
 
113
+ def repaitingAndMerge(append_image_path, width, height, xposition, yposition, input_base, mask_base):
 
 
 
114
  # lama inpainting start
115
  print("lama inpainting start")
116
  inpaint_result_np = batch_inpaint_cv2('lama', 'cpu', input_base, mask_base)
 
126
  resized_image = cv2.resize(append_image, (width, height), interpolation=cv2.INTER_AREA)
127
  # Convert the resized image to RGBA format (assuming it's in BGRA format)
128
  resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGBA)
129
+
130
  # Create a PIL Image from the resized image with transparent background
131
+ #append_image_pil = Image.fromarray(resized_image)
132
+
133
+ # remove the bg from image
134
+ append_image_pil = process(resized_image)
135
 
136
  # Paste the append image onto the final image
137
  final_image.paste(append_image_pil, (xposition, yposition), append_image_pil)
138
  # Save the resulting image
139
  print("merge end")
 
140
  # Convert the final image to base64
141
  with io.BytesIO() as output_buffer:
142
  final_image.save(output_buffer, format='PNG')
requirements.txt CHANGED
@@ -23,5 +23,4 @@ typer-config==1.4.0
23
  Pillow==9.5.0
24
  ultralytics
25
  flask
26
- flask_cors
27
- trainer
 
23
  Pillow==9.5.0
24
  ultralytics
25
  flask
26
+ flask_cors