Spaces:
Runtime error
Runtime error
"""SLIC dataset | |
- Returns two transformed views of each image together with its dataset index (no SLIC segmentation map is computed by the Dataset class in this file).
""" | |
import torch | |
import torch.utils.data as data | |
import torchvision.transforms as transforms | |
import numpy as np | |
from glob import glob | |
from PIL import Image | |
import torch.nn.functional as F | |
import torchvision.transforms.functional as TF | |
from .custom_transform import * | |
class Dataset(data.Dataset):
    """Image-folder dataset yielding two tensor views of every image.

    Every ``*.jpg`` found recursively under ``data_dir`` is one sample.
    ``__getitem__`` resizes and crops the image, runs the photometric
    ("invariance") pipeline once per view, converts both results with
    ``TensorTransform`` and returns ``[view1, view2, index]``.

    Args:
        data_dir: Root folder searched recursively for ``*.jpg`` files.
        img_size: Size passed to ``transforms.Resize`` before cropping.
        crop_size: Side of the crop taken after resizing (centre crop when
            ``test`` is true, random crop otherwise).
        test: Selects the deterministic centre crop instead of a random crop.
        sp_num: Stored on the instance; not used by the code in this class.
        slic: Stored on the instance; not used by the code in this class.
        lab: Stored on the instance; not used by the code in this class.
    """

    # (key looked up in ``eqv_list``, attribute holding the transform), in
    # application order.
    _EQV_PIPELINE = (
        ('random_crop', 'random_resized_crop'),
        ('h_flip', 'random_horizontal_flip'),
        ('v_flip', 'random_vertical_flip'),
    )

    # (key looked up in ``inv_list``, attribute holding the per-view pair of
    # transforms), in application order.
    _INV_PIPELINE = (
        ('brightness', 'random_color_brightness'),
        ('contrast', 'random_color_contrast'),
        ('saturation', 'random_color_saturation'),
        ('hue', 'random_color_hue'),
        ('gray', 'random_gray_scale'),
        ('blur', 'random_gaussian_blur'),
    )

    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super().__init__()

        # Collect images recursively. Sorting makes the index -> file mapping
        # reproducible, since glob() order depends on the filesystem.
        patterns = ["*.jpg"]
        found = []
        for pattern in patterns:
            found.extend(glob(data_dir + '/**/' + pattern, recursive=True))
        self.data_list = sorted(found)

        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab

        # Geometric pre-processing applied to the PIL image in __getitem__.
        crop = (transforms.CenterCrop(crop_size) if test
                else transforms.RandomCrop(crop_size))
        self.transform = transforms.Compose([
            transforms.Resize(int(img_size)),
            crop,
        ])

        # The custom transforms below are all built with the dataset length;
        # presumably they keep per-index random state -- confirm in
        # custom_transform.
        N = len(self.data_list)

        # Equivariant (geometric) transforms, used by transform_eqv().
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        # NOTE: res is fixed at 256 here, independent of img_size/crop_size.
        self.random_resized_crop = RandomResizedCrop(N=N, res=256)

        # Photometric transforms: one independent instance per view (2 views).
        # NOTE: the strengths/probabilities below are meant to become
        # configurable later.
        self.random_color_brightness = [
            RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_contrast = [
            RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_saturation = [
            RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_hue = [
            RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)]
        self.random_gray_scale = [
            RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [
            RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]

        # Names of the transforms that are switched on in each pipeline.
        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue',
                         'gray', 'blur']

        self.transform_tensor = TensorTransform()

    def transform_eqv(self, indice, image):
        """Apply the geometric transforms enabled in ``self.eqv_list``.

        Args:
            indice: Sample index forwarded to every transform.
            image: Image to transform.

        Returns:
            The image after the enabled transforms, applied in the order
            resized-crop, horizontal flip, vertical flip.
        """
        for key, attr in self._EQV_PIPELINE:
            if key in self.eqv_list:
                image = getattr(self, attr)(indice, image)
        return image

    def transform_inv(self, index, image, ver):
        """Apply the photometric transforms enabled in ``self.inv_list``.

        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)

        Args:
            index: Sample index forwarded to every transform.
            image: Image to transform.
            ver: Which view's transform instances to use (0 or 1).

        Returns:
            The image after the enabled transforms, applied in the order
            brightness, contrast, saturation, hue, gray-scale, blur.
        """
        for key, attr in self._INV_PIPELINE:
            if key in self.inv_list:
                image = getattr(self, attr)[ver](index, image)
        return image

    def transform_image(self, index, image):
        """Build the two views of ``image``.

        Each view is the photometric pipeline of that view followed by
        ``self.transform_tensor``.

        Args:
            index: Sample index forwarded to the transforms.
            image: Pre-processed (resized and cropped) input image.

        Returns:
            Tuple ``(view1, view2)``.
        """
        # Fix: the first view used to be converted from the un-augmented
        # input, silently discarding its photometric transforms.
        image1 = self.transform_tensor(self.transform_inv(index, image, 0))
        image2 = self.transform_tensor(self.transform_inv(index, image, 1))
        return image1, image2

    def __getitem__(self, index):
        """Load sample ``index`` and return ``[view1, view2, index]``."""
        data_path = self.data_list[index]
        # The context manager releases the file handle; convert() decodes the
        # pixels into a new image, so it stays valid after the file is
        # closed. Forcing RGB keeps the channel count uniform across
        # grayscale/CMYK JPEGs.
        with Image.open(data_path) as raw:
            ori_img = raw.convert('RGB')
        ori_img = self.transform(ori_img)
        image1, image2 = self.transform_image(index, ori_img)
        return [image1, image2, index]

    def __len__(self):
        """Return the number of image files discovered under ``data_dir``."""
        return len(self.data_list)
if __name__ == '__main__':
    # Smoke test: pull one batch from the dataset and write both views to
    # disk for visual inspection.
    import torchvision.utils as vutils

    # Fix: Dataset.__init__ has no ``sampled_num`` parameter; passing it
    # raised a TypeError before anything was loaded.
    dataset = Dataset('/home/xtli/DATA/texture_data/')
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)

    # Fix: use the built-in next(); DataLoader iterators no longer expose a
    # Python-2 style ``.next()`` method in recent PyTorch releases.
    # The dataset yields [view1, view2, index], not (img, points, pixs), so
    # the former point-canvas rendering could not run and has been dropped.
    img1, img2, index = next(iter(loader_))

    vutils.save_image(img1, 'img.png')
    vutils.save_image(img2, 'img2.png')