AntoreepJana committed
Commit 865f99b
1 Parent(s): 6ff092a

Upload 23 files

app.py ADDED
@@ -0,0 +1,129 @@
+ import gradio as gr
+ import os
+ import glob
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ from torch.autograd import Variable
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from skimage import io, transform
+
+ from u2net import U2NET
+ from inference import TestData, RescaleT, ToTensorLab, normPRED
+
+
+ def load_model(model_type):
+     # "U2Net" loads the full-precision weights; any other choice loads the quantized TorchScript model
+     model = U2NET(3, 1)
+     if model_type == "U2Net":
+         model_path = "weights/u2net.pth"
+         model.load_state_dict(torch.load(model_path))
+     else:
+         model_path = "weights/quant_model_u2net.pth"
+         model = torch.jit.load(model_path)
+
+     return model.eval()
+
+
+ def normPred(d):
+     # min-max normalise a prediction map to [0, 1] (same as normPRED in inference.py)
+     ma = torch.max(d)
+     mi = torch.min(d)
+     dn = (d - mi) / (ma - mi)
+
+     return dn
+
+
+ def segment(model_type, img):
+     src = img
+
+     model = load_model(model_type)
+
+     # wrap the single input image in the same dataset/transform pipeline used for offline inference
+     test_dataset = TestData(img_name_list=[img], lbl_name_list=[],
+                             transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]))
+     test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+     for i_test, data_test in enumerate(test_dataloader):
+
+         inputs_test = data_test['image']
+         inputs_test = inputs_test.type(torch.FloatTensor)
+         inputs_test = Variable(inputs_test)
+
+         d1, d2, d3, d4, d5, d6, d7 = model(inputs_test)
+
+         # the first side output is the fused saliency map; rescale it to [0, 1]
+         pred = d1[:, 0, :, :]
+         pred = normPRED(pred)
+
+         pred = pred.detach().numpy()
+         pred = np.transpose(pred, (1, 2, 0))
+         pred = np.squeeze(pred, axis=2)
+
+         pred = Image.fromarray((pred * 255).astype(np.uint8))
+
+         # attach the predicted mask to the source image as an alpha channel
+         segmented = np.dstack((src, pred))
+         return segmented
+
+
+ iface = gr.Interface(fn=segment,
+                      inputs=[gr.inputs.Dropdown(["Lite U2Net", "U2Net"]), gr.Image(shape=(512, 512))],
+                      outputs=gr.Image(shape=(512, 512)))
+ iface.launch()
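For reference, a minimal sketch of what segment() does end to end, outside the Gradio UI. It assumes weights/u2net.pth from this repo is present; the random tensor stands in for the preprocessed 512x512 image and is not part of the committed code:

import torch
from u2net import U2NET
from inference import normPRED

# full-precision branch of load_model()
model = U2NET(3, 1)
model.load_state_dict(torch.load("weights/u2net.pth", map_location="cpu"))
model.eval()

with torch.no_grad():
    dummy = torch.rand(1, 3, 512, 512)            # stand-in for the tensor produced by TestData + transforms
    d1, d2, d3, d4, d5, d6, d7 = model(dummy)
    mask = normPRED(d1[:, 0, :, :])                # fused side output, min-max scaled to [0, 1]
print(mask.shape)                                  # torch.Size([1, 512, 512])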
inference.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+ import os
+ from skimage import io, transform, color  # color is needed by ToTensorLab for the Lab conversions
+ import torchvision
+ from torch.autograd import Variable
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+
+ import numpy as np
+ from PIL import Image
+ import glob
+
+
+ def normPRED(d):
+     # min-max normalise a prediction map to [0, 1]
+     ma = torch.max(d)
+     mi = torch.min(d)
+
+     dn = (d - mi) / (ma - mi)
+
+     return dn
+
+
+ def save_output(image_name, pred, d_dir):
+     # save a prediction map as a PNG, resized back to the original image size
+     predict = pred
+     predict = predict.squeeze()
+     predict_np = predict.cpu().data.numpy()
+
+     im = Image.fromarray(predict_np * 255).convert('RGB')
+     img_name = image_name.split(os.sep)[-1]
+
+     image = io.imread(image_name)
+     imo = im.resize((image.shape[1], image.shape[0]), resample = Image.BILINEAR)
+
+     pb_np = np.array(imo)
+
+     aaa = img_name.split(".")
+     bbb = aaa[0:-1]
+     imidx = bbb[0]
+
+     for i in range(1, len(bbb)):
+         imidx = imidx + "." + bbb[i]
+
+     imo.save(d_dir + "/" + imidx + '.png')
+
+
+ #image_dir = "./test_data/"
+ #prediction_dir = './outputs_pred/'
+
+ #model_dir = 'quant_model_u2net.pth'#'u2net.pth'
+
+ #img_name_list = glob.glob(image_dir + "/*")
+
+ #print("Number of images : ", len(img_name_list))
+
+
+ ### Make test dataset
+
+ class RescaleT(object):
+
+     def __init__(self, output_size):
+         assert isinstance(output_size, (int, tuple))
+         self.output_size = output_size
+
+     def __call__(self, sample):
+         imidx, image, label = sample['imidx'], sample['image'], sample['label']
+
+         h, w = image.shape[:2]
+
+         if isinstance(self.output_size, int):
+             if h > w:
+                 new_h, new_w = self.output_size*h/w, self.output_size
+             else:
+                 new_h, new_w = self.output_size, self.output_size*w/h
+         else:
+             new_h, new_w = self.output_size
+
+         new_h, new_w = int(new_h), int(new_w)
+
+         # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
+         # img = transform.resize(image,(new_h,new_w),mode='constant')
+         # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
+
+         img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
+         lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0, preserve_range=True)
+
+         return {'imidx': imidx, 'image': img, 'label': lbl}
+
+
+ class ToTensorLab(object):
+     """Convert ndarrays in sample to Tensors."""
+     def __init__(self, flag=0):
+         self.flag = flag
+
+     def __call__(self, sample):
+
+         imidx, image, label = sample['imidx'], sample['image'], sample['label']
+
+         tmpLbl = np.zeros(label.shape)
+
+         if(np.max(label) < 1e-6):
+             label = label
+         else:
+             label = label/np.max(label)
+
+         # change the color space
+         if self.flag == 2:  # with rgb and Lab colors
+             tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
+             tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
+             if image.shape[2] == 1:
+                 tmpImgt[:,:,0] = image[:,:,0]
+                 tmpImgt[:,:,1] = image[:,:,0]
+                 tmpImgt[:,:,2] = image[:,:,0]
+             else:
+                 tmpImgt = image
+             tmpImgtl = color.rgb2lab(tmpImgt)
+
+             # normalize image to range [0,1]
+             tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
+             tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
+             tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
+             tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
+             tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
+             tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
+
+             # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
+
+             tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
+             tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
+             tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
+             tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
+             tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
+             tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
+
+         elif self.flag == 1:  # with Lab color
+             tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+
+             if image.shape[2] == 1:
+                 tmpImg[:,:,0] = image[:,:,0]
+                 tmpImg[:,:,1] = image[:,:,0]
+                 tmpImg[:,:,2] = image[:,:,0]
+             else:
+                 tmpImg = image
+
+             tmpImg = color.rgb2lab(tmpImg)
+
+             # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
+
+             tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
+             tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
+             tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
+
+             tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
+             tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
+             tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
+
+         else:  # with rgb color
+             tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+             image = image/np.max(image)
+             if image.shape[2] == 1:
+                 tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
+                 tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
+                 tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
+             else:
+                 tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
+                 tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
+                 tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
+
+         tmpLbl[:,:,0] = label[:,:,0]
+
+         tmpImg = tmpImg.transpose((2, 0, 1))
+         tmpLbl = label.transpose((2, 0, 1))
+
+         return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
+
+
+ class TestData(Dataset):
+
+     def __init__(self, img_name_list, lbl_name_list, transform = None):
+
+         self.img_list = img_name_list
+         self.label_name_list = lbl_name_list
+         self.transform = transform
+
+     def __len__(self):
+
+         return len(self.img_list)
+
+     def __getitem__(self, idx):
+
+         #image = io.imread(self.img_list[idx])
+         image = self.img_list[idx]
+         imname = self.img_list[idx]
+         imidx = np.array([idx])
+
+         if (0 == len(self.label_name_list)):
+             label_3 = np.zeros(image.shape)
+         else:
+             label_3 = io.imread(self.label_name_list[idx])
+
+         label = np.zeros(label_3.shape[0:2])
+
+         if(3 == len(label_3.shape)):
+             label = label_3[:,:,0]
+         elif(2 == len(label_3.shape)):
+             label = label_3
+
+         if(3 == len(image.shape) and 2 == len(label.shape)):
+             label = label[:,:,np.newaxis]
+         elif(2 == len(image.shape) and 2 == len(label.shape)):
+             image = image[:,:,np.newaxis]
+             label = label[:,:,np.newaxis]
+
+         sample = {'imidx': imidx, 'image': image, 'label': label}
+
+         if self.transform:
+             sample = self.transform(sample)
+
+         return sample
+
+
+ #test_dataset = TestData(img_name_list = img_name_list, lbl_name_list = [],
+ #                        transform = transforms.Compose([RescaleT(512), ToTensorLab(flag = 0)]))
+
+
+ ### Make test dataloader
+
+ #test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle = False, num_workers = 1)
+
+
+ #net = U2Net(3,1)
+
+ #net = torch.jit.load('quant_model_u2net.pth')
+ #net.load_state_dict(torch.load(model_dir))
+
+
+ """net.eval()
+
+ for i_test, data_test in enumerate(test_dataloader):
+
+     print("Inferencing : ", img_name_list[i_test].split(os.sep)[-1])
+
+     inputs_test = data_test['image']
+
+     inputs_test = inputs_test.type(torch.FloatTensor)
+
+     inputs_test = Variable(inputs_test)
+
+     d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
+
+     pred = 1.0 - d1[:,0,:,:]
+     pred = normPRED(pred)
+
+     #save_output(img_name_list[i_test], pred, prediction_dir)
+
+     #del d1, d2, d3, d4, d5, d6, d7"""
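For reference, a small sketch of how the preprocessing classes above compose. The random array is only a stand-in for the numpy image Gradio hands to app.py; it is not part of the committed code:

import numpy as np
from torchvision import transforms
from inference import TestData, RescaleT, ToTensorLab

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # stand-in for a real HxWx3 photo
dataset = TestData(img_name_list=[img], lbl_name_list=[],
                   transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]))
sample = dataset[0]
print(sample['image'].shape, sample['image'].dtype)   # torch.Size([3, 512, 512]) torch.float64

The float64 output here is why app.py casts the batch with .type(torch.FloatTensor) before the forward pass.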
requirements.txt ADDED
File without changes
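The uploaded requirements.txt is empty. As an assumption inferred from the imports in app.py and inference.py (package names only, versions deliberately unpinned), it would need to cover at least:

gradio
torch
torchvision
numpy
scikit-image
opencv-python
Pillow
matplotlib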
test_images/0002-01.jpg ADDED
test_images/0003.jpg ADDED
test_images/bike.jpg ADDED
test_images/boat.jpg ADDED
test_images/girl.png ADDED
test_images/hockey.png ADDED
test_images/horse.jpg ADDED
test_images/im_01.png ADDED
test_images/im_14.png ADDED
test_images/im_21.png ADDED
test_images/im_27.png ADDED
test_images/lamp2_meitu_1.jpg ADDED
test_images/long.jpg ADDED
test_images/rifle1.jpg ADDED
test_images/rifle2.jpeg ADDED
test_images/sailboat3.jpg ADDED
test_images/vangogh.jpeg ADDED
test_images/whisk.png ADDED
u2net.py ADDED
@@ -0,0 +1,525 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src, tar):
+
+     src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)
+
+     def forward(self, x):
+
+         hx = x
+
+         #stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         #stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         #stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         #stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         #stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         #stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         #-------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         #side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)
+
+     def forward(self, x):
+
+         hx = x
+
+         #stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         #stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         #stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         #stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         #stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         #stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         #decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         #side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
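A quick sanity-check sketch for the two variants defined above, using random weights and a random input; it is illustrative only and not part of the commit:

import torch
from u2net import U2NET, U2NETP

for name, net in [("U2NET", U2NET(3, 1)), ("U2NETP", U2NETP(3, 1))]:
    net.eval()
    n_params = sum(p.numel() for p in net.parameters())
    with torch.no_grad():
        outs = net(torch.rand(1, 3, 320, 320))    # any reasonable input size works
    print(f"{name}: {n_params/1e6:.1f}M params, {len(outs)} side outputs, fused map {tuple(outs[0].shape)}")

The app only ships U2NET weights (full-precision and quantized); U2NETP is the lighter architecture kept in the same file.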