import os
import random
import sys
from glob import glob

import cv2
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.io import loadmat
from skimage.segmentation import slic

sys.path.append("../..")


def label2one_hot_torch(labels, C=14):
    """Convert an integer label map to a one-hot encoding.

    Args:
        labels (torch.Tensor): segmentation labels of shape (B, 1, H, W).
        C (int): number of classes in ``labels``.

    Returns:
        torch.Tensor: one-hot tensor of shape (B, C, H, W), dtype float32.
    """
    b, _, h, w = labels.shape
    # allocate on the same device as `labels`; scatter_ needs a long index
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long).to(labels)
    target = one_hot.scatter_(1, labels.type(torch.long).data, 1)
    return target.type(torch.float32)


class Dataset(data.Dataset):
    """Image dataset that yields random crops, SLIC superpixel one-hot maps
    and (optionally) BSDS-style ground-truth segmentations loaded from .mat
    files.

    Each item is a list: ``[rgb_img, (superpixel one-hot,) filename, index,
    (gt segmentation)]`` — the optional entries depend on ``slic`` and
    ``gt_label``.
    """

    def __init__(self, data_dir, crop_size=128, test=False, sp_num=256,
                 slic=True, preprocess_name=False, gt_label=False,
                 label_path=None, test_time=False, img_path=None):
        """
        Args:
            data_dir (str): root directory, searched recursively for *.jpg.
            crop_size (int): side of the square crop.
            test (bool): test-mode flag (stored; not used in this file).
            sp_num (int): requested number of SLIC superpixels.
            slic (bool): if True, compute superpixels per item.
                NOTE: this parameter name shadows skimage's slic() inside
                __init__ only; the module-level function is still used in
                __getitem__.
            gt_label (bool): if True, load ground-truth .mat segmentations.
            label_path (str): root of the .mat annotations (when gt_label).
            img_path (str): if given, always load this single image with a
                fixed crop instead of indexing data_dir.
        """
        super(Dataset, self).__init__()
        self.test = test
        self.test_time = test_time
        ext = ["*.jpg"]
        dl = []
        for e in ext:  # plain loop: extend() is a side effect
            dl.extend(glob(data_dir + '/**/' + e, recursive=True))
        self.data_list = sorted(dl)
        self.sp_num = sp_num
        self.slic = slic
        self.crop = transforms.CenterCrop(size=(crop_size, crop_size))
        self.crop_size = crop_size
        self.gt_label = gt_label
        if gt_label:
            self.label_path = label_path
        self.img_path = img_path

    def preprocess_label(self, seg):
        """Drop empty classes from an integer label map and re-index it
        densely (0..K-1).

        Args:
            seg (torch.Tensor): integer label map of shape (1, H, W).

        Returns:
            torch.Tensor: re-indexed label map of shape (H, W).
        """
        n_cls = int(seg.max()) + 1
        segs = label2one_hot_torch(seg.unsqueeze(0), C=n_cls)
        # keep only the channels that actually contain pixels
        kept = [segs[0, c] for c in range(n_cls) if segs[0, c].sum() > 0]
        return torch.argmax(torch.stack(kept), dim=0)

    def _load_gt_seg(self, data_path):
        """Load the most finely annotated ground-truth map for data_path.

        BSDS .mat files contain several human annotations; the one with the
        largest number of distinct labels is kept.
        """
        # NOTE(review): split("_")[0] drops the '.jpg' suffix for names that
        # contain '_'; presumably dataset filenames have none — verify.
        img_name = data_path.split("/")[-1].split("_")[0]
        mat_path = os.path.join(self.label_path, data_path.split('/')[-2],
                                img_name.replace('.jpg', '.mat'))
        mat = loadmat(mat_path)
        max_label_num = 0
        final_seg = None
        for i in range(len(mat['groundTruth'][0])):
            seg = mat['groundTruth'][0][i][0][0][0]
            if len(np.unique(seg)) > max_label_num:
                max_label_num = len(np.unique(seg))
                final_seg = seg
        seg = torch.from_numpy(final_seg.astype(np.float32))
        return seg.long().unsqueeze(0)

    def __getitem__(self, index):
        # single fixed image (test-time) vs. indexed dataset image
        if self.img_path is None:
            data_path = self.data_list[index]
        else:
            data_path = self.img_path
        rgb_img = Image.open(data_path)

        segs = None
        if self.gt_label:
            segs = self._load_gt_seg(data_path)

        # random crop for training, fixed crop for a user-supplied image
        if self.img_path is None:
            i, j, h, w = transforms.RandomCrop.get_params(
                rgb_img, output_size=(self.crop_size, self.crop_size))
        else:
            i, j, h, w = 40, 40, self.crop_size, self.crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        if self.gt_label:
            segs = TF.crop(segs, i, j, h, w)
            segs = self.preprocess_label(segs)

        if self.slic:
            sp_num = self.sp_num
            # compute superpixels on the cropped image
            slic_i = slic(np.array(rgb_img), n_segments=sp_num,
                          compactness=10, start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            # slic may return more segments than requested; clamp the ids
            slic_i[slic_i >= sp_num] = sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0),
                                     C=sp_num).squeeze()

        rgb_img = TF.to_tensor(rgb_img)
        if rgb_img.shape[0] == 1:
            # grayscale -> fake RGB
            rgb_img = rgb_img.repeat(3, 1, 1)
        rgb_img = rgb_img[:3, :, :]  # drop alpha channel if present

        rets = [rgb_img]
        if self.slic:
            rets.append(oh)
        rets.append(data_path.split("/")[-1])
        rets.append(index)
        if self.gt_label:
            rets.append(segs.view(1, segs.shape[-2], segs.shape[-1]))
        return rets

    def __len__(self):
        return len(self.data_list)


if __name__ == '__main__':
    import torchvision.utils as vutils

    # Fixed: the old demo passed a non-existent `sampled_num` kwarg, used the
    # Python-2 `loader.next()`, and unpacked values the dataset never returns.
    dataset = Dataset('/home/xtli/DATA/texture_data/')
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)
    img, oh, name, idx = next(iter(loader_))

    # visualize the superpixel partition as a normalized id map
    sp = torch.argmax(oh, dim=1, keepdim=True).float()
    sp = sp / sp.max().clamp(min=1)
    vutils.save_image(sp, 'canvas.png')
    vutils.save_image(img, 'img.png')