# TextureScraping / libs / data_geo_pho.py
# sunshineatnoon: "Add application file" (1b2a9b1), 5.22 kB
"""SLIC dataset
- Returns two views of each resized-and-cropped image together with the
  image's dataset index (no SLIC segmentation map is produced here).
"""
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
from glob import glob
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from .custom_transform import *
class Dataset(data.Dataset):
    """Image-folder dataset that yields two views of every image.

    Every ``*.jpg`` found recursively under ``data_dir`` is one sample.  A
    sample is resized and cropped, then passed twice through the photometric
    pipeline (``transform_inv`` with version 0 and version 1) and through
    ``TensorTransform``.  ``__getitem__`` returns ``[view1, view2, index]``.

    The augmentation classes come from ``custom_transform``.  They are built
    with ``N=len(self.data_list)`` and are called with the sample index
    (presumably so each sample keeps its own random parameters -- confirm
    against ``custom_transform``).
    """

    # (key looked up in ``self.eqv_list``, attribute holding the transform),
    # listed in the order the transforms are applied.
    _EQV_TRANSFORMS = (
        ('random_crop', 'random_resized_crop'),
        ('h_flip', 'random_horizontal_flip'),
        ('v_flip', 'random_vertical_flip'),
    )

    # (key looked up in ``self.inv_list``, attribute holding the two
    # per-version transforms), listed in the order they are applied.
    _INV_TRANSFORMS = (
        ('brightness', 'random_color_brightness'),
        ('contrast', 'random_color_contrast'),
        ('saturation', 'random_color_saturation'),
        ('hue', 'random_color_hue'),
        ('gray', 'random_gray_scale'),
        ('blur', 'random_gaussian_blur'),
    )

    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        """Collect the image paths and build all transforms.

        Args:
            data_dir: Root directory searched recursively for ``*.jpg`` files.
            img_size: Size passed to ``transforms.Resize``.
            crop_size: Side of the crop taken after resizing.
            test: If true use a deterministic centre crop, otherwise a
                random crop.
            sp_num: Stored on the instance; not used by this class.
            slic: Stored on the instance; not used by this class.
            lab: Stored on the instance; not used by this class.
        """
        super().__init__()

        patterns = ["*.jpg"]
        paths = []
        for pattern in patterns:
            paths.extend(glob(data_dir + '/**/' + pattern, recursive=True))
        # glob() returns files in arbitrary order.  Sorting keeps the
        # index -> file mapping stable between runs, which matters because
        # the augmentations below are addressed by sample index.
        self.data_list = sorted(paths)

        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab

        if test:
            crop = transforms.CenterCrop(crop_size)
            resize = transforms.Resize(img_size)
        else:
            crop = transforms.RandomCrop(crop_size)
            resize = transforms.Resize(int(img_size))
        self.transform = transforms.Compose([resize, crop])

        N = len(self.data_list)

        # Geometric ("eqv") transforms.  Note the crop resolution is fixed
        # at 256 regardless of ``img_size``.
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=256)

        # Photometric ("inv") transforms: two independent instances of each,
        # one per view (index 0 and 1).  NOTE: make these configurable later.
        self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)]
        self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)]
        self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]

        # Names of the transforms that are switched on ('v_flip' is off).
        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']

        self.transform_tensor = TensorTransform()

    def transform_eqv(self, indice, image):
        """Apply the enabled geometric transforms to ``image``.

        Transforms run in the order crop, horizontal flip, vertical flip;
        each runs only if its key is present in ``self.eqv_list``.  This
        method is not called by ``__getitem__``.

        Args:
            indice: Sample index forwarded to every transform.
            image: Image to transform.

        Returns:
            The transformed image.
        """
        for key, attr in self._EQV_TRANSFORMS:
            if key in self.eqv_list:
                image = getattr(self, attr)(indice, image)
        return image

    def transform_inv(self, index, image, ver):
        """Apply the enabled photometric transforms for one view.

        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)

        Args:
            index: Sample index forwarded to every transform.
            image: Image to transform.
            ver: Which of the two transform instances to use (0 or 1).

        Returns:
            The transformed image.
        """
        for key, attr in self._INV_TRANSFORMS:
            if key in self.inv_list:
                image = getattr(self, attr)[ver](index, image)
        return image

    def transform_image(self, index, image):
        """Build the two views of ``image``.

        Each view is the photometric pipeline (version 0 for the first view,
        version 1 for the second) followed by ``self.transform_tensor``.

        Args:
            index: Sample index forwarded to the photometric transforms.
            image: The resized and cropped source image.

        Returns:
            Tuple ``(image1, image2)``.
        """
        image1 = self.transform_inv(index, image, 0)
        # Convert the augmented view, not the untouched input; otherwise the
        # version-0 augmentation computed above would be thrown away.
        image1 = self.transform_tensor(image1)

        image2 = self.transform_inv(index, image, 1)
        image2 = self.transform_tensor(image2)
        return image1, image2

    def __getitem__(self, index):
        """Load sample ``index`` and return ``[view1, view2, index]``."""
        data_path = self.data_list[index]
        # Close the file right after decoding, and force three channels so
        # grayscale/CMYK JPEGs give the same layout as RGB ones.
        with Image.open(data_path) as handle:
            ori_img = handle.convert('RGB')
        ori_img = self.transform(ori_img)

        image1, image2 = self.transform_image(index, ori_img)
        return [image1, image2, index]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.data_list)
if __name__ == '__main__':
    # Smoke test: load one sample and write its two views to disk.
    import torchvision.utils as vutils

    # The constructor has no ``sampled_num`` argument, so only the data
    # directory is passed.
    dataset = Dataset('/home/xtli/DATA/texture_data/')
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)

    # ``__getitem__`` returns [view1, view2, index]; use the built-in
    # ``next`` because iterators have no ``.next()`` method in Python 3.
    view1, view2, index = next(iter(loader_))

    vutils.save_image(view1, 'img.png')
    vutils.save_image(view2, 'img2.png')