schirrmacher commited on
Commit
02c46b0
1 Parent(s): fa28bd4

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -44,3 +44,6 @@ examples/image/example02.jpeg filter=lfs diff=lfs merge=lfs -text
44
  examples/image/example03.jpeg filter=lfs diff=lfs merge=lfs -text
45
  examples/image/image01.png filter=lfs diff=lfs merge=lfs -text
46
  examples/image/image01_no_background.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
44
  examples/image/example03.jpeg filter=lfs diff=lfs merge=lfs -text
45
  examples/image/image01.png filter=lfs diff=lfs merge=lfs -text
46
  examples/image/image01_no_background.png filter=lfs diff=lfs merge=lfs -text
47
+ hf_space/example01.jpeg filter=lfs diff=lfs merge=lfs -text
48
+ hf_space/example02.jpeg filter=lfs diff=lfs merge=lfs -text
49
+ hf_space/example03.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -17,7 +17,7 @@ colorFrom: red
17
  colorTo: red
18
  sdk: gradio
19
  sdk_version: 4.29.0
20
- app_file: hf_space.py
21
  pinned: false
22
  ---
23
 
 
17
  colorTo: red
18
  sdk: gradio
19
  sdk_version: 4.29.0
20
+ app_file: hf_space/app.py
21
  pinned: false
22
  ---
23
 
hf_space/app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ from ormbg import ORMBG
7
+ from PIL import Image
8
+
9
+ model_path = "models/ormbg.pth"
10
+
11
+ # Load the model globally but don't send to device yet
12
+ net = ORMBG()
13
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
14
+ net.eval()
15
+
16
+
17
+ def resize_image(image):
18
+ image = image.convert("RGB")
19
+ model_input_size = (1024, 1024)
20
+ image = image.resize(model_input_size, Image.BILINEAR)
21
+ return image
22
+
23
+
24
+ @spaces.GPU
25
+ @torch.inference_mode()
26
+ def inference(image):
27
+ # Check for CUDA and set the device inside inference
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ net.to(device)
30
+
31
+ # Prepare input
32
+ orig_image = Image.fromarray(image)
33
+ w, h = orig_image.size
34
+ image = resize_image(orig_image)
35
+ im_np = np.array(image)
36
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
37
+ im_tensor = torch.unsqueeze(im_tensor, 0)
38
+ im_tensor = torch.divide(im_tensor, 255.0)
39
+
40
+ if torch.cuda.is_available():
41
+ im_tensor = im_tensor.to(device)
42
+
43
+ # Inference
44
+ result = net(im_tensor)
45
+ # Post process
46
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
47
+ ma = torch.max(result)
48
+ mi = torch.min(result)
49
+ result = (result - mi) / (ma - mi)
50
+ # Image to PIL
51
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
52
+ pil_im = Image.fromarray(np.squeeze(im_array))
53
+ # Paste the mask on the original image
54
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
55
+ new_im.paste(orig_image, mask=pil_im)
56
+
57
+ return new_im
58
+
59
+
60
+ # Gradio interface setup
61
+ title = "Open Remove Background Model (ormbg)"
62
+ description = r"""
63
+ This model is a <strong>fully open-source background remover</strong> optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic <a href="https://huggingface.co/datasets/schirrmacher/humans">Human Segmentation Dataset</a>, <a href="https://paperswithcode.com/dataset/p3m-10k">P3M-10k</a> and <a href="https://paperswithcode.com/dataset/aim-500">AIM-500</a>.
64
+
65
+ If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
66
+
67
+ - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
68
+ - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
69
+ - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
70
+ """
71
+
72
+ examples = [
73
+ "./examples/image/example1.jpeg",
74
+ "./examples/image/example2.jpeg",
75
+ "./examples/image/example3.jpeg",
76
+ ]
77
+
78
+ demo = gr.Interface(
79
+ fn=inference,
80
+ inputs="image",
81
+ outputs="image",
82
+ examples=examples,
83
+ title=title,
84
+ description=description,
85
+ )
86
+
87
+ if __name__ == "__main__":
88
+ demo.launch(share=False, root_path="hf_space", allowed_paths=["hf_space"])
hf_space/example01.jpeg ADDED

Git LFS Details

  • SHA256: 436f546cc1d7b2fd7021180299b028c0d379e48a9e9f05214a694b9c4eb8a7e3
  • Pointer size: 132 Bytes
  • Size of remote file: 7.63 MB
hf_space/example02.jpeg ADDED

Git LFS Details

  • SHA256: 1dad92b56723fd8ac1c3832844873ad297300d0e85f6e14764334687a70c8abc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB
hf_space/example03.jpeg ADDED

Git LFS Details

  • SHA256: f392dc4716469f5367ce0e2ac788f284d1b8d70c39be109db7038c3306a1da16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
hf_space/ormbg.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
6
+
7
+
8
+ class REBNCONV(nn.Module):
9
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
10
+ super(REBNCONV, self).__init__()
11
+
12
+ self.conv_s1 = nn.Conv2d(
13
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
14
+ )
15
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
16
+ self.relu_s1 = nn.ReLU(inplace=True)
17
+
18
+ def forward(self, x):
19
+
20
+ hx = x
21
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
22
+
23
+ return xout
24
+
25
+
26
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
27
+ def _upsample_like(src, tar):
28
+
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+
31
+ return src
32
+
33
+
34
+ ### RSU-7 ###
35
+ class RSU7(nn.Module):
36
+
37
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
38
+ super(RSU7, self).__init__()
39
+
40
+ self.in_ch = in_ch
41
+ self.mid_ch = mid_ch
42
+ self.out_ch = out_ch
43
+
44
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
45
+
46
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
47
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
48
+
49
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
50
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
+
52
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
53
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
+
55
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
56
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+
63
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
64
+
65
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
71
+
72
+ def forward(self, x):
73
+ b, c, h, w = x.shape
74
+
75
+ hx = x
76
+ hxin = self.rebnconvin(hx)
77
+
78
+ hx1 = self.rebnconv1(hxin)
79
+ hx = self.pool1(hx1)
80
+
81
+ hx2 = self.rebnconv2(hx)
82
+ hx = self.pool2(hx2)
83
+
84
+ hx3 = self.rebnconv3(hx)
85
+ hx = self.pool3(hx3)
86
+
87
+ hx4 = self.rebnconv4(hx)
88
+ hx = self.pool4(hx4)
89
+
90
+ hx5 = self.rebnconv5(hx)
91
+ hx = self.pool5(hx5)
92
+
93
+ hx6 = self.rebnconv6(hx)
94
+
95
+ hx7 = self.rebnconv7(hx6)
96
+
97
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
98
+ hx6dup = _upsample_like(hx6d, hx5)
99
+
100
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
101
+ hx5dup = _upsample_like(hx5d, hx4)
102
+
103
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
104
+ hx4dup = _upsample_like(hx4d, hx3)
105
+
106
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
107
+ hx3dup = _upsample_like(hx3d, hx2)
108
+
109
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
110
+ hx2dup = _upsample_like(hx2d, hx1)
111
+
112
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
113
+
114
+ return hx1d + hxin
115
+
116
+
117
+ ### RSU-6 ###
118
+ class RSU6(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
121
+ super(RSU6, self).__init__()
122
+
123
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
124
+
125
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
126
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
136
+
137
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
138
+
139
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
140
+
141
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
145
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
146
+
147
+ def forward(self, x):
148
+
149
+ hx = x
150
+
151
+ hxin = self.rebnconvin(hx)
152
+
153
+ hx1 = self.rebnconv1(hxin)
154
+ hx = self.pool1(hx1)
155
+
156
+ hx2 = self.rebnconv2(hx)
157
+ hx = self.pool2(hx2)
158
+
159
+ hx3 = self.rebnconv3(hx)
160
+ hx = self.pool3(hx3)
161
+
162
+ hx4 = self.rebnconv4(hx)
163
+ hx = self.pool4(hx4)
164
+
165
+ hx5 = self.rebnconv5(hx)
166
+
167
+ hx6 = self.rebnconv6(hx5)
168
+
169
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
170
+ hx5dup = _upsample_like(hx5d, hx4)
171
+
172
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
173
+ hx4dup = _upsample_like(hx4d, hx3)
174
+
175
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
176
+ hx3dup = _upsample_like(hx3d, hx2)
177
+
178
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
179
+ hx2dup = _upsample_like(hx2d, hx1)
180
+
181
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
182
+
183
+ return hx1d + hxin
184
+
185
+
186
+ ### RSU-5 ###
187
+ class RSU5(nn.Module):
188
+
189
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
190
+ super(RSU5, self).__init__()
191
+
192
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
193
+
194
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
195
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
196
+
197
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
198
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
199
+
200
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
201
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
202
+
203
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
204
+
205
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
206
+
207
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
208
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
209
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
211
+
212
+ def forward(self, x):
213
+
214
+ hx = x
215
+
216
+ hxin = self.rebnconvin(hx)
217
+
218
+ hx1 = self.rebnconv1(hxin)
219
+ hx = self.pool1(hx1)
220
+
221
+ hx2 = self.rebnconv2(hx)
222
+ hx = self.pool2(hx2)
223
+
224
+ hx3 = self.rebnconv3(hx)
225
+ hx = self.pool3(hx3)
226
+
227
+ hx4 = self.rebnconv4(hx)
228
+
229
+ hx5 = self.rebnconv5(hx4)
230
+
231
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
232
+ hx4dup = _upsample_like(hx4d, hx3)
233
+
234
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
235
+ hx3dup = _upsample_like(hx3d, hx2)
236
+
237
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
238
+ hx2dup = _upsample_like(hx2d, hx1)
239
+
240
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
241
+
242
+ return hx1d + hxin
243
+
244
+
245
+ ### RSU-4 ###
246
+ class RSU4(nn.Module):
247
+
248
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
249
+ super(RSU4, self).__init__()
250
+
251
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
252
+
253
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
254
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
255
+
256
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
257
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
258
+
259
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
260
+
261
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
262
+
263
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
264
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
265
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
266
+
267
+ def forward(self, x):
268
+
269
+ hx = x
270
+
271
+ hxin = self.rebnconvin(hx)
272
+
273
+ hx1 = self.rebnconv1(hxin)
274
+ hx = self.pool1(hx1)
275
+
276
+ hx2 = self.rebnconv2(hx)
277
+ hx = self.pool2(hx2)
278
+
279
+ hx3 = self.rebnconv3(hx)
280
+
281
+ hx4 = self.rebnconv4(hx3)
282
+
283
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
284
+ hx3dup = _upsample_like(hx3d, hx2)
285
+
286
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
287
+ hx2dup = _upsample_like(hx2d, hx1)
288
+
289
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
290
+
291
+ return hx1d + hxin
292
+
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F, self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
311
+
312
+ def forward(self, x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(
333
+ self,
334
+ in_ch=3,
335
+ out_ch=1,
336
+ kernel_size=3,
337
+ stride=1,
338
+ padding=1,
339
+ dilation=1,
340
+ groups=1,
341
+ ):
342
+ super(myrebnconv, self).__init__()
343
+
344
+ self.conv = nn.Conv2d(
345
+ in_ch,
346
+ out_ch,
347
+ kernel_size=kernel_size,
348
+ stride=stride,
349
+ padding=padding,
350
+ dilation=dilation,
351
+ groups=groups,
352
+ )
353
+ self.bn = nn.BatchNorm2d(out_ch)
354
+ self.rl = nn.ReLU(inplace=True)
355
+
356
+ def forward(self, x):
357
+ return self.rl(self.bn(self.conv(x)))
358
+
359
+
360
+ bce_loss = nn.BCELoss(size_average=True)
361
+
362
+
363
+ class ORMBG(nn.Module):
364
+
365
+ def __init__(self, in_ch=3, out_ch=1):
366
+ super(ORMBG, self).__init__()
367
+
368
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
369
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage1 = RSU7(64, 32, 64)
372
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage2 = RSU6(64, 32, 128)
375
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
376
+
377
+ self.stage3 = RSU5(128, 64, 256)
378
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
379
+
380
+ self.stage4 = RSU4(256, 128, 512)
381
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
382
+
383
+ self.stage5 = RSU4F(512, 256, 512)
384
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
385
+
386
+ self.stage6 = RSU4F(512, 256, 512)
387
+
388
+ # decoder
389
+ self.stage5d = RSU4F(1024, 256, 512)
390
+ self.stage4d = RSU4(1024, 128, 256)
391
+ self.stage3d = RSU5(512, 64, 128)
392
+ self.stage2d = RSU6(256, 32, 64)
393
+ self.stage1d = RSU7(128, 16, 64)
394
+
395
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
396
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
397
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
398
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
399
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
400
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
401
+
402
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
403
+
404
+ def compute_loss(self, predictions, ground_truth):
405
+ loss0, loss = 0.0, 0.0
406
+ for i in range(0, len(predictions)):
407
+ loss = loss + bce_loss(predictions[i], ground_truth)
408
+ if i == 0:
409
+ loss0 = loss
410
+ return loss0, loss
411
+
412
+ def forward(self, x):
413
+
414
+ hx = x
415
+
416
+ hxin = self.conv_in(hx)
417
+ # hx = self.pool_in(hxin)
418
+
419
+ # stage 1
420
+ hx1 = self.stage1(hxin)
421
+ hx = self.pool12(hx1)
422
+
423
+ # stage 2
424
+ hx2 = self.stage2(hx)
425
+ hx = self.pool23(hx2)
426
+
427
+ # stage 3
428
+ hx3 = self.stage3(hx)
429
+ hx = self.pool34(hx3)
430
+
431
+ # stage 4
432
+ hx4 = self.stage4(hx)
433
+ hx = self.pool45(hx4)
434
+
435
+ # stage 5
436
+ hx5 = self.stage5(hx)
437
+ hx = self.pool56(hx5)
438
+
439
+ # stage 6
440
+ hx6 = self.stage6(hx)
441
+ hx6up = _upsample_like(hx6, hx5)
442
+
443
+ # -------------------- decoder --------------------
444
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
445
+ hx5dup = _upsample_like(hx5d, hx4)
446
+
447
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
448
+ hx4dup = _upsample_like(hx4d, hx3)
449
+
450
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
451
+ hx3dup = _upsample_like(hx3d, hx2)
452
+
453
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
454
+ hx2dup = _upsample_like(hx2d, hx1)
455
+
456
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
457
+
458
+ # side output
459
+ d1 = self.side1(hx1d)
460
+ d1 = _upsample_like(d1, x)
461
+
462
+ d2 = self.side2(hx2d)
463
+ d2 = _upsample_like(d2, x)
464
+
465
+ d3 = self.side3(hx3d)
466
+ d3 = _upsample_like(d3, x)
467
+
468
+ d4 = self.side4(hx4d)
469
+ d4 = _upsample_like(d4, x)
470
+
471
+ d5 = self.side5(hx5d)
472
+ d5 = _upsample_like(d5, x)
473
+
474
+ d6 = self.side6(hx6)
475
+ d6 = _upsample_like(d6, x)
476
+
477
+ return [
478
+ F.sigmoid(d1),
479
+ F.sigmoid(d2),
480
+ F.sigmoid(d3),
481
+ F.sigmoid(d4),
482
+ F.sigmoid(d5),
483
+ F.sigmoid(d6),
484
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]