from pathlib import Path
from typing import Dict, List, Union, Tuple

from omegaconf import OmegaConf
import numpy as np
import torch
from torch import nn
from PIL import Image, ImageDraw, ImageFont

import models

GENERATOR_PREFIX = "networks.g."
WHITE = 255
EXAMPLE_CHARACTERS = ['A', 'B', 'C', 'D', 'E']
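
# InferenceServicer wraps a trained Generator checkpoint for few-shot font generation:
# given one content character (whose reference glyph PNGs live in content_image_dir)
# and a user-supplied style font file, it renders EXAMPLE_CHARACTERS from the style
# font as reference shots and asks the generator to produce the content character in
# that style.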
class InferenceServicer:
    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        self.hp = hp
        self.imsize = imsize

        # Use the requested GPU when available; otherwise fall back to CPU
        if gpu_id is not None and torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu_id}')
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        model_config = self.hp.models.G
        self.model: nn.Module = models.Generator(model_config)

        # Load the Generator weights from the checkpoint
        model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
        generator_state_dict = self.convert_generator_state_dict(model_state_dict_pl)
        self.model.load_state_dict(generator_state_dict)
        self.model.to(device=self.device)
        self.model.eval()

        # Set up the content font glyph lookup
        self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))
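
    # The checkpoint is assumed to be a PyTorch Lightning checkpoint: its 'state_dict'
    # holds the weights of every sub-network, with generator parameters prefixed by
    # GENERATOR_PREFIX ("networks.g."). The helper below strips that prefix so the keys
    # match models.Generator's own state_dict.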
    @staticmethod
    def convert_generator_state_dict(model_state_dict_pl):
        generator_state_dict = {}
        for module_name, module_state in model_state_dict_pl['state_dict'].items():
            if module_name.startswith(GENERATOR_PREFIX):
                # Drop the prefix so the key matches the bare Generator module
                generator_state_dict[module_name[len(GENERATOR_PREFIX):]] = module_state
        return generator_state_dict
    @staticmethod
    def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
        # Map each content glyph PNG to its filename stem (the character's Unicode hex code)
        content_character_dict = {}
        for filepath in content_image_dir.glob("**/*.png"):
            content_character_dict[filepath.stem] = filepath
        return content_character_dict
    @staticmethod
    def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
        bg_img = bg_img.copy()
        item_img = item_img.copy()
        item_w, item_h = item_img.size
        W, H = bg_img.size

        if fit:
            item_ratio = item_w / item_h
            bg_ratio = W / H
            if bg_ratio > item_ratio:
                # Background is wider than the item: fit to height
                resize_ratio = H / item_h
            else:
                # Background is taller than the item: fit to width
                resize_ratio = W / item_w
            item_img = item_img.resize((int(item_w * resize_ratio), int(item_h * resize_ratio)))
            item_w, item_h = item_img.size

        # Paste the (possibly resized) item at the center of the background
        bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
        return bg_img
    def set_image(self, image: Union[Path, Image.Image]) -> Image.Image:
        if isinstance(image, (str, Path)):
            image = Image.open(image)
        assert isinstance(image, Image.Image)

        # Center the glyph on a white imsize x imsize canvas
        bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
        blend_img = self.center_align(bg_img, image, fit=True)
        return blend_img
    @staticmethod
    def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
        # Collapse RGB to a single channel and normalize to [0, 1]
        normalized_array = np.mean(np.array(blend_img, dtype=np.float32), axis=-1) / WHITE
        return normalized_array
    def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0), font_size: int = 128, padding: int = 100) -> List[Image.Image]:
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)

        font_images: List[Image.Image] = []
        for character in EXAMPLE_CHARACTERS:
            # Measure the glyph; getbbox replaces the deprecated ImageDraw.textsize
            left, top, right, bottom = imagefont.getbbox(character)
            w, h = right - left, bottom - top

            img = Image.new(imgmode, (w + padding, h + padding), color='white')
            draw = ImageDraw.Draw(img)
            draw.text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)
        return font_images
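
    # e.g. get_hex_from_char('A') -> '0041'; the content glyph PNGs in content_image_dir
    # are expected to be named by this code point (0041.png), since inference() uses the
    # hex string as the lookup key into content_character_dict.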
    @staticmethod
    def get_hex_from_char(char: str) -> str:
        assert len(char) == 1
        return f"{ord(char):04X}"  # 4-digit uppercase hex string
    def inference(self, content_char: str, style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
        assert len(content_char) > 0
        content_char = content_char[:1]  # only use the first character if more than one is given
        char_hex = self.get_hex_from_char(content_char)
        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        style_images: List[Image.Image] = self.get_images_from_fontfile(Path(style_font))
        style_images = [self.set_image(image) for image in style_images]

        content_image_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]  # 1 x C(=1) x H x W
        style_images_array: np.ndarray = np.array([self.pil_image_to_array(image) for image in style_images])[np.newaxis, ...]  # 1 x K(=5 shots) x H x W

        content_input_tensor = torch.from_numpy(content_image_array).to(self.device)
        style_input_tensor = torch.from_numpy(style_images_array).to(self.device)

        with torch.no_grad():  # inference only; no gradients needed
            generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
        generated_images = torch.clip(generated_images, 0, 1)

        assert generated_images.size(0) == 1
        generated_image_numpy = (generated_images[0].cpu().numpy() * 255).astype(np.uint8)[0, ...]  # H x W
        return content_image, style_images, Image.fromarray(generated_image_numpy, mode='L')


if __name__ == '__main__':
    hp = OmegaConf.load("config/models/google-font.yaml")
    checkpoint_path = "epoch=199-step=257400.ckpt"
    content_image_dir = "../DATA/NotoSans"
    servicer = InferenceServicer(hp, checkpoint_path, content_image_dir)

    style_font = "example_fonts/MaShanZheng-Regular.ttf"
    content_image, style_images, result = servicer.inference("7", style_font)

    content_image.save("result_content.png")
    for idx, style_image in enumerate(style_images):
        style_image.save(f"result_style_{idx:02d}.png")
    result.save("result_generated.png")