import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def get_nonorm_transform(resolution):
    """Build the transform used for the un-normalized target image.

    The pipeline resizes the image to ``(resolution, resolution)`` with
    bilinear interpolation and converts it to a tensor; it contains no
    normalization step.

    Args:
        resolution: Side length, in pixels, of the square output.

    Returns:
        A ``transforms.Compose`` applying ``Resize`` followed by ``ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(
            (resolution, resolution),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ])


class FontDataset(Dataset):
    """The dataset of font generation.

    Files are read from the following layout::

        {data_root}/{phase}/TargetImage/{style}/{style}+{content}.<ext>
        {data_root}/{phase}/ContentImage/{content}.jpg

    Each item pairs a target glyph with its content image and with a style
    reference image drawn at random from the same style.
    """

    def __init__(self, args, phase, transforms=None):
        """Index the target images and set up the transforms.

        Args:
            args: Object providing ``data_root`` (dataset root directory) and
                ``resolution`` (side length for the un-normalized target).
            phase: Name of the sub-directory of ``data_root`` to read.
            transforms: Optional indexable of three callables applied to the
                content, style and target image respectively. When ``None``
                the three images are returned as PIL images.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase

        # Get Data path
        self.get_path()
        # NOTE: the ``transforms`` parameter shadows the torchvision module
        # alias inside this method only; module-level helpers are unaffected.
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Scan ``TargetImage`` and populate the path indexes.

        Sets ``self.target_images`` (flat list of every target image path)
        and ``self.style_to_images`` (style directory name -> list of the
        image paths found in that directory).
        """
        self.target_images = []
        # images with related style
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        for style in sorted(os.listdir(target_image_dir)):
            style_dir = f"{target_image_dir}/{style}"
            if not os.path.isdir(style_dir):
                # Stray files next to the style folders are not styles.
                continue
            images_related_style = [
                f"{style_dir}/{img}" for img in sorted(os.listdir(style_dir))
            ]
            self.target_images.extend(images_related_style)
            self.style_to_images[style] = images_related_style

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an RGB copy, closing the file."""
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Args:
            index: Position in ``self.target_images``.

        Returns:
            A tuple ``(content_image, style_image, target_image,
            nonorm_target_image, target_image_path)``. The first three are
            transformed when ``self.transforms`` is set; the fourth is the
            target passed through ``self.nonorm_transforms``.

        Raises:
            ValueError: If the target file name is not ``{style}+{content}``.
            KeyError: If the style parsed from the file name has no entry in
                ``self.style_to_images``.
        """
        target_image_path = self.target_images[index]
        target_image_name = os.path.basename(target_image_path)
        # Strip only the final extension, and split on the first '+' only,
        # so stems containing '.' or a '+' content glyph are still parsed.
        stem = os.path.splitext(target_image_name)[0]
        style, separator, content = stem.partition("+")
        if not separator:
            raise ValueError(
                "Target image name must look like '<style>+<content>.<ext>', "
                f"got {target_image_name!r}"
            )

        # Read content image
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = self._load_rgb(content_image_path)

        # Random sample used for style image: any other image of this style.
        style_candidates = [
            path for path in self.style_to_images[style]
            if path != target_image_path
        ]
        if style_candidates:
            style_image_path = random.choice(style_candidates)
        else:
            # The style has a single glyph; reuse it as the style reference
            # instead of failing on an empty candidate list.
            style_image_path = target_image_path
        style_image = self._load_rgb(style_image_path)

        # Read target image
        target_image = self._load_rgb(target_image_path)
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return (
            content_image,
            style_image,
            target_image,
            nonorm_target_image,
            target_image_path,
        )

    def __len__(self):
        """Return the number of target images."""
        return len(self.target_images)