"""SLIC dataset - Returns an image together with its SLIC segmentation map. """ import torch import torch.utils.data as data import torchvision.transforms as transforms import numpy as np from glob import glob from PIL import Image from skimage.segmentation import slic from skimage.color import rgb2lab import torch.nn.functional as F from .utils import label2one_hot_torch class RandomResizedCrop(object): def __init__(self, N, res, scale=(0.5, 1.0)): self.res = res self.scale = scale self.rscale = [np.random.uniform(*scale) for _ in range(N)] self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] def random_crop(self, idx, img): ws, hs = self.rcrop[idx] res1 = int(img.size(-1)) res2 = int(self.rscale[idx]*res1) i1 = int(round((res1-res2)*ws)) j1 = int(round((res1-res2)*hs)) return img[:, :, i1:i1+res2, j1:j1+res2] def __call__(self, indice, image): new_image = [] res_tar = self.res // 8 if image.size(1) > 5 else self.res # View 1 or View 2? for i, idx in enumerate(indice): img = image[[i]] img = self.random_crop(idx, img) img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) new_image.append(img) new_image = torch.cat(new_image) return new_image class RandomVerticalFlip(object): def __init__(self, N, p=0.5): self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, indice, image): I = np.nonzero(self.plist[indice] < self.p_ref)[0] if len(image.size()) == 3: image_t = image[I].flip([1]) else: image_t = image[I].flip([2]) return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) class RandomHorizontalTensorFlip(object): def __init__(self, N, p=0.5): self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, indice, image, is_label=False): I = np.nonzero(self.plist[indice.cpu()] < self.p_ref)[0] if len(image.size()) == 3: image_t = image[I].flip([2]) else: image_t = image[I].flip([3]) return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) class Dataset(data.Dataset): def __init__(self, data_dir, img_size=256, crop_size=128, test=False, sp_num=256, slic = True, lab = False): super(Dataset, self).__init__() #self.data_list = glob(os.path.join(data_dir, "*.jpg")) ext = ["*.jpg"] dl = [] [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] self.data_list = dl self.sp_num = sp_num self.slic = slic self.lab = lab if test: self.transform = transforms.Compose([ transforms.Resize(img_size), transforms.CenterCrop(crop_size)]) else: self.transform = transforms.Compose([ transforms.RandomChoice([ transforms.ColorJitter(brightness=0.05), transforms.ColorJitter(contrast=0.05), transforms.ColorJitter(saturation=0.01), transforms.ColorJitter(hue=0.01)]), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.Resize(int(img_size)), transforms.RandomCrop(crop_size)]) N = len(self.data_list) self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) self.random_vertical_flip = RandomVerticalFlip(N=N) self.random_resized_crop = RandomResizedCrop(N=N, res=224) self.eqv_list = ['random_crop', 'h_flip'] def transform_eqv(self, indice, image): if 'random_crop' in self.eqv_list: image = self.random_resized_crop(indice, image) if 'h_flip' in self.eqv_list: image = self.random_horizontal_flip(indice, image) if 'v_flip' in self.eqv_list: image = self.random_vertical_flip(indice, image) return image def __getitem__(self, index): data_path = self.data_list[index] ori_img = Image.open(data_path) ori_img = self.transform(ori_img) ori_img = np.array(ori_img) # compute slic if self.slic: slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3) slic_i = torch.from_numpy(slic_i) slic_i[slic_i >= self.sp_num] = self.sp_num - 1 oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze() if ori_img.ndim < 3: ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2) ori_img = ori_img[:, :, :3] rets = [] if self.lab: lab_img = rgb2lab(ori_img) rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1)) ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1) rets.append(ori_img/255.0) if self.slic: rets.append(oh) rets.append(index) return rets def __len__(self): return len(self.data_list) if __name__ == '__main__': import torchvision.utils as vutils dataset = Dataset('/home/xtli/DATA/texture_data/', sampled_num=3000) loader_ = torch.utils.data.DataLoader(dataset = dataset, batch_size = 1, shuffle = True, num_workers = 1, drop_last = True) loader = iter(loader_) img, points, pixs = loader.next() crop_size = 128 canvas = torch.zeros((1, 3, crop_size, crop_size)) for i in range(points.shape[-2]): p = (points[0, i] + 1) / 2.0 * (crop_size - 1) canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] vutils.save_image(canvas, 'canvas.png') vutils.save_image(img, 'img.png')