import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
from u2net_test import normPRED
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob
import warnings

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

warnings.filterwarnings("ignore")


def save_images(image_name, pred, d_dir):
    """Resize a predicted mask to its source image's size and save it as PNG.

    Args:
        image_name: Path of the input image the prediction belongs to. The
            image is re-read only to obtain its height and width.
        pred: Prediction tensor. It is squeezed, so it must hold a single 2-D
            map; values are multiplied by 255 before conversion (presumably
            they are already normalised to [0, 1] by the caller — confirm).
        d_dir: Output directory (expected to exist). The mask is written to
            ``<d_dir>/<image name without its last extension>.png``.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()
    # Float array -> PIL image, then to 3-channel RGB for saving.
    im = Image.fromarray(predict_np * 255).convert('RGB')

    image = io.imread(image_name)
    # PIL's resize takes (width, height); the numpy shape is (height, width, ...).
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BICUBIC)

    # Strip only the final extension ("a.b.jpg" -> "a.b"). A name without any
    # extension is kept whole instead of raising IndexError.
    img_name = os.path.basename(image_name)
    imidx = os.path.splitext(img_name)[0]

    print('Saving output at {}'.format(os.path.join(d_dir, imidx + '.png')))
    imo.save(os.path.join(d_dir, imidx + '.png'))


def infer(
    net,
    image_dir=os.path.join(os.getcwd(), 'test_data', 'test_images'),
    prediction_dir=os.path.join(os.getcwd(), 'test_data', 'u2net' + '_results'),
):
    """Run ``net`` on every file in ``image_dir`` and save one PNG mask per image.

    Each image is rescaled with ``RescaleT(320)`` and ``ToTensorLab(flag=0)``.
    It is passed through ``net`` one at a time, and the first output map
    (``d1[:, 0]``) is normalised with ``normPRED``. The result is written to
    ``prediction_dir`` by ``save_images``.

    NOTE(review): the default directories are evaluated once, when this module
    is imported, so they reflect the working directory at import time.

    NOTE(review): this function does not call ``net.eval()``; presumably the
    caller has already put the model in eval mode — confirm.

    Args:
        net: Model returning seven outputs ``(d1, ..., d7)``. Only ``d1`` is
            used; it is indexed as a 4-D tensor.
        image_dir: Directory whose regular files are treated as input images.
        prediction_dir: Output directory; created if missing.
    """
    # Only regular files go to the dataset. A subdirectory matched by '*'
    # would otherwise be handed to the image reader and fail.
    img_name_list = [
        path for path in glob.glob(image_dir + os.sep + '*')
        if os.path.isfile(path)
    ]
    prediction_dir = prediction_dir + os.sep
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 1. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
    )
    # shuffle=False keeps batch i aligned with img_name_list[i] below.
    test_salobj_dataloader = DataLoader(
        test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1
    )

    use_cuda = torch.cuda.is_available()

    # --------- 2. inference for each image ---------
    # No gradients are needed; without no_grad() each forward pass would build
    # an autograd graph and keep its intermediate activations in memory.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            image_path = img_name_list[i_test]
            print("Generating mask for:", os.path.basename(image_path))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Normalise the first channel of the first output map.
            pred = normPRED(d1[:, 0, :, :])

            save_images(image_path, pred, prediction_dir)
            # Release the side outputs before the next image is processed.
            del d1, d2, d3, d4, d5, d6, d7