# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

# Per-channel (R, G, B) mean and standard deviation subtracted from / divided
# into images that have first been scaled so that their maximum is 1.
_RGB_MEAN = (0.485, 0.456, 0.406)
_RGB_STD = (0.229, 0.224, 0.225)

# A label whose maximum is below this value is treated as empty and is not
# rescaled.
_EMPTY_LABEL_MAX = 1e-6


#==========================helpers==========================
def _random_vertical_flip(image, label):
    """Reverse the row order of ``image`` and ``label`` with probability 0.5.

    Both arrays are flipped together so they stay aligned. The flipped
    arrays are views with a negative stride along the first axis.
    """
    if random.random() >= 0.5:
        image = image[::-1]
        label = label[::-1]
    return image, label


def _scale_label(label):
    """Divide ``label`` by its maximum unless it is (almost) all zero.

    The division is always performed (by 1.0 for an empty label) so the
    result has the same floating dtype in both cases and is a freshly
    allocated array. The latter matters because ``torch.from_numpy`` rejects
    the negative-stride views produced by the random flip.
    """
    peak = np.max(label)
    if peak < _EMPTY_LABEL_MAX:
        peak = 1.0
    return label / peak


def _normalize_rgb(image):
    """Scale ``image`` by its maximum and apply the RGB mean/std.

    ``image`` must be H x W x C. A single-channel image is replicated into
    three channels, all normalized with the first mean/std pair. An image
    whose maximum is 0 is left unscaled instead of being divided by zero.

    Returns a new float64 array of shape H x W x 3.
    """
    peak = np.max(image)
    if peak != 0:
        image = image / peak

    is_gray = image.shape[2] == 1
    normalized = np.zeros((image.shape[0], image.shape[1], 3))
    for channel in range(3):
        src = 0 if is_gray else channel
        normalized[:, :, channel] = (image[:, :, src] - _RGB_MEAN[src]) / _RGB_STD[src]
    return normalized


def _as_three_channels(image):
    """Return ``image`` unchanged, or a gray image replicated to 3 channels.

    The replicated array is float64 and holds the original values as-is.
    NOTE(review): for integer gray input this hands unscaled values to
    ``color.rgb2lab``; kept as in the original implementation - confirm
    that this is intended.
    """
    if image.shape[2] != 1:
        return image
    stacked = np.zeros((image.shape[0], image.shape[1], 3))
    for channel in range(3):
        stacked[:, :, channel] = image[:, :, 0]
    return stacked


def _min_max_normalize(channel):
    """Map ``channel`` linearly onto [0, 1]; a constant channel maps to zeros."""
    low = np.min(channel)
    span = np.max(channel) - low
    if span == 0:
        return np.zeros(channel.shape)
    return (channel - low) / span


def _standardize(channel):
    """Shift ``channel`` to zero mean and, when possible, unit std.

    A channel with zero standard deviation is only centered, which avoids a
    division by zero.
    """
    centered = channel - np.mean(channel)
    spread = np.std(channel)
    if spread == 0:
        return centered
    return centered / spread


def _pack_sample(imidx, image, label):
    """Build the tensor sample dict, moving H x W x C arrays to C x H x W."""
    return {'imidx': torch.from_numpy(imidx),
            'image': torch.from_numpy(image.transpose((2, 0, 1))),
            'label': torch.from_numpy(label.transpose((2, 0, 1)))}


#==========================dataset load==========================
class RescaleT(object):
    """Resize image and label to a fixed size, ignoring the aspect ratio.

    Args:
        output_size: an int ``s`` resizes to ``s x s``; a tuple is used as
            ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized 'image' and 'label'."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            # Previously a tuple was passed on as (tuple, tuple), which is
            # not a valid output shape.
            new_h, new_w = self.output_size
            target = (int(new_h), int(new_w))

        img = transform.resize(image, target, mode='constant')
        # order=0 (nearest neighbour) and preserve_range keep the label
        # values from being interpolated or rescaled.
        lbl = transform.resize(label, target, mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class Rescale(object):
    """Randomly flip, then resize image and label keeping the aspect ratio.

    Args:
        output_size: an int sets the shorter side to that many pixels and
            scales the longer side by the same factor; a tuple is used as
            ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with flipped/resized 'image' and 'label'."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        image, label = _random_vertical_flip(image, label)

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class RandomCrop(object):
    """Randomly flip, then crop the same random window from image and label.

    Args:
        output_size: an int ``s`` crops ``s x s``; a 2-tuple is used as
            ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with cropped 'image' and 'label'.

        Raises:
            ValueError: if the crop is larger than the image.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        image, label = _random_vertical_flip(image, label)

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size {}x{} exceeds image size {}x{}'.format(new_h, new_w, h, w))

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and allows a crop exactly as large as the image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        """Normalize the sample and return it as C x H x W tensors.

        'image' and 'label' must be H x W x C arrays. The image is scaled by
        its maximum and normalized with the RGB mean/std; the label is scaled
        by its maximum unless it is empty.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        return _pack_sample(imidx, _normalize_rgb(image), _scale_label(label))


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    Args:
        flag: selects the colour representation of the output image.
            2 - six channels: RGB followed by Lab, each channel min-max
                normalized and then standardized;
            1 - three Lab channels, min-max normalized and then standardized;
            anything else - three RGB channels normalized with the fixed
                mean/std (same result as ``ToTensor``).
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        """Normalize the sample and return it as C x H x W tensors."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        label = _scale_label(label)

        if self.flag == 2:  # with rgb and Lab colors
            rgb = _as_three_channels(image)
            lab = color.rgb2lab(rgb)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            for channel in range(3):
                tmpImg[:, :, channel] = _standardize(_min_max_normalize(rgb[:, :, channel]))
                tmpImg[:, :, channel + 3] = _standardize(_min_max_normalize(lab[:, :, channel]))
        elif self.flag == 1:  # with Lab color
            tmpImg = color.rgb2lab(_as_three_channels(image))
            for channel in range(3):
                tmpImg[:, :, channel] = _standardize(_min_max_normalize(tmpImg[:, :, channel]))
        else:  # with rgb color
            tmpImg = _normalize_rgb(image)

        return _pack_sample(imidx, tmpImg, label)


class SalObjDataset(Dataset):
    """Dataset of images with optional label masks, read from file paths.

    Args:
        img_name_list: paths of the input images.
        lbl_name_list: paths of the label images, index-aligned with
            ``img_name_list``. When empty (or None) an all-zero label is
            generated for every image.
        transform: optional callable applied to each sample dict.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` as {'imidx', 'image', 'label'}.

        'imidx' is a one-element array holding ``idx``. 'image' and 'label'
        are given a trailing channel axis when they are 2-D, and the label is
        reduced to its first channel when it has three dimensions. The
        transform, if any, is applied last.
        """
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        if self.label_name_list is None or len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Fallback for label arrays that are neither 2-D nor 3-D.
        label = np.zeros(label_3.shape[0:2])
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3

        if image.ndim == 3 and label.ndim == 2:
            label = label[:, :, np.newaxis]
        elif image.ndim == 2 and label.ndim == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample