deepkyu committed on
Commit
1ba3df3
1 Parent(s): eba4868

initial commit

.gitignore ADDED
@@ -0,0 +1,24 @@
1
+ # block temp directory
2
+ .idea/
3
+ __pycache__/
4
+ .ipynb_checkpoints/
5
+ .vscode/
6
+ .temp/
7
+ lightning_logs/
8
+
9
+ # block extension
10
+ *.pkl
11
+ *.png
12
+ *.pth
13
+ *.json
14
+ *.ckpt
15
+
16
+ # block logging directory
17
+ logs/
18
+ wandb/
19
+
20
+ # custom
21
+ font-image
22
+
23
+ !font_list.json
24
+ !font_list_noto_sans.json
app.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+ from typing import Optional, Union, Tuple, List
5
+ import subprocess
6
+
7
+ import gradio as gr
8
+ from PIL import Image
9
+ from omegaconf import OmegaConf, DictConfig
10
+
11
+ from inference import InferenceServicer
12
+
13
+ PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
14
+ MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
15
+
16
+ MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
17
+ NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)
18
+
19
+ LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
20
+ LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"
21
+
22
+ if MODEL_CHECKPOINT_PATH is not None:
23
+ subprocess.call(f"wget --no-check-certificate -O {LOCAL_CHECKPOINT_PATH} {MODEL_CHECKPOINT_PATH}", shell=True)
24
+ if NOTO_SANS_ZIP_PATH is not None:
25
+ subprocess.call(f"wget --no-check-certificate -O {LOCAL_NOTO_ZIP_PATH} {NOTO_SANS_ZIP_PATH}", shell=True)
26
+ subprocess.call(f"unzip {LOCAL_NOTO_ZIP_PATH} -d {str(Path(LOCAL_NOTO_ZIP_PATH).parent)}", shell=True)
27
+
28
+ assert Path("checkpoint/checkpoint.ckpt").exists()
29
+ assert Path("data/NotoSans").exists()
30
+
31
+ EXAMPLE_FONTS = sorted([
32
+ "example_fonts/BalooDa2-Bold.ttf",
33
+ "example_fonts/BalooDa2-Regular.ttf",
34
+ "example_fonts/Lalezar-Regular.ttf",
35
+ "example_fonts/MaShanZheng-Regular.ttf",
36
+ ])
37
+
38
+ def parse_args():
39
+
40
+ parser = argparse.ArgumentParser(description="Multilingual font style transfer demo")
41
+
42
+ # -------- User arguments ----------------------------------------
43
+
44
+ parser.add_argument(
45
+ '--docs', type=Path, default=PATH_DOCS,
46
+ help="Docs string file")
47
+
48
+ parser.add_argument(
49
+ '--config', type=Path, default=MODEL_CONFIG,
50
+ help="Config for model")
51
+
52
+ parser.add_argument(
53
+ '--local', action='store_true',
54
+ help="Whether to run in local environment or not")
55
+
56
+ parser.add_argument(
57
+ '--port', type=int, default=50003,
58
+ help="Service port (only applicable when running on local server)")
59
+
60
+ args, _ = parser.parse_known_args()
61
+
62
+ return args
63
+
64
+ class InferenceServiceResolver(InferenceServicer):
65
+ def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
66
+ super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)
67
+
68
+ def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
69
+ try:
70
+ content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
71
+ return [content_image, *style_images, result]
72
+ except Exception as e:
73
+ raise gr.Error(str(e))
74
+
75
+ def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
76
+
77
+ servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)
78
+ with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
79
+ gr.Markdown(docs_path.read_text())
80
+ with gr.Row(equal_height=True):
81
+ character_input = gr.Textbox(max_lines=1, value="7", info="Only a single character is accepted (e.g. '간', '7', or 'ជ')")
82
+ style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
83
+ run_button = gr.Button(value="Generate", variant='primary')
84
+
85
+ with gr.Row(equal_height=True):
86
+ with gr.Column(scale=1):
87
+ with gr.Group():
88
+ gr.Markdown(f"<center><h3>Content character</h3></center>")
89
+ content_char = gr.Image(label="Content character", show_label=False)
90
+ with gr.Column(scale=5):
91
+ with gr.Group():
92
+ gr.Markdown(f"<center><h3>Style font images</h3></center>")
93
+ with gr.Row(equal_height=True):
94
+ style_char_1 = gr.Image(label="Style #1", show_label=False)
95
+ style_char_2 = gr.Image(label="Style #2", show_label=False)
96
+ style_char_3 = gr.Image(label="Style #3", show_label=False)
97
+ style_char_4 = gr.Image(label="Style #4", show_label=False)
98
+ style_char_5 = gr.Image(label="Style #5", show_label=False)
99
+ with gr.Column(scale=1):
100
+ with gr.Group():
101
+ gr.Markdown(f"<center><h3>Generated font image</h3></center>")
102
+ generated_font = gr.Image(label="Generated font image", show_label=False)
103
+
104
+ outputs = [content_char, style_char_1, style_char_2, style_char_3, style_char_4, style_char_5, generated_font]
105
+ run_inputs = [character_input, style_font]
106
+ run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)
107
+
108
+ if is_local:
109
+ demo.launch(server_name="0.0.0.0", server_port=port)
110
+ else:
111
+ demo.launch()
112
+
113
+
114
+ if __name__ == "__main__":
115
+ args = parse_args()
116
+
117
+ hp = OmegaConf.load(args.config)
118
+ checkpoint_path = Path(LOCAL_CHECKPOINT_PATH)
119
+ content_image_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")
120
+
121
+ launch_gradio(args.docs, hp, checkpoint_path, content_image_dir, args.local, args.port)
checkpoint/.gitkeep ADDED
File without changes
config/datasets/googlefont.yaml ADDED
@@ -0,0 +1,37 @@
1
+ datasets:
2
+ type: GoogleFontDataset
3
+ train:
4
+ split: auto
5
+ font_dir: &font_dir ../DATA/fonts-image-20230929
6
+ imsize: 64
7
+ reference_imgs:
8
+ replace: False
9
+ char: &reference_char 1
10
+ style: &reference_style 5
11
+
12
+ squeeze_gray: &squeeze_gray True
13
+ transform:
14
+ # TODO
15
+
16
+ # loader configs
17
+ shuffle: True
18
+ batch_size: 64
19
+ num_workers: 12
20
+
21
+ eval:
22
+ split: auto
23
+ font_dir: *font_dir
24
+ imsize: 64
25
+ reference_imgs:
26
+ replace: False
27
+ char: *reference_char
28
+ style: *reference_style
29
+
30
+ squeeze_gray: *squeeze_gray
31
+ transform:
32
+ # TODO
33
+
34
+ # loader configs
35
+ shuffle: True
36
+ batch_size: 1
37
+ num_workers: 4
config/lightning.yaml ADDED
@@ -0,0 +1,16 @@
1
+ pl_config:
2
+ checkpoint:
3
+ callback:
4
+ save_top_k: -1
5
+ verbose: True
6
+ every_n_epochs: 5 #epochs
7
+
8
+ trainer:
9
+ gradient_clip_val: 0
10
+ max_epochs: 2000
11
+ num_sanity_val_steps: 1
12
+ fast_dev_run: False
13
+ check_val_every_n_epoch: 5
14
+ # distributed_backend: 'ddp'
15
+ accelerator: 'cuda'
16
+ benchmark: True
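These values map onto PyTorch Lightning's `ModelCheckpoint` callback and `Trainer` keyword arguments. The training entry point itself is not part of this commit, so the wiring below is only a sketch of how the two blocks are presumably consumed:

```python
import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint

pl_config = OmegaConf.load("config/lightning.yaml").pl_config

# save_top_k / verbose / every_n_epochs feed the checkpoint callback directly.
checkpoint_callback = ModelCheckpoint(**pl_config.checkpoint.callback)

# The remaining keys (gradient_clip_val, max_epochs, accelerator, ...) are Trainer kwargs.
trainer = pl.Trainer(callbacks=[checkpoint_callback], **pl_config.trainer)
```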
config/logging.yaml ADDED
@@ -0,0 +1,25 @@
1
+ logging:
2
+ dry_run: False
3
+ device: cuda
4
+ log_dir: /ssd1/hksong/LOG/font
5
+ seed: ftgan-patch-full
6
+ freq:
7
+ train: 100 # step
8
+
9
+ nepochs_decay: 100
10
+
11
+ gan_loss: lsgan
12
+ lambda_L1: 100
13
+ lambda_classifier: ~
14
+
15
+ trainer: base
16
+
17
+ savefiles: [
18
+ '*.py',
19
+ 'data/*.*',
20
+ 'datasets/*.*',
21
+ 'models/*.*',
22
+ 'configs/*.*',
23
+ 'utils/*.*',
24
+ 'trainer/*.*',
25
+ ]
config/models/google-font.yaml ADDED
@@ -0,0 +1,43 @@
1
+ models:
2
+ G:
3
+ encoder:
4
+ content:
5
+ type: ContentVanillaEncoder
6
+ depth: 2
7
+ style:
8
+ type: StyleVanillaEncoder
9
+ depth: 2
10
+ decoder:
11
+ type: VanillaDecoder
12
+ residual_blocks: 6
13
+ depth: 2
14
+
15
+ optim:
16
+ class: torch.optim.Adam
17
+ betas: [ 0.5, 0.999 ]
18
+ lr: 0.0002
19
+ lr_policy: step
20
+ lr_decay_iters: 1000
21
+
22
+ init_type: normal
23
+ init_gain: 0.02
24
+
25
+ D_content:
26
+ in_channels: 2 # char + 1
27
+ class: models.discriminator.PatchGANDiscriminator
28
+ optim:
29
+ class: torch.optim.Adam
30
+ betas: [ 0.5, 0.999 ]
31
+ lr: 2e-4
32
+ lr_policy: step
33
+ lr_decay_iters: 1000
34
+
35
+ D_style:
36
+ in_channels: 6 # style + 1
37
+ class: models.discriminator.PatchGANDiscriminator
38
+ optim:
39
+ class: torch.optim.Adam
40
+ betas: [ 0.5, 0.999 ]
41
+ lr: 2e-4
42
+ lr_policy: step
43
+ lr_decay_iters: 1000
config/setting-google-font.yaml ADDED
@@ -0,0 +1,5 @@
1
+ config:
2
+ dataset: 'config/datasets/googlefont.yaml'
3
+ model: 'config/models/google-font.yaml'
4
+ logging: 'config/logging.yaml'
5
+ lightning: 'config/lightning.yaml'
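This top-level file only lists the paths of the other config files; presumably a training entry point loads it first and merges the referenced YAMLs into a single `DictConfig`. A minimal sketch of that pattern (the merge step is an assumption, not code from this commit):

```python
from omegaconf import OmegaConf

setting = OmegaConf.load("config/setting-google-font.yaml").config

# Merge the dataset/model/logging/lightning configs into one DictConfig,
# so hp.datasets, hp.models, hp.logging and hp.pl_config live side by side.
hp = OmegaConf.merge(*[OmegaConf.load(path) for path in setting.values()])

print(hp.models.G.decoder.type)  # VanillaDecoder
print(hp.datasets.train.imsize)  # 64
```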
data/.gitkeep ADDED
File without changes
datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .googlefont import GoogleFontDataset
2
+ from .ftgan import FTGANDataset
datasets/googlefont.py ADDED
@@ -0,0 +1,286 @@
1
+ import os
2
+ import pickle
3
+ import random
4
+ import string
5
+ import json
6
+ import logging
7
+ from pathlib import Path
8
+
9
+ from omegaconf import OmegaConf
10
+ import numpy as np
11
+ import PIL.Image as Image
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ from tqdm import tqdm
15
+
16
+ REPEATE_NUM = 10000
17
+
18
+ WHITE = 255
19
+
20
+ MAX_TRIAL = 10
21
+
22
+ _upper_case = set(map(lambda s: f"{ord(s):04X}", string.ascii_uppercase))
23
+ _digits = set(map(lambda s: f"{ord(s):04X}", string.digits))
24
+ english_set = list(_upper_case.union(_digits))
25
+
26
+ NOTO_FONT_DIRNAME = "Noto"
27
+
28
+
29
+ class GoogleFontDataset(Dataset):
30
+ def __init__(self, args, mode='train',
31
+ metadata_path="./lang_set.json"):
32
+ super(GoogleFontDataset, self).__init__()
33
+ self.args = args
34
+ self.font_dir = Path(args.font_dir)
35
+ self.mode = mode
36
+ self.lang_list = sorted([x.stem for x in self.font_dir.iterdir() if x.is_dir()])
37
+ self.min_tight_bound = 10000
38
+ self.min_font_name = None
39
+
40
+ if self.mode == 'train':
41
+ self.lang_list = self.lang_list[:-2]
42
+ else:
43
+ self.lang_list = self.lang_list[-2:]
44
+ with open(metadata_path, "r") as json_f:
45
+ self.data = json.load(json_f)
46
+
47
+ self.num_lang = None
48
+ self.num_font = None
49
+ self.num_char = None
50
+ self.content_meta, self.style_meta, self.num_lang, self.num_font, self.num_char = self.get_meta()
51
+ logging.info(f"min_tight_bound: {self.min_tight_bound}") # 20
52
+
53
+ @staticmethod
54
+ def center_align(bg_img, item_img, fit=False):
55
+ bg_img = bg_img.copy()
56
+ item_img = item_img.copy()
57
+ item_w, item_h = item_img.size
58
+ W, H = bg_img.size
59
+ if fit:
60
+ item_ratio = item_w / item_h
61
+ bg_ratio = W / H
62
+
63
+ if bg_ratio > item_ratio:
64
+ # height fitting
65
+ resize_ratio = H / item_h
66
+ else:
67
+ # width fitting
68
+ resize_ratio = W / item_w
69
+ item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
70
+ item_w, item_h = item_img.size
71
+
72
+ bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
73
+ return bg_img
74
+
75
+ def _get_content_image(self, png_path):
76
+ im = Image.open(png_path)
77
+ bg_img = Image.new('RGB', (self.args.imsize, self.args.imsize), color='white')
78
+ blend_img = self.center_align(bg_img, im, fit=True)
79
+ return blend_img
80
+
81
+ def _get_style_image(self, png_path):
82
+ im = Image.open(png_path)
83
+ w, h = im.size
84
+
85
+ # tight_bound_check & update
86
+ tight_bound = self.get_tight_bound_size(np.array(im))
87
+ if self.min_tight_bound > tight_bound:
88
+ self.min_tight_bound = tight_bound
89
+ self.min_font_name = png_path
90
+ logging.debug(f"min_tight_bound: {self.min_tight_bound}, min_font_name: {self.min_font_name}")
91
+
92
+ bg_img = Image.new('RGB', (max([w, h, self.args.imsize]), max([w, h, self.args.imsize])), color='white')
93
+ blend_img = self.center_align(bg_img, im)
94
+ return blend_img
95
+
96
+ def get_meta(self):
97
+ content_meta = dict()
98
+ style_meta = dict()
99
+
100
+ num_lang = 0
101
+ num_font = 0
102
+ num_char = 0
103
+ for lang_dir in tqdm(self.lang_list, total=len(self.lang_list)):
104
+ font_list = sorted([x for x in (self.font_dir / lang_dir).iterdir() if x.is_dir()])
105
+
106
+ font_content_dict = dict()
107
+ font_style_dict = dict()
108
+
109
+ for font_dir in font_list:
110
+ image_content_dict = dict()
111
+ image_style_dict = dict()
112
+
113
+ png_list = [x for x in font_dir.glob("*.png")]
114
+
115
+ for png_path in png_list:
116
+
117
+ # image_content_dict[png_path.stem] = self._get_content_image(png_path)
118
+ # image_style_dict[png_path.stem] = self._get_style_image(png_path)
119
+ image_content_dict[png_path.stem] = png_path
120
+ image_style_dict[png_path.stem] = png_path
121
+ num_char += 1
122
+
123
+ font_content_dict[font_dir.stem] = image_content_dict
124
+ font_style_dict[font_dir.stem] = image_style_dict
125
+ num_font += 1
126
+
127
+ content_meta[lang_dir] = font_content_dict
128
+ style_meta[lang_dir] = font_style_dict
129
+ num_lang += 1
130
+
131
+ return content_meta, style_meta, num_lang, num_font, num_char
132
+
133
+ @staticmethod
134
+ def get_tight_bound_size(img):
135
+ contents_cell = np.where(img < WHITE)
136
+
137
+ if len(contents_cell[0]) == 0:
138
+ return 0
139
+
140
+ size = {
141
+ 'xmin': np.min(contents_cell[1]),
142
+ 'ymin': np.min(contents_cell[0]),
143
+ 'xmax': np.max(contents_cell[1]) + 1,
144
+ 'ymax': np.max(contents_cell[0]) + 1,
145
+ }
146
+ return max(size['xmax'] - size['xmin'], size['ymax'] - size['ymin'])
147
+
148
+ def get_patch_from_style_image(self, image, patch_per_image=1):
149
+ w, h = image.size
150
+ image_list = []
151
+ relative_patch_size = int(self.args.imsize * 2)
152
+ for _ in range(patch_per_image):
153
+ offset = w - relative_patch_size
154
+ if offset < relative_patch_size // 2:
155
+ # if image is too small, just resize
156
+ crop_candidate = np.array(image.resize((self.args.imsize, self.args.imsize)))
157
+ else:
158
+ # if image is sufficient to be cropped, randomly crop
159
+ x = np.random.randint(0, offset)
160
+ y = np.random.randint(0, offset)
161
+ crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
162
+
163
+ _trial = 0
164
+ while self.get_tight_bound_size(np.array(crop_candidate)) < relative_patch_size // 16 and _trial < MAX_TRIAL:
165
+ x = np.random.randint(0, offset)
166
+ y = np.random.randint(0, offset)
167
+ crop_candidate = image.crop((x, y, x + relative_patch_size, y + relative_patch_size))
168
+ _trial += 1
169
+
170
+ crop_candidate = np.array(crop_candidate.resize((self.args.imsize, self.args.imsize)))
171
+ image_list.append(crop_candidate)
172
+ return image_list
173
+
174
+ def get_pairs(self, content_english=False, style_english=False):
175
+ lang_content = random.choice(self.lang_list)
176
+
177
+ content_unicode_list = english_set if content_english else self.data[lang_content]
178
+ style_unicode_list = english_set if style_english else self.data[lang_content]
179
+
180
+ if content_english == style_english:
181
+ # content_unicode_list == style_unicode_list
182
+ chars = random.sample(content_unicode_list,
183
+ k=self.args.reference_imgs.style + 1)
184
+ content_char = chars[-1]
185
+ style_chars = chars[:self.args.reference_imgs.style]
186
+ else:
187
+ content_char = random.choice(content_unicode_list)
188
+ style_chars = random.sample(style_unicode_list, k=self.args.reference_imgs.style)
189
+
190
+ # fonts = random.sample(self.content_meta[lang_content].keys(),
191
+ # k=self.args.reference_imgs.char + 1)
192
+ # content_fonts = fonts[:self.args.reference_imgs.char]
193
+ # style_font = fonts[-1]
194
+
195
+ style_font_list = list(self.content_meta[lang_content].keys())
196
+ style_font_list.remove(NOTO_FONT_DIRNAME)
197
+ style_font = random.choice(style_font_list)
198
+ content_fonts = [NOTO_FONT_DIRNAME]
199
+
200
+ content_fonts_image = [self.content_meta[lang_content][x][content_char] for x in content_fonts]
201
+ style_chars_image = [self.content_meta[lang_content][style_font][x] for x in style_chars]
202
+
203
+ # style_chars_image = [self.content_meta[lang_content][style_font][x] for x in style_chars]
204
+
205
+ # style_chars_cropped = []
206
+ # for style_char_image in style_chars_image:
207
+ # style_chars_cropped.extend(self.get_patch_from_style_image(style_char_image,
208
+ # patch_per_image=self.args.reference_imgs.style // self.args.reference_imgs.char))
209
+
210
+ target_image = self.content_meta[lang_content][style_font][content_char]
211
+
212
+ content_fonts_image = [self._get_content_image(image_path) for image_path in content_fonts_image]
213
+ style_chars_image = [self._get_content_image(image_path) for image_path in style_chars_image]
214
+ target_image = self._get_content_image(target_image)
215
+
216
+ return content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image
217
+
218
+ def __getitem__(self, idx):
219
+ """__getitem__ of GoogleFontDataset
220
+
221
+ Args:
222
+ idx (int): torch dataset index
223
+
224
+ Returns:
225
+ dict: return dict with following keys
226
+
227
+ gt_images: target_image,
228
+ content_images: same_chars_image,
229
+ style_images: same_fonts_image,
230
+ style_idx: font_idx,
231
+ char_idx: char_idx,
232
+ content_image_idxs: same_chars,
233
+ style_image_idxs: same_fonts,
234
+ image_paths: ''
235
+ """
236
+ use_eng_content, use_eng_style = random.choice([(True, False), (False, True), (False, False)])
237
+
238
+ if self.mode != 'train':
239
+ use_eng_content = False
240
+ use_eng_style = True
241
+
242
+ content_char, content_fonts, content_fonts_image, style_font, style_chars, style_chars_image, target_image = \
243
+ self.get_pairs(content_english=use_eng_content, style_english=use_eng_style)
244
+
245
+ content_fonts_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
246
+ for x in content_fonts_image], dtype=np.float32)
247
+ style_chars_image = np.array([np.mean(np.array(x), axis=-1) / WHITE
248
+ for x in style_chars_image], dtype=np.float32)
249
+ target_image = np.mean(np.array(target_image, dtype=np.float32), axis=-1)[np.newaxis, ...] / WHITE
250
+
251
+ dict_return = {
252
+ # data for training
253
+ 'gt_images': target_image,
254
+ 'content_images': content_fonts_image,
255
+ 'style_images': style_chars_image, # TODO: crop style image with fixed size
256
+ # data for logging
257
+ 'style_idx': style_font,
258
+ 'char_idx': content_char,
259
+ 'content_image_idxs': content_fonts,
260
+ 'style_image_idxs': style_chars,
261
+ 'image_paths': '',
262
+ }
263
+ return dict_return
264
+
265
+ def __len__(self):
266
+ return len(self.lang_list) * REPEATE_NUM
267
+
268
+
269
+ if __name__ == '__main__':
270
+ hp = OmegaConf.load('config/datasets/googlefont.yaml').datasets.train
271
+ metadata_path = "./lang_set.json"
272
+ FONT_DIR = "/data2/hksong/DATA/fonts-image"
273
+
274
+ hp.font_dir = FONT_DIR  # the constructor reads the font directory from the config, not a kwarg
+ _dataset = GoogleFontDataset(hp, metadata_path=metadata_path)
275
+ TEST_ITER_NUM = 4
276
+ for i in range(TEST_ITER_NUM):
277
+ data = _dataset[i]
278
+ print(data.keys())
279
+ print(data['gt_images'].shape,
+ data['content_images'].shape,
+ data['style_images'].shape,
+ data['style_idx'],
+ data['char_idx'],
+ data['content_image_idxs'],
+ data['style_image_idxs'])
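For a quick sanity check of the batch layout produced by `__getitem__` (one Noto Sans content glyph, five style references, one target), the dataset can be wrapped in a standard `DataLoader`. A minimal sketch, assuming the font-image directory from `config/datasets/googlefont.yaml` and `lang_set.json` are available locally:

```python
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets import GoogleFontDataset

hp = OmegaConf.load("config/datasets/googlefont.yaml").datasets.train
dataset = GoogleFontDataset(hp, mode="train", metadata_path="./lang_set.json")
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch = next(iter(loader))
print(batch["gt_images"].shape)       # (4, 1, 64, 64): target glyph in the style font
print(batch["content_images"].shape)  # (4, 1, 64, 64): same character rendered in Noto Sans
print(batch["style_images"].shape)    # (4, 5, 64, 64): five reference glyphs of the style font
```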
docs/ml-font-style-transfer.md ADDED
@@ -0,0 +1,9 @@
1
+ <center><h1>Multilingual Font Style Transfer</h1></center>
2
+
3
+ - Composition-free font style transfer across 13 different languages
4
+ - Trained with [Google Fonts](https://github.com/google/fonts) (OFL fonts and Noto Sans)
5
+
6
+ This is a personal proof-of-concept demo, so the quality of the output is not guaranteed.
7
+ I hope that someday there will be an established model for a better multilingual society.
8
+
9
+ I only used personal RTX 30-series GPU(s) to train the model. The model is heavily inspired by [FTransGAN](https://github.com/ligoudaner377/font_translator_gan) (Li et al.).
font_list.json ADDED
@@ -0,0 +1,427 @@
1
+ {
2
+ "arabic": [
3
+ "ofl/baloobhaijaan2/BalooBhaijaan2[wght].ttf",
4
+ "ofl/bonanova/BonaNova-Italic.ttf",
5
+ "ofl/bonanova/BonaNova-Regular.ttf",
6
+ "ofl/cairo/Cairo[wght].ttf",
7
+ "ofl/changa/Changa[wght].ttf",
8
+ "ofl/elmessiri/ElMessiri[wght].ttf",
9
+ "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
10
+ "ofl/harmattan/Harmattan-Regular.ttf",
11
+ "ofl/ibmplexsansarabic/IBMPlexSansArabic-Regular.ttf",
12
+ "ofl/katibeh/Katibeh-Regular.ttf",
13
+ "ofl/kufam/Kufam-Italic[wght].ttf",
14
+ "ofl/kufam/Kufam[wght].ttf",
15
+ "ofl/lalezar/Lalezar-Regular.ttf",
16
+ "ofl/lemonada/Lemonada[wght].ttf",
17
+ "ofl/lemonadavfbeta/LemonadaVFBeta.ttf",
18
+ "ofl/markazitext/MarkaziText[wght].ttf",
19
+ "ofl/mirza/Mirza-Regular.ttf",
20
+ "ofl/qahiri/Qahiri-Regular.ttf",
21
+ "ofl/rakkas/Rakkas-Regular.ttf",
22
+ "ofl/readexpro/ReadexPro[wght].ttf",
23
+ "ofl/reemkufi/ReemKufi[wght].ttf",
24
+ "ofl/scheherazadenew/ScheherazadeNew-Regular.ttf",
25
+ "ofl/scheherazade/Scheherazade-Regular.ttf",
26
+ "ofl/tajawal/Tajawal-Regular.ttf",
27
+ "ofl/vibes/Vibes-Regular.ttf"
28
+ ],
29
+ "bengali": [
30
+ "ofl/atma/Atma-Regular.ttf",
31
+ "ofl/balooda2/BalooDa2-Regular.ttf",
32
+ "ofl/galada/Galada-Regular.ttf",
33
+ "ofl/hindsiliguri/HindSiliguri-Regular.ttf",
34
+ "ofl/mina/Mina-Regular.ttf"
35
+ ],
36
+ "gujarati": [
37
+ "ofl/baloobhai2/BalooBhai2[wght].ttf",
38
+ "ofl/farsan/Farsan-Regular.ttf",
39
+ "ofl/hindvadodara/HindVadodara-Regular.ttf",
40
+ "ofl/mogra/Mogra-Regular.ttf",
41
+ "ofl/muktavaani/MuktaVaani-Regular.ttf",
42
+ "ofl/rasa/Rasa[wght].ttf",
43
+ "ofl/shrikhand/Shrikhand-Regular.ttf"
44
+ ],
45
+ "hebrew": [
46
+ "ofl/adobeblank/AdobeBlank-Regular.ttf",
47
+ "ofl/alef/Alef-Regular.ttf",
48
+ "ofl/amaticsc/AmaticSC-Regular.ttf",
49
+ "ofl/bellefair/Bellefair-Regular.ttf",
50
+ "ofl/bonanova/BonaNova-Italic.ttf",
51
+ "ofl/bonanova/BonaNova-Regular.ttf",
52
+ "ofl/cardo/Cardo-Italic.ttf",
53
+ "ofl/cardo/Cardo-Regular.ttf",
54
+ "ofl/davidlibre/DavidLibre-Regular.ttf",
55
+ "ofl/frankruhllibre/FrankRuhlLibre-Regular.ttf",
56
+ "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
57
+ "ofl/heebo/Heebo[wght].ttf",
58
+ "ofl/ibmplexsanshebrew/IBMPlexSansHebrew-Regular.ttf",
59
+ "ofl/karantina/Karantina-Regular.ttf",
60
+ "ofl/miriamlibre/MiriamLibre-Regular.ttf",
61
+ "ofl/mplus1p/Mplus1p-Regular.ttf",
62
+ "ofl/roundedmplus1c/RoundedMplus1c-Regular.ttf",
63
+ "ofl/rubikbeastly/RubikBeastly-Regular.ttf",
64
+ "ofl/rubikmonoone/RubikMonoOne-Regular.ttf",
65
+ "ofl/rubikone/RubikOne-Regular.ttf",
66
+ "ofl/rubik/Rubik-Italic[wght].ttf",
67
+ "ofl/secularone/SecularOne-Regular.ttf",
68
+ "ofl/suezone/SuezOne-Regular.ttf"
69
+ ],
70
+ "japanese": [
71
+ "ofl/delagothicone/DelaGothicOne-Regular.ttf",
72
+ "ofl/dotgothic16/DotGothic16-Regular.ttf",
73
+ "ofl/hachimarupop/HachiMaruPop-Regular.ttf",
74
+ "ofl/jejugothic/JejuGothic-Regular.ttf",
75
+ "ofl/jejuhallasan/JejuHallasan-Regular.ttf",
76
+ "ofl/jejumyeongjo/JejuMyeongjo-Regular.ttf",
77
+ "ofl/kaiseidecol/KaiseiDecol-Regular.ttf",
78
+ "ofl/kaiseiharunoumi/KaiseiHarunoUmi-Regular.ttf",
79
+ "ofl/kaiseiopti/KaiseiOpti-Regular.ttf",
80
+ "ofl/kaiseitokumin/KaiseiTokumin-Regular.ttf"
81
+ ],
82
+ "khmer": [
83
+ "ofl/angkor/Angkor-Regular.ttf",
84
+ "ofl/battambang/Battambang-Regular.ttf",
85
+ "ofl/bayon/Bayon-Regular.ttf",
86
+ "ofl/bokor/Bokor-Regular.ttf",
87
+ "ofl/dangrek/Dangrek-Regular.ttf",
88
+ "ofl/fasthand/Fasthand-Regular.ttf",
89
+ "ofl/freehand/Freehand-Regular.ttf",
90
+ "ofl/hanuman/Hanuman-Regular.ttf",
91
+ "ofl/kohsantepheap/KohSantepheap-Regular.ttf",
92
+ "ofl/koulen/Koulen-Regular.ttf",
93
+ "ofl/metal/Metal-Regular.ttf",
94
+ "ofl/moul/Moul-Regular.ttf",
95
+ "ofl/moulpali/Moulpali-Regular.ttf",
96
+ "ofl/nokora/Nokora-Regular.ttf",
97
+ "ofl/odormeanchey/OdorMeanChey-Regular.ttf",
98
+ "ofl/preahvihear/Preahvihear-Regular.ttf",
99
+ "ofl/suwannaphum/Suwannaphum-Regular.ttf",
100
+ "ofl/taprom/Taprom-Regular.ttf"
101
+ ],
102
+ "korean": [
103
+ "ofl/blackandwhitepicture/BlackAndWhitePicture-Regular.ttf",
104
+ "ofl/dongle/Dongle-Regular.ttf",
105
+ "ofl/gamjaflower/GamjaFlower-Regular.ttf",
106
+ "ofl/gothica1/GothicA1-Regular.ttf",
107
+ "ofl/gowunbatang/GowunBatang-Regular.ttf",
108
+ "ofl/gowundodum/GowunDodum-Regular.ttf",
109
+ "ofl/himelody/HiMelody-Regular.ttf",
110
+ "ofl/poorstory/PoorStory-Regular.ttf"
111
+ ],
112
+ "malayalam": [
113
+ "ofl/baloochettan2/BalooChettan2-Regular.ttf",
114
+ "ofl/chilanka/Chilanka-Regular.ttf",
115
+ "ofl/gayathri/Gayathri-Regular.ttf",
116
+ "ofl/hindkochi/HindKochi-Regular.ttf",
117
+ "ofl/manjari/Manjari-Regular.ttf"
118
+ ],
119
+ "cyrillic": [
120
+ "ofl/adobeblank/AdobeBlank-Regular.ttf",
121
+ "ofl/alegreya/Alegreya-Italic[wght].ttf",
122
+ "ofl/alegreya/Alegreya[wght].ttf",
123
+ "ofl/alegreyasans/AlegreyaSans-Italic.ttf",
124
+ "ofl/alegreyasans/AlegreyaSans-Regular.ttf",
125
+ "ofl/alegreyasanssc/AlegreyaSansSC-Italic.ttf",
126
+ "ofl/alegreyasanssc/AlegreyaSansSC-Regular.ttf",
127
+ "ofl/alegreyasc/AlegreyaSC-Italic.ttf",
128
+ "ofl/alegreyasc/AlegreyaSC-Regular.ttf",
129
+ "ofl/alice/Alice-Regular.ttf",
130
+ "ofl/alumnisans/AlumniSans-Italic[wght].ttf",
131
+ "ofl/amaticsc/AmaticSC-Regular.ttf",
132
+ "ofl/andika/Andika-Regular.ttf",
133
+ "ofl/anonymouspro/AnonymousPro-Italic.ttf",
134
+ "ofl/anonymouspro/AnonymousPro-Regular.ttf",
135
+ "ofl/arsenal/Arsenal-Italic.ttf",
136
+ "ofl/arsenal/Arsenal-Regular.ttf",
137
+ "ofl/badscript/BadScript-Regular.ttf",
138
+ "ofl/balsamiqsans/BalsamiqSans-Italic.ttf",
139
+ "ofl/balsamiqsans/BalsamiqSans-Regular.ttf",
140
+ "ofl/bellota/Bellota-Italic.ttf",
141
+ "ofl/bellota/Bellota-Regular.ttf",
142
+ "ofl/bellotatext/BellotaText-Italic.ttf",
143
+ "ofl/bellotatext/BellotaText-Regular.ttf",
144
+ "ofl/bitter/Bitter-Italic[wght].ttf",
145
+ "ofl/bonanova/BonaNova-Italic.ttf",
146
+ "ofl/bonanova/BonaNova-Regular.ttf",
147
+ "ofl/brygada1918/Brygada1918-Italic[wght].ttf",
148
+ "ofl/brygada1918/Brygada1918[wght].ttf",
149
+ "ofl/caveat/Caveat[wght].ttf",
150
+ "ofl/comfortaa/Comfortaa[wght].ttf",
151
+ "ofl/comforterbrush/ComforterBrush-Regular.ttf",
152
+ "ofl/comforter/Comforter-Regular.ttf",
153
+ "ofl/cormorant/Cormorant-Italic.ttf",
154
+ "ofl/cormorant/Cormorant-Regular.ttf",
155
+ "ofl/cormorantgaramond/CormorantGaramond-Italic.ttf",
156
+ "ofl/cormorantgaramond/CormorantGaramond-Regular.ttf",
157
+ "ofl/cormorantinfant/CormorantInfant-Italic.ttf",
158
+ "ofl/cormorantinfant/CormorantInfant-Regular.ttf",
159
+ "ofl/cormorantsc/CormorantSC-Regular.ttf",
160
+ "ofl/cormorantunicase/CormorantUnicase-Regular.ttf",
161
+ "ofl/crimsontext/CrimsonText-Regular.ttf",
162
+ "ofl/cuprum/Cuprum-Italic[wght].ttf",
163
+ "ofl/cuprum/Cuprum[wght].ttf",
164
+ "ofl/daysone/DaysOne-Regular.ttf",
165
+ "ofl/delagothicone/DelaGothicOne-Regular.ttf",
166
+ "ofl/didactgothic/DidactGothic-Regular.ttf",
167
+ "ofl/dotgothic16/DotGothic16-Regular.ttf",
168
+ "ofl/ebgaramond/EBGaramond-Italic[wght].ttf",
169
+ "ofl/ebgaramond/EBGaramond[wght].ttf",
170
+ "ofl/elmessiri/ElMessiri[wght].ttf",
171
+ "ofl/exo2/Exo2-Italic[wght].ttf",
172
+ "ofl/exo2/Exo2[wght].ttf",
173
+ "ofl/firasanscondensed/FiraSansCondensed-Italic.ttf",
174
+ "ofl/firasanscondensed/FiraSansCondensed-Regular.ttf",
175
+ "ofl/firasansextracondensed/FiraSansExtraCondensed-Italic.ttf",
176
+ "ofl/firasansextracondensed/FiraSansExtraCondensed-Regular.ttf",
177
+ "ofl/firasans/FiraSans-Italic.ttf",
178
+ "ofl/firasans/FiraSans-Regular.ttf",
179
+ "ofl/flowblock/FlowBlock-Regular.ttf",
180
+ "ofl/flowcircular/FlowCircular-Regular.ttf",
181
+ "ofl/flowrounded/FlowRounded-Regular.ttf",
182
+ "ofl/forum/Forum-Regular.ttf",
183
+ "ofl/gabriela/Gabriela-Regular.ttf",
184
+ "ofl/gothica1/GothicA1-Regular.ttf",
185
+ "ofl/hachimarupop/HachiMaruPop-Regular.ttf",
186
+ "ofl/handjet/Handjet[EGRD,ESHP,wght].ttf",
187
+ "ofl/hinamincho/HinaMincho-Regular.ttf",
188
+ "ofl/ibmplexmono/IBMPlexMono-Italic.ttf",
189
+ "ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
190
+ "ofl/ibmplexsans/IBMPlexSans-Italic.ttf",
191
+ "ofl/ibmplexsans/IBMPlexSans-Regular.ttf",
192
+ "ofl/ibmplexserif/IBMPlexSerif-Italic.ttf",
193
+ "ofl/ibmplexserif/IBMPlexSerif-Regular.ttf",
194
+ "ofl/inter/Inter[slnt,wght].ttf",
195
+ "ofl/istokweb/IstokWeb-Italic.ttf",
196
+ "ofl/istokweb/IstokWeb-Regular.ttf",
197
+ "ofl/jejugothic/JejuGothic-Regular.ttf",
198
+ "ofl/jejuhallasan/JejuHallasan-Regular.ttf",
199
+ "ofl/jejumyeongjo/JejuMyeongjo-Regular.ttf",
200
+ "ofl/jetbrainsmono/JetBrainsMono-Italic[wght].ttf",
201
+ "ofl/jetbrainsmono/JetBrainsMono[wght].ttf",
202
+ "ofl/jost/Jost-Italic[wght].ttf",
203
+ "ofl/jost/Jost[wght].ttf",
204
+ "ofl/kaiseidecol/KaiseiDecol-Regular.ttf",
205
+ "ofl/kaiseiharunoumi/KaiseiHarunoUmi-Regular.ttf",
206
+ "ofl/kaiseiopti/KaiseiOpti-Regular.ttf",
207
+ "ofl/kaiseitokumin/KaiseiTokumin-Regular.ttf",
208
+ "ofl/kellyslab/KellySlab-Regular.ttf",
209
+ "ofl/kiwimaru/KiwiMaru-Regular.ttf",
210
+ "ofl/kleeone/KleeOne-Regular.ttf",
211
+ "ofl/kopubbatang/KoPubBatang-Regular.ttf",
212
+ "ofl/kurale/Kurale-Regular.ttf",
213
+ "ofl/lato/Lato-Italic.ttf",
214
+ "ofl/lato/Lato-Regular.ttf",
215
+ "ofl/ledger/Ledger-Regular.ttf",
216
+ "ofl/literata/Literata-Italic[opsz,wght].ttf",
217
+ "ofl/literata/Literata[opsz,wght].ttf",
218
+ "ofl/lobster/Lobster-Regular.ttf",
219
+ "ofl/lora/Lora-Italic[wght].ttf",
220
+ "ofl/lora/Lora[wght].ttf",
221
+ "ofl/marckscript/MarckScript-Regular.ttf",
222
+ "ofl/marmelad/Marmelad-Regular.ttf",
223
+ "ofl/merriweather/Merriweather-Italic.ttf",
224
+ "ofl/merriweather/Merriweather-Regular.ttf",
225
+ "ofl/montserratalternates/MontserratAlternates-Italic.ttf",
226
+ "ofl/montserratalternates/MontserratAlternates-Regular.ttf",
227
+ "ofl/montserrat/Montserrat-Italic.ttf",
228
+ "ofl/montserrat/Montserrat-Regular.ttf",
229
+ "ofl/mplus1p/Mplus1p-Regular.ttf",
230
+ "ofl/mulish/Mulish-Italic[wght].ttf",
231
+ "ofl/nanumgothiccoding/NanumGothicCoding-Regular.ttf",
232
+ "ofl/neucha/Neucha.ttf",
233
+ "ofl/newscycle/NewsCycle-Regular.ttf",
234
+ "ofl/nobile/Nobile-Italic.ttf",
235
+ "ofl/nobile/Nobile-Regular.ttf",
236
+ "ofl/nunito/Nunito-Italic[wght].ttf",
237
+ "ofl/nunitosans/NunitoSans-Italic.ttf",
238
+ "ofl/nunitosans/NunitoSans-Regular.ttf",
239
+ "ofl/oi/Oi-Regular.ttf",
240
+ "ofl/oranienbaum/Oranienbaum-Regular.ttf",
241
+ "ofl/orelegaone/OrelegaOne-Regular.ttf",
242
+ "ofl/oswald/Oswald[wght].ttf",
243
+ "ofl/overpass/Overpass-Italic[wght].ttf",
244
+ "ofl/overpass/Overpass[wght].ttf",
245
+ "ofl/pacifico/Pacifico-Regular.ttf",
246
+ "ofl/pangolin/Pangolin-Regular.ttf",
247
+ "ofl/pattaya/Pattaya-Regular.ttf",
248
+ "ofl/philosopher/Philosopher-Italic.ttf",
249
+ "ofl/philosopher/Philosopher-Regular.ttf",
250
+ "ofl/piazzolla/Piazzolla-Italic[opsz,wght].ttf",
251
+ "ofl/playfairdisplay/PlayfairDisplay-Italic[wght].ttf",
252
+ "ofl/playfairdisplay/PlayfairDisplay[wght].ttf",
253
+ "ofl/playfairdisplaysc/PlayfairDisplaySC-Italic.ttf",
254
+ "ofl/playfairdisplaysc/PlayfairDisplaySC-Regular.ttf",
255
+ "ofl/play/Play-Regular.ttf",
256
+ "ofl/podkova/Podkova[wght].ttf",
257
+ "ofl/podkovavfbeta/PodkovaVFBeta.ttf",
258
+ "ofl/poiretone/PoiretOne-Regular.ttf",
259
+ "ofl/prata/Prata-Regular.ttf",
260
+ "ofl/pressstart2p/PressStart2P-Regular.ttf",
261
+ "ofl/prostoone/ProstoOne-Regular.ttf",
262
+ "ofl/pushster/Pushster-Regular.ttf",
263
+ "ofl/raleway/Raleway-Italic[wght].ttf",
264
+ "ofl/rampartone/RampartOne-Regular.ttf",
265
+ "ofl/reggaeone/ReggaeOne-Regular.ttf",
266
+ "ofl/robotoflex/RobotoFlex[GRAD,XOPQ,XTRA,YOPQ,YTAS,YTDE,YTFI,YTLC,YTUC,opsz,slnt,wdth,wght].ttf",
267
+ "ofl/rocknrollone/RocknRollOne-Regular.ttf",
268
+ "ofl/roundedmplus1c/RoundedMplus1c-Regular.ttf",
269
+ "ofl/rubikbeastly/RubikBeastly-Regular.ttf",
270
+ "ofl/rubikmonoone/RubikMonoOne-Regular.ttf",
271
+ "ofl/rubikone/RubikOne-Regular.ttf",
272
+ "ofl/rubik/Rubik-Italic[wght].ttf",
273
+ "ofl/ruda/Ruda[wght].ttf",
274
+ "ofl/ruslandisplay/RuslanDisplay.ttf",
275
+ "ofl/russoone/RussoOne-Regular.ttf",
276
+ "ofl/sawarabigothic/SawarabiGothic-Regular.ttf",
277
+ "ofl/scada/Scada-Italic.ttf",
278
+ "ofl/scada/Scada-Regular.ttf",
279
+ "ofl/seoulhangangcondensed/SeoulHangangCondensed-BoldL.ttf",
280
+ "ofl/seoulhangangcondensed/SeoulHangangCondensed-Bold.ttf",
281
+ "ofl/seoulhangangcondensed/SeoulHangangCondensed-ExtraBold.ttf",
282
+ "ofl/seoulhangangcondensed/SeoulHangangCondensed-Medium.ttf",
283
+ "ofl/seoulhangang/SeoulHangang-Bold.ttf",
284
+ "ofl/seoulhangang/SeoulHangang-ExtraBold.ttf",
285
+ "ofl/seoulhangang/SeoulHangang-Medium.ttf",
286
+ "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Black.ttf",
287
+ "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Bold.ttf",
288
+ "ofl/seoulnamsancondensed/SeoulNamsanCondensed-ExtraBold.ttf",
289
+ "ofl/seoulnamsancondensed/SeoulNamsanCondensed-Medium.ttf",
290
+ "ofl/seoulnamsan/SeoulNamsan-Bold.ttf",
291
+ "ofl/seoulnamsan/SeoulNamsan-ExtraBold.ttf",
292
+ "ofl/seoulnamsan/SeoulNamsan-Medium.ttf",
293
+ "ofl/seoulnamsanvertical/SeoulNamsanVertical-Regular.ttf",
294
+ "ofl/seymourone/SeymourOne-Regular.ttf",
295
+ "ofl/sofiasans/SofiaSans-Italic[wdth,wght].ttf",
296
+ "ofl/sofiasans/SofiaSans[wdth,wght].ttf",
297
+ "ofl/sourcesans3/SourceSans3-Italic[wght].ttf",
298
+ "ofl/sourcesans3/SourceSans3[wght].ttf",
299
+ "ofl/sourcesanspro/SourceSansPro-Regular.ttf",
300
+ "ofl/sourceserifpro/SourceSerifPro-Italic.ttf",
301
+ "ofl/sourceserifpro/SourceSerifPro-Regular.ttf",
302
+ "ofl/spectralsc/SpectralSC-Italic.ttf",
303
+ "ofl/spectralsc/SpectralSC-Regular.ttf",
304
+ "ofl/spectral/Spectral-Italic.ttf",
305
+ "ofl/spectral/Spectral-Regular.ttf",
306
+ "ofl/stalinistone/StalinistOne-Regular.ttf",
307
+ "ofl/stick/Stick-Regular.ttf",
308
+ "ofl/stixtwomath/STIXTwoMath-Regular.ttf",
309
+ "ofl/stixtwotext/STIXTwoText-Italic[wght].ttf",
310
+ "ofl/stixtwotext/STIXTwoText[wght].ttf",
311
+ "ofl/strong/Strong-Regular.ttf",
312
+ "ofl/tenorsans/TenorSans-Regular.ttf",
313
+ "ofl/trainone/TrainOne-Regular.ttf",
314
+ "ofl/tuffy/Tuffy-Italic.ttf",
315
+ "ofl/tuffy/Tuffy-Regular.ttf",
316
+ "ofl/underdog/Underdog-Regular.ttf",
317
+ "ofl/viaodalibre/ViaodaLibre-Regular.ttf",
318
+ "ofl/vollkornsc/VollkornSC-Regular.ttf",
319
+ "ofl/vollkorn/Vollkorn-Italic[wght].ttf",
320
+ "ofl/vollkorn/Vollkorn[wght].ttf",
321
+ "ofl/yesevaone/YesevaOne-Regular.ttf",
322
+ "ofl/yomogi/Yomogi-Regular.ttf",
323
+ "ofl/yujiboku/YujiBoku-Regular.ttf",
324
+ "ofl/yujimai/YujiMai-Regular.ttf",
325
+ "ofl/yujisyuku/YujiSyuku-Regular.ttf",
326
+ "ofl/zenantiquesoft/ZenAntiqueSoft-Regular.ttf",
327
+ "ofl/zenantique/ZenAntique-Regular.ttf",
328
+ "ofl/zenkakugothicantique/ZenKakuGothicAntique-Regular.ttf",
329
+ "ofl/zenkakugothicnew/ZenKakuGothicNew-Regular.ttf",
330
+ "ofl/zenkurenaido/ZenKurenaido-Regular.ttf",
331
+ "ofl/zenmarugothic/ZenMaruGothic-Regular.ttf",
332
+ "ofl/zenoldmincho/ZenOldMincho-Regular.ttf"
333
+ ],
334
+ "tamil": [
335
+ "ofl/arimamadurai/ArimaMadurai-Regular.ttf",
336
+ "ofl/baloothambi2/BalooThambi2-Regular.ttf",
337
+ "ofl/coiny/Coiny-Regular.ttf",
338
+ "ofl/hindmadurai/HindMadurai-Regular.ttf",
339
+ "ofl/kavivanar/Kavivanar-Regular.ttf",
340
+ "ofl/meerainimai/MeeraInimai-Regular.ttf",
341
+ "ofl/muktamalar/MuktaMalar-Regular.ttf",
342
+ "ofl/oi/Oi-Regular.ttf",
343
+ "ofl/pavanam/Pavanam-Regular.ttf",
344
+ "ofl/postnobillsjaffna/PostNoBillsJaffna-Regular.ttf"
345
+ ],
346
+ "telugu": [
347
+ "ofl/akayatelivigala/AkayaTelivigala-Regular.ttf",
348
+ "ofl/balootammudu2/BalooTammudu2[wght].ttf",
349
+ "ofl/chathura/Chathura-Regular.ttf",
350
+ "ofl/dhurjati/Dhurjati-Regular.ttf",
351
+ "ofl/gidugu/Gidugu-Regular.ttf",
352
+ "ofl/gurajada/Gurajada-Regular.ttf",
353
+ "ofl/hindguntur/HindGuntur-Regular.ttf",
354
+ "ofl/lakkireddy/LakkiReddy-Regular.ttf",
355
+ "ofl/mallanna/Mallanna-Regular.ttf",
356
+ "ofl/mandali/Mandali-Regular.ttf",
357
+ "ofl/nats/NATS-Regular.ttf",
358
+ "ofl/ntr/NTR-Regular.ttf",
359
+ "ofl/peddana/Peddana-Regular.ttf",
360
+ "ofl/ramabhadra/Ramabhadra-Regular.ttf",
361
+ "ofl/ramaraja/Ramaraja-Regular.ttf",
362
+ "ofl/raviprakash/RaviPrakash-Regular.ttf",
363
+ "ofl/sreekrushnadevaraya/SreeKrushnadevaraya-Regular.ttf",
364
+ "ofl/suranna/Suranna-Regular.ttf",
365
+ "ofl/suravaram/Suravaram-Regular.ttf",
366
+ "ofl/tenaliramakrishna/TenaliRamakrishna-Regular.ttf",
367
+ "ofl/timmana/Timmana-Regular.ttf"
368
+ ],
369
+ "thai": [
370
+ "ofl/athiti/Athiti-Regular.ttf",
371
+ "ofl/baijamjuree/BaiJamjuree-Italic.ttf",
372
+ "ofl/baijamjuree/BaiJamjuree-Regular.ttf",
373
+ "ofl/chakrapetch/ChakraPetch-Italic.ttf",
374
+ "ofl/chakrapetch/ChakraPetch-Regular.ttf",
375
+ "ofl/charm/Charm-Regular.ttf",
376
+ "ofl/charmonman/Charmonman-Regular.ttf",
377
+ "ofl/chonburi/Chonburi-Regular.ttf",
378
+ "ofl/fahkwang/Fahkwang-Italic.ttf",
379
+ "ofl/fahkwang/Fahkwang-Regular.ttf",
380
+ "ofl/ibmplexsansthai/IBMPlexSansThai-Regular.ttf",
381
+ "ofl/ibmplexsansthailooped/IBMPlexSansThaiLooped-Regular.ttf",
382
+ "ofl/itim/Itim-Regular.ttf",
383
+ "ofl/k2d/K2D-Italic.ttf",
384
+ "ofl/k2d/K2D-Regular.ttf",
385
+ "ofl/kanit/Kanit-Italic.ttf",
386
+ "ofl/kanit/Kanit-Regular.ttf",
387
+ "ofl/kodchasan/Kodchasan-Italic.ttf",
388
+ "ofl/kodchasan/Kodchasan-Regular.ttf",
389
+ "ofl/koho/KoHo-Italic.ttf",
390
+ "ofl/koho/KoHo-Regular.ttf",
391
+ "ofl/krub/Krub-Italic.ttf",
392
+ "ofl/krub/Krub-Regular.ttf",
393
+ "ofl/maitree/Maitree-Regular.ttf",
394
+ "ofl/mali/Mali-Italic.ttf",
395
+ "ofl/mali/Mali-Regular.ttf",
396
+ "ofl/mitr/Mitr-Regular.ttf",
397
+ "ofl/niramit/Niramit-Italic.ttf",
398
+ "ofl/niramit/Niramit-Regular.ttf",
399
+ "ofl/pattaya/Pattaya-Regular.ttf",
400
+ "ofl/pridi/Pridi-Regular.ttf",
401
+ "ofl/prompt/Prompt-Italic.ttf",
402
+ "ofl/prompt/Prompt-Regular.ttf",
403
+ "ofl/sarabun/Sarabun-Italic.ttf",
404
+ "ofl/sarabun/Sarabun-Regular.ttf",
405
+ "ofl/sriracha/Sriracha-Regular.ttf",
406
+ "ofl/srisakdi/Srisakdi-Regular.ttf",
407
+ "ofl/taviraj/Taviraj-Italic.ttf",
408
+ "ofl/taviraj/Taviraj-Regular.ttf",
409
+ "ofl/thasadith/Thasadith-Italic.ttf",
410
+ "ofl/thasadith/Thasadith-Regular.ttf",
411
+ "ofl/trirong/Trirong-Italic.ttf",
412
+ "ofl/trirong/Trirong-Regular.ttf"
413
+ ],
414
+ "chinese": [
415
+ "ofl/liujianmaocao/LiuJianMaoCao-Regular.ttf",
416
+ "ofl/longcang/LongCang-Regular.ttf",
417
+ "ofl/mashanzheng/MaShanZheng-Regular.ttf",
418
+ "ofl/mochiypopone/MochiyPopOne-Regular.ttf",
419
+ "ofl/mochiypoppone/MochiyPopPOne-Regular.ttf",
420
+ "ofl/mplus1code/MPLUS1Code[wght].ttf",
421
+ "ofl/mplus1p/Mplus1p-Regular.ttf",
422
+ "ofl/newtegomin/NewTegomin-Regular.ttf",
423
+ "ofl/pottaone/PottaOne-Regular.ttf",
424
+ "ofl/rampartone/RampartOne-Regular.ttf",
425
+ "ofl/reggaeone/ReggaeOne-Regular.ttf"
426
+ ]
427
+ }
font_list_noto_sans.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "arabic": [
3
+ "notosans/notosansarabic/NotoSansArabic[wdth,wght].ttf"
4
+ ],
5
+ "bengali": [
6
+ "notosans/notosansbengali/NotoSansBengali[wdth,wght].ttf"
7
+ ],
8
+ "gujarati": [
9
+ "notosans/notosansgujarati/NotoSansGujarati-Regular.ttf"
10
+ ],
11
+ "hebrew": [
12
+ "notosans/notosanshebrew/NotoSansHebrew[wdth,wght].ttf"
13
+ ],
14
+ "japanese": [
15
+ "notosans/notosansjp/NotoSansJP-Regular.otf"
16
+ ],
17
+ "khmer": [
18
+ "notosans/notosanskhmer/NotoSansKhmer[wdth,wght].ttf"
19
+ ],
20
+ "korean": [
21
+ "notosans/notosanskr/NotoSansKR-Regular.otf"
22
+ ],
23
+ "malayalam": [
24
+ "notosans/notosansmalayalam/NotoSansMalayalam[wdth,wght].ttf"
25
+ ],
26
+ "cyrillic": [
27
+ "notosans/notosans/NotoSans-Regular.ttf"
28
+ ],
29
+ "tamil": [
30
+ "notosans/notosanstamil/NotoSansTamil[wdth,wght].ttf"
31
+ ],
32
+ "telugu": [
33
+ "notosans/notosanstelugu/NotoSansTelugu[wdth,wght].ttf"
34
+ ],
35
+ "thai": [
36
+ "notosans/notosansthai/NotoSansThai[wdth,wght].ttf"
37
+ ],
38
+ "chinese": [
39
+ "notosans/notosanssc/NotoSansSC-Regular.otf"
40
+ ]
41
+ }
inference.py ADDED
@@ -0,0 +1,163 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List, Union, Tuple
3
+
4
+ from omegaconf import OmegaConf
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+ import models
11
+
12
+ GENERATOR_PREFIX = "networks.g."
13
+ WHITE = 255
14
+ EXAMPLE_CHARACTERS = ['A', 'B', 'C', 'D', 'E']
15
+
16
+ class InferenceServicer:
17
+ def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
18
+ self.hp = hp
19
+ self.imsize = imsize
20
+
21
+ if gpu_id is None:
22
+ self.device = torch.device(f'cuda:0') if torch.cuda.is_available() else 'cpu'
23
+ else:
24
+ self.device = torch.device(f'cuda:{gpu_id}')
25
+
26
+ model_config = self.hp.models.G
27
+ self.model: nn.Module = models.Generator(model_config)
28
+
29
+ # Load Generator model weight
30
+ model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
31
+ generator_state_dict = self.convert_generator_state_dict(model_state_dict_pl)
32
+ self.model.load_state_dict(generator_state_dict)
33
+ self.model.to(device=self.device)
34
+ self.model.eval()
35
+
36
+ # Setting Content font files
37
+ self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))
38
+
39
+ @staticmethod
40
+ def convert_generator_state_dict(model_state_dict_pl):
41
+ generator_prefix = GENERATOR_PREFIX
42
+ generator_state_dict = {}
43
+ for module_name, module_state in model_state_dict_pl['state_dict'].items():
44
+ if module_name.startswith(generator_prefix):
45
+ generator_state_dict[module_name[len(generator_prefix):]] = module_state
46
+
47
+ return generator_state_dict
48
+
49
+ @staticmethod
50
+ def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
51
+ content_character_dict = {}
52
+ for filepath in content_image_dir.glob("**/*.png"):
53
+ content_character_dict[filepath.stem] = filepath
54
+ return content_character_dict
55
+
56
+ @staticmethod
57
+ def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
58
+ bg_img = bg_img.copy()
59
+ item_img = item_img.copy()
60
+ item_w, item_h = item_img.size
61
+ W, H = bg_img.size
62
+ if fit:
63
+ item_ratio = item_w / item_h
64
+ bg_ratio = W / H
65
+
66
+ if bg_ratio > item_ratio:
67
+ # height fitting
68
+ resize_ratio = H / item_h
69
+ else:
70
+ # width fitting
71
+ resize_ratio = W / item_w
72
+ item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
73
+ item_w, item_h = item_img.size
74
+
75
+ bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
76
+ return bg_img
77
+
78
+ def set_image(self, image: Union[Path, Image.Image]) -> Image.Image:
79
+ if isinstance(image, (str, Path)):
80
+ image = Image.open(image)
81
+ assert isinstance(image, Image.Image)
82
+
83
+ bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
84
+ blend_img = self.center_align(bg_img, image, fit=True)
85
+ return blend_img
86
+
87
+ @staticmethod
88
+ def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
89
+ normalized_array = np.mean(np.array(blend_img, dtype=np.float32), axis=-1) / WHITE # L-only image normalized to [0, 1]
90
+ return normalized_array
91
+
92
+ def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0), font_size: int = 128, padding: int = 100) -> List[Image.Image]:
93
+
94
+ imagefont = ImageFont.truetype(str(font_file_path), size=font_size)
95
+ example_characters = EXAMPLE_CHARACTERS
96
+
97
+ font_images: List[Image.Image] = []
98
+
99
+ for character in example_characters:
100
+ x, y, _, _ = imagefont.getbbox(character)
101
+ img = Image.new(imgmode, (x + padding, y + padding), color='white')
102
+ draw = ImageDraw.Draw(img)
103
+
104
+ # bbox = draw.textbbox((0,0), character, font=imagefont)
105
+ # w = bbox[2] - bbox[0]
106
+ # h = bbox[3] - bbox[1]
107
+
108
+ w, h = draw.textsize(character, font=imagefont)
109
+
110
+ img = Image.new(imgmode, (w + padding, h + padding), color='white')
111
+ draw = ImageDraw.Draw(img)
112
+ draw.text(position, text=character, font=imagefont, fill='black')
113
+ img = img.convert(imgmode)
114
+ font_images.append(img)
115
+
116
+ return font_images
117
+
118
+ @staticmethod
119
+ def get_hex_from_char(char: str) -> str:
120
+ assert len(char) == 1
121
+ return f"{ord(char):04X}".upper() # 4-digit hex string
122
+
123
+ @torch.no_grad()
124
+ def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
125
+ assert len(content_char) > 0
126
+ content_char = content_char[:1] # only get the first character if the length > 1
127
+ char_hex = self.get_hex_from_char(content_char)
128
+
129
+ if char_hex not in self.content_character_dict:
130
+ raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")
131
+
132
+ content_image = self.set_image(self.content_character_dict[char_hex])
133
+ style_images: List[Image.Image] = self.get_images_from_fontfile(Path(style_font))
134
+ style_images: List[Image.Image] = [self.set_image(image) for image in style_images]
135
+
136
+ content_image_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...] # 1 x C(=1) x H x W
137
+ style_images_array: np.ndarray = np.array([self.pil_image_to_array(image) for image in style_images])[np.newaxis, ...] # 1 x C(=5, # shots) x H x W, k-shots goes to batch
138
+
139
+ content_input_tensor = torch.from_numpy(content_image_array).to(self.device)
140
+ style_input_tensor = torch.from_numpy(style_images_array).to(self.device)
141
+
142
+ generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
143
+ generated_images = torch.clip(generated_images, 0, 1)
144
+ assert generated_images.size(0) == 1
145
+
146
+ generated_image_numpy = (generated_images[0].cpu().numpy() * 255).astype(np.uint8)[0, ...] # H x W
147
+ return content_image, style_images, Image.fromarray(generated_image_numpy, mode='L')
148
+
149
+
150
+ if __name__ == '__main__':
151
+ hp = OmegaConf.load("config/models/google-font.yaml")
152
+ checkpoint_path = "epoch=199-step=257400.ckpt"
153
+ content_image_dir = "../DATA/NotoSans"
154
+
155
+ servicer = InferenceServicer(hp, checkpoint_path, content_image_dir)
156
+
157
+ style_font = "example_fonts/MaShanZheng-Regular.ttf"
158
+ content_image, style_images, result = servicer.inference("7", style_font)
159
+
160
+ content_image.save("result_content.png")
161
+ for idx, style_image in enumerate(style_images):
162
+ style_image.save(f"result_style_{idx:02d}.png")
163
+ result.save("result_generated.png")
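Note that `ImageDraw.textsize` (used in `get_images_from_fontfile` above) was deprecated in Pillow 9.2 and removed in Pillow 10, which is presumably why the `textbbox`-based measurement is kept in the comments. A small sketch of the equivalent glyph measurement for newer Pillow versions:

```python
from PIL import Image, ImageDraw, ImageFont

def measure_glyph(font_file_path: str, character: str, font_size: int = 128) -> tuple:
    """Measure a rendered character with ImageDraw.textbbox (Pillow >= 8.0)."""
    imagefont = ImageFont.truetype(font_file_path, size=font_size)
    draw = ImageDraw.Draw(Image.new("RGB", (1, 1), color="white"))
    left, top, right, bottom = draw.textbbox((0, 0), character, font=imagefont)
    return right - left, bottom - top  # width and height used to size the canvas
```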
lightning.py ADDED
@@ -0,0 +1,313 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ import pytorch_lightning as pl
7
+ import importlib
8
+ import PIL.Image as Image
9
+
10
+ import models
11
+ import datasets
12
+ from evaluator.ssim import SSIM, MSSSIM
13
+ import lpips
14
+ from models.loss import GANHingeLoss
15
+ from utils import set_logger, magic_image_handler
16
+
17
+ NUM_TEST_SAVE_IMAGE = 10
18
+
19
+
20
+ class FontLightningModule(pl.LightningModule):
21
+ def __init__(self, args):
22
+ super().__init__()
23
+ self.args = args
24
+
25
+ self.losses = {}
26
+ self.metrics = {}
27
+ self.networks = nn.ModuleDict(self.build_models())
28
+ self.module_keys = list(self.networks.keys())
29
+
30
+ self.losses = self.build_losses()
31
+ self.metrics = self.build_metrics()
32
+
33
+ self.opt_tag = {key: None for key in self.networks.keys()}
34
+ self.sched_tag = {key: None for key in self.networks.keys()}
35
+ self.sched_use = False
36
+ # self.automatic_optimization = False
37
+
38
+ self.train_d_content = True
39
+ self.train_d_style = True
40
+
41
+ def build_models(self):
42
+ networks = {}
43
+ for key, hp_model in self.args.models.items():
44
+ key_ = key.lower()
45
+ if 'g' == key_[0]:
46
+ model_ = models.Generator(hp_model)
47
+ elif 'd' == key_[0]:
48
+ model_ = models.PatchGANDiscriminator(hp_model) # TODO: add option for selecting discriminator
49
+ else:
50
+ raise ValueError(f"No key such as {key}")
51
+
52
+ networks[key.lower()] = model_
53
+ return networks
54
+
55
+ def build_losses(self):
56
+ losses_dict = {}
57
+ losses_dict['L1'] = torch.nn.L1Loss()
58
+
59
+ if 'd_content' in self.module_keys:
60
+ losses_dict['GANLoss_content'] = GANHingeLoss()
61
+ if 'd_style' in self.module_keys:
62
+ losses_dict['GANLoss_style'] = GANHingeLoss()
63
+
64
+ return losses_dict
65
+
66
+ def build_metrics(self):
67
+ metrics_dict = nn.ModuleDict()
68
+ metrics_dict['ssim'] = SSIM(val_range=1) # img value is in [0, 1]
69
+ metrics_dict['msssim'] = MSSSIM(weights=[0.45, 0.3, 0.25], val_range=1) # since imsize=64, len(weight)<=3
70
+ metrics_dict['lpips'] = lpips.LPIPS(net='vgg')
71
+ return metrics_dict
72
+
73
+ def configure_optimizers(self):
74
+ optims = {}
75
+ for key, args_model in self.args.models.items():
76
+ key = key.lower()
77
+ if args_model['optim'] is not None:
78
+ args_optim = args_model['optim']
79
+ module, cls = args_optim['class'].rsplit(".", 1)
80
+ O = getattr(importlib.import_module(module, package=None), cls)
81
+ o = O([p for p in self.networks[key].parameters() if p.requires_grad],
82
+ lr=args_optim.lr, betas=args_optim.betas)
83
+
84
+ optims[key] = o
85
+
86
+ optim_module_keys = optims.keys()
87
+
88
+ count = 0
89
+ optim_list = []
90
+
91
+ for _key in self.module_keys:
92
+ if _key in optim_module_keys:
93
+ optim_list.append(optims[_key])
94
+ self.opt_tag[_key] = count
95
+ count += 1
96
+
97
+ return optim_list
98
+
99
+ def forward(self, content_images, style_images):
100
+ return self.networks['g']((content_images, style_images))
101
+
102
+ def common_forward(self, batch, batch_idx):
103
+ loss = {}
104
+ logs = {}
105
+
106
+ content_images = batch['content_images']
107
+ style_images = batch['style_images']
108
+ gt_images = batch['gt_images']
109
+ image_paths = batch['image_paths']
110
+ char_idx = batch['char_idx']
111
+
112
+ generated_images = self(content_images, style_images)
113
+
114
+ # l1 loss
115
+ loss['g_L1'] = self.losses['L1'](generated_images, gt_images)
116
+ loss['g_backward'] = loss['g_L1'] * self.args.logging.lambda_L1
117
+
118
+ # loss for training generator
119
+ if 'd_content' in self.module_keys:
120
+ loss = self.d_content_loss_for_G(content_images, generated_images, loss)
121
+
122
+ if 'd_style' in self.networks.keys():
123
+             loss = self.d_style_loss_for_G(style_images, generated_images, loss)
+
+         # loss for training discriminator
+         generated_images = generated_images.detach()
+
+         if 'd_content' in self.module_keys:
+             if self.train_d_content:
+                 loss = self.d_content_loss_for_D(content_images, generated_images, gt_images, loss)
+
+         if 'd_style' in self.module_keys:
+             if self.train_d_style:
+                 loss = self.d_style_loss_for_D(style_images, generated_images, gt_images, loss)
+
+         logs['content_images'] = content_images
+         logs['style_images'] = style_images
+         logs['gt_images'] = gt_images
+         logs['generated_images'] = generated_images
+
+         return loss, logs
+
+     @property
+     def automatic_optimization(self):
+         return False
+
+     def training_step(self, batch, batch_idx):
+         metrics = {}
+         # forward
+         loss, logs = self.common_forward(batch, batch_idx)
+
+         if self.global_step % self.args.logging.freq['train'] == 0:
+             with torch.no_grad():
+                 metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))
+
+         # backward
+         opts = self.optimizers()
+
+         opts[self.opt_tag['g']].zero_grad()
+         self.manual_backward(loss['g_backward'])
+
+         if 'd_content' in self.module_keys:
+             if self.train_d_content:
+                 opts[self.opt_tag['d_content']].zero_grad()
+                 self.manual_backward(loss['dcontent_backward'])
+
+         if 'd_style' in self.module_keys:
+             if self.train_d_style:
+                 opts[self.opt_tag['d_style']].zero_grad()
+                 self.manual_backward(loss['dstyle_backward'])
+
+         opts[self.opt_tag['g']].step()
+
+         if 'd_content' in self.module_keys:
+             if self.train_d_content:
+                 opts[self.opt_tag['d_content']].step()
+
+         if 'd_style' in self.module_keys:
+             if self.train_d_style:
+                 opts[self.opt_tag['d_style']].step()
+
+         if self.global_step % self.args.logging.freq['train'] == 0:
+             self.custom_log(loss, metrics, logs, mode='train')
+
+     def validation_step(self, batch, batch_idx):
+         metrics = {}
+         loss, logs = self.common_forward(batch, batch_idx)
+         self.custom_log(loss, metrics, logs, mode='eval')
+
+     def test_step(self, batch, batch_idx):
+         metrics = {}
+         loss, logs = self.common_forward(batch, batch_idx)
+         metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))
+
+         if batch_idx < NUM_TEST_SAVE_IMAGE:
+             for key, value in logs.items():
+                 if 'image' in key:
+                     sample_images = (magic_image_handler(value) * 255)[..., 0].astype(np.uint8)
+                     Image.fromarray(sample_images).save(f"{batch_idx:02d}_{key}.png")
+
+         return loss, logs, metrics
+
+     def test_epoch_end(self, test_step_outputs):
+         # do something with the outputs of all test batches
+         # all_test_preds = test_step_outputs.metrics
+         ssim_list = []
+         msssim_list = []
+
+         for _, test_output in enumerate(test_step_outputs):
+
+             ssim_list.append(test_output[2]['SSIM'].cpu().numpy())
+             msssim_list.append(test_output[2]['MSSSIM'].cpu().numpy())
+
+         print(f"SSIM: {np.mean(ssim_list)}")
+         print(f"MSSSIM: {np.mean(msssim_list)}")
+
+     def common_dataloader(self, mode='train', batch_size=None):
+         dataset_cls = getattr(datasets, self.args.datasets.type)
+         dataset_config = getattr(self.args.datasets, mode)
+         dataset = dataset_cls(dataset_config, mode=mode)
+         _batch_size = batch_size if batch_size is not None else dataset_config.batch_size
+         dataloader = DataLoader(dataset,
+                                 shuffle=dataset_config.shuffle,
+                                 batch_size=_batch_size,
+                                 num_workers=dataset_config.num_workers,
+                                 drop_last=True)
+
+         return dataloader
+
+     def train_dataloader(self):
+         return self.common_dataloader(mode='train')
+
+     def val_dataloader(self):
+         return self.common_dataloader(mode='eval')
+
+     def test_dataloader(self):
+         return self.common_dataloader(mode='eval')
+
+     def calc_metrics(self, gt_images, generated_images):
+         """
+
+         :param gt_images:
+         :param generated_images:
+         :return:
+         """
+         metrics = {}
+         _gt = torch.clamp(gt_images.clone(), 0, 1)
+         _gen = torch.clamp(generated_images.clone(), 0, 1)
+         metrics['SSIM'] = self.metrics['ssim'](_gt, _gen)
+         msssim_value = self.metrics['msssim'](_gt, _gen)
+         metrics['MSSSIM'] = msssim_value if not torch.isnan(msssim_value) else torch.tensor(0.).type_as(_gt)
+         metrics['LPIPS'] = self.metrics['lpips'](_gt * 2 - 1, _gen * 2 - 1).squeeze().mean()
+         return metrics
+
+     # region step
+     def d_content_loss_for_G(self, content_images, generated_images, loss):
+         pred_generated = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
+         loss['g_gan_content'] = self.losses['GANLoss_content'](pred_generated, True, for_discriminator=False)
+
+         loss['g_backward'] += loss['g_gan_content']
+         return loss
+
+     def d_content_loss_for_D(self, content_images, generated_images, gt_images, loss):
+         # D
+         if 'd_content' in self.module_keys:
+             if self.train_d_content:
+                 pred_gt_images = self.networks['d_content'](torch.cat([content_images, gt_images], dim=1))
+                 pred_generated_images = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
+
+                 loss['dcontent_gt'] = self.losses['GANLoss_content'](pred_gt_images, True, for_discriminator=True)
+                 loss['dcontent_gen'] = self.losses['GANLoss_content'](pred_generated_images, False, for_discriminator=True)
+                 loss['dcontent_backward'] = (loss['dcontent_gt'] + loss['dcontent_gen'])
+
+         return loss
+
+     def d_style_loss_for_G(self, style_images, generated_images, loss):
+         pred_generated = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
+         loss['g_gan_style'] = self.losses['GANLoss_style'](pred_generated, True, for_discriminator=False)
+
+         assert self.train_d_style
+         loss['g_backward'] += loss['g_gan_style']
+         return loss
+
+     def d_style_loss_for_D(self, style_images, generated_images, gt_images, loss):
+         pred_gt_images = self.networks['d_style'](torch.cat([style_images, gt_images], dim=1))
+         pred_generated_images = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
+
+         loss['dstyle_gt'] = self.losses['GANLoss_style'](pred_gt_images, True, for_discriminator=True)
+         loss['dstyle_gen'] = self.losses['GANLoss_style'](pred_generated_images, False, for_discriminator=True)
+         loss['dstyle_backward'] = (loss['dstyle_gt'] + loss['dstyle_gen'])
+
+         return loss
+
+     def custom_log(self, loss, metrics, logs, mode):
+         # logging values with tensorboard
+         for loss_full_key, value in loss.items():
+             model_type, loss_type = loss_full_key.split('_')[0], "_".join(loss_full_key.split('_')[1:])
+             self.log(f'{model_type}/{mode}_{loss_type}', value)
+
+         for metric_full_key, value in metrics.items():
+             model_type, metric_type = metric_full_key.split('_')[0], "_".join(metric_full_key.split('_')[1:])
+             self.log(f'{model_type}/{mode}_{metric_type}', value)
+
+         # logging images, params, etc.
+         tensorboard = self.logger.experiment
+         for key, value in logs.items():
+             if 'image' in key:
+                 sample_images = magic_image_handler(value)
+                 tensorboard.add_image(f"{mode}/" + key, sample_images, self.global_step, dataformats='HWC')
+             elif 'param' in key:
+                 tensorboard.add_histogram(f"{mode}" + key, value, self.global_step)
+             else:
+                 raise RuntimeError(f"Only logging with one of keywords: image, param | current input: {key}")
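
Because `automatic_optimization` returns False, `training_step` drives the GAN update by hand: every loss is backpropagated (generator first, then the discriminators on detached fakes) before any optimizer calls `step()`, so all gradients are computed against the parameters of the current iteration. A minimal standalone sketch of the same pattern, using toy `nn.Linear` modules as stand-ins for the real networks (names and shapes here are placeholders, not the project's models):

import torch
import torch.nn as nn

# Toy generator / discriminator standing in for the real networks.
g = nn.Linear(8, 8)
d = nn.Linear(8, 1)
opt_g = torch.optim.Adam(g.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(d.parameters(), lr=2e-4)

x = torch.randn(4, 8)
fake = g(x)

# Generator loss uses the non-detached fake output.
loss_g = -d(fake).mean()
# Discriminator loss uses detached fakes, mirroring generated_images.detach() above.
fake_detached = fake.detach()
loss_d = torch.relu(1 - d(x)).mean() + torch.relu(1 + d(fake_detached)).mean()

# Same ordering as training_step: zero_grad and backward everything, then step everything.
opt_g.zero_grad()
loss_g.backward()
opt_d.zero_grad()
loss_d.backward()
opt_g.step()
opt_d.step()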
models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .generator import *
+ from .discriminator import *
models/decoder.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torch.nn as nn
+ from models.module import ResidualBlocks
+
+ _DECODER_CHANNEL_DEFAULT = 512
+
+
+ class Decoder(nn.Module):
+     def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1):
+         super().__init__()
+         self.module = nn.ModuleList()
+
+     def forward(self, x):
+         for block in self.module:
+             x = block(x)
+         return x
+
+
+ class VanillaDecoder(Decoder):
+     def __init__(self, hp, in_channels, out_channels):
+         super().__init__(hp, in_channels, out_channels)
+         self.depth = hp.decoder.depth
+         self.blocks = hp.decoder.residual_blocks
+
+         self.module = nn.ModuleList()
+         if self.blocks > 0:
+             self.module.append(ResidualBlocks(in_channels, n_blocks=self.blocks))
+
+         for layer_idx in range(1, self.depth + 1):  # add upsampling layers
+             self.module.append(nn.Sequential(
+                 nn.ConvTranspose2d(in_channels // (2 ** (layer_idx - 1)),
+                                    in_channels // (2 ** layer_idx),
+                                    kernel_size=3, stride=2,
+                                    padding=1, output_padding=1,
+                                    bias=False),
+                 nn.BatchNorm2d(in_channels // (2 ** layer_idx)),
+                 nn.ReLU(True)
+             ))
+
+         final = nn.Sequential(
+             nn.Conv2d(in_channels // (2 ** self.depth), out_channels, kernel_size=7, padding=3, padding_mode='reflect'),
+             nn.Tanh()
+         )
+
+         self.module.append(final)
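
VanillaDecoder halves the channel count and doubles the spatial resolution at every ConvTranspose2d stage, then maps to `out_channels` with a 7x7 convolution and Tanh. A rough shape check under assumed hyperparameters (`depth=2` and `residual_blocks=6` are illustrative values, not necessarily the shipped config):

import torch
from omegaconf import OmegaConf
from models.decoder import VanillaDecoder

# Illustrative hyperparameters; the real values live in the config YAMLs.
hp = OmegaConf.create({"decoder": {"type": "VanillaDecoder", "depth": 2, "residual_blocks": 6}})
decoder = VanillaDecoder(hp, in_channels=512, out_channels=1)

feature = torch.randn(2, 512, 16, 16)    # B x C x H x W bottleneck feature
out = decoder(feature)
print(out.shape)                         # torch.Size([2, 1, 64, 64]) with depth=2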
models/discriminator.py ADDED
@@ -0,0 +1,54 @@
+ import functools
+
+ import omegaconf
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # FIXME
+
+
+ class PatchGANDiscriminator(nn.Module):
+     """Defines a PatchGAN discriminator"""
+
+     def __init__(self, hp, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+         """Construct a PatchGAN discriminator
+
+         Parameters:
+             ndf (int)       -- the number of filters in the last conv layer
+             n_layers (int)  -- the number of conv layers in the discriminator
+             norm_layer      -- normalization layer
+         """
+         super().__init__()
+         self.hp = hp
+         in_channels = hp.in_channels
+
+         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+             use_bias = norm_layer.func == nn.InstanceNorm2d
+         else:
+             use_bias = norm_layer == nn.InstanceNorm2d
+         kw = 4
+         padw = 1
+         sequence = [nn.Conv2d(in_channels, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+         nf_mult = 1
+         nf_mult_prev = 1
+         for n in range(1, n_layers):  # gradually increase the number of filters
+             nf_mult_prev = nf_mult
+             nf_mult = min(2 ** n, 8)
+             sequence += [
+                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                 norm_layer(ndf * nf_mult),
+                 nn.LeakyReLU(0.2, True)
+             ]
+         nf_mult_prev = nf_mult
+         nf_mult = min(2 ** n_layers, 8)
+         sequence += [
+             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+             norm_layer(ndf * nf_mult),
+             nn.LeakyReLU(0.2, True)
+         ]
+         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
+         self.model = nn.Sequential(*sequence)
+
+     def forward(self, x):
+         return self.model(x)
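
The PatchGAN head ends with a 1-channel convolution, so the output is a grid of per-patch real/fake logits rather than a single score; the hinge loss in models/loss.py simply averages over that grid. A quick shape sketch, where `in_channels=2` is an assumption reflecting that the discriminators receive a condition image concatenated with the generated or ground-truth image (the exact channel count is configuration-dependent):

import torch
from omegaconf import OmegaConf
from models.discriminator import PatchGANDiscriminator

hp = OmegaConf.create({"in_channels": 2})    # e.g. content image + generated image
netD = PatchGANDiscriminator(hp, ndf=64, n_layers=3)

pair = torch.randn(4, 2, 64, 64)             # B x 2 x H x W concatenated pair
logits = netD(pair)
print(logits.shape)                          # a patch map: torch.Size([4, 1, 6, 6]) for 64x64 input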
models/encoder.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ from models.module import Conv2d, StyleAttentionBlock
+
+ _ENCODER_CHANNEL_DEFAULT = 256
+
+
+ class Encoder(nn.Module):
+     def __init__(self, hp, in_channels=1, out_channels=_ENCODER_CHANNEL_DEFAULT):
+         super().__init__()
+         self.hp = hp
+         self.module = nn.ModuleList()
+
+     def forward(self, x):
+         for block in self.module:
+             x = block(x)
+         return x
+
+
+ class ContentVanillaEncoder(Encoder):
+     def __init__(self, hp, in_channels, out_channels):
+         super().__init__(hp, in_channels, out_channels)
+         self.depth = hp.encoder.content.depth
+         assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"
+
+         self.module = nn.ModuleList()
+         self.module.append(
+             Conv2d(in_channels, out_channels // (2 ** self.depth),
+                    kernel_size=7, padding=3, padding_mode='reflect', bias=False)
+         )
+
+         for layer_idx in range(1, self.depth + 1):  # downsample
+             self.module.append(
+                 Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
+                        out_channels // (2 ** (self.depth - layer_idx)),
+                        kernel_size=3, stride=2, padding=1, bias=False)
+             )
+
+
+ class StyleVanillaEncoder(Encoder):
+     def __init__(self, hp, in_channels, out_channels):
+         super().__init__(hp, in_channels, out_channels)
+         self.depth = hp.encoder.style.depth
+         assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"
+
+         encoder_module = []
+         encoder_module.append(
+             Conv2d(in_channels, out_channels // (2 ** self.depth),
+                    kernel_size=7, padding=3, padding_mode='reflect', bias=False)
+         )
+
+         for layer_idx in range(1, self.depth + 1):  # downsample
+             encoder_module.append(
+                 Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
+                        out_channels // (2 ** (self.depth - layer_idx)),
+                        kernel_size=3, stride=2, padding=1, bias=False)
+             )
+         self.add_module("encoder_module", nn.Sequential(*encoder_module))
+         self.add_module("attention_module", StyleAttentionBlock(out_channels))
+
+     def forward(self, x):
+         B, K, H, W = x.size()
+         out = self.encoder_module(x.view(-1, 1, H, W))
+         out = self.attention_module(out, B, K)
+         return out
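
StyleVanillaEncoder treats the K style references of each sample as extra batch entries: the B x K x H x W input is reshaped to (B*K) x 1 x H x W before the shared convolutional stack, and StyleAttentionBlock folds the K-shot dimension back in when aggregating. A shape sketch under assumed depths (`depth=2` for both encoders is an assumption, not taken from the released configs):

import torch
from omegaconf import OmegaConf
from models.encoder import ContentVanillaEncoder, StyleVanillaEncoder

hp = OmegaConf.create({"encoder": {"content": {"depth": 2}, "style": {"depth": 2}}})

content_enc = ContentVanillaEncoder(hp, in_channels=1, out_channels=256)
style_enc = StyleVanillaEncoder(hp, in_channels=1, out_channels=256)

content = torch.randn(2, 1, 64, 64)      # B x 1 x H x W content glyph
styles = torch.randn(2, 5, 64, 64)       # B x K x H x W, K=5 style references

content_feature = content_enc(content)   # 2 x 256 x 16 x 16 (downsampled by 2**depth)
style_feature = style_enc(styles)        # 2 x 256 x 1 x 1 after attention pooling
print(content_feature.shape, style_feature.shape)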
models/generator.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import torch.nn as nn
+ from . import encoder, decoder
+
+ class Generator(nn.Module):
+     def __init__(self, hp, in_channels=1):
+         super().__init__()
+         self.hp = hp
+         _ngf = 64
+         hidden_dim = _ngf * 4
+         self.content_encoder = getattr(encoder, self.hp.encoder.content.type)(self.hp, in_channels, hidden_dim)
+         self.style_encoder = getattr(encoder, self.hp.encoder.style.type)(self.hp, in_channels, hidden_dim)
+         self.decoder = getattr(decoder, self.hp.decoder.type)(self.hp, hidden_dim * 2, in_channels)
+
+     def forward(self, images):
+         content_images, style_images = images
+         content_feature = self.content_encoder(content_images)
+         style_images = style_images * 2 - 1  # pixel value range -1 to 1
+         style_feature = self.style_encoder(style_images)  # K-shot as batch
+         _, _, H, W = content_feature.size()
+         out = self.decoder(torch.cat([content_feature, style_feature.expand(-1, -1, H, W)], dim=1))
+         return out
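
Generator simply concatenates the two encoder outputs along the channel axis: the 1x1 style vector is broadcast with `expand` to the spatial size of the content feature before being fed to the decoder. A rough end-to-end sketch with assumed hyperparameters (the type names mirror the classes in this repo, but the depths and sizes below are illustrative, not the released configuration):

import torch
from omegaconf import OmegaConf
from models.generator import Generator

hp = OmegaConf.create({
    "encoder": {"content": {"type": "ContentVanillaEncoder", "depth": 2},
                "style": {"type": "StyleVanillaEncoder", "depth": 2}},
    "decoder": {"type": "VanillaDecoder", "depth": 2, "residual_blocks": 6},
})

generator = Generator(hp, in_channels=1)
content = torch.rand(2, 1, 64, 64)    # content glyph, values in [0, 1]
styles = torch.rand(2, 5, 64, 64)     # K=5 style reference glyphs
fake = generator((content, styles))
print(fake.shape)                     # torch.Size([2, 1, 64, 64]), Tanh output in [-1, 1]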
models/loss.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+ class GANHingeLoss(nn.Module):
+     def __init__(self):
+         super(GANHingeLoss, self).__init__()
+         self.relu = nn.ReLU()
+
+     def __call__(self, pred, is_real, for_discriminator):
+         if for_discriminator:
+             if is_real:
+                 return self.relu(1 - pred).mean()
+             return self.relu(1 + pred).mean()
+
+         assert is_real, "The generator's hinge loss must be aiming for real"
+         return -1.0 * pred.mean()
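
GANHingeLoss implements the standard hinge GAN objective: for the discriminator it penalizes real logits below +1 and fake logits above -1, while the generator simply maximizes the discriminator's score on fakes. A small numeric sketch of the three branches:

import torch
from models.loss import GANHingeLoss

criterion = GANHingeLoss()
pred_real = torch.tensor([0.5, 2.0])    # discriminator logits on real images
pred_fake = torch.tensor([-0.5, 1.0])   # discriminator logits on generated images

d_real = criterion(pred_real, True, for_discriminator=True)    # mean(relu(1 - pred)) = 0.25
d_fake = criterion(pred_fake, False, for_discriminator=True)   # mean(relu(1 + pred)) = 1.25
g_loss = criterion(pred_fake, True, for_discriminator=False)   # -mean(pred) = -0.25
print(d_real.item(), d_fake.item(), g_loss.item())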
models/module.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ import torch.nn as nn
+
+
+ class Conv2d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                  padding=0, padding_mode='zeros', bias=True, residual=False):
+         super(Conv2d, self).__init__()
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size, stride,
+                       padding, padding_mode=padding_mode, bias=bias),
+             nn.BatchNorm2d(out_channels)
+         )
+         self.residual = residual
+         self.act = nn.ReLU()
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+         out = self.act(out)
+         return out
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, channel, padding_mode, norm_layer=nn.BatchNorm2d, bias=False):
+         super().__init__()
+         if padding_mode not in ['reflect', 'zeros']:  # 'zeros' is the padding_mode name expected by nn.Conv2d
+             raise NotImplementedError(f"{padding_mode} is not supported!")
+
+         self.block = nn.Sequential(
+             nn.Conv2d(channel, channel, kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias),
+             norm_layer(channel)
+         )
+         self.act = nn.ReLU()
+
+     def forward(self, x):
+         out = self.block(x)
+         out = out + x
+         out = self.act(out)
+         return out
+
+
+ class ResidualBlocks(nn.Module):
+     def __init__(self, channel, n_blocks=6):
+         super().__init__()
+         model = []
+         for i in range(n_blocks):  # add ResNet blocks
+             model += [ResnetBlock(channel, padding_mode='reflect')]
+
+         self.module = nn.Sequential(*model)
+
+     def forward(self, x):
+         return self.module(x)
+
+
+ class SelfAttentionBlock(nn.Module):
+
+     def __init__(self, in_dim):
+         super().__init__()
+         self.feature_dim = in_dim // 8
+         self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
+         self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
+         self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+         self.gamma = nn.Parameter(torch.zeros(1))
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x):
+         B, C, H, W = x.size()
+         _query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x (H*W) x C'
+         _key = self.key_conv(x).view(B, -1, H * W)  # B x C' x (H*W)
+         attn_matrix = torch.bmm(_query, _key)
+         attention = self.softmax(attn_matrix)  # B x (H*W) x (H*W)
+         _value = self.value_conv(x).view(B, -1, H * W)  # B x C x (H*W)
+
+         out = torch.bmm(_value, attention.permute(0, 2, 1))
+         out = out.view(B, C, H, W)
+
+         out = self.gamma * out + x
+         return out
+
+
+ class ContextAwareAttentionBlock(nn.Module):
+
+     def __init__(self, in_channels, hidden_dim=128):
+         super().__init__()
+         self.self_attn = SelfAttentionBlock(in_channels)
+         self.fc = nn.Linear(in_channels, hidden_dim)
+         self.context_vector = nn.Linear(hidden_dim, 1, bias=False)
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, style_features):
+         B, C, H, W = style_features.size()
+         h = self.self_attn(style_features)
+         h = h.permute(0, 2, 3, 1).reshape(-1, C)
+         h = torch.tanh(self.fc(h))  # (B*H*W) x hidden_dim
+         h = self.context_vector(h)  # (B*H*W) x 1
+         attention_score = self.softmax(h.view(B, H * W)).view(B, 1, H, W)  # B x 1 x H x W
+         return torch.sum(style_features * attention_score, dim=[2, 3])  # B x C
+
+
+ class LayerAttentionBlock(nn.Module):
+     """from FTransGAN
+     """
+
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+         self.width_feat = 4
+         self.height_feat = 4
+         self.fc = nn.Linear(self.in_channels * self.width_feat * self.height_feat, 3)
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, style_features, style_features_1, style_features_2, style_features_3, B, K):
+         style_features = torch.mean(style_features.view(B, K, self.in_channels, self.height_feat, self.width_feat), dim=1)
+         style_features = style_features.view(B, -1)
+         weight = self.softmax(self.fc(style_features))
+
+         style_features_1 = torch.mean(style_features_1.view(B, K, self.in_channels), dim=1)
+         style_features_2 = torch.mean(style_features_2.view(B, K, self.in_channels), dim=1)
+         style_features_3 = torch.mean(style_features_3.view(B, K, self.in_channels), dim=1)
+
+         style_features = (style_features_1 * weight.narrow(1, 0, 1) +
+                           style_features_2 * weight.narrow(1, 1, 1) +
+                           style_features_3 * weight.narrow(1, 2, 1))
+         style_features = style_features.view(B, self.in_channels, 1, 1)
+         return style_features
+
+
+ class StyleAttentionBlock(nn.Module):
+     """from FTransGAN
+     """
+
+     def __init__(self, in_channels):
+         super().__init__()
+         self.num_local_attention = 3
+         for module_idx in range(1, self.num_local_attention + 1):
+             self.add_module(f"local_attention_{module_idx}",
+                             ContextAwareAttentionBlock(in_channels))
+
+         for module_idx in range(1, self.num_local_attention):
+             self.add_module(f"downsample_{module_idx}",
+                             Conv2d(in_channels, in_channels,
+                                    kernel_size=3, stride=2, padding=1, bias=False))
+
+         self.add_module("layer_attention", LayerAttentionBlock(in_channels))
+
+     def forward(self, x, B, K):
+         feature_1 = self.local_attention_1(x)
+
+         x = self.downsample_1(x)
+         feature_2 = self.local_attention_2(x)
+
+         x = self.downsample_2(x)
+         feature_3 = self.local_attention_3(x)
+
+         out = self.layer_attention(x, feature_1, feature_2, feature_3, B, K)
+
+         return out
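
ContextAwareAttentionBlock is a spatial attention pooling: self-attention refines the feature map, a small MLP scores every location against a learned context vector, and the softmax-weighted sum collapses H x W into a single descriptor per map. StyleAttentionBlock repeats this at three scales and folds the K-shot dimension back into the batch. A quick shape check (the 16x16 input is chosen so that the final map is the 4x4 grid LayerAttentionBlock expects):

import torch
from models.module import ContextAwareAttentionBlock, StyleAttentionBlock

feat = torch.randn(10, 256, 16, 16)          # (B*K) feature maps, e.g. B=2, K=5
pooled = ContextAwareAttentionBlock(256)(feat)
print(pooled.shape)                          # torch.Size([10, 256])

# Two downsamplings reduce 16x16 to the 4x4 map consumed by LayerAttentionBlock.
style = StyleAttentionBlock(256)(feat, 2, 5)
print(style.shape)                           # torch.Size([2, 256, 1, 1])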
pretrained/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ pytorch-lightning==1.6.0
+ omegaconf
+ fire
+ lpips
+ tensorboard
+ pillow==8.4.0
trainer.py ADDED
@@ -0,0 +1,88 @@
+ import argparse
+ import glob
+ from pathlib import Path
+
+ from omegaconf import OmegaConf
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from pytorch_lightning.loggers import TensorBoardLogger
+
+ from lightning import FontLightningModule
+ from utils import save_files
+
+
+ def load_configuration(path_config):
+     setting = OmegaConf.load(path_config)
+
+     # load hyperparameter
+     hp = OmegaConf.load(setting.config.dataset)
+     hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.model))
+     hp = OmegaConf.merge(hp, OmegaConf.load(setting.config.logging))
+
+     # with lightning setting
+     if hasattr(setting.config, 'lightning'):
+         pl_config = OmegaConf.load(setting.config.lightning)
+         if hasattr(pl_config, 'pl_config'):
+             return hp, pl_config.pl_config
+         return hp, pl_config
+
+     # without lightning setting
+     return hp
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Code to train font style transfer')
+
+     parser.add_argument("--config", type=str, default="./config/setting.yaml",
+                         help="Config file for training")
+     parser.add_argument('-g', '--gpus', type=str, default='0,1',
+                         help="Comma-separated GPU ids to use (e.g. '0,1,2,3')")
+     parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
+                         help="Path of checkpoint for resuming")
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     hp, pl_config = load_configuration(args.config)
+
+     logging_dir = Path(hp.logging.log_dir)
+
+     # call lightning module
+     font_pl = FontLightningModule(hp)
+
+     # set logging
+     hp.logging['log_dir'] = logging_dir / 'tensorboard'
+     savefiles = []
+     for reg in hp.logging.savefiles:
+         savefiles += glob.glob(reg)
+     hp.logging['log_dir'].mkdir(exist_ok=True)
+     save_files(str(logging_dir), savefiles)
+
+     # set tensorboard logger
+     logger = TensorBoardLogger(str(logging_dir), name=str(hp.logging.seed))
+
+     # set checkpoint callback
+     weights_save_path = logging_dir / 'checkpoint' / str(hp.logging.seed)
+     weights_save_path.mkdir(exist_ok=True)
+     checkpoint_callback = ModelCheckpoint(
+         dirpath=str(weights_save_path),
+         **pl_config.checkpoint.callback
+     )
+
+     # set lightning trainer
+     trainer = pl.Trainer(
+         logger=logger,
+         gpus=-1 if args.gpus is None else args.gpus,
+         callbacks=[checkpoint_callback],
+         **pl_config.trainer
+     )
+
+     # let's train
+     trainer.fit(font_pl)
+
+
+ if __name__ == "__main__":
+     main()
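
load_configuration expects the top-level setting file to point at separate dataset / model / logging (and optionally lightning) YAMLs, which it merges into one OmegaConf tree. The config files themselves are not part of this commit, so the sketch below only illustrates the nesting the function reads; every path and key is an assumption:

from omegaconf import OmegaConf

# Hypothetical contents of the top-level setting file (e.g. config/setting.yaml).
setting = OmegaConf.create({
    "config": {
        "dataset": "config/datasets/google-font.yaml",
        "model": "config/models/google-font.yaml",
        "logging": "config/logging.yaml",
        "lightning": "config/lightning.yaml",   # optional; trainer/checkpoint kwargs
    }
})

# load_configuration(path) then, in effect, merges load(dataset), load(model), load(logging)
# into hp, and returns pl_config from the lightning file when it is present.
print(OmegaConf.to_yaml(setting))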
trainer.sh ADDED
@@ -0,0 +1 @@
+ python trainer.py --config ./config/setting-google-font.yaml --gpus 0,1,2,3
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .logger import *
+ from .tb import *
+ from .util import *
utils/logger.py ADDED
@@ -0,0 +1,39 @@
+ import fire
+ import logging
+ import time
+
+ def _custom_logger(name):
+     fmt = '[{}|%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >>> %(message)s'.format(name)
+     fmt_date = '%Y-%m-%d_%T %Z'
+
+     handler = logging.StreamHandler()
+
+     formatter = logging.Formatter(fmt, fmt_date)
+     handler.setFormatter(formatter)
+
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.DEBUG)
+     logger.addHandler(handler)
+
+ def set_logger(logger_name, level):
+     try:
+         time.tzset()
+     except AttributeError as e:
+         print(e)
+         print("Skipping timezone setting.")
+     _custom_logger(name=logger_name)
+     logger = logging.getLogger(logger_name)
+     if level == 'DEBUG':
+         logger.setLevel(logging.DEBUG)
+     elif level == 'INFO':
+         logger.setLevel(logging.INFO)
+     elif level == 'WARNING':
+         logger.setLevel(logging.WARNING)
+     elif level == 'ERROR':
+         logger.setLevel(logging.ERROR)
+     elif level == 'CRITICAL':
+         logger.setLevel(logging.CRITICAL)
+     return logger
+
+ if __name__ == '__main__':
+     set_logger("test", "DEBUG")
utils/tb.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import numpy as np
+
+ def magic_image_handler(img):
+     if isinstance(img, torch.Tensor):
+         img = img.detach().cpu().numpy()
+     if img.ndim == 3:
+         img = img.transpose((1, 2, 0))
+     elif img.ndim == 2:
+         img = np.repeat(img[..., np.newaxis], 3, axis=2)
+     elif img.ndim == 4:
+         img = img[:4]  # first 4 batch
+         img = np.concatenate(img, axis=-1)
+         img = img.transpose((1, 2, 0))
+     elif img.ndim == 5:
+         img = img[:4]  # first 4 batch
+         img = np.concatenate(img, axis=-2)
+         img = np.concatenate(img, axis=-1)
+         img = img.transpose((1, 2, 0))
+     else:
+         raise ValueError(f'img ndim is {img.ndim}, should be 2~5')
+     if img.shape[-1] not in (1, 3):  # stack extra channels vertically so the result is single-channel
+         img = np.expand_dims(np.concatenate([img[..., i] for i in range(img.shape[-1])], axis=0), -1)
+     img = np.clip(img, a_min=0, a_max=255)
+     return img
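
magic_image_handler normalizes whatever lands in the logging dict into an HWC array that TensorBoard's add_image (dataformats='HWC') accepts: CHW tensors are transposed, 4D/5D batches are tiled (first four samples only), and grayscale inputs are expanded. A small sketch of the common cases:

import torch
from utils.tb import magic_image_handler

single = torch.rand(1, 64, 64)            # C x H x W
batch = torch.rand(8, 1, 64, 64)          # B x C x H x W
kshot = torch.rand(8, 5, 1, 64, 64)       # B x K x C x H x W

print(magic_image_handler(single).shape)  # (64, 64, 1)
print(magic_image_handler(batch).shape)   # first 4 samples tiled side by side: (64, 256, 1)
print(magic_image_handler(kshot).shape)   # 4 samples x 5 shots tiled into one image: (256, 320, 1)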
utils/util.py ADDED
@@ -0,0 +1,28 @@
+ from pathlib import Path
+ import shutil
+
+
+ def save_files(path_save_, savefiles):
+     path_save = Path(path_save_)
+     path_save.mkdir(exist_ok=True)
+
+     for savefile in savefiles:
+         parents_dir = Path(savefile).parents
+         if len(parents_dir) >= 1:
+             for parent_dir in list(parents_dir)[::-1]:
+                 target_dir = path_save / parent_dir
+                 target_dir.mkdir(exist_ok=True)
+         try:
+             shutil.copy2(savefile, str(path_save / savefile))
+         except Exception as e:
+             # skip the file
+             print(f'{e} occurred while saving {savefile}')
+
+     return  # success
+
+
+ if __name__ == "__main__":
+     import glob
+     savefiles = glob.glob('config/*.yaml')
+     savefiles += glob.glob('config/**/*.yaml')
+     save_files(".temp", savefiles)