"""SLIC dataset | |
- Returns an image together with its SLIC segmentation map. | |
""" | |
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from glob import glob
from PIL import Image
from skimage.color import rgb2lab
from skimage.segmentation import slic

from .utils import label2one_hot_torch


class RandomResizedCrop(object):
    """Random resized crop whose parameters are sampled once per dataset
    index, so the same index always replays the same crop."""

    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        # Pre-sample one scale and one crop offset per dataset entry.
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        # Assumes square inputs: one side length is used for both axes.
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx] * res1)
        i1 = int(round((res1 - res2) * ws))
        j1 = int(round((res1 - res2) * hs))
        return img[:, :, i1:i1 + res2, j1:j1 + res2]

    def __call__(self, indice, image):
        new_image = []
        # Tensors with more than 5 channels are treated as feature maps,
        # which live at 1/8 of the image resolution.
        res_tar = self.res // 8 if image.size(1) > 5 else self.res
        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)
            new_image.append(img)
        return torch.cat(new_image)
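

# Illustrative sketch (not part of the original file): because scale and
# offset are frozen per dataset index, calling the transform twice with the
# same indices yields identical crops. The helper name is hypothetical.
def _demo_random_resized_crop():
    crop = RandomResizedCrop(N=4, res=224)
    imgs = torch.rand(2, 3, 128, 128)   # two square samples
    out1 = crop([0, 3], imgs)           # use the params of dataset entries 0 and 3
    out2 = crop([0, 3], imgs)           # replay: same params, identical result
    assert out1.shape == (2, 3, 224, 224)
    assert torch.equal(out1, out2)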


class RandomVerticalFlip(object):
    """Vertical flip whose per-index decision is sampled once up front."""

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        # Batch positions whose pre-sampled probability falls below p_ref.
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]
        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])
        # Reassemble the batch: flipped samples where selected, originals elsewhere.
        return torch.stack([image_t[np.where(I == i)[0][0]] if i in I else image[i]
                            for i in range(image.size(0))])


class RandomHorizontalTensorFlip(object):
    """Horizontal flip whose per-index decision is sampled once up front."""

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        # is_label is currently unused.
        # Batch positions whose pre-sampled probability falls below p_ref.
        I = np.nonzero(self.plist[indice.cpu()] < self.p_ref)[0]
        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])
        return torch.stack([image_t[np.where(I == i)[0][0]] if i in I else image[i]
                            for i in range(image.size(0))])
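

# Illustrative sketch (hypothetical helper): the flip decision is likewise
# frozen per dataset index, so tensors flipped with the same indices stay
# spatially aligned across calls.
def _demo_horizontal_flip():
    flip = RandomHorizontalTensorFlip(N=4, p=1.0)   # p=1.0 makes every flip fire
    imgs = torch.rand(2, 3, 8, 8)
    idx = torch.tensor([0, 3])                      # dataset indices of the batch
    out = flip(idx, imgs)
    assert torch.equal(out, imgs.flip([3]))         # all samples flipped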


class Dataset(data.Dataset):
    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super(Dataset, self).__init__()
        # Recursively collect all .jpg images under data_dir.
        ext = ["*.jpg"]
        dl = []
        for e in ext:
            dl.extend(glob(data_dir + '/**/' + e, recursive=True))
        self.data_list = dl
        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab

        if test:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size)])
        else:
            self.transform = transforms.Compose([
                transforms.RandomChoice([
                    transforms.ColorJitter(brightness=0.05),
                    transforms.ColorJitter(contrast=0.05),
                    transforms.ColorJitter(saturation=0.01),
                    transforms.ColorJitter(hue=0.01)]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.Resize(img_size),
                transforms.RandomCrop(crop_size)])

        # Equivariant transforms keyed by dataset index, so the same
        # augmentation can be replayed across views of the same sample.
        N = len(self.data_list)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=224)
        self.eqv_list = ['random_crop', 'h_flip']
    def transform_eqv(self, indice, image):
        """Replay the pre-sampled equivariant transforms for the given indices."""
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)
        return image
    def __getitem__(self, index):
        data_path = self.data_list[index]
        ori_img = Image.open(data_path)
        ori_img = self.transform(ori_img)
        ori_img = np.array(ori_img)

        # Ensure a 3-channel RGB array (replicate grayscale, drop alpha)
        # before computing SLIC or converting color spaces.
        if ori_img.ndim < 3:
            ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis=2)
        ori_img = ori_img[:, :, :3]

        # Compute the SLIC superpixel segmentation and one-hot encode it.
        if self.slic:
            slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10,
                          start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            # Clamp spill-over labels so the one-hot map has exactly sp_num channels.
            slic_i[slic_i >= self.sp_num] = self.sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=self.sp_num).squeeze()

        rets = []
        if self.lab:
            lab_img = rgb2lab(ori_img)
            rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))
        ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
        rets.append(ori_img / 255.0)
        if self.slic:
            rets.append(oh)
        rets.append(index)
        return rets
    def __len__(self):
        return len(self.data_list)


if __name__ == '__main__':
    import torchvision.utils as vutils

    # Smoke test. With the default slic=True and lab=False the dataset
    # returns [image, superpixel one-hot, index].
    dataset = Dataset('/home/xtli/DATA/texture_data/', sp_num=256)
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)
    img, oh, index = next(iter(loader_))

    # Save the image and a grayscale rendering of the superpixel assignment.
    seg = oh.argmax(dim=1, keepdim=True).float() / (oh.size(1) - 1)
    vutils.save_image(seg, 'slic.png')
    vutils.save_image(img, 'img.png')
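
    # Sketch (illustrative): transform_eqv replays the same pre-sampled
    # crop/flip for the same dataset indices, so two calls give
    # pixel-identical views.
    view1 = dataset.transform_eqv(index, img)
    view2 = dataset.transform_eqv(index, img)
    assert torch.equal(view1, view2)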