# u2net_portrait / U-2-Net / u2net_portrait_composite.py
# hylee's picture
# init
# de51c6d
import os
from skimage import io, transform
from skimage.filters import gaussian
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
import argparse
# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalize a prediction tensor to the range [0, 1].

    Args:
        d: Tensor of predicted values.

    Returns:
        A tensor of the same shape as ``d`` where the smallest input value
        maps to 0 and the largest to 1. If every element of ``d`` is equal,
        a zero tensor is returned instead of dividing by zero.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    # A constant map has zero range; dividing would fill the result with NaN.
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span
def save_output(image_name, pred, d_dir, sigma=2, alpha=0.5):
    """Fuse a blurred source image with its predicted portrait and save it.

    The prediction is resized to the source image's height and width, scaled
    so that its maximum is (almost exactly) 255, and blended with a
    Gaussian-blurred copy of the source image:

        composite = blurred_image * alpha + portrait * (1 - alpha)

    The result is written to ``d_dir`` as
    ``<stem>_sigma_<sigma>_alpha_<alpha>_composite.png`` where ``<stem>`` is
    the source file name without its last extension.

    Args:
        image_name: Path of the source image.
        pred: Predicted portrait as a torch tensor. Singleton dimensions are
            squeezed away; the remainder is presumably a 2-D map (confirm
            against the caller).
        d_dir: Directory that receives the composite PNG.
        sigma: Standard deviation of the Gaussian blur applied to the source.
        alpha: Blend weight of the blurred source; the portrait is weighted
            by ``1 - alpha``.
    """
    predict_np = pred.squeeze().cpu().data.numpy()

    image = io.imread(image_name)
    portrait = transform.resize(predict_np, image.shape[0:2], order=2)
    # Stretch so the brightest stroke is ~255; the epsilon avoids 0/0.
    portrait = portrait / (np.amax(portrait) + 1e-8) * 255

    if image.ndim == 3:
        # Colour image: give the portrait a channel axis so it broadcasts
        # over every channel, and pass a per-axis sigma with 0 on the last
        # axis so the blur stays spatial whichever way the installed
        # scikit-image interprets that axis.
        portrait = portrait[:, :, np.newaxis]
        blur_sigma = (sigma, sigma, 0)
    else:
        # Grayscale image: image and portrait are both (rows, cols) already.
        blur_sigma = sigma

    print(image.shape)
    print(portrait.shape)

    # Fuse the original portrait image and the portrait into one composite:
    # 1. use a Gaussian filter to blur the original image
    blurred = gaussian(image, sigma=blur_sigma, preserve_range=True)

    # 2. fuse the blurred original and the portrait with weight alpha
    im_comp = blurred * alpha + portrait * (1 - alpha)
    print(im_comp.shape)

    # splitext drops only the last extension and also copes with file names
    # that have no extension at all.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    out_name = f"{stem}_sigma_{sigma}_alpha_{alpha}_composite.png"

    # The blend is float data in [0, 255]; convert explicitly to 8-bit so the
    # saved pixels do not depend on how the image-writing backend treats
    # out-of-[0, 1] floats.
    im_out = np.clip(np.rint(im_comp), 0, 255).astype(np.uint8)
    io.imsave(os.path.join(d_dir, out_name), im_out)
def main():
    """Run the portrait model over a folder of images and save composites.

    Command-line options:
        -s  Gaussian blur sigma forwarded to ``save_output`` (default 2.0).
        -a  Blend weight alpha forwarded to ``save_output`` (default 0.5).

    Images are read from ``./test_data/test_portrait_images/your_portrait_im``
    and results are written to
    ``./test_data/test_portrait_images/your_portrait_results``.
    """
    parser = argparse.ArgumentParser(description="image and portrait composite")
    # Parse as float here, with save_output's defaults, so a missing or
    # malformed flag is reported before the model is loaded rather than
    # failing on float(None) inside the inference loop.
    parser.add_argument('-s', action='store', dest='sigma', type=float, default=2.0)
    parser.add_argument('-a', action='store', dest='alpha', type=float, default=0.5)
    args = parser.parse_args()
    print(args.sigma)
    print(args.alpha)
    print("--------------------")

    # --------- 1. get image path and name ---------
    image_dir = './test_data/test_portrait_images/your_portrait_im'
    prediction_dir = './test_data/test_portrait_images/your_portrait_results'
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(prediction_dir, exist_ok=True)

    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    img_name_list = glob.glob(image_dir+'/*')
    print("Number of images: ", len(img_name_list))

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]),
    )
    # shuffle=False with batch_size=1 keeps batch i aligned with
    # img_name_list[i], which the loop below relies on.
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)

    use_cuda = torch.cuda.is_available()
    # On a CPU-only host, remap the checkpoint's tensors to the CPU; with
    # CUDA available keep the default loading behaviour.
    net.load_state_dict(
        torch.load(model_dir, map_location=None if use_cuda else 'cpu')
    )
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    # Gradients are never used here, so skip building autograd graphs.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            image_path = img_name_list[i_test]
            print("inferencing:", image_path.split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Invert the first output's channel 0, then min-max normalize.
            pred = normPRED(1.0 - d1[:, 0, :, :])

            # save results to the prediction folder
            save_output(image_path, pred, prediction_dir,
                        sigma=args.sigma, alpha=args.alpha)

            del d1, d2, d3, d4, d5, d6, d7
# Script entry point: run the composite pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()