# FontDiffuser-Gradio/dataset/font_dataset.py
# Uploaded by yeungchenwa, commit 508b842: "[Update] Add files and checkpoint" (2.66 kB).
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
def get_nonorm_transform(resolution):
    """Build the preprocessing pipeline that applies no normalisation.

    The pipeline resizes its input to a ``resolution`` x ``resolution``
    square with bilinear interpolation and then converts it to a tensor.

    Args:
        resolution: Side length used for both dimensions of the resize.

    Returns:
        The composed transform (resize followed by tensor conversion).
    """
    square_size = (resolution, resolution)
    pipeline_steps = [
        transforms.Resize(
            square_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline_steps)
class FontDataset(Dataset):
    """Dataset of (content, style, target) images for font generation.

    Layout read from disk::

        {data_root}/{phase}/TargetImage/{style}/{style}+{content}.<ext>
        {data_root}/{phase}/ContentImage/{content}.jpg

    Each sample is built around one target image. The content image is
    looked up from the ``content`` part of the target file name, and the
    style image is a random *other* target image of the same style.
    """

    def __init__(self, args, phase, transforms=None):
        """Index the target images of one phase.

        Args:
            args: Object exposing ``data_root`` (dataset root directory) and
                ``resolution`` (side length for the un-normalised target).
            phase: Name of the sub-directory of ``data_root`` to read.
            transforms: Optional indexable of three callables applied to the
                content, style and target image respectively (indices 0, 1
                and 2). When ``None`` the PIL images are returned as-is.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase

        # Get Data path
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Scan ``TargetImage`` and fill ``target_images`` / ``style_to_images``.

        ``target_images`` is the flat list of every target image path and
        ``style_to_images`` maps each style directory name to its paths.
        """
        self.target_images = []
        # images with related style
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for style in sorted(os.listdir(target_image_dir)):
            style_dir = f"{target_image_dir}/{style}"
            # Stray files next to the style folders (e.g. ".DS_Store") are
            # skipped instead of crashing the nested listdir below.
            if not os.path.isdir(style_dir):
                continue
            images_related_style = [
                f"{style_dir}/{img}" for img in sorted(os.listdir(style_dir))
            ]
            self.target_images.extend(images_related_style)
            self.style_to_images[style] = images_related_style

    @staticmethod
    def _split_target_name(target_image_path):
        """Return ``(style, content)`` parsed from ``{style}+{content}.<ext>``.

        Raises:
            ValueError: If the file name contains no ``+`` separator.
        """
        stem = os.path.splitext(os.path.basename(target_image_path))[0]
        style, separator, content = stem.partition('+')
        if not separator:
            raise ValueError(
                f"Target image name must look like '<style>+<content>.<ext>': "
                f"{target_image_path}"
            )
        return style, content

    def _pick_style_path(self, style, target_image_path):
        """Randomly pick another image of ``style``, excluding the target.

        Raises:
            KeyError: If ``style`` was not indexed by :meth:`get_path`.
            IndexError: If the style has no image other than the target.
        """
        candidates = [
            path for path in self.style_to_images[style]
            if path != target_image_path
        ]
        if not candidates:
            # Same exception type random.choice would raise on an empty
            # list, but with a message that points at the offending data.
            raise IndexError(
                f"Style '{style}' has no image other than {target_image_path} "
                f"to use as style reference"
            )
        return random.choice(candidates)

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an RGB copy.

        The context manager makes sure the underlying file is released even
        if decoding fails part-way through.
        """
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return the sample built around the ``index``-th target image.

        Returns:
            A tuple ``(content_image, style_image, target_image,
            nonorm_target_image, target_image_path)``. The first three are
            transformed only when ``transforms`` was given;
            ``nonorm_target_image`` always goes through the resize +
            to-tensor pipeline without normalisation.
        """
        target_image_path = self.target_images[index]
        style, content = self._split_target_name(target_image_path)

        # Read content image
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = self._load_rgb(content_image_path)

        # Random sample used for style image
        style_image_path = self._pick_style_path(style, target_image_path)
        style_image = self._load_rgb(style_image_path)

        # Read target image
        target_image = self._load_rgb(target_image_path)
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        """Return the number of indexed target images."""
        return len(self.target_images)