linxy97 not-lain committed on
Commit 694e47c
0 Parent(s):

Duplicate from not-lain/CustomCodeForRMBG

Co-authored-by: L_Ai_n <not-lain@users.noreply.huggingface.co>

Files changed (6)
  1. .gitattributes +35 -0
  2. MyConfig.py +14 -0
  3. MyPipe.py +76 -0
  4. README.md +39 -0
  5. briarmbg.py +459 -0
  6. config.json +25 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
MyConfig.py ADDED
@@ -0,0 +1,14 @@
+
+ from transformers import PretrainedConfig
+ from typing import List
+
+ class RMBGConfig(PretrainedConfig):
+     model_type = "SegformerForSemanticSegmentation"
+     def __init__(
+         self,
+         in_ch=3,
+         out_ch=1,
+         **kwargs):
+         self.in_ch = in_ch
+         self.out_ch = out_ch
+         super().__init__(**kwargs)
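For context, here is a minimal sketch of how `RMBGConfig` could be exercised on its own; the `rmbg-config` directory name is only a placeholder for this example:

```python
from MyConfig import RMBGConfig

# instantiate with the defaults expected by BriaRMBG: 3 input channels, 1 output channel (mask)
config = RMBGConfig(in_ch=3, out_ch=1)

# round-trip through save_pretrained / from_pretrained to check that the custom
# fields end up in config.json as expected
config.save_pretrained("rmbg-config")  # placeholder output directory
reloaded = RMBGConfig.from_pretrained("rmbg-config")
print(reloaded.model_type, reloaded.in_ch, reloaded.out_ch)
# SegformerForSemanticSegmentation 3 1
```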
MyPipe.py ADDED
@@ -0,0 +1,76 @@
+
+ import torch, os
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import normalize
+ import numpy as np
+ from transformers import Pipeline
+ from skimage import io
+ from PIL import Image
+
+ class RMBGPipe(Pipeline):
+     def __init__(self, **kwargs):
+         Pipeline.__init__(self, **kwargs)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def _sanitize_parameters(self, **kwargs):
+         # parse parameters
+         preprocess_kwargs = {}
+         postprocess_kwargs = {}
+         if "model_input_size" in kwargs:
+             preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
+         if "out_name" in kwargs:
+             postprocess_kwargs["out_name"] = kwargs["out_name"]
+         return preprocess_kwargs, {}, postprocess_kwargs
+
+     def preprocess(self, im_path: str, model_input_size: list = [1024, 1024]):
+         # preprocess the input
+         orig_im = io.imread(im_path)
+         orig_im_size = orig_im.shape[0:2]
+         image = self.preprocess_image(orig_im, model_input_size).to(self.device)
+         inputs = {
+             "image": image,
+             "orig_im_size": orig_im_size,
+             "im_path": im_path
+         }
+         return inputs
+
+     def _forward(self, inputs):
+         result = self.model(inputs.pop("image"))
+         inputs["result"] = result
+         return inputs
+     def postprocess(self, inputs, out_name=""):
+         result = inputs.pop("result")
+         orig_im_size = inputs.pop("orig_im_size")
+         im_path = inputs.pop("im_path")
+         result_image = self.postprocess_image(result[0][0], orig_im_size)
+         if out_name != "":
+             # if out_name is specified we save the image using that name
+             pil_im = Image.fromarray(result_image)
+             no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+             orig_image = Image.open(im_path)
+             no_bg_image.paste(orig_image, mask=pil_im)
+             no_bg_image.save(out_name)
+         else:
+             return result_image
+
+     # utilities functions
+     def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
+         # same as utilities.py with minor modification
+         if len(im.shape) < 3:
+             im = im[:, :, np.newaxis]
+         # orig_im_size=im.shape[0:2]
+         im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+         im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
+         image = torch.divide(im_tensor, 255.0)
+         image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+         return image
+     def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
+         result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+         ma = torch.max(result)
+         mi = torch.min(result)
+         result = (result - mi) / (ma - mi)
+         im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+         im_array = np.squeeze(im_array)
+         return im_array
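As a reference for how this class fits together, the pipeline can also be instantiated directly from the local module instead of going through `trust_remote_code`. This is only a sketch: it assumes the RMBG-1.4 weights are reachable under the revision mentioned in the README below, and `example.jpg` / `no_bg.png` are placeholder file names:

```python
from transformers import AutoModelForImageSegmentation
from MyPipe import RMBGPipe

# load the model with its remote code (repo id and revision taken from the README)
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4", revision="refs/pr/9", trust_remote_code=True
)

# build the pipeline directly from the class defined above
pipe = RMBGPipe(model=model, task="image-segmentation")

mask = pipe("example.jpg")                 # numpy mask, resized back to the input image
pipe("example.jpg", out_name="no_bg.png")  # paste the mask onto the original and save the cut-out
```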
README.md ADDED
@@ -0,0 +1,39 @@
+ ---
+ library_name: transformers
+ pipeline_tag: image-segmentation
+ ---
+
+ # How to use
+ Either load the model directly:
+ ```python
+ from transformers import AutoModelForImageSegmentation
+ model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", revision="refs/pr/9", trust_remote_code=True)
+ ```
+ or load the pipeline:
+ ```python
+ from transformers import pipeline
+
+ pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", revision="refs/pr/9", trust_remote_code=True)
+
+ numpy_mask = pipe("img_path")  # outputs a numpy mask
+
+ pipe("image_path", out_name="myout.png")  # applies the mask and saves the extracted image as `myout.png`
+
+ ```
+
+
+
+
+ # Parameters
+ The pipeline accepts the following parameters:
+ * `model_input_size`: defaults to `[1024, 1024]`
+ * `out_name`: if specified, the numpy mask is used to extract the foreground and the result is saved under this name
+ * `preprocess_image`: original method created by briaai
+ * `postprocess_image`: original method created by briaai
+
+ # Disclaimer
+ I do not own, distribute, or take credit for this model.
+
+ All rights belong to [briaai](https://huggingface.co/briaai/).
+
+ This repo is a temporary one to test the custom architecture for [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4); please refer to the original model.
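To make the parameter list above concrete, here is a short usage sketch; the image file names are placeholders:

```python
from transformers import pipeline

pipe = pipeline(
    "image-segmentation",
    model="briaai/RMBG-1.4",
    revision="refs/pr/9",
    trust_remote_code=True,
)

# run the model at a smaller input resolution than the default [1024, 1024]
mask = pipe("portrait.jpg", model_input_size=[512, 512])

# apply the mask to the original image and save the cut-out
pipe("portrait.jpg", out_name="portrait_no_bg.png")
```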
briarmbg.py ADDED
@@ -0,0 +1,459 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from .MyConfig import RMBGConfig
+
+ class REBNCONV(nn.Module):
+     def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
+         super(REBNCONV,self).__init__()
+
+         self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self,x):
+
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src,tar):
+
+     src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7,self).__init__()
+
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)  ## 1 -> 1/2
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+         b, c, h, w = x.shape
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
+         hx6dup = _upsample_like(hx6d,hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+
+         hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+         self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+         return hx1d + hxin
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F,self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
+
+         return hx1d + hxin
+
+
+ class myrebnconv(nn.Module):
+     def __init__(self, in_ch=3,
+                  out_ch=1,
+                  kernel_size=3,
+                  stride=1,
+                  padding=1,
+                  dilation=1,
+                  groups=1):
+         super(myrebnconv,self).__init__()
+
+         self.conv = nn.Conv2d(in_ch,
+                               out_ch,
+                               kernel_size=kernel_size,
+                               stride=stride,
+                               padding=padding,
+                               dilation=dilation,
+                               groups=groups)
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+
+     def forward(self,x):
+         return self.rl(self.bn(self.conv(x)))
+
+
+ class BriaRMBG(PreTrainedModel):
+     config_class = RMBGConfig
+     def __init__(self,config):
+         super().__init__(config)
+         in_ch = config.in_ch # 3
+         out_ch = config.out_ch # 1
+         self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+         self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage1 = RSU7(64,32,64)
+         self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage2 = RSU6(64,32,128)
+         self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage3 = RSU5(128,64,256)
+         self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage4 = RSU4(256,128,512)
+         self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage5 = RSU4F(512,256,512)
+         self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+         self.stage6 = RSU4F(512,256,512)
+
+         # decoder
+         self.stage5d = RSU4F(1024,256,512)
+         self.stage4d = RSU4(1024,128,256)
+         self.stage3d = RSU5(512,64,128)
+         self.stage2d = RSU6(256,32,64)
+         self.stage1d = RSU7(128,16,64)
+
+         self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
+         self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
+         self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
+         self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
+         self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
+         self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+     def forward(self,x):
+
+         hx = x
+
+         hxin = self.conv_in(hx)
+         #hx = self.pool_in(hxin)
+
+         #stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         #stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         #stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         #stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         #stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         #stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6,hx5)
+
+         #-------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
+         hx5dup = _upsample_like(hx5d,hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
+         hx4dup = _upsample_like(hx4d,hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
+         hx3dup = _upsample_like(hx3d,hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
+         hx2dup = _upsample_like(hx2d,hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
+
+
+         #side output
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1,x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2,x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3,x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4,x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5,x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6,x)
+
+         return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
+
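To clarify the `result[0][0]` indexing used in `MyPipe.postprocess`, here is a hedged sketch of the raw output structure of `BriaRMBG.forward`, run on a dummy input (the checkpoint location and revision are taken from the README and assumed to resolve):

```python
import torch
from transformers import AutoModelForImageSegmentation

model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4", revision="refs/pr/9", trust_remote_code=True
)
model.eval()

# dummy normalized batch: 1 image, 3 channels, 1024 x 1024
x = torch.rand(1, 3, 1024, 1024)

with torch.no_grad():
    side_outputs, decoder_features = model(x)

# six sigmoid side outputs, finest first; the pipeline keeps side_outputs[0]
print(len(side_outputs), side_outputs[0].shape)  # 6 torch.Size([1, 1, 1024, 1024])
print(len(decoder_features))                     # 6 decoder feature maps
```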
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "not-lain/CustomCodeForRMBG",
+   "architectures": [
+     "BriaRMBG"
+   ],
+   "auto_map": {
+     "AutoConfig": "not-lain/CustomCodeForRMBG--MyConfig.RMBGConfig",
+     "AutoModelForImageSegmentation": "not-lain/CustomCodeForRMBG--briarmbg.BriaRMBG"
+   },
+   "custom_pipelines": {
+     "image-segmentation": {
+       "impl": "MyPipe.RMBGPipe",
+       "pt": [
+         "AutoModelForImageSegmentation"
+       ],
+       "tf": [],
+       "type": "image"
+     }
+   },
+   "in_ch": 3,
+   "model_type": "SegformerForSemanticSegmentation",
+   "out_ch": 1,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.0.dev0"
+ }
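The `auto_map` and `custom_pipelines` entries are what let the Auto classes and `pipeline(...)` resolve the custom code when `trust_remote_code=True` is passed. A small sketch of the config side, assuming the repository being loaded serves this same config.json:

```python
from transformers import AutoConfig

# auto_map routes AutoConfig to MyConfig.RMBGConfig in the remote code
config = AutoConfig.from_pretrained(
    "briaai/RMBG-1.4", revision="refs/pr/9", trust_remote_code=True
)
print(type(config).__name__, config.in_ch, config.out_ch)  # RMBGConfig 3 1
```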