|
from font_dataset.fontlabel import FontLabel |
|
from font_dataset.font import DSFont, load_font_with_exclusion |
|
from . import config |
|
|
|
|
|
import math |
|
import os |
|
import pickle |
|
import torch |
|
import torchvision.transforms as transforms |
|
from typing import List, Dict, Tuple |
|
from torch.utils.data import Dataset, DataLoader |
|
from pytorch_lightning import LightningDataModule |
|
from PIL import Image |
|
|
|
|
|
class FontDataset(Dataset):
    """Dataset of rendered font images paired with pickled FontLabel files.

    Every ``<name>.jpg`` directly inside ``path`` is one sample. Its label is
    read from the sibling file ``<name>.bin``, which holds a pickled
    ``FontLabel``.
    """

    def __init__(self, path: str, config_path: str = "configs/font.yml"):
        """Index all ``.jpg`` files in ``path``.

        Args:
            path: Directory containing ``<name>.jpg`` / ``<name>.bin`` pairs.
            config_path: Passed to ``load_font_with_exclusion``. The returned
                mapping is indexed by a label's font path in
                ``fontlabel2tensor`` (presumably font path -> class index;
                confirm against ``load_font_with_exclusion``).
        """
        self.path = path
        self.fonts = load_font_with_exclusion(config_path)

        # Sorted so the index -> file mapping is deterministic across runs.
        self.images = sorted(
            os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
        )

    def __len__(self):
        """Return the number of ``.jpg`` samples found at construction time."""
        return len(self.images)

    def fontlabel2tensor(self, label: FontLabel, label_path: str) -> torch.Tensor:
        """Encode ``label`` as a 12-element float tensor.

        Layout:
            [0]     ``self.fonts[label.font.path]``
            [1]     0 if ``text_direction == "ltr"`` else 1
            [2:5]   text color channels divided by 255
            [5]     text size / image width
            [6]     stroke width / image width
            [7:10]  stroke color channels divided by 255, or 0.5 each when
                    ``stroke_color`` is falsy (no stroke color)
            [10]    line spacing / image width
            [11]    angle / 180 + 0.5

        Args:
            label: The unpickled label for one sample.
            label_path: Path the label was loaded from; only used in the
                diagnostic message when the font is unknown.

        Raises:
            KeyError: If ``label.font.path`` is not present in ``self.fonts``.
        """
        out = torch.zeros(12, dtype=torch.float)

        try:
            out[0] = self.fonts[label.font.path]
        except KeyError:
            print(f"Unqualified font: {label.font.path}")
            print(f"Label path: {label_path}")
            # Re-raise the original KeyError so the missing key and the
            # traceback are preserved (a fresh `KeyError` carried neither).
            raise

        out[1] = 0 if label.text_direction == "ltr" else 1

        out[2] = label.text_color[0] / 255.0
        out[3] = label.text_color[1] / 255.0
        out[4] = label.text_color[2] / 255.0

        out[5] = label.text_size / label.image_width
        out[6] = label.stroke_width / label.image_width

        if label.stroke_color:
            out[7] = label.stroke_color[0] / 255.0
            out[8] = label.stroke_color[1] / 255.0
            out[9] = label.stroke_color[2] / 255.0
        else:
            out[7:10] = 0.5

        out[10] = label.line_spacing / label.image_width
        out[11] = label.angle / 180.0 + 0.5

        return out

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` tensors for the sample at ``index``.

        The image is converted to RGB, resized to
        ``(config.INPUT_SIZE, config.INPUT_SIZE)`` and turned into a tensor.
        The label is the encoding produced by ``fontlabel2tensor``.
        """
        image_path = self.images[index]

        # Use a context manager so the file handle is released immediately;
        # convert() loads the pixels and returns an independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Built per call on purpose: config.INPUT_SIZE is read at access time,
        # so a value changed after dataset construction is still honoured.
        transform = transforms.Compose(
            [
                transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
                transforms.ToTensor(),
            ]
        )
        image = transform(image)

        # Every entry in self.images ends with ".jpg" (see __init__). Swap only
        # that suffix; str.replace would also rewrite ".jpg" occurring in
        # directory names.
        label_path = image_path[: -len(".jpg")] + ".bin"

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point this dataset at trusted, locally generated data.
        with open(label_path, "rb") as f:
            label: FontLabel = pickle.load(f)

        return image, self.fontlabel2tensor(label, label_path)
|
|
|
|
|
class FontDataModule(LightningDataModule):
    """LightningDataModule holding train/val/test ``FontDataset`` splits.

    Extra keyword arguments (e.g. ``batch_size``, ``num_workers``,
    ``drop_last``) are stored in ``self.dataloader_args``. They are forwarded
    unchanged to every ``DataLoader`` this module builds.
    """

    def __init__(
        self,
        config_path: str = "configs/font.yml",
        train_path: str = "./dataset/font_img/train",
        val_path: str = "./dataset/font_img/val",
        test_path: str = "./dataset/font_img/test",
        train_shuffle: bool = True,
        val_shuffle: bool = False,
        test_shuffle: bool = False,
        **kwargs,
    ):
        """Create the three dataset splits eagerly.

        Args:
            config_path: Font config passed to each ``FontDataset``.
            train_path: Directory of the training split.
            val_path: Directory of the validation split.
            test_path: Directory of the test split.
            train_shuffle: ``shuffle`` flag for the training loader.
            val_shuffle: ``shuffle`` flag for the validation loader.
            test_shuffle: ``shuffle`` flag for the test loader.
            **kwargs: Forwarded to ``torch.utils.data.DataLoader``. Must not
                contain ``shuffle``, which is supplied per split.
        """
        super().__init__()
        self.dataloader_args = kwargs
        self.train_shuffle = train_shuffle
        self.val_shuffle = val_shuffle
        self.test_shuffle = test_shuffle
        self.train_dataset = FontDataset(train_path, config_path)
        self.val_dataset = FontDataset(val_path, config_path)
        self.test_dataset = FontDataset(test_path, config_path)

    def get_train_num_iter(self, num_device: int) -> int:
        """Return the number of training batches per device per epoch.

        Args:
            num_device: Number of devices the training set is split across.

        Raises:
            ValueError: If ``num_device`` is not positive.
        """
        if num_device <= 0:
            raise ValueError(f"num_device must be positive, got {num_device}")

        # Mirror DataLoader's own defaults when these kwargs were not given.
        batch_size = self.dataloader_args.get("batch_size", 1)
        drop_last = self.dataloader_args.get("drop_last", False)

        # Assumes each device receives ceil(N / num_device) samples, i.e. a
        # padding DistributedSampler (PyTorch default) -- confirm for your
        # strategy. With drop_last=False this equals ceil(N / (B * D)).
        samples_per_device = math.ceil(len(self.train_dataset) / num_device)
        if drop_last:
            # The trailing partial batch is discarded by the DataLoader.
            return samples_per_device // batch_size
        return math.ceil(samples_per_device / batch_size)

    def _make_dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` using the shared loader kwargs."""
        return DataLoader(dataset, shuffle=shuffle, **self.dataloader_args)

    def train_dataloader(self):
        """Return the training DataLoader (shuffled per ``train_shuffle``)."""
        return self._make_dataloader(self.train_dataset, self.train_shuffle)

    def val_dataloader(self):
        """Return the validation DataLoader (shuffled per ``val_shuffle``)."""
        return self._make_dataloader(self.val_dataset, self.val_shuffle)

    def test_dataloader(self):
        """Return the test DataLoader (shuffled per ``test_shuffle``)."""
        return self._make_dataloader(self.test_dataset, self.test_shuffle)
|
|