hylee commited on
Commit
de51c6d
1 Parent(s): ada0e20
U-2-Net/__pycache__/data_loader.cpython-38.pyc ADDED
Binary file (8.75 kB). View file
 
U-2-Net/data_loader.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data loader
2
+ from __future__ import print_function, division
3
+ import glob
4
+ import torch
5
+ from skimage import io, transform, color
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ import matplotlib.pyplot as plt
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ from PIL import Image
13
+
14
+ #==========================dataset load==========================
15
+ class RescaleT(object):
16
+
17
+ def __init__(self,output_size):
18
+ assert isinstance(output_size,(int,tuple))
19
+ self.output_size = output_size
20
+
21
+ def __call__(self,sample):
22
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
23
+
24
+ h, w = image.shape[:2]
25
+
26
+ if isinstance(self.output_size,int):
27
+ if h > w:
28
+ new_h, new_w = self.output_size*h/w,self.output_size
29
+ else:
30
+ new_h, new_w = self.output_size,self.output_size*w/h
31
+ else:
32
+ new_h, new_w = self.output_size
33
+
34
+ new_h, new_w = int(new_h), int(new_w)
35
+
36
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
37
+ # img = transform.resize(image,(new_h,new_w),mode='constant')
38
+ # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
39
+
40
+ img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
41
+ lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
42
+
43
+ return {'imidx':imidx, 'image':img,'label':lbl}
44
+
45
+ class Rescale(object):
46
+
47
+ def __init__(self,output_size):
48
+ assert isinstance(output_size,(int,tuple))
49
+ self.output_size = output_size
50
+
51
+ def __call__(self,sample):
52
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
53
+
54
+ if random.random() >= 0.5:
55
+ image = image[::-1]
56
+ label = label[::-1]
57
+
58
+ h, w = image.shape[:2]
59
+
60
+ if isinstance(self.output_size,int):
61
+ if h > w:
62
+ new_h, new_w = self.output_size*h/w,self.output_size
63
+ else:
64
+ new_h, new_w = self.output_size,self.output_size*w/h
65
+ else:
66
+ new_h, new_w = self.output_size
67
+
68
+ new_h, new_w = int(new_h), int(new_w)
69
+
70
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
71
+ img = transform.resize(image,(new_h,new_w),mode='constant')
72
+ lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
73
+
74
+ return {'imidx':imidx, 'image':img,'label':lbl}
75
+
76
+ class RandomCrop(object):
77
+
78
+ def __init__(self,output_size):
79
+ assert isinstance(output_size, (int, tuple))
80
+ if isinstance(output_size, int):
81
+ self.output_size = (output_size, output_size)
82
+ else:
83
+ assert len(output_size) == 2
84
+ self.output_size = output_size
85
+ def __call__(self,sample):
86
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
87
+
88
+ if random.random() >= 0.5:
89
+ image = image[::-1]
90
+ label = label[::-1]
91
+
92
+ h, w = image.shape[:2]
93
+ new_h, new_w = self.output_size
94
+
95
+ top = np.random.randint(0, h - new_h)
96
+ left = np.random.randint(0, w - new_w)
97
+
98
+ image = image[top: top + new_h, left: left + new_w]
99
+ label = label[top: top + new_h, left: left + new_w]
100
+
101
+ return {'imidx':imidx,'image':image, 'label':label}
102
+
103
+ class ToTensor(object):
104
+ """Convert ndarrays in sample to Tensors."""
105
+
106
+ def __call__(self, sample):
107
+
108
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
109
+
110
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
111
+ tmpLbl = np.zeros(label.shape)
112
+
113
+ image = image/np.max(image)
114
+ if(np.max(label)<1e-6):
115
+ label = label
116
+ else:
117
+ label = label/np.max(label)
118
+
119
+ if image.shape[2]==1:
120
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
121
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
122
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
123
+ else:
124
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
125
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
126
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
127
+
128
+ tmpLbl[:,:,0] = label[:,:,0]
129
+
130
+
131
+ tmpImg = tmpImg.transpose((2, 0, 1))
132
+ tmpLbl = label.transpose((2, 0, 1))
133
+
134
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
+
136
+ class ToTensorLab(object):
137
+ """Convert ndarrays in sample to Tensors."""
138
+ def __init__(self,flag=0):
139
+ self.flag = flag
140
+
141
+ def __call__(self, sample):
142
+
143
+ imidx, image, label =sample['imidx'], sample['image'], sample['label']
144
+
145
+ tmpLbl = np.zeros(label.shape)
146
+
147
+ if(np.max(label)<1e-6):
148
+ label = label
149
+ else:
150
+ label = label/np.max(label)
151
+
152
+ # change the color space
153
+ if self.flag == 2: # with rgb and Lab colors
154
+ tmpImg = np.zeros((image.shape[0],image.shape[1],6))
155
+ tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
156
+ if image.shape[2]==1:
157
+ tmpImgt[:,:,0] = image[:,:,0]
158
+ tmpImgt[:,:,1] = image[:,:,0]
159
+ tmpImgt[:,:,2] = image[:,:,0]
160
+ else:
161
+ tmpImgt = image
162
+ tmpImgtl = color.rgb2lab(tmpImgt)
163
+
164
+ # nomalize image to range [0,1]
165
+ tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
166
+ tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
167
+ tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
168
+ tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
169
+ tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
170
+ tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
171
+
172
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
173
+
174
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
175
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
176
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
177
+ tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
178
+ tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
179
+ tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
180
+
181
+ elif self.flag == 1: #with Lab color
182
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
183
+
184
+ if image.shape[2]==1:
185
+ tmpImg[:,:,0] = image[:,:,0]
186
+ tmpImg[:,:,1] = image[:,:,0]
187
+ tmpImg[:,:,2] = image[:,:,0]
188
+ else:
189
+ tmpImg = image
190
+
191
+ tmpImg = color.rgb2lab(tmpImg)
192
+
193
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
194
+
195
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
196
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
197
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
198
+
199
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
200
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
201
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
202
+
203
+ else: # with rgb color
204
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
205
+ image = image/np.max(image)
206
+ if image.shape[2]==1:
207
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
208
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
209
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
210
+ else:
211
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
212
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
213
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
214
+
215
+ tmpLbl[:,:,0] = label[:,:,0]
216
+
217
+
218
+ tmpImg = tmpImg.transpose((2, 0, 1))
219
+ tmpLbl = label.transpose((2, 0, 1))
220
+
221
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
222
+
223
+ class SalObjDataset(Dataset):
224
+ def __init__(self,img_name_list,lbl_name_list,transform=None):
225
+ # self.root_dir = root_dir
226
+ # self.image_name_list = glob.glob(image_dir+'*.png')
227
+ # self.label_name_list = glob.glob(label_dir+'*.png')
228
+ self.image_name_list = img_name_list
229
+ self.label_name_list = lbl_name_list
230
+ self.transform = transform
231
+
232
+ def __len__(self):
233
+ return len(self.image_name_list)
234
+
235
+ def __getitem__(self,idx):
236
+
237
+ # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
238
+ # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
239
+
240
+ image = io.imread(self.image_name_list[idx])
241
+ imname = self.image_name_list[idx]
242
+ imidx = np.array([idx])
243
+
244
+ if(0==len(self.label_name_list)):
245
+ label_3 = np.zeros(image.shape)
246
+ else:
247
+ label_3 = io.imread(self.label_name_list[idx])
248
+
249
+ label = np.zeros(label_3.shape[0:2])
250
+ if(3==len(label_3.shape)):
251
+ label = label_3[:,:,0]
252
+ elif(2==len(label_3.shape)):
253
+ label = label_3
254
+
255
+ if(3==len(image.shape) and 2==len(label.shape)):
256
+ label = label[:,:,np.newaxis]
257
+ elif(2==len(image.shape) and 2==len(label.shape)):
258
+ image = image[:,:,np.newaxis]
259
+ label = label[:,:,np.newaxis]
260
+
261
+ sample = {'imidx':imidx, 'image':image, 'label':label}
262
+
263
+ if self.transform:
264
+ sample = self.transform(sample)
265
+
266
+ return sample
U-2-Net/gradio/demo.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import paddlehub as hub
3
+ import gradio as gr
4
+ import torch
5
+
6
+ # Images
7
+ torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2018/08/12/16/59/ara-3601194_1280.jpg', 'parrot.jpg')
8
+ torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2016/10/21/14/46/fox-1758183_1280.jpg', 'fox.jpg')
9
+
10
+ model = hub.Module(name='U2Net')
11
+
12
+ def infer(img):
13
+ result = model.Segmentation(
14
+ images=[cv2.imread(img.name)],
15
+ paths=None,
16
+ batch_size=1,
17
+ input_size=320,
18
+ output_dir='output',
19
+ visualization=True)
20
+ return result[0]['front'][:,:,::-1], result[0]['mask']
21
+
22
+ inputs = gr.inputs.Image(type='file', label="Original Image")
23
+ outputs = [
24
+ gr.outputs.Image(type="numpy",label="Front"),
25
+ gr.outputs.Image(type="numpy",label="Mask")
26
+ ]
27
+
28
+ title = "U^2-Net"
29
+ description = "demo for U^2-Net. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
30
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2005.09007'>U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection</a> | <a href='https://github.com/xuebinqin/U-2-Net'>Github Repo</a></p>"
31
+
32
+ examples = [
33
+ ['fox.jpg'],
34
+ ['parrot.jpg']
35
+ ]
36
+
37
+ gr.Interface(infer, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()
U-2-Net/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .u2net import U2NET
2
+ from .u2net import U2NETP
U-2-Net/model/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (257 Bytes). View file
 
U-2-Net/model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (190 Bytes). View file
 
U-2-Net/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (203 Bytes). View file
 
U-2-Net/model/__pycache__/u2net.cpython-36.pyc ADDED
Binary file (11.6 kB). View file
 
U-2-Net/model/__pycache__/u2net.cpython-37.pyc ADDED
Binary file (11.1 kB). View file
 
U-2-Net/model/__pycache__/u2net.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
U-2-Net/model/u2net.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
U-2-Net/model/u2net_refactor.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ __all__ = ['U2NET_full', 'U2NET_lite']
7
+
8
+
9
+ def _upsample_like(x, size):
10
+ return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)
11
+
12
+
13
+ def _size_map(x, height):
14
+ # {height: size} for Upsample
15
+ size = list(x.shape[-2:])
16
+ sizes = {}
17
+ for h in range(1, height):
18
+ sizes[h] = size
19
+ size = [math.ceil(w / 2) for w in size]
20
+ return sizes
21
+
22
+
23
+ class REBNCONV(nn.Module):
24
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
25
+ super(REBNCONV, self).__init__()
26
+
27
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
28
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
29
+ self.relu_s1 = nn.ReLU(inplace=True)
30
+
31
+ def forward(self, x):
32
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
33
+
34
+
35
+ class RSU(nn.Module):
36
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
37
+ super(RSU, self).__init__()
38
+ self.name = name
39
+ self.height = height
40
+ self.dilated = dilated
41
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
42
+
43
+ def forward(self, x):
44
+ sizes = _size_map(x, self.height)
45
+ x = self.rebnconvin(x)
46
+
47
+ # U-Net like symmetric encoder-decoder structure
48
+ def unet(x, height=1):
49
+ if height < self.height:
50
+ x1 = getattr(self, f'rebnconv{height}')(x)
51
+ if not self.dilated and height < self.height - 1:
52
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
53
+ else:
54
+ x2 = unet(x1, height + 1)
55
+
56
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
57
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
58
+ else:
59
+ return getattr(self, f'rebnconv{height}')(x)
60
+
61
+ return x + unet(x)
62
+
63
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
64
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
65
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
66
+
67
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
68
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
69
+
70
+ for i in range(2, height):
71
+ dilate = 1 if not dilated else 2 ** (i - 1)
72
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
73
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
74
+
75
+ dilate = 2 if not dilated else 2 ** (height - 1)
76
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
77
+
78
+
79
+ class U2NET(nn.Module):
80
+ def __init__(self, cfgs, out_ch):
81
+ super(U2NET, self).__init__()
82
+ self.out_ch = out_ch
83
+ self._make_layers(cfgs)
84
+
85
+ def forward(self, x):
86
+ sizes = _size_map(x, self.height)
87
+ maps = [] # storage for maps
88
+
89
+ # side saliency map
90
+ def unet(x, height=1):
91
+ if height < 6:
92
+ x1 = getattr(self, f'stage{height}')(x)
93
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
94
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
95
+ side(x, height)
96
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
97
+ else:
98
+ x = getattr(self, f'stage{height}')(x)
99
+ side(x, height)
100
+ return _upsample_like(x, sizes[height - 1])
101
+
102
+ def side(x, h):
103
+ # side output saliency map (before sigmoid)
104
+ x = getattr(self, f'side{h}')(x)
105
+ x = _upsample_like(x, sizes[1])
106
+ maps.append(x)
107
+
108
+ def fuse():
109
+ # fuse saliency probability maps
110
+ maps.reverse()
111
+ x = torch.cat(maps, 1)
112
+ x = getattr(self, 'outconv')(x)
113
+ maps.insert(0, x)
114
+ return [torch.sigmoid(x) for x in maps]
115
+
116
+ unet(x)
117
+ maps = fuse()
118
+ return maps
119
+
120
+ def _make_layers(self, cfgs):
121
+ self.height = int((len(cfgs) + 1) / 2)
122
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
123
+ for k, v in cfgs.items():
124
+ # build rsu block
125
+ self.add_module(k, RSU(v[0], *v[1]))
126
+ if v[2] > 0:
127
+ # build side layer
128
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
129
+ # build fuse layer
130
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
131
+
132
+
133
+ def U2NET_full():
134
+ full = {
135
+ # cfgs for building RSUs and sides
136
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
137
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
138
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
139
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
140
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
141
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
142
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
143
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
144
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
145
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
146
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
147
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
148
+ }
149
+ return U2NET(cfgs=full, out_ch=1)
150
+
151
+
152
+ def U2NET_lite():
153
+ lite = {
154
+ # cfgs for building RSUs and sides
155
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
156
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
157
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
158
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
159
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
160
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
161
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
162
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
163
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
164
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
165
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
166
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
167
+ }
168
+ return U2NET(cfgs=lite, out_ch=1)
U-2-Net/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.15.2
2
+ scikit-image==0.14.0
3
+ torch
4
+ torchvision
5
+ pillow==8.1.1
6
+ opencv-python
7
+ paddlepaddle
8
+ paddlehub
9
+ gradio
U-2-Net/saved_models/face_detection_cv2/haarcascade_frontalface_default.xml ADDED
The diff for this file is too large to render. See raw diff
 
U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb9f0378a16868d08e2325c8b36eae2b174b040b91bf64781fbb5dd4d31712b4
3
+ size 176315791
U-2-Net/setup_model_weights.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gdown
3
+
4
+ os.makedirs('./saved_models/u2net', exist_ok=True)
5
+ os.makedirs('./saved_models/u2net_portrait', exist_ok=True)
6
+
7
+ gdown.download('https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ',
8
+ './saved_models/u2net/u2net.pth',
9
+ quiet=False)
10
+
11
+ gdown.download('https://drive.google.com/uc?id=1IG3HdpcRiDoWNookbncQjeaPN28t90yW',
12
+ './saved_models/u2net_portrait/u2net_portrait.pth',
13
+ quiet=False)
U-2-Net/u2net_human_seg_test.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from skimage import io, transform
3
+ import torch
4
+ import torchvision
5
+ from torch.autograd import Variable
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms#, utils
10
+ # import torch.optim as optim
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ import glob
15
+
16
+ from data_loader import RescaleT
17
+ from data_loader import ToTensor
18
+ from data_loader import ToTensorLab
19
+ from data_loader import SalObjDataset
20
+
21
+ from model import U2NET # full size version 173.6 MB
22
+
23
+ # normalize the predicted SOD probability map
24
+ def normPRED(d):
25
+ ma = torch.max(d)
26
+ mi = torch.min(d)
27
+
28
+ dn = (d-mi)/(ma-mi)
29
+
30
+ return dn
31
+
32
+ def save_output(image_name,pred,d_dir):
33
+
34
+ predict = pred
35
+ predict = predict.squeeze()
36
+ predict_np = predict.cpu().data.numpy()
37
+
38
+ im = Image.fromarray(predict_np*255).convert('RGB')
39
+ img_name = image_name.split(os.sep)[-1]
40
+ image = io.imread(image_name)
41
+ imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
42
+
43
+ pb_np = np.array(imo)
44
+
45
+ aaa = img_name.split(".")
46
+ bbb = aaa[0:-1]
47
+ imidx = bbb[0]
48
+ for i in range(1,len(bbb)):
49
+ imidx = imidx + "." + bbb[i]
50
+
51
+ imo.save(d_dir+imidx+'.png')
52
+
53
+ def main():
54
+
55
+ # --------- 1. get image path and name ---------
56
+ model_name='u2net'
57
+
58
+
59
+ image_dir = os.path.join(os.getcwd(), 'test_data', 'test_human_images')
60
+ prediction_dir = os.path.join(os.getcwd(), 'test_data', 'test_human_images' + '_results' + os.sep)
61
+ model_dir = os.path.join(os.getcwd(), 'saved_models', model_name+'_human_seg', model_name + '_human_seg.pth')
62
+
63
+ img_name_list = glob.glob(image_dir + os.sep + '*')
64
+ print(img_name_list)
65
+
66
+ # --------- 2. dataloader ---------
67
+ #1. dataloader
68
+ test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
69
+ lbl_name_list = [],
70
+ transform=transforms.Compose([RescaleT(320),
71
+ ToTensorLab(flag=0)])
72
+ )
73
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
74
+ batch_size=1,
75
+ shuffle=False,
76
+ num_workers=1)
77
+
78
+ # --------- 3. model define ---------
79
+ if(model_name=='u2net'):
80
+ print("...load U2NET---173.6 MB")
81
+ net = U2NET(3,1)
82
+
83
+ if torch.cuda.is_available():
84
+ net.load_state_dict(torch.load(model_dir))
85
+ net.cuda()
86
+ else:
87
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
88
+ net.eval()
89
+
90
+ # --------- 4. inference for each image ---------
91
+ for i_test, data_test in enumerate(test_salobj_dataloader):
92
+
93
+ print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
94
+
95
+ inputs_test = data_test['image']
96
+ inputs_test = inputs_test.type(torch.FloatTensor)
97
+
98
+ if torch.cuda.is_available():
99
+ inputs_test = Variable(inputs_test.cuda())
100
+ else:
101
+ inputs_test = Variable(inputs_test)
102
+
103
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
104
+
105
+ # normalization
106
+ pred = d1[:,0,:,:]
107
+ pred = normPRED(pred)
108
+
109
+ # save results to test_results folder
110
+ if not os.path.exists(prediction_dir):
111
+ os.makedirs(prediction_dir, exist_ok=True)
112
+ save_output(img_name_list[i_test],pred,prediction_dir)
113
+
114
+ del d1,d2,d3,d4,d5,d6,d7
115
+
116
+ if __name__ == "__main__":
117
+ main()
U-2-Net/u2net_portrait_composite.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from skimage import io, transform
3
+ from skimage.filters import gaussian
4
+ import torch
5
+ import torchvision
6
+ from torch.autograd import Variable
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms#, utils
11
+ # import torch.optim as optim
12
+
13
+ import numpy as np
14
+ from PIL import Image
15
+ import glob
16
+
17
+ from data_loader import RescaleT
18
+ from data_loader import ToTensor
19
+ from data_loader import ToTensorLab
20
+ from data_loader import SalObjDataset
21
+
22
+ from model import U2NET # full size version 173.6 MB
23
+ from model import U2NETP # small version u2net 4.7 MB
24
+
25
+ import argparse
26
+
27
+ # normalize the predicted SOD probability map
28
+ def normPRED(d):
29
+ ma = torch.max(d)
30
+ mi = torch.min(d)
31
+
32
+ dn = (d-mi)/(ma-mi)
33
+
34
+ return dn
35
+
36
+ def save_output(image_name,pred,d_dir,sigma=2,alpha=0.5):
37
+
38
+ predict = pred
39
+ predict = predict.squeeze()
40
+ predict_np = predict.cpu().data.numpy()
41
+
42
+ image = io.imread(image_name)
43
+ pd = transform.resize(predict_np,image.shape[0:2],order=2)
44
+ pd = pd/(np.amax(pd)+1e-8)*255
45
+ pd = pd[:,:,np.newaxis]
46
+
47
+ print(image.shape)
48
+ print(pd.shape)
49
+
50
+ ## fuse the orignal portrait image and the portraits into one composite image
51
+ ## 1. use gaussian filter to blur the orginal image
52
+ sigma=sigma
53
+ image = gaussian(image, sigma=sigma, preserve_range=True)
54
+
55
+ ## 2. fuse these orignal image and the portrait with certain weight: alpha
56
+ alpha = alpha
57
+ im_comp = image*alpha+pd*(1-alpha)
58
+
59
+ print(im_comp.shape)
60
+
61
+
62
+ img_name = image_name.split(os.sep)[-1]
63
+ aaa = img_name.split(".")
64
+ bbb = aaa[0:-1]
65
+ imidx = bbb[0]
66
+ for i in range(1,len(bbb)):
67
+ imidx = imidx + "." + bbb[i]
68
+ io.imsave(d_dir+'/'+imidx+'_sigma_' + str(sigma) + '_alpha_' + str(alpha) + '_composite.png',im_comp)
69
+
70
+ def main():
71
+
72
+ parser = argparse.ArgumentParser(description="image and portrait composite")
73
+ parser.add_argument('-s',action='store',dest='sigma')
74
+ parser.add_argument('-a',action='store',dest='alpha')
75
+ args = parser.parse_args()
76
+ print(args.sigma)
77
+ print(args.alpha)
78
+ print("--------------------")
79
+
80
+ # --------- 1. get image path and name ---------
81
+ model_name='u2net_portrait'#u2netp
82
+
83
+
84
+ image_dir = './test_data/test_portrait_images/your_portrait_im'
85
+ prediction_dir = './test_data/test_portrait_images/your_portrait_results'
86
+ if(not os.path.exists(prediction_dir)):
87
+ os.mkdir(prediction_dir)
88
+
89
+ model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
90
+
91
+ img_name_list = glob.glob(image_dir+'/*')
92
+ print("Number of images: ", len(img_name_list))
93
+
94
+ # --------- 2. dataloader ---------
95
+ #1. dataloader
96
+ test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
97
+ lbl_name_list = [],
98
+ transform=transforms.Compose([RescaleT(512),
99
+ ToTensorLab(flag=0)])
100
+ )
101
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
102
+ batch_size=1,
103
+ shuffle=False,
104
+ num_workers=1)
105
+
106
+ # --------- 3. model define ---------
107
+
108
+ print("...load U2NET---173.6 MB")
109
+ net = U2NET(3,1)
110
+
111
+ net.load_state_dict(torch.load(model_dir))
112
+ if torch.cuda.is_available():
113
+ net.cuda()
114
+ net.eval()
115
+
116
+ # --------- 4. inference for each image ---------
117
+ for i_test, data_test in enumerate(test_salobj_dataloader):
118
+
119
+ print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
120
+
121
+ inputs_test = data_test['image']
122
+ inputs_test = inputs_test.type(torch.FloatTensor)
123
+
124
+ if torch.cuda.is_available():
125
+ inputs_test = Variable(inputs_test.cuda())
126
+ else:
127
+ inputs_test = Variable(inputs_test)
128
+
129
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
130
+
131
+ # normalization
132
+ pred = 1.0 - d1[:,0,:,:]
133
+ pred = normPRED(pred)
134
+
135
+ # save results to test_results folder
136
+ save_output(img_name_list[i_test],pred,prediction_dir,sigma=float(args.sigma),alpha=float(args.alpha))
137
+
138
+ del d1,d2,d3,d4,d5,d6,d7
139
+
140
+ if __name__ == "__main__":
141
+ main()
U-2-Net/u2net_portrait_demo.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ from model import U2NET
4
+ from torch.autograd import Variable
5
+ import numpy as np
6
+ from glob import glob
7
+ import os
8
+
9
+ def detect_single_face(face_cascade,img):
10
+ # Convert into grayscale
11
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
12
+
13
+ # Detect faces
14
+ faces = face_cascade.detectMultiScale(gray, 1.1, 4)
15
+ if(len(faces)==0):
16
+ print("Warming: no face detection, the portrait u2net will run on the whole image!")
17
+ return None
18
+
19
+ # filter to keep the largest face
20
+ wh = 0
21
+ idx = 0
22
+ for i in range(0,len(faces)):
23
+ (x,y,w,h) = faces[i]
24
+ if(wh<w*h):
25
+ idx = i
26
+ wh = w*h
27
+
28
+ return faces[idx]
29
+
30
+ # crop, pad and resize face region to 512x512 resolution
31
+ def crop_face(img, face):
32
+
33
+ # no face detected, return the whole image and the inference will run on the whole image
34
+ if(face is None):
35
+ return img
36
+ (x, y, w, h) = face
37
+
38
+ height,width = img.shape[0:2]
39
+
40
+ # crop the face with a bigger bbox
41
+ hmw = h - w
42
+ # hpad = int(h/2)+1
43
+ # wpad = int(w/2)+1
44
+
45
+ l,r,t,b = 0,0,0,0
46
+ lpad = int(float(w)*0.4)
47
+ left = x-lpad
48
+ if(left<0):
49
+ l = lpad-x
50
+ left = 0
51
+
52
+ rpad = int(float(w)*0.4)
53
+ right = x+w+rpad
54
+ if(right>width):
55
+ r = right-width
56
+ right = width
57
+
58
+ tpad = int(float(h)*0.6)
59
+ top = y - tpad
60
+ if(top<0):
61
+ t = tpad-y
62
+ top = 0
63
+
64
+ bpad = int(float(h)*0.2)
65
+ bottom = y+h+bpad
66
+ if(bottom>height):
67
+ b = bottom-height
68
+ bottom = height
69
+
70
+
71
+ im_face = img[top:bottom,left:right]
72
+ if(len(im_face.shape)==2):
73
+ im_face = np.repeat(im_face[:,:,np.newaxis],(1,1,3))
74
+
75
+ im_face = np.pad(im_face,((t,b),(l,r),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
76
+
77
+ # pad to achieve image with square shape for avoding face deformation after resizing
78
+ hf,wf = im_face.shape[0:2]
79
+ if(hf-2>wf):
80
+ wfp = int((hf-wf)/2)
81
+ im_face = np.pad(im_face,((0,0),(wfp,wfp),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
82
+ elif(wf-2>hf):
83
+ hfp = int((wf-hf)/2)
84
+ im_face = np.pad(im_face,((hfp,hfp),(0,0),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
85
+
86
+ # resize to have 512x512 resolution
87
+ im_face = cv2.resize(im_face, (512,512), interpolation = cv2.INTER_AREA)
88
+
89
+ return im_face
90
+
91
+ def normPRED(d):
92
+ ma = torch.max(d)
93
+ mi = torch.min(d)
94
+
95
+ dn = (d-mi)/(ma-mi)
96
+
97
+ return dn
98
+
99
+ def inference(net,input):
100
+
101
+ # normalize the input
102
+ tmpImg = np.zeros((input.shape[0],input.shape[1],3))
103
+ input = input/np.max(input)
104
+
105
+ tmpImg[:,:,0] = (input[:,:,2]-0.406)/0.225
106
+ tmpImg[:,:,1] = (input[:,:,1]-0.456)/0.224
107
+ tmpImg[:,:,2] = (input[:,:,0]-0.485)/0.229
108
+
109
+ # convert BGR to RGB
110
+ tmpImg = tmpImg.transpose((2, 0, 1))
111
+ tmpImg = tmpImg[np.newaxis,:,:,:]
112
+ tmpImg = torch.from_numpy(tmpImg)
113
+
114
+ # convert numpy array to torch tensor
115
+ tmpImg = tmpImg.type(torch.FloatTensor)
116
+
117
+ if torch.cuda.is_available():
118
+ tmpImg = Variable(tmpImg.cuda())
119
+ else:
120
+ tmpImg = Variable(tmpImg)
121
+
122
+ # inference
123
+ d1,d2,d3,d4,d5,d6,d7= net(tmpImg)
124
+
125
+ # normalization
126
+ pred = 1.0 - d1[:,0,:,:]
127
+ pred = normPRED(pred)
128
+
129
+ # convert torch tensor to numpy array
130
+ pred = pred.squeeze()
131
+ pred = pred.cpu().data.numpy()
132
+
133
+ del d1,d2,d3,d4,d5,d6,d7
134
+
135
+ return pred
136
+
137
+ def main():
138
+
139
+ # get the image path list for inference
140
+ im_list = glob('./test_data/test_portrait_images/your_portrait_im/*')
141
+ print("Number of images: ",len(im_list))
142
+ # indicate the output directory
143
+ out_dir = './test_data/test_portrait_images/your_portrait_results'
144
+ if(not os.path.exists(out_dir)):
145
+ os.mkdir(out_dir)
146
+
147
+ # Load the cascade face detection model
148
+ face_cascade = cv2.CascadeClassifier('./saved_models/face_detection_cv2/haarcascade_frontalface_default.xml')
149
+ # u2net_portrait path
150
+ model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
151
+
152
+ # load u2net_portrait model
153
+ net = U2NET(3,1)
154
+ net.load_state_dict(torch.load(model_dir))
155
+ if torch.cuda.is_available():
156
+ net.cuda()
157
+ net.eval()
158
+
159
+ # do the inference one-by-one
160
+ for i in range(0,len(im_list)):
161
+ print("--------------------------")
162
+ print("inferencing ", i, "/", len(im_list), im_list[i])
163
+
164
+ # load each image
165
+ img = cv2.imread(im_list[i])
166
+ height,width = img.shape[0:2]
167
+ face = detect_single_face(face_cascade,img)
168
+ im_face = crop_face(img, face)
169
+ im_portrait = inference(net,im_face)
170
+
171
+ # save the output
172
+ cv2.imwrite(out_dir+"/"+im_list[i].split('/')[-1][0:-4]+'.png',(im_portrait*255).astype(np.uint8))
173
+
174
+ if __name__ == '__main__':
175
+ main()
U-2-Net/u2net_portrait_test.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from skimage import io, transform
3
+ import torch
4
+ import torchvision
5
+ from torch.autograd import Variable
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms#, utils
10
+ # import torch.optim as optim
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ import glob
15
+
16
+ from data_loader import RescaleT
17
+ from data_loader import ToTensor
18
+ from data_loader import ToTensorLab
19
+ from data_loader import SalObjDataset
20
+
21
+ from model import U2NET # full size version 173.6 MB
22
+ from model import U2NETP # small version u2net 4.7 MB
23
+
24
+ # normalize the predicted SOD probability map
25
+ def normPRED(d):
26
+ ma = torch.max(d)
27
+ mi = torch.min(d)
28
+
29
+ dn = (d-mi)/(ma-mi)
30
+
31
+ return dn
32
+
33
+ def save_output(image_name,pred,d_dir):
34
+
35
+ predict = pred
36
+ predict = predict.squeeze()
37
+ predict_np = predict.cpu().data.numpy()
38
+
39
+ im = Image.fromarray(predict_np*255).convert('RGB')
40
+ img_name = image_name.split(os.sep)[-1]
41
+ image = io.imread(image_name)
42
+ imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
43
+
44
+ pb_np = np.array(imo)
45
+
46
+ aaa = img_name.split(".")
47
+ bbb = aaa[0:-1]
48
+ imidx = bbb[0]
49
+ for i in range(1,len(bbb)):
50
+ imidx = imidx + "." + bbb[i]
51
+
52
+ imo.save(d_dir+'/'+imidx+'.png')
53
+
54
+ def main():
55
+
56
+ # --------- 1. get image path and name ---------
57
+ model_name='u2net_portrait'#u2netp
58
+
59
+
60
+ image_dir = './test_data/test_portrait_images/portrait_im'
61
+ prediction_dir = './test_data/test_portrait_images/portrait_results'
62
+ if(not os.path.exists(prediction_dir)):
63
+ os.mkdir(prediction_dir)
64
+
65
+ model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
66
+
67
+ img_name_list = glob.glob(image_dir+'/*')
68
+ print("Number of images: ", len(img_name_list))
69
+
70
+ # --------- 2. dataloader ---------
71
+ #1. dataloader
72
+ test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
73
+ lbl_name_list = [],
74
+ transform=transforms.Compose([RescaleT(512),
75
+ ToTensorLab(flag=0)])
76
+ )
77
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
78
+ batch_size=1,
79
+ shuffle=False,
80
+ num_workers=1)
81
+
82
+ # --------- 3. model define ---------
83
+
84
+ print("...load U2NET---173.6 MB")
85
+ net = U2NET(3,1)
86
+
87
+ net.load_state_dict(torch.load(model_dir))
88
+ if torch.cuda.is_available():
89
+ net.cuda()
90
+ net.eval()
91
+
92
+ # --------- 4. inference for each image ---------
93
+ for i_test, data_test in enumerate(test_salobj_dataloader):
94
+
95
+ print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
96
+
97
+ inputs_test = data_test['image']
98
+ inputs_test = inputs_test.type(torch.FloatTensor)
99
+
100
+ if torch.cuda.is_available():
101
+ inputs_test = Variable(inputs_test.cuda())
102
+ else:
103
+ inputs_test = Variable(inputs_test)
104
+
105
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
106
+
107
+ # normalization
108
+ pred = 1.0 - d1[:,0,:,:]
109
+ pred = normPRED(pred)
110
+
111
+ # save results to test_results folder
112
+ save_output(img_name_list[i_test],pred,prediction_dir)
113
+
114
+ del d1,d2,d3,d4,d5,d6,d7
115
+
116
+ if __name__ == "__main__":
117
+ main()
U-2-Net/u2net_test.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from skimage import io, transform
3
+ import torch
4
+ import torchvision
5
+ from torch.autograd import Variable
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms#, utils
10
+ # import torch.optim as optim
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ import glob
15
+
16
+ from data_loader import RescaleT
17
+ from data_loader import ToTensor
18
+ from data_loader import ToTensorLab
19
+ from data_loader import SalObjDataset
20
+
21
+ from model import U2NET # full size version 173.6 MB
22
+ from model import U2NETP # small version u2net 4.7 MB
23
+
24
+ # normalize the predicted SOD probability map
25
+ def normPRED(d):
26
+ ma = torch.max(d)
27
+ mi = torch.min(d)
28
+
29
+ dn = (d-mi)/(ma-mi)
30
+
31
+ return dn
32
+
33
+ def save_output(image_name,pred,d_dir):
34
+
35
+ predict = pred
36
+ predict = predict.squeeze()
37
+ predict_np = predict.cpu().data.numpy()
38
+
39
+ im = Image.fromarray(predict_np*255).convert('RGB')
40
+ img_name = image_name.split(os.sep)[-1]
41
+ image = io.imread(image_name)
42
+ imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
43
+
44
+ pb_np = np.array(imo)
45
+
46
+ aaa = img_name.split(".")
47
+ bbb = aaa[0:-1]
48
+ imidx = bbb[0]
49
+ for i in range(1,len(bbb)):
50
+ imidx = imidx + "." + bbb[i]
51
+
52
+ imo.save(d_dir+imidx+'.png')
53
+
54
+ def main():
55
+
56
+ # --------- 1. get image path and name ---------
57
+ model_name='u2net'#u2netp
58
+
59
+
60
+
61
+ image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
62
+ prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
63
+ model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
64
+
65
+ img_name_list = glob.glob(image_dir + os.sep + '*')
66
+ print(img_name_list)
67
+
68
+ # --------- 2. dataloader ---------
69
+ #1. dataloader
70
+ test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
71
+ lbl_name_list = [],
72
+ transform=transforms.Compose([RescaleT(320),
73
+ ToTensorLab(flag=0)])
74
+ )
75
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
76
+ batch_size=1,
77
+ shuffle=False,
78
+ num_workers=1)
79
+
80
+ # --------- 3. model define ---------
81
+ if(model_name=='u2net'):
82
+ print("...load U2NET---173.6 MB")
83
+ net = U2NET(3,1)
84
+ elif(model_name=='u2netp'):
85
+ print("...load U2NEP---4.7 MB")
86
+ net = U2NETP(3,1)
87
+
88
+ if torch.cuda.is_available():
89
+ net.load_state_dict(torch.load(model_dir))
90
+ net.cuda()
91
+ else:
92
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
93
+ net.eval()
94
+
95
+ # --------- 4. inference for each image ---------
96
+ for i_test, data_test in enumerate(test_salobj_dataloader):
97
+
98
+ print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
99
+
100
+ inputs_test = data_test['image']
101
+ inputs_test = inputs_test.type(torch.FloatTensor)
102
+
103
+ if torch.cuda.is_available():
104
+ inputs_test = Variable(inputs_test.cuda())
105
+ else:
106
+ inputs_test = Variable(inputs_test)
107
+
108
+ d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
109
+
110
+ # normalization
111
+ pred = d1[:,0,:,:]
112
+ pred = normPRED(pred)
113
+
114
+ # save results to test_results folder
115
+ if not os.path.exists(prediction_dir):
116
+ os.makedirs(prediction_dir, exist_ok=True)
117
+ save_output(img_name_list[i_test],pred,prediction_dir)
118
+
119
+ del d1,d2,d3,d4,d5,d6,d7
120
+
121
+ if __name__ == "__main__":
122
+ main()
U-2-Net/u2net_train.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch.autograd import Variable
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision import transforms, utils
10
+ import torch.optim as optim
11
+ import torchvision.transforms as standard_transforms
12
+
13
+ import numpy as np
14
+ import glob
15
+ import os
16
+
17
+ from data_loader import Rescale
18
+ from data_loader import RescaleT
19
+ from data_loader import RandomCrop
20
+ from data_loader import ToTensor
21
+ from data_loader import ToTensorLab
22
+ from data_loader import SalObjDataset
23
+
24
+ from model import U2NET
25
+ from model import U2NETP
26
+
27
+ # ------- 1. define loss function --------
28
+
29
+ bce_loss = nn.BCELoss(size_average=True)
30
+
31
+ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
32
+
33
+ loss0 = bce_loss(d0,labels_v)
34
+ loss1 = bce_loss(d1,labels_v)
35
+ loss2 = bce_loss(d2,labels_v)
36
+ loss3 = bce_loss(d3,labels_v)
37
+ loss4 = bce_loss(d4,labels_v)
38
+ loss5 = bce_loss(d5,labels_v)
39
+ loss6 = bce_loss(d6,labels_v)
40
+
41
+ loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
42
+ print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
43
+
44
+ return loss0, loss
45
+
46
+
47
+ # ------- 2. set the directory of training dataset --------
48
+
49
+ model_name = 'u2net' #'u2netp'
50
+
51
+ data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
52
+ tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
53
+ tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
54
+
55
+ image_ext = '.jpg'
56
+ label_ext = '.png'
57
+
58
+ model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
59
+
60
+ epoch_num = 100000
61
+ batch_size_train = 12
62
+ batch_size_val = 1
63
+ train_num = 0
64
+ val_num = 0
65
+
66
+ tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
67
+
68
+ tra_lbl_name_list = []
69
+ for img_path in tra_img_name_list:
70
+ img_name = img_path.split(os.sep)[-1]
71
+
72
+ aaa = img_name.split(".")
73
+ bbb = aaa[0:-1]
74
+ imidx = bbb[0]
75
+ for i in range(1,len(bbb)):
76
+ imidx = imidx + "." + bbb[i]
77
+
78
+ tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
79
+
80
+ print("---")
81
+ print("train images: ", len(tra_img_name_list))
82
+ print("train labels: ", len(tra_lbl_name_list))
83
+ print("---")
84
+
85
+ train_num = len(tra_img_name_list)
86
+
87
+ salobj_dataset = SalObjDataset(
88
+ img_name_list=tra_img_name_list,
89
+ lbl_name_list=tra_lbl_name_list,
90
+ transform=transforms.Compose([
91
+ RescaleT(320),
92
+ RandomCrop(288),
93
+ ToTensorLab(flag=0)]))
94
+ salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
95
+
96
+ # ------- 3. define model --------
97
+ # define the net
98
+ if(model_name=='u2net'):
99
+ net = U2NET(3, 1)
100
+ elif(model_name=='u2netp'):
101
+ net = U2NETP(3,1)
102
+
103
+ if torch.cuda.is_available():
104
+ net.cuda()
105
+
106
+ # ------- 4. define optimizer --------
107
+ print("---define optimizer...")
108
+ optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
109
+
110
+ # ------- 5. training process --------
111
+ print("---start training...")
112
+ ite_num = 0
113
+ running_loss = 0.0
114
+ running_tar_loss = 0.0
115
+ ite_num4val = 0
116
+ save_frq = 2000 # save the model every 2000 iterations
117
+
118
+ for epoch in range(0, epoch_num):
119
+ net.train()
120
+
121
+ for i, data in enumerate(salobj_dataloader):
122
+ ite_num = ite_num + 1
123
+ ite_num4val = ite_num4val + 1
124
+
125
+ inputs, labels = data['image'], data['label']
126
+
127
+ inputs = inputs.type(torch.FloatTensor)
128
+ labels = labels.type(torch.FloatTensor)
129
+
130
+ # wrap them in Variable
131
+ if torch.cuda.is_available():
132
+ inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
133
+ requires_grad=False)
134
+ else:
135
+ inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
136
+
137
+ # y zero the parameter gradients
138
+ optimizer.zero_grad()
139
+
140
+ # forward + backward + optimize
141
+ d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
142
+ loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
143
+
144
+ loss.backward()
145
+ optimizer.step()
146
+
147
+ # # print statistics
148
+ running_loss += loss.data.item()
149
+ running_tar_loss += loss2.data.item()
150
+
151
+ # del temporary outputs and loss
152
+ del d0, d1, d2, d3, d4, d5, d6, loss2, loss
153
+
154
+ print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
155
+ epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
156
+
157
+ if ite_num % save_frq == 0:
158
+
159
+ torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
160
+ running_loss = 0.0
161
+ running_tar_loss = 0.0
162
+ net.train() # resume train
163
+ ite_num4val = 0
164
+
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+
6
+ def process(im):
7
+
8
+ return im
9
+
10
+ title = "U-2-Net"
11
+ description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
12
+ article = ""
13
+
14
+ gr.Interface(
15
+ process,
16
+ [gr.inputs.Image(type="pil", label="Input")
17
+ ],
18
+ gr.outputs.Image(type="pil", label="Output"),
19
+ title=title,
20
+ description=description,
21
+ article=article,
22
+ examples=[],
23
+ allow_flagging=False,
24
+ allow_screenshot=False
25
+ ).launch(enable_queue=True,cache_examples=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy==1.15.2
2
+ scikit-image==0.14.0
3
+ torch
4
+ torchvision
5
+ pillow==8.1.1
6
+ opencv-python-headless