File size: 5,177 Bytes
1ba3df3
 
 
 
 
c6f5e45
1ba3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6f5e45
1ba3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import argparse
from pathlib import Path
from typing import Optional, Union, Tuple, List
import subprocess
from itertools import chain

import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig

from inference import InferenceServicer

# Documentation page and model configuration file; both can be overridden through the environment.
PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")

# Optional download sources. When one is unset, the matching local file must already exist.
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)

# Fixed local locations the rest of the script reads from.
LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"


def _run_command(command: List[str]) -> int:
    """Run *command* without a shell and return its exit status.

    The arguments are passed as a list so that characters such as ``&``, ``?``
    or ``;`` inside an environment-provided URL are never interpreted by a
    shell. A missing executable is reported and mapped to status 127 (what a
    shell would have returned) instead of raising, so the existence checks
    below remain the single place where start-up is aborted.
    """
    try:
        return subprocess.call(command)
    except OSError as err:
        print(f"Could not run {command[0]!r}: {err}")
        return 127


def _download(url: str, destination: str) -> int:
    """Fetch *url* into *destination* with wget, creating the target directory first."""
    # wget -O cannot create missing parent directories on its own.
    Path(destination).parent.mkdir(parents=True, exist_ok=True)
    return _run_command(["wget", "--no-check-certificate", "-O", destination, url])


if MODEL_CHECKPOINT_PATH is not None:
    _download(MODEL_CHECKPOINT_PATH, LOCAL_CHECKPOINT_PATH)
if NOTO_SANS_ZIP_PATH is not None:
    _download(NOTO_SANS_ZIP_PATH, LOCAL_NOTO_ZIP_PATH)

# Extraction is attempted unconditionally; its exit status is ignored on purpose
# because the extracted directory may already be present from an earlier run.
_run_command(["unzip", LOCAL_NOTO_ZIP_PATH, "-d", str(Path(LOCAL_NOTO_ZIP_PATH).parent)])

# Explicit checks instead of `assert`, which is stripped when Python runs with -O.
for required_path in (Path(LOCAL_CHECKPOINT_PATH), Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")):
    if not required_path.exists():
        raise FileNotFoundError(f"Required resource is missing: {required_path}")

# Every .ttf / .otf file shipped in example_fonts/, as sorted path strings.
EXAMPLE_FONTS = sorted(
    str(font_path)
    for font_path in chain(Path("example_fonts").glob("*.ttf"), Path("example_fonts").glob("*.otf"))
)

def parse_args():
    """Parse the command-line options of the demo script.

    Unrecognised arguments are tolerated and discarded (``parse_known_args``).

    Returns:
        argparse.Namespace: carries ``docs`` and ``config`` (both ``Path``),
        ``local`` (bool) and ``port`` (int).
    """
    # NOTE(review): this description does not mention font style transfer and
    # looks carried over from another tool -- confirm before changing it.
    cli = argparse.ArgumentParser(description="Augmentation simulator for NetsPresso Trainer")

    # One (flag, add_argument keyword options) entry per user-facing option.
    option_table = [
        ('--docs', dict(type=Path, default=PATH_DOCS, help="Docs string file")),
        ('--config', dict(type=Path, default=MODEL_CONFIG, help="Config for model")),
        ('--local', dict(action='store_true', help="Whether to run in local environment or not")),
        ('--port', dict(type=int, default=50003,
                        help="Service port (only applicable when running on local server)")),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    known_args, _unknown = cli.parse_known_args()
    return known_args

class InferenceServiceResolver(InferenceServicer):
    """Thin ``InferenceServicer`` subclass whose ``generate`` serves as a Gradio click handler."""

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        """Forward every argument, in the same order, to ``InferenceServicer.__init__``."""
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Run inference and flatten its result into one list for the UI outputs.

        Args:
            content_char: Text entered by the user; passed to ``self.inference``.
            style_font: Selected font file path; passed to ``self.inference``.

        Returns:
            ``[content_image, *style_images, result]`` in that order.

        Raises:
            gr.Error: carrying the message of any exception raised while
                running inference or assembling the result.
        """
        try:
            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
            return [content_image, *style_images, result]
        except Exception as e:
            # UI boundary: turn any failure into a Gradio error, and chain the
            # original exception so its traceback stays attached as the cause.
            raise gr.Error(str(e)) from e

def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
    """Build the demo page and start the Gradio server.

    Args:
        docs_path: Markdown file whose text is rendered at the top of the page.
        hp: Model configuration handed to ``InferenceServiceResolver``.
        checkpoint_path: Checkpoint path handed to ``InferenceServiceResolver``.
        content_image_dir: Directory handed to ``InferenceServiceResolver``.
        is_local: When true, serve on ``0.0.0.0`` at ``port``; otherwise call
            ``demo.launch()`` with Gradio's defaults.
        port: Server port; only used when ``is_local`` is true.
    """
    # The click handler returns one content image, the style images and one
    # result; the outputs list below must have a component for each of them.
    # NOTE(review): 5 style images is assumed from the original layout -- confirm
    # against InferenceServicer.inference.
    num_style_images = 5

    # gpu_id=None overrides the resolver's default of '0'.
    # NOTE(review): presumably selects CPU inference -- confirm in InferenceServicer.
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)

    # Fall back to no preselected font rather than failing with IndexError
    # when the example_fonts directory holds no fonts.
    default_font = EXAMPLE_FONTS[0] if EXAMPLE_FONTS else None

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Read as UTF-8 explicitly so rendering does not depend on the host locale.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))
        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=default_font)
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        style_chars = [
                            gr.Image(label=f"Style #{index}", show_label=False)
                            for index in range(1, num_style_images + 1)
                        ]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Component order mirrors the list returned by servicer.generate.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()


if __name__ == "__main__":
    # Script entry point: read the CLI options, load the model configuration
    # and hand everything to the Gradio launcher.
    cli_args = parse_args()

    launch_gradio(
        docs_path=cli_args.docs,
        hp=OmegaConf.load(cli_args.config),
        checkpoint_path=Path(LOCAL_CHECKPOINT_PATH),
        # The archive path minus its ".zip" suffix is the extracted directory.
        content_image_dir=Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""),
        is_local=cli_args.local,
        port=cli_args.port,
    )