# Source: Hugging Face file view by deepkyu — commit 1ba3df3 ("initial commit"), 5.22 kB.
import os
import argparse
from pathlib import Path
from typing import Optional, Union, Tuple, List
import subprocess
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig
from inference import InferenceServicer
# Runtime configuration, overridable through environment variables.
PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)  # URL of the checkpoint to fetch, if any
NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)  # URL of the Noto Sans zip to fetch, if any

# Local locations the assets are downloaded to / expected at.
LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"


def _download(url: str, destination: Union[str, Path]) -> None:
    """Download ``url`` to ``destination`` with wget, creating parent directories first.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status. This matters
            because ``wget -O`` leaves a (possibly empty) file behind even on failure, so
            an existence check alone would not catch a failed download.
    """
    destination = Path(destination)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Argument list with no shell: the URL from the environment is never interpreted by a shell.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", str(destination), url],
        check=True,
    )


if MODEL_CHECKPOINT_PATH is not None:
    _download(MODEL_CHECKPOINT_PATH, LOCAL_CHECKPOINT_PATH)

if NOTO_SANS_ZIP_PATH is not None:
    _download(NOTO_SANS_ZIP_PATH, LOCAL_NOTO_ZIP_PATH)
    _noto_zip = Path(LOCAL_NOTO_ZIP_PATH)
    # -o: overwrite without prompting, so a restart does not block on unzip's interactive question.
    subprocess.run(["unzip", "-o", str(_noto_zip), "-d", str(_noto_zip.parent)], check=True)

# Explicit checks instead of ``assert`` so they still run under ``python -O``.
# data/NotoSans.zip -> data/NotoSans (the directory the demo reads content images from).
for _required_asset in (Path(LOCAL_CHECKPOINT_PATH), Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")):
    if not _required_asset.exists():
        raise FileNotFoundError(f"Required asset not found: {_required_asset}")
# Bundled font files offered in the style dropdown, in sorted order
# (the first entry is used as the dropdown's default value).
EXAMPLE_FONTS = sorted(
    f"example_fonts/{font_stem}.ttf"
    for font_stem in (
        "BalooDa2-Bold",
        "BalooDa2-Regular",
        "Lalezar-Regular",
        "MaShanZheng-Regular",
    )
)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the font style transfer demo.

    Unrecognized arguments are ignored rather than rejected (``parse_known_args``).

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        Namespace with ``docs`` (Path), ``config`` (Path), ``local`` (bool) and ``port`` (int).
        String defaults are converted to ``Path`` by argparse via ``type=Path``.
    """
    # The previous description was copied from an unrelated tool; describe this demo instead.
    parser = argparse.ArgumentParser(description="Multilingual font style transfer demo (Gradio)")
    # -------- User arguments ----------------------------------------
    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")
    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")
    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")
    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")
    args, _ = parser.parse_known_args(argv)
    return args
class InferenceServiceResolver(InferenceServicer):
    """Gradio-facing adapter around ``InferenceServicer``.

    Flattens the inference output into one list of images and re-raises any
    failure as ``gr.Error`` so Gradio can display it to the user.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        # Forward positionally to the base servicer. This override only fixes the
        # defaults of ``imsize`` and ``gpu_id``.
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Render ``content_char`` in the style of ``style_font``.

        Args:
            content_char: The character to render (taken from the UI textbox).
            style_font: Path to the style font file (taken from the UI dropdown).

        Returns:
            ``[content_image, *style_images, result]``, in that order.

        Raises:
            gr.Error: wrapping any exception raised while running inference.
        """
        try:
            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
            return [content_image, *style_images, result]
        except Exception as e:
            # Some exceptions (e.g. a bare AssertionError) have an empty message; fall back to
            # the exception type so the UI never shows a blank error. Chain the cause so the
            # original traceback is preserved in server logs.
            raise gr.Error(str(e) or type(e).__name__) from e
def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
    """Build and launch the Gradio UI for font style transfer.

    Args:
        docs_path: Markdown file rendered at the top of the page (read as UTF-8).
        hp: Model configuration forwarded to the inference servicer.
        checkpoint_path: Model checkpoint forwarded to the inference servicer.
        content_image_dir: Content-image directory forwarded to the inference servicer.
        is_local: If True, serve on 0.0.0.0 at ``port``; otherwise use Gradio's default launch.
        port: Server port; only used when ``is_local`` is True.
    """
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8: the platform default encoding is not guaranteed to be UTF-8,
        # and the demo's text includes non-ASCII characters.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))

        # Input row: character, style font selection, and the trigger button.
        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
            run_button = gr.Button(value="Generate", variant='primary')

        # Output row: content glyph | five style glyphs | generated glyph.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        style_chars = [
                            gr.Image(label=f"Style #{idx}", show_label=False)
                            for idx in range(1, 6)
                        ]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Order mirrors generate()'s return value: content, five style images, result.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the model config, and serve the demo
    # from the local checkpoint and the extracted Noto Sans directory.
    args = parse_args()
    launch_gradio(
        docs_path=args.docs,
        hp=OmegaConf.load(args.config),
        checkpoint_path=Path(LOCAL_CHECKPOINT_PATH),
        content_image_dir=Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""),  # data/NotoSans.zip -> data/NotoSans
        is_local=args.local,
        port=args.port,
    )