"""SLIC dataset - Returns an image together with its SLIC segmentation map. """ import torch import torch.utils.data as data import torchvision.transforms as transforms import numpy as np from glob import glob from PIL import Image import torch.nn.functional as F import torchvision.transforms.functional as TF from .custom_transform import * class Dataset(data.Dataset): def __init__(self, data_dir, img_size=256, crop_size=128, test=False, sp_num=256, slic = True, lab = False): super(Dataset, self).__init__() #self.data_list = glob(os.path.join(data_dir, "*.jpg")) ext = ["*.jpg"] dl = [] [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] self.data_list = dl self.sp_num = sp_num self.slic = slic self.lab = lab if test: self.transform = transforms.Compose([ transforms.Resize(img_size), transforms.CenterCrop(crop_size)]) else: self.transform = transforms.Compose([ transforms.Resize(int(img_size)), transforms.RandomCrop(crop_size)]) N = len(self.data_list) # eqv transform self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) self.random_vertical_flip = RandomVerticalFlip(N=N) self.random_resized_crop = RandomResizedCrop(N=N, res=256) # photometric transform self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)] self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)] self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)] self.eqv_list = ['random_crop', 'h_flip'] self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur'] self.transform_tensor = TensorTransform() def transform_eqv(self, indice, image): if 'random_crop' in self.eqv_list: image = self.random_resized_crop(indice, image) if 'h_flip' in self.eqv_list: image = self.random_horizontal_flip(indice, image) if 'v_flip' in self.eqv_list: image = self.random_vertical_flip(indice, image) return image def transform_inv(self, index, image, ver): """ Hyperparameters same as MoCo v2. (https://github.com/facebookresearch/moco/blob/master/main_moco.py) """ if 'brightness' in self.inv_list: image = self.random_color_brightness[ver](index, image) if 'contrast' in self.inv_list: image = self.random_color_contrast[ver](index, image) if 'saturation' in self.inv_list: image = self.random_color_saturation[ver](index, image) if 'hue' in self.inv_list: image = self.random_color_hue[ver](index, image) if 'gray' in self.inv_list: image = self.random_gray_scale[ver](index, image) if 'blur' in self.inv_list: image = self.random_gaussian_blur[ver](index, image) return image def transform_image(self, index, image): image1 = self.transform_inv(index, image, 0) image1 = self.transform_tensor(image) image2 = self.transform_inv(index, image, 1) #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR) image2 = self.transform_tensor(image2) return image1, image2 def __getitem__(self, index): data_path = self.data_list[index] ori_img = Image.open(data_path) ori_img = self.transform(ori_img) image1, image2 = self.transform_image(index, ori_img) rets = [] rets.append(image1) rets.append(image2) rets.append(index) return rets def __len__(self): return len(self.data_list) if __name__ == '__main__': import torchvision.utils as vutils dataset = Dataset('/home/xtli/DATA/texture_data/', sampled_num=3000) loader_ = torch.utils.data.DataLoader(dataset = dataset, batch_size = 1, shuffle = True, num_workers = 1, drop_last = True) loader = iter(loader_) img, points, pixs = loader.next() crop_size = 128 canvas = torch.zeros((1, 3, crop_size, crop_size)) for i in range(points.shape[-2]): p = (points[0, i] + 1) / 2.0 * (crop_size - 1) canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] vutils.save_image(canvas, 'canvas.png') vutils.save_image(img, 'img.png')