import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET  # full size version 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB


def normPRED(d):
    """Min-max normalize the predicted map ``d`` into the range [0, 1].

    Args:
        d: Tensor holding the raw prediction.

    Returns:
        A tensor of the same shape as ``d``, rescaled so that its minimum
        maps to 0 and its maximum to 1. When every element of ``d`` is
        equal the result is all zeros, because the plain formula would
        divide by zero and produce NaNs.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        # Constant map: avoid 0/0 -> NaN, which would corrupt the saved image.
        return torch.zeros_like(d)
    return (d - mi) / span


def save_output(image_name, pred, d_dir):
    """Save a prediction as a PNG sized like its source image.

    The prediction is scaled by 255, converted to RGB, resized with
    bilinear resampling to the height/width of the image stored at
    ``image_name`` and written to ``<d_dir>/<stem>.png``, where ``stem``
    is the source file name without its last extension.

    Args:
        image_name: Path of the source image; read to obtain the target size.
        pred: Prediction tensor; singleton dimensions are squeezed away.
            NOTE(review): presumably values are in [0, 1] -- confirm.
        d_dir: Directory that receives the PNG. It must already exist.
    """
    # detach() drops any autograd history so the tensor can become a NumPy array.
    predict_np = pred.squeeze().detach().cpu().numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # The source image is loaded only for its (height, width).
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # splitext removes only the last extension ("a.b.jpg" -> "a.b") and also
    # copes with names that have no extension at all.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, stem + '.png'))


def main():
    """Run the U2NET portrait model on every image of the test folder.

    Reads images from ``./test_data/test_portrait_images/portrait_im``,
    loads the weights from ``./saved_models/u2net_portrait/u2net_portrait.pth``
    and writes one PNG per input image into
    ``./test_data/test_portrait_images/portrait_results``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net_portrait'  # u2netp

    image_dir = './test_data/test_portrait_images/portrait_im'
    prediction_dir = './test_data/test_portrait_images/portrait_results'
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(prediction_dir, exist_ok=True)

    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    # Keep regular files only (a sub-directory would break image loading) and
    # sort them so that the processing order is reproducible.
    img_name_list = sorted(
        path for path in glob.glob(image_dir + '/*') if os.path.isfile(path)
    )
    print("Number of images: ", len(img_name_list))

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]),
    )
    # batch_size=1 and shuffle=False keep batch i aligned with img_name_list[i].
    test_salobj_dataloader = DataLoader(
        test_salobj_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    # --------- 3. model define ---------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    # --------- 4. inference for each image ---------
    # No gradients are needed for inference; disabling them saves memory.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", os.path.basename(img_name_list[i_test]))

            inputs_test = data_test['image'].float().to(device)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization: invert channel 0 of the first output, then
            # rescale it to [0, 1]
            pred = 1.0 - d1[:, 0, :, :]
            pred = normPRED(pred)

            # save results to test_results folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            # Drop the side outputs before the next forward pass.
            del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()