gyrojeff's picture
fix: typo
daa52ce
raw
history blame
3.98 kB
from font_dataset.fontlabel import FontLabel
from font_dataset.font import DSFont, load_font_with_exclusion
from . import config
import math
import os
import pickle
import torch
import torchvision.transforms as transforms
from typing import List, Dict, Tuple
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
from PIL import Image
class FontDataset(Dataset):
def __init__(self, path: str, config_path: str = "configs/font.yml"):
self.path = path
self.fonts = load_font_with_exclusion(config_path)
self.images = [
os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
]
self.images.sort()
def __len__(self):
return len(self.images)
def fontlabel2tensor(self, label: FontLabel, label_path) -> torch.Tensor:
out = torch.zeros(12, dtype=torch.float)
try:
out[0] = self.fonts[label.font.path]
except KeyError:
print(f"Unqualified font: {label.font.path}")
print(f"Label path: {label_path}")
raise KeyError
out[1] = 0 if label.text_direction == "ltr" else 1
# [0, 1]
out[2] = label.text_color[0] / 255.0
out[3] = label.text_color[1] / 255.0
out[4] = label.text_color[2] / 255.0
out[5] = label.text_size / label.image_width
out[6] = label.stroke_width / label.image_width
if label.stroke_color:
out[7] = label.stroke_color[0] / 255.0
out[8] = label.stroke_color[1] / 255.0
out[9] = label.stroke_color[2] / 255.0
else:
out[7:10] = 0.5
out[10] = label.line_spacing / label.image_width
out[11] = label.angle / 180.0 + 0.5
return out
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# Load image
image_path = self.images[index]
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose(
[
transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
transforms.ToTensor(),
]
)
image = transform(image)
# Load label
label_path = image_path.replace(".jpg", ".bin")
with open(label_path, "rb") as f:
label: FontLabel = pickle.load(f)
# encode label
label = self.fontlabel2tensor(label, label_path)
return image, label
class FontDataModule(LightningDataModule):
def __init__(
self,
config_path: str = "configs/font.yml",
train_path: str = "./dataset/font_img/train",
val_path: str = "./dataset/font_img/val",
test_path: str = "./dataset/font_img/test",
train_shuffle: bool = True,
val_shuffle: bool = False,
test_shuffle: bool = False,
**kwargs,
):
super().__init__()
self.dataloader_args = kwargs
self.train_shuffle = train_shuffle
self.val_shuffle = val_shuffle
self.test_shuffle = test_shuffle
self.train_dataset = FontDataset(train_path, config_path)
self.val_dataset = FontDataset(val_path, config_path)
self.test_dataset = FontDataset(test_path, config_path)
def get_train_num_iter(self, num_device: int) -> int:
return math.ceil(
len(self.train_dataset) / (self.dataloader_args["batch_size"] * num_device)
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
shuffle=self.train_shuffle,
**self.dataloader_args,
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
shuffle=self.val_shuffle,
**self.dataloader_args,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
shuffle=self.test_shuffle,
**self.dataloader_args,
)