# fifa-tryon-demo / u2net_run.py
# hasibzunair's picture
# added files
# 4a285f6
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
from u2net_test import normPRED
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
import warnings
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
warnings.filterwarnings("ignore")
def save_images(image_name, pred, d_dir):
    """Save a predicted mask as a PNG with the same size as its source image.

    Args:
        image_name: Path of the source image. It is read to obtain the target
            height and width, and its file name (minus the final extension)
            becomes the name of the output file.
        pred: Prediction tensor. It is squeezed, moved to the CPU and scaled
            by 255 before being turned into an image (presumably the values
            are already normalised to [0, 1] -- confirm against the caller).
        d_dir: Directory in which ``<stem>.png`` is written. This function
            does not create it.
    """
    mask_np = pred.squeeze().detach().cpu().numpy()
    mask_img = Image.fromarray(mask_np * 255).convert('RGB')

    # The source image is loaded only for its height/width, so the mask can be
    # resized back to the original resolution.
    source = io.imread(image_name)
    resized = mask_img.resize(
        (source.shape[1], source.shape[0]), resample=Image.BICUBIC
    )

    # Strip only the last extension ("a.b.jpg" -> "a.b"). Unlike manual
    # splitting on ".", this also works for file names without any extension.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    out_path = os.path.join(d_dir, stem + '.png')

    print('Saving output at {}'.format(out_path))
    resized.save(out_path)
def _input_device(net):
    """Return the device on which input tensors for ``net`` should be placed."""
    try:
        return next(net.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        # ``net`` exposes no parameters to inspect: fall back to the previous
        # heuristic of using CUDA whenever it is available.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def infer(
    net,
    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images'),
    prediction_dir = os.path.join(os.getcwd(), 'test_data', 'u2net' + '_results')
    ):
    """Run ``net`` on every file in ``image_dir`` and save one mask per image.

    Args:
        net: Callable model. Its result is indexed with ``[0]`` and channel 0
            of that first output is used as the prediction. This function
            does not switch the model to eval mode; the caller is expected to
            have prepared it.
        image_dir: Directory whose regular files are treated as input images.
            The default is built from the working directory at the time this
            module is imported, not at call time.
        prediction_dir: Directory that receives the ``.png`` masks; created
            if missing. The default is likewise fixed at import time.

    Returns:
        None. Results are written to ``prediction_dir`` by ``save_images``.
    """
    # Keep regular files only (a sub-directory cannot be loaded as an image)
    # and sort them so the processing order is reproducible.
    img_name_list = sorted(
        path for path in glob.glob(image_dir + os.sep + '*')
        if os.path.isfile(path)
    )
    if not img_name_list:
        # Nothing to process: no directory is created, no model call is made.
        return

    prediction_dir = prediction_dir + os.sep
    # Create the output directory once instead of re-checking per image.
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
    )
    # shuffle=False with batch_size=1 keeps batch i aligned with
    # img_name_list[i], which the loop below relies on.
    test_salobj_dataloader = DataLoader(
        test_salobj_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    # Feed the model on whatever device it actually lives on, so a CPU model
    # still works on a machine where CUDA happens to be available.
    device = _input_device(net)

    # --------- inference for each image ---------
    # Inference only: disable autograd so no graph is kept per forward pass.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            image_path = img_name_list[i_test]
            print("Generating mask for:", image_path.split(os.sep)[-1])

            inputs_test = data_test['image'].to(
                device=device, dtype=torch.float32
            )

            # Only the first output is needed; the remaining ones are
            # released as soon as the returned sequence goes out of scope.
            d1 = net(inputs_test)[0]

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            # save results to the prediction folder
            save_images(image_path, pred, prediction_dir)