import os
import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from skimage import io, transform, color  # color is required by ToTensorLab (flag 1/2)
from PIL import Image


def normPRED(d):
    """Min-max normalise a prediction tensor to the range [0, 1]."""
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


def save_output(image_name, pred, d_dir):
    """Save a predicted mask as a PNG, resized back to the source image size."""
    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np * 255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    pb_np = np.array(imo)

    # Strip the file extension to build the output name.
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir + "/" + imidx + '.png')
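
# --- Example (not part of the original pipeline) ----------------------------
# A minimal sketch of how normPRED/save_output are meant to be used: the
# prediction is a 1xHxW float tensor, and save_output resizes the mask back to
# the size of the source image on disk. The file names below are made up for
# illustration only, and the function is never called by this script.
def _demo_save_output(image_path="./example.jpg", out_dir="."):
    # Assumes image_path points to an existing RGB image and out_dir exists.
    pred = torch.rand(1, 320, 320)           # stand-in for a network output
    pred = normPRED(pred)                    # rescale to [0, 1]
    save_output(image_path, pred, out_dir)   # writes <out_dir>/example.png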

#image_dir = "./test_data/"
#prediction_dir = './outputs_pred/'
#model_dir = 'quant_model_u2net.pth'  #'u2net.pth'

#img_name_list = glob.glob(image_dir + "/*")
#print("Number of images : ", len(img_name_list))


### Make test dataset

class RescaleT(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        # Resize image and label to a square output_size x output_size;
        # skimage also rescales the image from [0, 255] to [0, 1].
        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant',
                               order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors."""

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpLbl = np.zeros(label.shape)

        if np.max(label) < 1e-6:
            label = label
        else:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize image to range [0,1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            # standardize each channel to zero mean and unit variance
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])

        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            # ImageNet mean/std normalization
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        tmpLbl[:, :, 0] = label[:, :, 0]

        # HWC -> CHW
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx),
                'image': torch.from_numpy(tmpImg),
                'label': torch.from_numpy(tmpLbl)}


class TestData(Dataset):

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.img_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # img_name_list holds file paths (from glob), so read the image from disk
        image = io.imread(self.img_list[idx])
        imname = self.img_list[idx]
        imidx = np.array([idx])

        if 0 == len(self.label_name_list):
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        label = np.zeros(label_3.shape[0:2])
        if 3 == len(label_3.shape):
            label = label_3[:, :, 0]
        elif 2 == len(label_3.shape):
            label = label_3

        if 3 == len(image.shape) and 2 == len(label.shape):
            label = label[:, :, np.newaxis]
        elif 2 == len(image.shape) and 2 == len(label.shape):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample
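
# --- Example (not part of the original pipeline) ----------------------------
# A hedged sketch of the preprocessing chain: RescaleT resizes the sample to a
# square and ToTensorLab(flag=0) applies ImageNet-style normalisation and
# returns CHW tensors. The synthetic arrays below are placeholders, not real
# data, and this function is never called by the script.
def _demo_transforms():
    sample = {
        'imidx': np.array([0]),
        'image': np.random.randint(0, 256, (480, 640, 3)).astype(np.uint8),
        'label': np.zeros((480, 640, 1)),
    }
    composed = transforms.Compose([RescaleT(512), ToTensorLab(flag=0)])
    out = composed(sample)
    # Expected: torch.Size([3, 512, 512]) torch.Size([1, 512, 512])
    print(out['image'].shape, out['label'].shape)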

#test_dataset = TestData(img_name_list = img_name_list, lbl_name_list = [],
#                        transform = transforms.Compose([RescaleT(512), ToTensorLab(flag = 0)]))

### Make test dataloader
#test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle = False, num_workers = 1)

#net = U2Net(3,1)
#net = torch.jit.load('quant_model_u2net.pth')
#net.load_state_dict(torch.load(model_dir))

"""net.eval()

for i_test, data_test in enumerate(test_dataloader):

    print("Inferencing : ", img_name_list[i_test].split(os.sep)[-1])

    inputs_test = data_test['image']
    inputs_test = inputs_test.type(torch.FloatTensor)
    inputs_test = Variable(inputs_test)

    d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

    pred = 1.0 - d1[:, 0, :, :]
    pred = normPRED(pred)

    #save_output(img_name_list[i_test], pred, prediction_dir)

    #del d1, d2, d3, d4, d5, d6, d7"""
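
# --- Example (not part of the original pipeline) ----------------------------
# A hedged end-to-end sketch that mirrors the commented-out inference loop
# above. The directory and checkpoint paths are placeholders, and it assumes a
# TorchScript U2Net checkpoint whose forward pass returns seven side outputs
# (d1..d7), with d1 used as the mask, as in the loop above. torch.no_grad() is
# used here in place of the deprecated Variable wrapper. Never called by the
# script itself.
def _demo_inference(image_dir="./test_data", prediction_dir="./outputs_pred",
                    model_path="quant_model_u2net.pth"):
    img_name_list = glob.glob(image_dir + "/*")

    dataset = TestData(img_name_list=img_name_list, lbl_name_list=[],
                       transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]))
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    net = torch.jit.load(model_path)  # assumed scripted checkpoint, as in the comments above
    net.eval()

    with torch.no_grad():
        for i_test, data_test in enumerate(dataloader):
            print("Inferencing : ", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            pred = 1.0 - d1[:, 0, :, :]   # foreground probability, as in the original loop
            pred = normPRED(pred)

            save_output(img_name_list[i_test], pred, prediction_dir)
            del d1, d2, d3, d4, d5, d6, d7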