Spaces:
Runtime error
Runtime error
File size: 5,217 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
"""SLIC dataset
- Returns an image together with its SLIC segmentation map.
"""
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
from glob import glob
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from .custom_transform import *
class Dataset(data.Dataset):
    """Image dataset that yields two augmented views of each image.

    Images are found recursively as ``*.jpg`` files under ``data_dir``.
    Each image is resized and cropped by ``self.transform``. It is then
    passed twice through the photometric ("invariant") pipeline, once per
    view, and each view is finished with ``TensorTransform``.

    Args:
        data_dir: Root directory searched recursively for ``*.jpg`` files.
        img_size: Size passed to ``transforms.Resize``.
        crop_size: Size of the center crop (test) or random crop (train).
        test: If True, use a deterministic center crop; otherwise a random crop.
        sp_num: Stored as ``self.sp_num``; not read by this class.
        slic: Stored as ``self.slic``; not read by this class.
        lab: Stored as ``self.lab``; not read by this class.
    """

    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super().__init__()
        ext = ["*.jpg"]
        dl = []
        for e in ext:
            dl.extend(glob(data_dir + '/**/' + e, recursive=True))
        # glob returns files in arbitrary order. Sort so the index -> file
        # mapping is stable across runs: the custom transforms below are
        # built with N=len(data_list) and called with the sample index.
        self.data_list = sorted(dl)
        self.img_size = img_size
        self.crop_size = crop_size
        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab
        if test:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size)])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(int(img_size)),
                transforms.RandomCrop(crop_size)])
        N = len(self.data_list)
        # Equivariant (geometric) transforms.
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=256)
        # Photometric transforms: one instance per view (index 0 and 1).
        # NOTE: strengths are hard-coded; control this later.
        self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)]
        self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]
        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']
        self.transform_tensor = TensorTransform()

    def transform_eqv(self, indice, image):
        """Apply the geometric transforms enabled in ``self.eqv_list``.

        Not called by ``__getitem__`` in this class.
        """
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)
        return image

    def transform_inv(self, index, image, ver):
        """Apply the photometric transforms enabled in ``self.inv_list``.

        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)

        Args:
            index: Sample index, forwarded to every transform.
            image: Input image.
            ver: Which view's transform instances to use (0 or 1).
        """
        if 'brightness' in self.inv_list:
            image = self.random_color_brightness[ver](index, image)
        if 'contrast' in self.inv_list:
            image = self.random_color_contrast[ver](index, image)
        if 'saturation' in self.inv_list:
            image = self.random_color_saturation[ver](index, image)
        if 'hue' in self.inv_list:
            image = self.random_color_hue[ver](index, image)
        if 'gray' in self.inv_list:
            image = self.random_gray_scale[ver](index, image)
        if 'blur' in self.inv_list:
            image = self.random_gaussian_blur[ver](index, image)
        return image

    def transform_image(self, index, image):
        """Return ``(view1, view2)``: two photometrically augmented views of ``image``."""
        image1 = self.transform_inv(index, image, 0)
        # Fixed: previously the un-augmented `image` was passed here, which
        # silently discarded view 1's photometric augmentation.
        image1 = self.transform_tensor(image1)
        image2 = self.transform_inv(index, image, 1)
        image2 = self.transform_tensor(image2)
        return image1, image2

    def __getitem__(self, index):
        """Return ``[view1, view2, index]`` for the image at ``index``."""
        data_path = self.data_list[index]
        # The context manager releases the file handle promptly. Converting to
        # RGB gives grayscale/CMYK JPEGs the same channel layout as RGB ones
        # (the color-jitter pipeline presumably expects 3 channels).
        with Image.open(data_path) as img:
            ori_img = img.convert('RGB')
        ori_img = self.transform(ori_img)
        image1, image2 = self.transform_image(index, ori_img)
        return [image1, image2, index]

    def __len__(self):
        """Number of image files found under ``data_dir``."""
        return len(self.data_list)
if __name__ == '__main__':
    # Smoke test: load one sample and save both augmented views to disk.
    # NOTE: this module uses a relative import (``from .custom_transform``),
    # so it must be run as part of its package (``python -m <pkg>.<module>``).
    import sys
    import torchvision.utils as vutils

    # An optional first CLI argument overrides the default data directory.
    data_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/xtli/DATA/texture_data/'
    dataset = Dataset(data_dir)
    if len(dataset) == 0:
        sys.exit(f'No *.jpg files found under {data_dir}')
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)
    # __getitem__ returns [view1, view2, index]. DataLoader iterators are
    # advanced with the builtin next(); they have no .next() method.
    image1, image2, index = next(iter(loader_))
    vutils.save_image(image1, 'img.png')
    vutils.save_image(image2, 'img2.png')
|