# NOTE(review): removed non-Python paste artifacts ("Spaces:" / "Runtime
# error" judge output) that preceded the source and made the file unparseable.
import os
import pickle
import random
import string
import json
import logging
from pathlib import Path

from omegaconf import OmegaConf
import numpy as np
import PIL.Image as Image
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

# Number of virtual repetitions of the language list per Dataset pass
# (see GoogleFontDataset.__len__).
REPEATE_NUM = 10000
# Pixel value treated as background when computing tight glyph bounds.
WHITE = 255
# Maximum re-crop attempts when a random style patch is nearly blank.
MAX_TRIAL = 10

# Unicode code points (4-digit uppercase hex strings) for A-Z and 0-9,
# used when English glyphs are requested as content/style.
_upper_case = set(map(lambda s: f"{ord(s):04X}", string.ascii_uppercase))
_digits = set(map(lambda s: f"{ord(s):04X}", string.digits))
english_set = list(_upper_case.union(_digits))

# Directory name of the reference (content) font family.
NOTO_FONT_DIRNAME = "Noto"
class GoogleFontDataset(Dataset):
    """Pairs content glyphs (the Noto font) with style glyphs from Google Fonts.

    Expects ``args.font_dir`` to contain one sub-directory per language, each
    holding one sub-directory per font, each holding per-character
    ``<unicode-hex>.png`` glyph images.
    """

    def __init__(self, args, mode='train',
                 metadata_path="./lang_set.json"):
        """
        Args:
            args: config object providing ``font_dir``, ``imsize`` and
                ``reference_imgs.style`` / ``reference_imgs.char``.
            mode (str): 'train' keeps all languages except the last two
                (sorted order); any other value keeps only the last two.
            metadata_path (str): JSON file mapping language name to a list of
                unicode hex codes available for that language.
        """
        super(GoogleFontDataset, self).__init__()
        self.args = args
        self.font_dir = Path(args.font_dir)
        self.mode = mode
        self.lang_list = sorted([x.stem for x in self.font_dir.iterdir() if x.is_dir()])
        # Smallest tight glyph bound observed so far (diagnostics only).
        self.min_tight_bound = 10000
        self.min_font_name = None
        if self.mode == 'train':
            self.lang_list = self.lang_list[:-2]
        else:
            self.lang_list = self.lang_list[-2:]
        with open(metadata_path, "r") as json_f:
            self.data = json.load(json_f)
        self.num_lang = None
        self.num_font = None
        self.num_char = None
        self.content_meta, self.style_meta, self.num_lang, self.num_font, self.num_char = self.get_meta()
        logging.info(f"min_tight_bound: {self.min_tight_bound}")  # 20

    def center_align(self, bg_img, item_img, fit=False):
        """Paste ``item_img`` centered onto a copy of ``bg_img`` and return it.

        When ``fit`` is True, ``item_img`` is first scaled (aspect ratio
        preserved) so it fills the background along its tighter dimension.

        Fix: the original definition omitted ``self`` even though every call
        site invokes it as a bound method (``self.center_align(...)``), which
        would have bound the instance to ``bg_img`` and crashed.
        """
        bg_img = bg_img.copy()
        item_img = item_img.copy()
        item_w, item_h = item_img.size
        W, H = bg_img.size
        if fit:
            item_ratio = item_w / item_h
            bg_ratio = W / H
            if bg_ratio > item_ratio:
                # height fitting
                resize_ratio = H / item_h
            else:
                # width fitting
                resize_ratio = W / item_w
            item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
            item_w, item_h = item_img.size
        bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
        return bg_img

    def _get_content_image(self, png_path):
        """Load a glyph and center it, scaled to fit, on a white imsize canvas."""
        im = Image.open(png_path)
        bg_img = Image.new('RGB', (self.args.imsize, self.args.imsize), color='white')
        blend_img = self.center_align(bg_img, im, fit=True)
        return blend_img

    def _get_style_image(self, png_path):
        """Load a glyph unscaled, centered on a white square canvas.

        The canvas side is the max of the glyph's width/height and
        ``args.imsize``. Also tracks the smallest tight bound seen, for
        diagnostics.
        """
        im = Image.open(png_path)
        w, h = im.size
        # tight_bound_check & update
        tight_bound = self.get_tight_bound_size(np.array(im))
        if self.min_tight_bound > tight_bound:
            self.min_tight_bound = tight_bound
            self.min_font_name = png_path
            logging.debug(f"min_tight_bound: {self.min_tight_bound}, min_font_name: {self.min_font_name}")
        bg_img = Image.new('RGB', (max([w, h, self.args.imsize]), max([w, h, self.args.imsize])), color='white')
        blend_img = self.center_align(bg_img, im)
        return blend_img

    def get_meta(self):
        """Index the font tree.

        Returns:
            tuple: ``(content_meta, style_meta, num_lang, num_font, num_char)``
            where both metas map ``lang -> font -> unicode_hex -> png Path``
            (paths only; images are loaded lazily in ``get_pairs``).
        """
        content_meta = dict()
        style_meta = dict()
        num_lang = 0
        num_font = 0
        num_char = 0
        for lang_dir in tqdm(self.lang_list, total=len(self.lang_list)):
            font_list = sorted([x for x in (self.font_dir / lang_dir).iterdir() if x.is_dir()])
            font_content_dict = dict()
            font_style_dict = dict()
            for font_dir in font_list:
                image_content_dict = dict()
                image_style_dict = dict()
                png_list = [x for x in font_dir.glob("*.png")]
                for png_path in png_list:
                    # Store paths, not decoded images, to keep memory bounded.
                    image_content_dict[png_path.stem] = png_path
                    image_style_dict[png_path.stem] = png_path
                    num_char += 1
                font_content_dict[font_dir.stem] = image_content_dict
                font_style_dict[font_dir.stem] = image_style_dict
                num_font += 1
            content_meta[lang_dir] = font_content_dict
            style_meta[lang_dir] = font_style_dict
            num_lang += 1
        return content_meta, style_meta, num_lang, num_font, num_char

    def get_tight_bound_size(self, img):
        """Return the longer side of the tight bbox around non-white pixels.

        Returns 0 for an entirely white (blank) image.

        Fix: the original definition omitted ``self`` even though it is called
        as a bound method (``self.get_tight_bound_size(...)``).
        """
        contents_cell = np.where(img < WHITE)
        if len(contents_cell[0]) == 0:
            return 0
        size = {
            'xmin': np.min(contents_cell[1]),
            'ymin': np.min(contents_cell[0]),
            'xmax': np.max(contents_cell[1]) + 1,
            'ymax': np.max(contents_cell[0]) + 1,
        }
        return max(size['xmax'] - size['xmin'], size['ymax'] - size['ymin'])

    def get_patch_from_style_image(self, image, patch_per_image=1):
        """Sample ``patch_per_image`` imsize x imsize patches from a style image.

        Crops random ``2*imsize`` squares (re-trying up to MAX_TRIAL times if a
        crop is nearly blank) and resizes them to imsize; falls back to a plain
        resize when the image is too small to crop.
        NOTE(review): the crop offset is derived from the width only — assumes
        the style canvas is square (as produced by ``_get_style_image``).
        """
        w, h = image.size
        image_list = []
        relative_patch_size = int(self.args.imsize * 2)
        for _ in range(patch_per_image):
            offset = w - relative_patch_size
            if offset < relative_patch_size // 2:
                # if image is too small, just resize
                crop_candidate = np.array(image.resize((self.args.imsize, self.args.imsize)))
            else:
                # if image is sufficent to be cropped, randomly crop
                x = np.random.randint(0, offset)
                y = np.random.randint(0, offset)
                crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
                _trial = 0
                # Re-crop while the patch is nearly blank (tiny tight bound).
                while self.get_tight_bound_size(np.array(crop_candidate)) < relative_patch_size // 16 and _trial < MAX_TRIAL:
                    x = np.random.randint(0, offset)
                    y = np.random.randint(0, offset)
                    crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
                    _trial += 1
                crop_candidate = np.array(crop_candidate.resize((self.args.imsize, self.args.imsize)))
            image_list.append(crop_candidate)
        return image_list

    def get_pairs(self, content_english=False, style_english=False):
        """Sample one training tuple: content glyphs, style glyphs and target.

        Content always comes from the Noto font; style comes from a random
        non-Noto font of a random language. The target is the content char
        rendered in the style font.

        Returns:
            tuple: (content_char, content_fonts, content_fonts_image,
            style_font, style_chars, style_chars_image, target_image)
        """
        lang_content = random.choice(self.lang_list)
        content_unicode_list = english_set if content_english else self.data[lang_content]
        style_unicode_list = english_set if style_english else self.data[lang_content]
        if content_english == style_english:
            # content_unicode_list == style_unicode_list: sample without
            # replacement so the target char never appears among style refs.
            chars = random.sample(content_unicode_list,
                                  k=self.args.reference_imgs.style + 1)
            content_char = chars[-1]
            style_chars = chars[:self.args.reference_imgs.style]
        else:
            content_char = random.choice(content_unicode_list)
            style_chars = random.sample(style_unicode_list, k=self.args.reference_imgs.style)
        # Style font: any font of the language except the Noto content font.
        style_font_list = list(self.content_meta[lang_content].keys())
        style_font_list.remove(NOTO_FONT_DIRNAME)
        style_font = random.choice(style_font_list)
        content_fonts = [NOTO_FONT_DIRNAME]
        content_fonts_image = [self.content_meta[lang_content][x][content_char] for x in content_fonts]
        style_chars_image = [self.content_meta[lang_content][style_font][x] for x in style_chars]
        target_image = self.content_meta[lang_content][style_font][content_char]
        # Decode the stored paths into centered imsize canvases.
        content_fonts_image = [self._get_content_image(image_path) for image_path in content_fonts_image]
        style_chars_image = [self._get_content_image(image_path) for image_path in style_chars_image]
        target_image = self._get_content_image(target_image)
        return content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image

    def __getitem__(self, idx):
        """Return one sample (``idx`` is ignored; sampling is random).

        Args:
            idx (int): torch dataset index.

        Returns:
            dict: with keys
                gt_images: target image, float32 (1, imsize, imsize) in [0, 1],
                content_images: content glyph stack, float32,
                style_images: style glyph stack, float32,
                style_idx: style font name,
                char_idx: content unicode hex,
                content_image_idxs: content font names,
                style_image_idxs: style unicode hexes,
                image_paths: '' (placeholder).
        """
        use_eng_content, use_eng_style = random.choice([(True, False), (False, True), (False, False)])
        if self.mode != 'train':
            # Evaluation always uses English style refs with native content.
            use_eng_content = False
            use_eng_style = True
        content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image = \
            self.get_pairs(content_english=use_eng_content, style_english=use_eng_style)
        # RGB -> grayscale via channel mean, normalized to [0, 1].
        content_fonts_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
                                        for x in content_fonts_image], dtype=np.float32)
        style_chars_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
                                      for x in style_chars_image], dtype=np.float32)
        target_image = np.mean(np.array(target_image, dtype=np.float32), axis=-1)[np.newaxis, ...] / WHITE
        dict_return = {
            # data for training
            'gt_images': target_image,
            'content_images': content_fonts_image,
            'style_images': style_chars_image,  # TODO: crop style image with fixed size
            # data for logging
            'style_idx': style_font,
            'char_idx': content_char,
            'content_image_idxs': content_fonts,
            'style_image_idxs': style_chars,
            'image_paths': '',
        }
        return dict_return

    def __len__(self):
        # Virtual length: languages repeated REPEATE_NUM times per epoch.
        return len(self.lang_list) * REPEATE_NUM
if __name__ == '__main__':
    # Smoke test: build the train split and inspect a few samples.
    hp = OmegaConf.load('config/datasets/googlefont.yaml').datasets.train
    metadata_path = "./lang_set.json"
    FONT_DIR = "/data2/hksong/DATA/fonts-image"
    # Fix: __init__ has no `font_dir` keyword — it reads args.font_dir, so
    # override the config value instead of passing an unsupported kwarg.
    hp.font_dir = FONT_DIR
    _dataset = GoogleFontDataset(hp, metadata_path=metadata_path)
    TEST_ITER_NUM = 4
    for i in range(TEST_ITER_NUM):
        data = _dataset[i]
        print(data.keys())
        # Fix: keys must match __getitem__'s output — 'gt_images' (not
        # 'gt_image'), and there is no 'lang' key.
        print(data['gt_images'].shape,
              data['content_images'][0].shape,
              data['style_images'][0].shape,
              data['style_idx'],
              data['char_idx'],
              data['content_image_idxs'],
              data['style_image_idxs'])