# sunshineatnoon
# Add application file
# 1b2a9b1
# raw
# history blame
# 5.18 kB
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
from glob import glob
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
from scipy.io import loadmat
import random
import cv2
import sys
sys.path.append("../..")
def label2one_hot_torch(labels, C=14):
    """Convert an integer label map to a one-hot float tensor.

    Args:
        labels (torch.Tensor): segmentation label map of shape (B, 1, H, W).
            Values are cast to ``torch.long`` and must lie in ``[0, C)``.
        C (int): number of classes, i.e. the size of the one-hot dimension.

    Returns:
        torch.Tensor: float32 tensor of shape (B, C, H, W) on the same device
        as ``labels``, where ``target[b, c, y, x] == 1`` iff
        ``labels[b, 0, y, x] == c`` and 0 otherwise.
    """
    b, _, h, w = labels.shape
    # scatter_ requires an int64 index tensor.
    index = labels.long()
    # Allocate directly on the labels' device; copying labels' dtype as well
    # (as `.to(labels)` would) silently discards the requested int64 dtype.
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long, device=labels.device)
    target = one_hot.scatter_(1, index, 1)
    return target.type(torch.float32)
class Dataset(data.Dataset):
    """Image dataset returning crops, optional SLIC superpixels and GT labels.

    Each item is a list ``[rgb_img, (superpixel one-hot), file_name, index,
    (gt segmentation)]``; the parenthesised entries are present only when
    ``slic`` / ``gt_label`` are enabled.

    Args:
        data_dir (str): root directory searched recursively for ``*.jpg``.
        crop_size (int): side length of the square crop.
        test (bool): stored as ``self.test``; not read by this class.
        sp_num (int): ``n_segments`` passed to SLIC and the channel count of
            the returned superpixel one-hot map.
        slic (bool): whether to compute and return the superpixel one-hot map.
        preprocess_name: accepted for backward compatibility; unused.
        gt_label (bool): whether to load a ground-truth segmentation from a
            ``.mat`` file under ``label_path``.
        label_path (str | None): root directory of the ``.mat`` label files.
        test_time (bool): stored as ``self.test_time``; not read by this class.
        img_path (str | None): if given, every item loads this one image and
            uses a fixed crop at (40, 40) instead of a random crop of
            ``data_list[index]``.
    """

    def __init__(self, data_dir, crop_size=128, test=False,
                 sp_num=256, slic=True, preprocess_name=False,
                 gt_label=False, label_path=None, test_time=False,
                 img_path=None):
        super(Dataset, self).__init__()
        ext = ["*.jpg"]
        found = []
        for pattern in ext:
            found.extend(glob(os.path.join(data_dir, '**', pattern), recursive=True))
        self.data_list = sorted(found)
        self.test = test
        self.test_time = test_time
        self.sp_num = sp_num
        # NOTE: the `slic` argument shadows skimage's `slic` only inside
        # __init__; __getitem__ still calls the module-level function.
        self.slic = slic
        # Kept as a public attribute; not used by this class.
        self.crop = transforms.CenterCrop(size=(crop_size, crop_size))
        self.crop_size = crop_size
        self.gt_label = gt_label
        self.label_path = label_path
        # Always assigned: __getitem__ reads it whether or not gt_label is set.
        self.img_path = img_path

    def _load_gt_segmentation(self, data_path):
        """Return the annotation with the most distinct labels as a (1, H, W) long tensor."""
        img_name = os.path.basename(data_path).split("_")[0]
        parent_dir = os.path.basename(os.path.dirname(data_path))
        # loadmat appends '.mat' itself when the name has no extension.
        mat_path = os.path.join(self.label_path, parent_dir, img_name.replace('.jpg', '.mat'))
        annotations = loadmat(mat_path)['groundTruth'][0]
        final_seg = None
        max_label_num = 0
        for annotation in annotations:
            seg = annotation[0][0][0]
            label_num = len(np.unique(seg))
            # Strict '>' keeps the first annotation among ties.
            if label_num > max_label_num:
                max_label_num = label_num
                final_seg = seg
        if final_seg is None:
            raise ValueError(f"no non-empty ground-truth annotation in {mat_path}")
        return torch.from_numpy(final_seg.astype(np.int64)).unsqueeze(0)

    def _superpixel_one_hot(self, rgb_img):
        """Run SLIC on a PIL crop and return a (sp_num, H, W) float one-hot map."""
        sp_num = self.sp_num
        labels = slic(np.array(rgb_img), n_segments=sp_num, compactness=10,
                      start_label=0, min_size_factor=0.3)
        labels = torch.from_numpy(labels)
        # SLIC's segment count is approximate; fold any overflow into the last
        # channel so the one-hot depth is always exactly sp_num.
        labels[labels >= sp_num] = sp_num - 1
        return label2one_hot_torch(labels.unsqueeze(0).unsqueeze(0), C=sp_num)[0]

    def preprocess_label(self, seg):
        """Renumber a (1, H, W) integer label map to consecutive ids 0..K-1.

        Labels keep their relative order (smallest original value -> 0).
        Returns an int64 tensor of shape (H, W).
        """
        _, relabeled = torch.unique(seg, sorted=True, return_inverse=True)
        return relabeled.reshape(seg.shape[-2:])

    def __getitem__(self, index):
        data_path = self.data_list[index] if self.img_path is None else self.img_path
        # Copy inside `with` so the file handle is released immediately rather
        # than lingering until garbage collection (e.g. in DataLoader workers).
        with Image.open(data_path) as img_file:
            rgb_img = img_file.copy()

        if self.gt_label:
            segs = self._load_gt_segmentation(data_path)

        if self.img_path is None:
            i, j, h, w = transforms.RandomCrop.get_params(
                rgb_img, output_size=(self.crop_size, self.crop_size))
        else:
            # Deterministic crop when a single image path is given.
            i, j, h, w = 40, 40, self.crop_size, self.crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        if self.gt_label:
            segs = self.preprocess_label(TF.crop(segs, i, j, h, w))

        if self.slic:
            oh = self._superpixel_one_hot(rgb_img)

        rgb_img = TF.to_tensor(rgb_img)
        if rgb_img.shape[0] == 1:
            rgb_img = rgb_img.repeat(3, 1, 1)  # single channel -> 3 channels
        rgb_img = rgb_img[:3, :, :]  # keep at most 3 channels (e.g. drop alpha)

        rets = [rgb_img]
        if self.slic:
            rets.append(oh)
        rets.append(os.path.basename(data_path))
        rets.append(index)
        if self.gt_label:
            rets.append(segs.view(1, segs.shape[-2], segs.shape[-1]))
        return rets

    def __len__(self):
        return len(self.data_list)
if __name__ == '__main__':
    # Smoke test: load one sample and save the crop ('img.png') next to a
    # rendering where each superpixel is filled with its mean colour
    # ('canvas.png'). Optional argv[1] overrides the data directory.
    import torchvision.utils as vutils
    data_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/xtli/DATA/texture_data/'
    dataset = Dataset(data_dir)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=1,
                                         drop_last=True)
    # With the defaults (slic=True, gt_label=False) each item is
    # [image, superpixel one-hot, file name, index].
    img, sp_one_hot, _, _ = next(iter(loader))
    b, k, h, w = sp_one_hot.shape
    c = img.shape[1]
    assign = sp_one_hot.reshape(b, k, h * w)                 # (B, K, HW)
    pix = img.reshape(b, c, h * w)                           # (B, C, HW)
    area = assign.sum(dim=2).clamp(min=1)                    # (B, K); avoid /0 for empty channels
    sp_mean = torch.bmm(pix, assign.transpose(1, 2)) / area.unsqueeze(1)  # (B, C, K)
    canvas = torch.bmm(sp_mean, assign).reshape(b, c, h, w)
    vutils.save_image(canvas, 'canvas.png')
    vutils.save_image(img, 'img.png')