# Spaces:
# Runtime error
# Runtime error
import argparse
import os
import subprocess
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple, Union

import gradio as gr
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from inference import InferenceServicer
# Runtime configuration; each value can be overridden through the environment.
# Markdown file rendered at the top of the demo page (default of --docs).
PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
# OmegaConf YAML with the model hyper-parameters (default of --config).
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
# Optional URLs; when set, the corresponding asset is downloaded at import time.
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)

LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"


def _download(url: str, destination: str) -> None:
    """Fetch *url* into *destination* with wget, creating parent directories first.

    TLS certificate verification stays disabled (``--no-check-certificate``),
    as in the original deployment commands.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
            Checking this matters because ``wget -O`` leaves an empty file
            behind on failure, which would otherwise pass the existence check.
    """
    Path(destination).parent.mkdir(parents=True, exist_ok=True)
    # Argument list, no shell: URL characters such as '&' or ';' are passed
    # through verbatim instead of being interpreted by /bin/sh.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", destination, url],
        check=True,
    )


def _require_asset(path: Path, env_var: str) -> None:
    """Raise FileNotFoundError if *path* is missing, naming the env var that can supply it.

    An explicit raise is used instead of ``assert`` because ``python -O`` strips asserts.
    """
    if not path.exists():
        raise FileNotFoundError(
            f"Required asset '{path}' not found; place it there or set "
            f"{env_var} to a URL to download it from."
        )


if MODEL_CHECKPOINT_PATH is not None:
    _download(MODEL_CHECKPOINT_PATH, LOCAL_CHECKPOINT_PATH)
if NOTO_SANS_ZIP_PATH is not None:
    _download(NOTO_SANS_ZIP_PATH, LOCAL_NOTO_ZIP_PATH)
    # Extract next to the archive. zipfile overwrites existing files without
    # the interactive prompt the `unzip` CLI would show.
    with zipfile.ZipFile(LOCAL_NOTO_ZIP_PATH) as archive:
        archive.extractall(Path(LOCAL_NOTO_ZIP_PATH).parent)

_require_asset(Path(LOCAL_CHECKPOINT_PATH), "MODEL_CHECKPOINT_PATH")
# "data/NotoSans.zip" -> "data/NotoSans": the content-image directory.
_require_asset(Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""), "NOTO_SANS_ZIP_PATH")
# Font files bundled with the demo and offered in the style dropdown,
# kept in alphabetical order (the first entry is the dropdown default).
EXAMPLE_FONTS = [
    "example_fonts/MaShanZheng-Regular.ttf",
    "example_fonts/Lalezar-Regular.ttf",
    "example_fonts/BalooDa2-Regular.ttf",
    "example_fonts/BalooDa2-Bold.ttf",
]
EXAMPLE_FONTS.sort()
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the font style transfer demo.

    Unrecognised arguments are ignored (``parse_known_args``) rather than
    aborting the program.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, so existing zero-argument calls are unchanged.

    Returns:
        Namespace with ``docs`` (Path), ``config`` (Path), ``local`` (bool)
        and ``port`` (int).
    """
    # The description previously read "Augmentation simulator for NetsPresso
    # Trainer", which was copied from an unrelated tool.
    parser = argparse.ArgumentParser(
        description="Gradio demo for multilingual font style transfer")
    # -------- User arguments ----------------------------------------
    # argparse runs string defaults through `type`, so the env-derived
    # defaults below are also delivered as Path objects.
    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")
    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")
    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")
    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")
    args, _unknown = parser.parse_known_args(argv)
    return args
class InferenceServiceResolver(InferenceServicer):
    """Gradio-facing adapter around ``InferenceServicer``.

    Flattens the servicer's result into the flat list of images the UI binds
    to its output components. Any failure is re-raised as ``gr.Error`` so
    Gradio shows the message to the user.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        # Pure pass-through; exists to pin the defaults imsize=64 and gpu_id='0'.
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Render *content_char* in the style of *style_font*.

        Args:
            content_char: Text from the UI textbox (intended to be one character).
            style_font: Font file path chosen from the example-font dropdown.

        Returns:
            ``[content_image, *style_images, generated_image]``. The UI wires
            seven outputs, so ``style_images`` is presumably five images
            (TODO confirm against ``InferenceServicer.inference``).

        Raises:
            gr.Error: wrapping any exception from inference or result unpacking.
        """
        try:
            content_image, style_images, result = self.inference(
                content_char=content_char, style_font=style_font)
            return [content_image, *style_images, result]
        except Exception as e:
            # Chain the cause so the server-side traceback still shows the root error.
            raise gr.Error(str(e)) from e
def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
    """Build the font style transfer UI and start serving it.

    Args:
        docs_path: Markdown file rendered at the top of the page.
        hp: Model configuration passed to the inference servicer.
        checkpoint_path: Model checkpoint file.
        content_image_dir: Directory of content-font assets for the servicer.
        is_local: If True, bind to 0.0.0.0 on *port*; otherwise use Gradio's defaults.
        port: Server port, only used when *is_local* is True.
    """
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8 so reading the docs does not depend on the host locale.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))

        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        # Five style slots labelled "Style #1" .. "Style #5".
                        style_chars = [
                            gr.Image(label=f"Style #{idx}", show_label=False)
                            for idx in range(1, 6)
                        ]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Order must match the list returned by InferenceServiceResolver.generate.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()
if __name__ == "__main__":
    # Asset locations are fixed module constants; only docs/config/port come from the CLI.
    checkpoint_path = Path(LOCAL_CHECKPOINT_PATH)
    # Drop the ".zip" suffix: "data/NotoSans.zip" -> "data/NotoSans".
    content_image_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")

    args = parse_args()
    hp = OmegaConf.load(args.config)
    launch_gradio(
        args.docs,
        hp,
        checkpoint_path,
        content_image_dir,
        args.local,
        port=args.port,
    )