JCTN committed on
Commit
63fa184
1 Parent(s): 138576e

Upload 11 files

Files changed (12)
  1. .gitattributes +2 -0
  2. README.md +138 -0
  3. briarmbg.py +457 -0
  4. example_inference.py +39 -0
  5. example_input.jpg +0 -0
  6. gitattributes +41 -0
  7. model.pth +3 -0
  8. pytorch_model.bin +3 -0
  9. requirements.txt +7 -0
  10. results.png +3 -0
  11. t4.png +3 -0
  12. utilities.py +25 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ results.png filter=lfs diff=lfs merge=lfs -text
+ t4.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,138 @@
+ ---
+ license: other
+ license_name: bria-rmbg-1.4
+ license_link: https://bria.ai/bria-huggingface-model-license-agreement/
+ pipeline_tag: image-to-image
+ tags:
+ - remove background
+ - background
+ - background-removal
+ - Pytorch
+ - vision
+ - legal liability
+
+ extra_gated_prompt: The model weights from BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we will reach out to you.
+ extra_gated_fields:
+   Name: text
+   Company/Org name: text
+   Org Type (Early/Growth Startup, Enterprise, Academy): text
+   Role: text
+   Country: text
+   Email: text
+   By submitting this form, I agree to BRIA’s Privacy policy and Terms & conditions, see links below: checkbox
+ ---
+
+ # BRIA Background Removal v1.4 Model Card
+
+ RMBG v1.4 is our state-of-the-art background removal model, designed to effectively separate foreground from background across a range of
+ categories and image types. It was trained on a carefully selected dataset that includes
+ general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale.
+ Its accuracy, efficiency, and versatility currently rival leading source-available models.
+ RMBG v1.4 is ideal where content safety, legally licensed datasets, and bias mitigation are paramount.
+
+ Developed by BRIA AI, RMBG v1.4 is available as a source-available model for non-commercial use.
+
+ [CLICK HERE FOR A DEMO](https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4)
+ ![examples](t4.png)
+
+ ### Model Description
+
+ - **Developed by:** [BRIA AI](https://bria.ai/)
+ - **Model type:** Background Removal
+ - **License:** [bria-rmbg-1.4](https://bria.ai/bria-huggingface-model-license-agreement/)
+   - The model is released under a Creative Commons license for non-commercial use.
+   - Commercial use is subject to a commercial agreement with BRIA. [Contact Us](https://bria.ai/contact-us) for more information.
+ - **Model description:** BRIA RMBG 1.4 is a saliency segmentation model trained exclusively on a professional-grade dataset.
+ - **Resources for more information:** [BRIA AI](https://bria.ai/)
+
+ ## Training data
+
+ The Bria-RMBG model was trained on over 12,000 high-quality, high-resolution, manually labeled (pixel-wise accurate), fully licensed images.
+ Our benchmark included balanced gender, balanced ethnicity, and people with different types of disabilities.
+ For clarity, we provide our data distribution across several categories, demonstrating the model’s versatility.
+
+ ### Distribution of images
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Objects only | 45.11% |
+ | People with objects/animals | 25.24% |
+ | People only | 17.35% |
+ | People/objects/animals with text | 8.52% |
+ | Text only | 2.52% |
+ | Animals only | 1.89% |
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------------:|
+ | Photorealistic | 87.70% |
+ | Non-Photorealistic | 12.30% |
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Non-Solid Background | 52.05% |
+ | Solid Background | 47.95% |
+
+ | Category | Distribution |
+ | -----------------------------------| -----------------------------------:|
+ | Single main foreground object | 51.42% |
+ | Multiple objects in the foreground | 48.58% |
+
+ ## Qualitative Evaluation
+
+ ![examples](results.png)
+
+ ## Architecture
+
+ RMBG v1.4 is built on the [IS-Net](https://github.com/xuebinqin/DIS) architecture, enhanced with our unique training scheme and proprietary dataset.
+ These modifications significantly improve the model’s accuracy and effectiveness in diverse image-processing scenarios.
+
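+ The forward pass returns six sigmoid side outputs (finest first), each upsampled to the input resolution, together with the intermediate decoder features. A minimal shape check, using a small random input rather than a real image:
+
+ ```python
+ import torch
+ from briarmbg import BriaRMBG
+
+ net = BriaRMBG().eval()  # randomly initialized; sufficient for a shape check
+ with torch.no_grad():
+     masks, features = net(torch.rand(1, 3, 256, 256))
+ # masks[0] is the highest-resolution prediction; every side output is a
+ # 1x1xHxW tensor of values in [0, 1] at the input size.
+ print([tuple(m.shape) for m in masks])
+ ```
+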
+ ## Installation
+
+ ```bash
+ git clone https://huggingface.co/briaai/RMBG-1.4
+ cd RMBG-1.4/
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ ```python
+ import os
+
+ import torch
+ from PIL import Image
+ from skimage import io
+
+ from briarmbg import BriaRMBG
+ from utilities import preprocess_image, postprocess_image
+
+ im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
+
+ # load the pretrained model from the Hub
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+ net.to(device)
+ net.eval()
+
+ # prepare input
+ model_input_size = [1024, 1024]
+ orig_im = io.imread(im_path)
+ orig_im_size = orig_im.shape[0:2]
+ image = preprocess_image(orig_im, model_input_size).to(device)
+
+ # inference
+ result = net(image)
+
+ # post-process: the first side output is the finest-resolution mask
+ result_image = postprocess_image(result[0][0], orig_im_size)
+
+ # save result: use the predicted mask as the alpha channel
+ pil_im = Image.fromarray(result_image)
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+ orig_image = Image.open(im_path)
+ no_bg_image.paste(orig_image, mask=pil_im)
+ no_bg_image.save("example_image_no_bg.png")
+ ```
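+
+ The repository also ships the raw checkpoint (`model.pth`). If you prefer to bypass `from_pretrained`, the weights can be fetched and loaded directly; a minimal sketch, assuming `model.pth` holds a state dict compatible with `BriaRMBG`:
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from briarmbg import BriaRMBG
+
+ # download (and cache) the checkpoint from the Hub
+ ckpt_path = hf_hub_download(repo_id="briaai/RMBG-1.4", filename="model.pth")
+
+ net = BriaRMBG()
+ net.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
+ net.eval()
+ ```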
briarmbg.py ADDED
@@ -0,0 +1,457 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+ ## Conv2d -> BatchNorm2d -> ReLU block with configurable dilation
+ class REBNCONV(nn.Module):
+     def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
+         super(REBNCONV,self).__init__()
+
+         self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self,x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+         return xout
+
+ ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+ def _upsample_like(src,tar):
+     src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7,self).__init__()
+
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+         b, c, h, w = x.shape
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
+         hx6dup = _upsample_like(hx6d,hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
+
+         return hx1d + hxin
+
+
+ class myrebnconv(nn.Module):
+     def __init__(self, in_ch=3,
+                  out_ch=1,
+                  kernel_size=3,
+                  stride=1,
+                  padding=1,
+                  dilation=1,
+                  groups=1):
+         super(myrebnconv,self).__init__()
+
+         self.conv = nn.Conv2d(in_ch,
+                               out_ch,
+                               kernel_size=kernel_size,
+                               stride=stride,
+                               padding=padding,
+                               dilation=dilation,
+                               groups=groups)
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+
+     def forward(self,x):
+         return self.rl(self.bn(self.conv(x)))
+
+
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
+
+     def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
+         super(BriaRMBG,self).__init__()
+         in_ch = config["in_ch"]
+         out_ch = config["out_ch"]
+         self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+         self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         # encoder
+         self.stage1 = RSU7(64,32,64)
+         self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage2 = RSU6(64,32,128)
+         self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage3 = RSU5(128,64,256)
+         self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage4 = RSU4(256,128,512)
+         self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage5 = RSU4F(512,256,512)
+         self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage6 = RSU4F(512,256,512)
+
+         # decoder
+         self.stage5d = RSU4F(1024,256,512)
+         self.stage4d = RSU4(1024,128,256)
+         self.stage3d = RSU5(512,64,128)
+         self.stage2d = RSU6(256,32,64)
+         self.stage1d = RSU7(128,16,64)
+
+         # side-output heads, one per decoder resolution
+         self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
+         self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
+         self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
+         self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
+         self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
+         self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6,hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
+
+         # side outputs, upsampled to the input resolution
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1,x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2,x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3,x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4,x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5,x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6,x)
+
+         return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
+
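Because `BriaRMBG` mixes in `PyTorchModelHubMixin`, the standard `huggingface_hub` save/load helpers work on it directly. A brief sketch (the local directory name is hypothetical):

```python
from briarmbg import BriaRMBG

# round-trip the model through a local directory via the mixin helpers
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
net.save_pretrained("rmbg-local")  # hypothetical local path
reloaded = BriaRMBG.from_pretrained("rmbg-local")
```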
example_inference.py ADDED
@@ -0,0 +1,39 @@
+ import os
+
+ import torch
+ from PIL import Image
+ from skimage import io
+
+ from briarmbg import BriaRMBG
+ from utilities import preprocess_image, postprocess_image
+
+
+ def example_inference():
+
+     im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
+
+     # load the pretrained model from the Hub
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+     net.to(device)
+     net.eval()
+
+     # prepare input
+     model_input_size = [1024, 1024]
+     orig_im = io.imread(im_path)
+     orig_im_size = orig_im.shape[0:2]
+     image = preprocess_image(orig_im, model_input_size).to(device)
+
+     # inference
+     result = net(image)
+
+     # post-process: the first side output is the finest-resolution mask
+     result_image = postprocess_image(result[0][0], orig_im_size)
+
+     # save result: use the predicted mask as the alpha channel
+     pil_im = Image.fromarray(result_image)
+     no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+     orig_image = Image.open(im_path)
+     no_bg_image.paste(orig_image, mask=pil_im)
+     no_bg_image.save("example_image_no_bg.png")
+
+
+ if __name__ == "__main__":
+     example_inference()
example_input.jpg ADDED
gitattributes ADDED
@@ -0,0 +1,41 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example.png filter=lfs diff=lfs merge=lfs -text
+ results.png filter=lfs diff=lfs merge=lfs -text
+ Screenshot[[:space:]]2024-01-21[[:space:]]at[[:space:]]11.56.17.png filter=lfs diff=lfs merge=lfs -text
+ T1.png filter=lfs diff=lfs merge=lfs -text
+ T2.png filter=lfs diff=lfs merge=lfs -text
+ t4.png filter=lfs diff=lfs merge=lfs -text
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:893c16c340b1ddafc93e78457a4d94190da9b7179149f8574284c83caebf5e8c
+ size 176718373
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59569acdb281ac9fc9f78f9d33b6f9f17f68e25086b74f9025c35bb5f2848967
+ size 176574018
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ pillow
+ numpy
+ typing
+ scikit-image
+ huggingface_hub
results.png ADDED

Git LFS Details

  • SHA256: 2b7f08fc4c09db56b516186c0629f72523a5cbe328beaedda8b36349af4b04bc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
t4.png ADDED

Git LFS Details

  • SHA256: 43a9453f567d9bff7fe4481205575bbf302499379047ee6073247315452ba8fb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
utilities.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import normalize
+ import numpy as np
+
+
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
+     # promote grayscale images to a trailing channel axis
+     if len(im.shape) < 3:
+         im = im[:, :, np.newaxis]
+     # HWC uint8 -> CHW float tensor, resized to the model input size
+     im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+     im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
+     # scale to [0, 1], then shift to [-0.5, 0.5] (mean 0.5, std 1.0)
+     image = torch.divide(im_tensor, 255.0)
+     image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+     return image
+
+
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
+     # resize the predicted mask back to the original image size
+     result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+     # min-max normalize to [0, 1], then convert to a uint8 HxW array
+     ma = torch.max(result)
+     mi = torch.min(result)
+     result = (result - mi) / (ma - mi)
+     im_array = (result * 255).permute(1, 2, 0).cpu().detach().numpy().astype(np.uint8)
+     im_array = np.squeeze(im_array)
+     return im_array
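As a quick sanity check of the two helpers above, a minimal round trip on a synthetic image (a sketch only; the random tensor stands in for a real model output):

```python
import numpy as np
import torch

from utilities import preprocess_image, postprocess_image

# synthetic 400x600 RGB image in place of a real input
im = np.random.randint(0, 256, size=(400, 600, 3), dtype=np.uint8)

batch = preprocess_image(im, [1024, 1024])
print(batch.shape)  # torch.Size([1, 3, 1024, 1024]), values in [-0.5, 0.5]

# a random tensor standing in for one side output of the model
fake_mask = torch.rand(1, 1, 1024, 1024)
mask = postprocess_image(fake_mask, im.shape[0:2])
print(mask.shape, mask.dtype)  # (400, 600) uint8
```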