import os
import sys
from glob import glob

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.io import loadmat
from skimage.segmentation import slic

sys.path.append("../..")
def label2one_hot_torch(labels, C=14):
    """Convert an integer label tensor to a one-hot tensor.

    Args:
        labels (tensor): segmentation label map
        C (int): number of classes in labels

    Returns:
        target (tensor): one-hot encoding of the input label map

    Shape:
        labels: (B, 1, H, W)
        target: (B, C, H, W)
    """
    b, _, h, w = labels.shape
    # Keep the explicit long dtype; `.to(labels)` would silently override it
    # with labels' dtype, so only the device is taken from the input.
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long, device=labels.device)
    # scatter_ requires a long index tensor
    target = one_hot.scatter_(1, labels.long(), 1)
    return target.type(torch.float32)
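

def _example_label2one_hot():
    # Hedged usage sketch (not part of the original pipeline): a 2x2 label
    # map with 3 classes becomes a (1, 3, 2, 2) one-hot tensor.
    lbl = torch.tensor([[[[0, 1], [2, 1]]]])     # (B=1, 1, H=2, W=2)
    oh = label2one_hot_torch(lbl, C=3)
    assert oh.shape == (1, 3, 2, 2)
    assert oh[0, 1, 0, 1] == 1.0                 # pixel (0, 1) carries label 1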


class Dataset(data.Dataset):
    def __init__(self, data_dir, crop_size=128, test=False,
                 sp_num=256, slic=True,
                 gt_label=False, label_path=None, test_time=False,
                 img_path=None):
        super().__init__()
        ext = ["*.jpg"]
        dl = []
        self.test = test
        self.test_time = test_time
        for e in ext:
            dl.extend(glob(data_dir + '/**/' + e, recursive=True))
        self.data_list = sorted(dl)
        self.sp_num = sp_num
        self.slic = slic
        self.crop = transforms.CenterCrop(size=(crop_size, crop_size))
        self.crop_size = crop_size
        self.gt_label = gt_label
        if gt_label:
            self.label_path = label_path
        self.img_path = img_path

    def preprocess_label(self, seg):
        # Re-index the label map so that empty label ids are dropped and the
        # remaining ids become contiguous.
        C = int(seg.max()) + 1
        segs = label2one_hot_torch(seg.unsqueeze(0), C=C)
        new_seg = []
        for cnt in range(C):
            if segs[0, cnt].sum() > 0:
                new_seg.append(segs[0, cnt])
        new_seg = torch.stack(new_seg)
        return torch.argmax(new_seg, dim=0)
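
    # Hedged sketch (not executed anywhere in this file): preprocess_label
    # compacts unused label ids, e.g. for a (1, 2, 2) map using ids {0, 2}:
    #   seg = torch.tensor([[[0, 2], [2, 0]]])     # id 1 never occurs
    #   out = dataset.preprocess_label(seg)        # ids compacted to {0, 1}
    #   sorted(out.unique().tolist()) == [0, 1]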

    def __getitem__(self, index):
        if self.img_path is None:
            data_path = self.data_list[index]
        else:
            data_path = self.img_path
        rgb_img = Image.open(data_path)
        imgW, imgH = rgb_img.size  # PIL reports size as (width, height)

        if self.gt_label:
            # Locate the BSDS-style .mat annotation matching this image and,
            # among the available human annotations, keep the one with the
            # most labels.
            img_name = data_path.split("/")[-1].split("_")[0]
            mat_path = os.path.join(self.label_path, data_path.split('/')[-2],
                                    img_name.replace('.jpg', '.mat'))
            mat = loadmat(mat_path)
            max_label_num = 0
            final_seg = None
            for i in range(len(mat['groundTruth'][0])):
                seg = mat['groundTruth'][0][i][0][0][0]
                if len(np.unique(seg)) > max_label_num:
                    max_label_num = len(np.unique(seg))
                    final_seg = seg
            seg = torch.from_numpy(final_seg.astype(np.float32))
            segs = seg.long().unsqueeze(0)

        if self.img_path is None:
            i, j, h, w = transforms.RandomCrop.get_params(
                rgb_img, output_size=(self.crop_size, self.crop_size))
        else:
            # Fixed crop for the single-image (test-time) mode.
            i, j, h, w = 40, 40, self.crop_size, self.crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        if self.gt_label:
            segs = TF.crop(segs, i, j, h, w)
            segs = self.preprocess_label(segs)

        if self.slic:
            sp_num = self.sp_num
            # Compute superpixels on the cropped image.
            slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10,
                          start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            # SLIC may emit more segments than requested; clamp the overflow
            # onto the last superpixel id.
            slic_i[slic_i >= sp_num] = sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
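            # `oh` is a (sp_num, H, W) stack of binary superpixel masks; see
            # _example_superpixel_pooling below for a hedged sketch of how a
            # consumer can mean-pool features with it.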

        rgb_img = TF.to_tensor(rgb_img)
        if rgb_img.shape[0] == 1:
            # Grayscale: replicate to three channels.
            rgb_img = rgb_img.repeat(3, 1, 1)
        # Drop any alpha channel.
        rgb_img = rgb_img[:3, :, :]

        rets = []
        rets.append(rgb_img)
        if self.slic:
            rets.append(oh)
        rets.append(data_path.split("/")[-1])
        rets.append(index)
        if self.gt_label:
            rets.append(segs.view(1, segs.shape[-2], segs.shape[-1]))
        return rets

    def __len__(self):
        return len(self.data_list)
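

def _example_superpixel_pooling():
    # Hedged sketch (not part of the original pipeline): the one-hot
    # superpixel map returned by __getitem__ lets per-superpixel mean
    # features be computed without explicit loops. `feats` is a dummy
    # feature map standing in for whatever a downstream model produces.
    feats = torch.randn(3, 8, 8)                       # (C, H, W)
    labels = torch.randint(0, 4, (1, 1, 8, 8))         # 4 fake superpixels
    oh = label2one_hot_torch(labels, C=4).squeeze(0)   # (4, 8, 8) binary masks
    area = oh.sum(dim=(1, 2)).clamp(min=1)             # pixels per superpixel
    sp_feats = torch.einsum('chw,nhw->nc', feats, oh) / area.unsqueeze(1)
    assert sp_feats.shape == (4, 3)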


if __name__ == '__main__':
    import torchvision.utils as vutils

    # Smoke test: load one batch and visualize the SLIC superpixels by
    # painting every superpixel with its mean color. (The original demo
    # passed a nonexistent `sampled_num` kwarg and unpacked points/pixs
    # that this Dataset never returns; it is rewritten to match the actual
    # return signature: [rgb_img, oh, name, index].)
    dataset = Dataset('/home/xtli/DATA/texture_data/', sp_num=256)
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)
    img, oh, name, index = next(iter(loader_))
    # Mean color per superpixel, then broadcast back onto the pixel grid.
    area = oh.sum(dim=(-2, -1)).clamp(min=1)                      # (B, N)
    mean_color = torch.einsum('bchw,bnhw->bcn', img, oh) / area.unsqueeze(1)
    canvas = torch.einsum('bcn,bnhw->bchw', mean_color, oh)
    vutils.save_image(canvas, 'canvas.png')
    vutils.save_image(img, 'img.png')