|
import os |
|
import random |
|
from PIL import Image |
|
|
|
from torch.utils.data import Dataset |
|
import torchvision.transforms as transforms |
|
|
|
def get_nonorm_transform(resolution):
    """Build a resize-then-tensorize transform with no normalization step.

    The image is resized to a ``resolution x resolution`` square using bilinear
    interpolation and then converted with ``ToTensor``. No ``Normalize`` is
    appended, so values are left exactly as ``ToTensor`` produces them.

    Args:
        resolution: Side length, in pixels, of the square output.

    Returns:
        A ``transforms.Compose`` pipeline of ``Resize`` followed by ``ToTensor``.
    """
    square_size = (resolution, resolution)
    resize_step = transforms.Resize(
        square_size,
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    to_tensor_step = transforms.ToTensor()
    return transforms.Compose([resize_step, to_tensor_step])
|
|
|
|
|
class FontDataset(Dataset):
    """Dataset of (content, style, target) glyph images for font generation.

    On-disk layout this class reads (derived from the paths it builds)::

        {data_root}/{phase}/TargetImage/{style}/{style}+{content}.<ext>
        {data_root}/{phase}/ContentImage/{content}.jpg

    Each sample pairs a target glyph with the content image of the same
    character and a randomly chosen image from the same style directory
    (another glyph whenever one exists).
    """

    def __init__(self, args, phase, transforms=None):
        """
        Args:
            args: Config object; ``args.data_root`` and ``args.resolution`` are read.
            phase: Sub-directory of ``data_root`` to load (presumably a split
                name such as ``"train"`` -- confirm against callers).
            transforms: Optional indexable of three callables applied to the
                content, style and target images respectively.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase
        self.get_path()
        # NOTE: the ``transforms`` parameter shadows the module-level
        # torchvision ``transforms`` import, but only inside this method.
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Index every target image under ``{root}/{phase}/TargetImage``.

        Populates:
            self.target_images: flat list of image paths, one per sample.
            self.style_to_images: style directory name -> list of its image paths.

        Listings are sorted so sample indices are reproducible regardless of
        the filesystem's ``os.listdir`` order. Non-directory entries directly
        under ``TargetImage`` (e.g. stray files) are skipped instead of
        crashing the inner ``os.listdir``.
        """
        self.target_images = []
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        for style in sorted(os.listdir(target_image_dir)):
            style_dir = f"{target_image_dir}/{style}"
            if not os.path.isdir(style_dir):
                continue
            images_related_style = [
                f"{style_dir}/{img}" for img in sorted(os.listdir(style_dir))
            ]
            self.target_images.extend(images_related_style)
            self.style_to_images[style] = images_related_style

    @staticmethod
    def _parse_target_name(target_image_path):
        """Split ``.../{style}+{content}.<ext>`` into ``(style, content)``.

        Only the extension is stripped and only the first ``+`` separates
        style from content. This keeps punctuation glyphs such as ``+`` or
        ``.`` intact as content names.

        Raises:
            ValueError: if the file name contains no ``+`` separator.
        """
        stem = os.path.splitext(os.path.basename(target_image_path))[0]
        style, content = stem.split('+', 1)
        return style, content

    def _sample_style_path(self, style, target_image_path):
        """Pick a random image of ``style`` other than the target itself.

        If the target is the only image of its style, the target path is
        returned instead of letting ``random.choice`` raise on an empty list.

        Raises:
            KeyError: if ``style`` was not indexed by ``get_path``.
            ValueError: if ``target_image_path`` is not listed under ``style``
                (i.e. the file-name prefix disagrees with its directory).
        """
        candidates = list(self.style_to_images[style])
        candidates.remove(target_image_path)
        if not candidates:
            return target_image_path
        return random.choice(candidates)

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB, and close the file handle."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the source file is closed.
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return one sample.

        Returns:
            Tuple ``(content_image, style_image, target_image,
            nonorm_target_image, target_image_path)``.
            - The first three are PIL RGB images, or the outputs of
              ``self.transforms[0]``, ``[1]`` and ``[2]`` when transforms
              were supplied.
            - ``nonorm_target_image`` is the untransformed target passed
              through ``self.nonorm_transforms``.
        """
        target_image_path = self.target_images[index]
        style, content = self._parse_target_name(target_image_path)

        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = self._load_rgb(content_image_path)

        style_image_path = self._sample_style_path(style, target_image_path)
        style_image = self._load_rgb(style_image_path)

        target_image = self._load_rgb(target_image_path)
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        """Number of target images (one sample per target)."""
        return len(self.target_images)
|
|