# Provenance (file-viewer header captured with the source; not part of the program):
# author: deepkyu | commit: "initial commit" (1ba3df3) | size: 6.97 kB
from pathlib import Path
from typing import Dict, List, Union, Tuple
from omegaconf import OmegaConf
import numpy as np
import torch
from torch import nn
from PIL import Image, ImageDraw, ImageFont
import models
# Key prefix that marks generator weights inside the checkpoint's state dict;
# it is stripped before the weights are loaded into the generator.
GENERATOR_PREFIX = "networks.g."

# Maximum 8-bit pixel value, used as the divisor when scaling pixels to [0, 1].
WHITE = 255

# Characters rendered from the style font to serve as style reference images.
EXAMPLE_CHARACTERS = list("ABCDE")
class InferenceServicer:
    """Run a ``models.Generator`` checkpoint to render a character in a font's style.

    The servicer loads generator weights from a checkpoint, indexes the
    content (source glyph) images found in a directory, and exposes
    :meth:`inference`, which combines one content glyph with reference glyphs
    rendered from a style font file.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        """Build the generator, load its weights and index the content images.

        Args:
            hp: Configuration object; ``hp.models.G`` is passed to
                ``models.Generator``.
            checkpoint_path: Path to a checkpoint whose ``'state_dict'`` entry
                holds the generator weights under the ``GENERATOR_PREFIX`` keys.
            content_image_dir: Directory searched recursively for ``*.png``
                content images.
            imsize: Side length (pixels) of the square canvas every input
                image is fitted onto.
            gpu_id: CUDA device index. ``None`` means device 0. When CUDA is
                not available the CPU is used regardless of this value.
        """
        self.hp = hp
        self.imsize = imsize

        # Always store a torch.device, and fall back to the CPU when CUDA is
        # missing so that the default gpu_id='0' also works on CPU-only hosts.
        if torch.cuda.is_available():
            cuda_index = 0 if gpu_id is None else gpu_id
            self.device = torch.device(f'cuda:{cuda_index}')
        else:
            self.device = torch.device('cpu')

        model_config = self.hp.models.G
        self.model: nn.Module = models.Generator(model_config)

        # Load Generator model weight.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
        generator_state_dict = self.convert_generator_state_dict(model_state_dict_pl)
        self.model.load_state_dict(generator_state_dict)
        self.model.to(device=self.device)
        self.model.eval()

        # Setting Content font files
        self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))

    @staticmethod
    def convert_generator_state_dict(model_state_dict_pl):
        """Extract the generator weights from a full checkpoint dictionary.

        Args:
            model_state_dict_pl: Loaded checkpoint; must contain a
                ``'state_dict'`` mapping of parameter names to values.

        Returns:
            A dict with only the entries whose name starts with
            ``GENERATOR_PREFIX``, keyed by the name with that prefix removed.
        """
        prefix_length = len(GENERATOR_PREFIX)
        return {
            module_name[prefix_length:]: module_state
            for module_name, module_state in model_state_dict_pl['state_dict'].items()
            if module_name.startswith(GENERATOR_PREFIX)
        }

    @staticmethod
    def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
        """Map each PNG file stem under ``content_image_dir`` to its path.

        :meth:`inference` looks characters up by their 4-digit uppercase hex
        code point, so files are expected to be named like ``0041.png``.
        Paths are visited in sorted order so that, if two files share a stem,
        the surviving entry does not depend on filesystem enumeration order.
        """
        return {
            filepath.stem: filepath
            for filepath in sorted(content_image_dir.glob("**/*.png"))
        }

    @staticmethod
    def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
        """Paste ``item_img`` centered onto a copy of ``bg_img``.

        Args:
            bg_img: Background image; it is not modified.
            item_img: Image to place on the background; it is not modified.
            fit: When true, ``item_img`` is first resized, keeping its aspect
                ratio, so that it spans the background's full width or height.

        Returns:
            A new image with the (optionally resized) item pasted at the center.
        """
        canvas = bg_img.copy()
        item_w, item_h = item_img.size
        W, H = canvas.size

        if fit:
            if W / H > item_w / item_h:
                # Height fitting: the fitted side is set exactly (no float
                # truncation) and the other side never collapses to 0 pixels.
                target_size = (max(1, int(item_w * (H / item_h))), H)
            else:
                # Width fitting.
                target_size = (W, max(1, int(item_h * (W / item_w))))
            item_img = item_img.resize(target_size)
            item_w, item_h = item_img.size

        canvas.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
        return canvas

    def set_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
        """Fit an image onto a white ``imsize`` x ``imsize`` RGB canvas.

        Args:
            image: A PIL image, or a path (``str`` or ``Path``) to an image file.

        Returns:
            A new RGB image with the input scaled to fit and centered.

        Raises:
            TypeError: If ``image`` is neither a path nor a PIL image.
        """
        if isinstance(image, (str, Path)):
            # Copy the pixels out so the file handle is closed deterministically.
            with Image.open(image) as opened:
                image = opened.copy()
        if not isinstance(image, Image.Image):
            raise TypeError(f"image must be a path or a PIL image, got {type(image).__name__}")

        bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
        return self.center_align(bg_img, image, fit=True)

    @staticmethod
    def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
        """Convert an image to a 2-D float32 array scaled by ``1 / WHITE``.

        The image is converted to RGB first so the channel mean is well
        defined for any input mode; the three channels are then averaged.
        """
        rgb_array = np.array(blend_img.convert('RGB'), dtype=np.float32)
        return np.mean(rgb_array, axis=-1) / WHITE  # L-only image normalized to [0, 1]

    def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0), font_size: int = 128, padding: int = 100) -> List[Image.Image]:
        """Render every character of ``EXAMPLE_CHARACTERS`` with the given font.

        Args:
            font_file_path: Path to a font file loadable by ``ImageFont.truetype``.
            imgmode: Mode of the created images.
            position: Coordinates at which the text is drawn.
            font_size: Font size passed to ``ImageFont.truetype``.
            padding: Extra pixels added to both the width and the height of
                each canvas.

        Returns:
            One image per example character: black text on a white background.
        """
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)

        font_images: List[Image.Image] = []
        for character in EXAMPLE_CHARACTERS:
            # ImageDraw.textsize() was removed in Pillow 10. The right/bottom
            # edges of the bounding box are measured from the drawing origin,
            # so a canvas of that size holds the glyph drawn at (0, 0).
            _, _, text_w, text_h = imagefont.getbbox(character)
            img = Image.new(imgmode, (int(text_w) + padding, int(text_h) + padding), color='white')
            ImageDraw.Draw(img).text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)
        return font_images

    @staticmethod
    def get_hex_from_char(char: str) -> str:
        """Return the code point of ``char`` as an uppercase hex string.

        The result is zero-padded to at least four digits (e.g. ``'A'`` gives
        ``'0041'``).

        Raises:
            ValueError: If ``char`` is not exactly one character long.
        """
        if len(char) != 1:
            raise ValueError(f"Expected a single character, got {char!r}")
        return f"{ord(char):04X}"  # 4-digit hex string

    @torch.no_grad()
    def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
        """Generate ``content_char`` in the style of ``style_font``.

        Args:
            content_char: Character to generate; only the first character is
                used when the string is longer.
            style_font: Path to the font file providing the style references.

        Returns:
            A tuple ``(content_image, style_images, generated_image)``: the
            fitted content image, the fitted style reference images, and the
            generated glyph as a mode ``'L'`` image.

        Raises:
            ValueError: If ``content_char`` is empty, or if no content image
                exists for its code point.
        """
        if not content_char:
            raise ValueError("content_char must be a non-empty string")
        content_char = content_char[:1]  # only get the first character if the length > 1

        char_hex = self.get_hex_from_char(content_char)
        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        rendered_style_images = self.get_images_from_fontfile(Path(style_font))
        style_images: List[Image.Image] = [self.set_image(image) for image in rendered_style_images]

        # 1 x C(=1) x H x W
        content_image_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]
        # 1 x C(=number of style shots) x H x W: the k shots are stacked on the channel axis
        style_images_array: np.ndarray = np.stack(
            [self.pil_image_to_array(image) for image in style_images]
        )[np.newaxis, ...]

        content_input_tensor = torch.from_numpy(content_image_array).to(self.device)
        style_input_tensor = torch.from_numpy(style_images_array).to(self.device)

        generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
        generated_images = torch.clip(generated_images, 0, 1)
        assert generated_images.size(0) == 1  # internal invariant: batch of one

        generated_image_numpy = (generated_images[0].cpu().numpy() * 255).astype(np.uint8)[0, ...]  # H x W
        return content_image, style_images, Image.fromarray(generated_image_numpy, mode='L')
if __name__ == '__main__':
    # Demo run: build the servicer from the bundled config and checkpoint,
    # generate the character "7" in the example font's style, and save the
    # content image, each style reference and the generated result as PNGs.
    servicer = InferenceServicer(
        OmegaConf.load("config/models/google-font.yaml"),
        "epoch=199-step=257400.ckpt",
        "../DATA/NotoSans",
    )

    outputs = servicer.inference("7", "example_fonts/MaShanZheng-Regular.ttf")
    content_image, style_images, result = outputs

    content_image.save("result_content.png")
    for shot_index, shot_image in enumerate(style_images):
        shot_image.save(f"result_style_{shot_index:02d}.png")
    result.save("result_generated.png")