|
import os |
|
import glob |
|
import torch |
|
import random |
|
from PIL import Image |
|
from torch.utils import data |
|
from torchvision import transforms as T |
|
|
|
class data_prefetcher():
    """Wrap a DataLoader and copy/normalize its batches on a dedicated CUDA stream.

    The next batch is moved to the GPU and normalized on ``self.stream`` while
    the caller works on the current one. When the underlying loader is
    exhausted it is restarted transparently, so ``next()`` never stops on its
    own (unless the loader yields no batches at all).
    """

    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        self.stream = torch.cuda.Stream()
        # Per-channel mean/std (the widely used ImageNet statistics), shaped
        # (1, 3, 1, 1) so they broadcast over the channel dim of a 4-D batch.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
        # len(loader) is the number of batches per pass, not the number of images.
        self.num_images = len(loader)
        self.preload()

    def preload(self):
        """Fetch the next (src_image1, src_image2) batch and start its GPU transfer.

        The host-to-device copy and in-place normalization are queued on
        ``self.stream`` so they can overlap with work on the current stream.

        Raises:
            StopIteration: if the restarted loader yields no batch at all.
        """
        try:
            self.src_image1, self.src_image2 = next(self.dataiter)
        except StopIteration:
            # Restart the loader so the prefetcher behaves as an endless stream.
            self.dataiter = iter(self.loader)
            self.src_image1, self.src_image2 = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched (src_image1, src_image2) pair and start loading the next one."""
        current_stream = torch.cuda.current_stream()
        # Make the consumer stream wait until the side-stream copy/normalize is done.
        current_stream.wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        # These tensors were allocated on self.stream. Tell the caching allocator
        # they are now in use on the consumer stream; otherwise their memory could
        # be handed out again while the consumer is still reading them.
        src_image1.record_stream(current_stream)
        src_image2.record_stream(current_stream)
        self.preload()
        return src_image1, src_image2

    def __len__(self):
        """Return the number of batches in one pass over the wrapped loader (``len(loader)``)."""
        return self.num_images
|
|
|
class SwappingDataset(data.Dataset):
    """Dataset over the sub-directories of ``image_dir``.

    Each item corresponds to one sub-directory and consists of two images
    sampled (with replacement) from that directory.
    """

    def __init__(self,
                 image_dir,
                 img_transform,
                 subffix='jpg',
                 random_seed=1234):
        """Scan ``image_dir`` and build the per-directory image path lists.

        Args:
            image_dir: root directory whose immediate sub-directories hold the images.
            img_transform: callable applied to every loaded (RGB) PIL image.
            subffix: file extension, without the dot, of the images to collect.
            random_seed: seed used for the shuffle of the directory order.
        """
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Collect the matching image paths of every sub-directory into ``self.dataset``.

        Directories without any matching image are skipped, because sampling
        from an empty list in ``__getitem__`` would fail. Paths are sorted
        before the seeded shuffle so the final order does not depend on the
        filesystem's listing order.
        """
        print("processing Swapping dataset images...")
        pattern = '*.%s' % self.subffix
        self.dataset = []
        for dir_item in sorted(glob.glob(os.path.join(self.image_dir, '*/'))):
            print("processing %s" % dir_item, end='\r')
            images = sorted(glob.glob(os.path.join(dir_item, pattern)))
            if images:
                self.dataset.append(images)
        # NOTE: this seeds the module-level `random` state, which __getitem__
        # also draws from; kept as-is for reproducibility with existing runs.
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the Swapping dataset, total dirs number: %d...' % len(self.dataset))

    def __getitem__(self, index):
        """Return two transformed images sampled (with replacement) from directory ``index``."""
        dir_images = self.dataset[index]
        filename1 = random.choice(dir_images)
        filename2 = random.choice(dir_images)
        return self._load(filename1), self._load(filename2)

    def _load(self, filename):
        """Open ``filename`` as RGB, release the file handle, and apply ``img_transform``."""
        # Converting to RGB guarantees 3 channels, even for grayscale or CMYK
        # files, so batches collate and match the 3-channel normalization.
        with Image.open(filename) as img:
            rgb = img.convert('RGB')
        return self.img_transform(rgb)

    def __len__(self):
        """Return the number of (non-empty) image directories, i.e. samples per epoch."""
        return self.num_images
|
|
|
def GetLoader(dataset_roots,
              batch_size=16,
              dataloader_workers=8,
              random_seed=1234
              ):
    """Build the swapping dataset and wrap its DataLoader in a CUDA prefetcher.

    Args:
        dataset_roots: root directory handed to ``SwappingDataset``.
        batch_size: samples per batch; an incomplete final batch is dropped.
        dataloader_workers: number of DataLoader worker processes.
        random_seed: seed forwarded to ``SwappingDataset``.

    Returns:
        A ``data_prefetcher`` over a shuffled DataLoader with pinned memory.
    """
    to_tensor = T.Compose([T.ToTensor()])
    swap_dataset = SwappingDataset(dataset_roots, to_tensor, "jpg", random_seed)
    loader = data.DataLoader(
        dataset=swap_dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=dataloader_workers,
        pin_memory=True,
    )
    return data_prefetcher(loader)
|
|
|
def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping anything that falls outside.

    A new tensor is returned; ``x`` itself is left unmodified.
    """
    shifted = x.add(1)
    scaled = shifted.div(2)
    return scaled.clamp_(min=0, max=1)