hylee commited on
Commit
18f931d
1 Parent(s): 443d6d3
Files changed (1) hide show
  1. app.py +120 -2
app.py CHANGED
@@ -2,10 +2,128 @@ import os
2
 
3
  import gradio as gr
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def process(im):
7
-
8
- return im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  title = "U-2-Net"
11
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
 
2
 
3
  import gradio as gr
4
 
5
+ import sys
6
+ sys.path.insert(0, 'U-2-Net')
7
+
8
+ from skimage import io, transform
9
+ import torch
10
+ import torchvision
11
+ from torch.autograd import Variable
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torchvision import transforms#, utils
16
+ # import torch.optim as optim
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+ import glob
21
+
22
+ from data_loader import RescaleT
23
+ from data_loader import ToTensor
24
+ from data_loader import ToTensorLab
25
+ from data_loader import SalObjDataset
26
+
27
+ from model import U2NET # full size version 173.6 MB
28
+ from model import U2NETP # small version u2net 4.7 MB
29
+
30
+
31
+ # normalize the predicted SOD probability map
32
+ def normPRED(d):
33
+ ma = torch.max(d)
34
+ mi = torch.min(d)
35
+
36
+ dn = (d-mi)/(ma-mi)
37
+
38
+ return dn
39
+ def save_output(image_name,pred,d_dir):
40
+ predict = pred
41
+ predict = predict.squeeze()
42
+ predict_np = predict.cpu().data.numpy()
43
+
44
+ im = Image.fromarray(predict_np*255).convert('RGB')
45
+ img_name = image_name.split(os.sep)[-1]
46
+ image = io.imread(image_name)
47
+ imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
48
+
49
+ pb_np = np.array(imo)
50
+
51
+ aaa = img_name.split(".")
52
+ bbb = aaa[0:-1]
53
+ imidx = bbb[0]
54
+ for i in range(1,len(bbb)):
55
+ imidx = imidx + "." + bbb[i]
56
+
57
+ imo.save(d_dir+'/'+imidx+'.png')
58
+ return d_dir+'/'+imidx+'.png'
59
+
60
+
61
+ # --------- 1. get image path and name ---------
62
+ model_name='u2net_portrait'#u2netp
63
+
64
+
65
+ image_dir = 'portrait_im'
66
+ prediction_dir = 'portrait_results'
67
+ if(not os.path.exists(prediction_dir)):
68
+ os.mkdir(prediction_dir)
69
+
70
+ model_dir = os.path.jos.path.join(os.path.abspath(os.path.dirname(__file__)), 'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')
71
+
72
+
73
+ # --------- 3. model define ---------
74
+
75
+ print("...load U2NET---173.6 MB")
76
+ net = U2NET(3,1)
77
+
78
+ net.load_state_dict(torch.load(model_dir))
79
+ # if torch.cuda.is_available():
80
+ # net.cuda()
81
+ net.eval()
82
+
83
 
84
  def process(im):
85
+ img_name_list = glob.glob(im.name)
86
+ print("Number of images: ", len(img_name_list))
87
+ # --------- 2. dataloader ---------
88
+ # 1. dataloader
89
+ test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
90
+ lbl_name_list=[],
91
+ transform=transforms.Compose([RescaleT(512),
92
+ ToTensorLab(flag=0)])
93
+ )
94
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
95
+ batch_size=1,
96
+ shuffle=False,
97
+ num_workers=1)
98
+
99
+ results = []
100
+ # --------- 4. inference for each image ---------
101
+ for i_test, data_test in enumerate(test_salobj_dataloader):
102
+
103
+ print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
104
+
105
+ inputs_test = data_test['image']
106
+ inputs_test = inputs_test.type(torch.FloatTensor)
107
+
108
+ # if torch.cuda.is_available():
109
+ # inputs_test = Variable(inputs_test.cuda())
110
+ # else:
111
+ inputs_test = Variable(inputs_test)
112
+
113
+ d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
114
+
115
+ # normalization
116
+ pred = 1.0 - d1[:, 0, :, :]
117
+ pred = normPRED(pred)
118
+
119
+ # save results to test_results folder
120
+ results.append(save_output(img_name_list[i_test], pred, prediction_dir))
121
+
122
+ del d1, d2, d3, d4, d5, d6, d7
123
+
124
+ print(results)
125
+
126
+ return Image.open(results[0])
127
 
128
  title = "U-2-Net"
129
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"