# NOTE: file recovered from a Hugging Face Spaces snapshot (commit 1b2a9b1);
# the viewer's page chrome ("Runtime error", file size, line-number gutter)
# has been stripped so the module parses.
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
from glob import glob
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
from scipy.io import loadmat
import random
import cv2
import sys
sys.path.append("../..")
def label2one_hot_torch(labels, C=14):
    """Convert an integer segmentation label map into a one-hot tensor.

    Args:
        labels (tensor): integer label map of shape (B, 1, H, W)
        C (int): number of classes encoded in ``labels``

    Returns:
        tensor: one-hot encoding of shape (B, C, H, W), dtype float32
    """
    b, _, h, w = labels.shape
    # Indices for scatter_ must be int64.
    idx = labels.type(torch.long).data
    # ``.to(labels)`` puts the buffer on the same device (and dtype) as the input.
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long).to(labels)
    one_hot.scatter_(1, idx, 1)
    return one_hot.type(torch.float32)
class Dataset(data.Dataset):
    """Image dataset that yields random (or fixed) crops, optional SLIC
    superpixel one-hot maps, and optional ground-truth segmentations
    loaded from .mat files (BSDS-style annotation layout — TODO confirm).

    ``__getitem__`` returns a list:
        [rgb_img(3,H,W)] (+ [sp_onehot(sp_num,H,W)] if slic)
        + [filename, index] (+ [seg(1,H,W)] if gt_label)
    """

    def __init__(self, data_dir, crop_size = 128, test=False,
            sp_num = 256, slic = True, preprocess_name = False,
            gt_label = False, label_path = None, test_time = False,
            img_path = None):
        # NOTE: the ``slic`` parameter (a bool flag) shadows the imported
        # skimage ``slic`` function only inside __init__; __getitem__ still
        # sees the module-level function.
        super(Dataset, self).__init__()
        ext = ["*.jpg"]
        dl = []
        for e in ext:
            dl.extend(glob(data_dir + '/**/' + e, recursive=True))
        self.data_list = sorted(dl)
        self.test = test
        self.test_time = test_time
        self.sp_num = sp_num
        self.slic = slic
        self.crop = transforms.CenterCrop(size = (crop_size, crop_size))
        self.crop_size = crop_size
        self.gt_label = gt_label
        self.label_path = label_path
        # BUG FIX: img_path was previously assigned only when gt_label=True,
        # so __getitem__ crashed with AttributeError on the default path.
        self.img_path = img_path

    def preprocess_label(self, seg):
        """Relabel ``seg`` (1, H, W) so class ids are contiguous by dropping
        empty channels of its one-hot encoding, then taking the argmax."""
        # ``seg.max()`` is a 0-dim tensor; make the class count a plain int.
        num_classes = int(seg.max()) + 1
        segs = label2one_hot_torch(seg.unsqueeze(0), C = num_classes)
        new_seg = []
        for cnt in range(num_classes):
            if segs[0, cnt].sum() > 0:          # keep only classes present in this crop
                new_seg.append(segs[0, cnt])
        new_seg = torch.stack(new_seg)
        return torch.argmax(new_seg, dim = 0)

    def __getitem__(self, index):
        # A fixed ``img_path`` overrides the indexed file list (single-image mode).
        if self.img_path is None:
            data_path = self.data_list[index]
        else:
            data_path = self.img_path
        rgb_img = Image.open(data_path)

        if self.gt_label:
            img_name = data_path.split("/")[-1].split("_")[0]
            # NOTE(review): if the file name contains an underscore, the
            # '.jpg' suffix is already stripped by split("_")[0] and the
            # replace below is a no-op — verify against the actual naming.
            mat_path = os.path.join(self.label_path, data_path.split('/')[-2], img_name.replace('.jpg', '.mat'))
            mat = loadmat(mat_path)
            # Each image has several human annotations; keep the one with
            # the most distinct labels.
            max_label_num = 0
            final_seg = None
            for i in range(len(mat['groundTruth'][0])):
                seg = mat['groundTruth'][0][i][0][0][0]
                if len(np.unique(seg)) > max_label_num:
                    max_label_num = len(np.unique(seg))
                    final_seg = seg
            seg = torch.from_numpy(final_seg.astype(np.float32))
            segs = seg.long().unsqueeze(0)

        # Random crop for dataset mode; a fixed crop for single-image mode
        # so results are reproducible.
        if self.img_path is None:
            i, j, h, w = transforms.RandomCrop.get_params(rgb_img, output_size=(self.crop_size, self.crop_size))
        else:
            i = 40; j = 40; h = self.crop_size; w = self.crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        if self.gt_label:
            segs = TF.crop(segs, i, j, h, w)
            segs = self.preprocess_label(segs)

        if self.slic:
            sp_num = self.sp_num
            # Compute the superpixel over-segmentation of the crop.
            slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            # slic may emit slightly more segments than requested; clamp so
            # the one-hot encoding below stays within sp_num channels.
            slic_i[slic_i >= sp_num] = sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = sp_num).squeeze()

        rgb_img = TF.to_tensor(rgb_img)
        if rgb_img.shape[0] == 1:
            # Grayscale image: replicate to 3 channels.
            rgb_img = rgb_img.repeat(3, 1, 1)
        rgb_img = rgb_img[:3, :, :]  # drop alpha channel if present

        rets = []
        rets.append(rgb_img)
        if self.slic:
            rets.append(oh)
        rets.append(data_path.split("/")[-1])
        rets.append(index)
        if self.gt_label:
            rets.append(segs.view(1, segs.shape[-2], segs.shape[-1]))
        return rets

    def __len__(self):
        return len(self.data_list)
if __name__ == '__main__':
    # Smoke test: load one random crop and dump it (and its superpixel map)
    # to disk for visual inspection.
    import torchvision.utils as vutils

    # BUG FIX: the old call passed ``sampled_num=3000``, which is not a
    # Dataset parameter and raised a TypeError.
    dataset = Dataset('/home/xtli/DATA/texture_data/',
                      crop_size=128, sp_num=256)
    loader_ = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = 1,
                                          shuffle = True,
                                          num_workers = 1,
                                          drop_last = True)
    # BUG FIX: Python-3 iterators have no .next() method; also unpack the
    # actual return layout of __getitem__: [img, sp_onehot, name, index].
    img, oh, name, index = next(iter(loader_))

    vutils.save_image(img, 'img.png')
    # Visualize the superpixel labels as a grayscale map scaled to [0, 1].
    sp_map = torch.argmax(oh, dim=1, keepdim=True).float()
    sp_map = sp_map / sp_map.max().clamp(min=1)
    vutils.save_image(sp_map, 'canvas.png')