Spaces:
Runtime error
Runtime error
import argparse | |
import functools | |
import pathlib | |
import os | |
from typing import Dict, List, Optional, Tuple | |
import gradio as gr | |
import PIL.Image | |
from encoder import Encoder | |
from face_detector import FaceAligner | |
from generator import Generator | |
from huggingface_hub import hf_hub_download | |
def parse_args() -> argparse.Namespace: | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--models_repo_id', type=str, | |
default='senior-sigan/nijigenka') | |
return parser.parse_args() | |
def load_examples(): | |
image_dir = pathlib.Path('examples') | |
images = sorted(image_dir.glob('*.jpg')) | |
return [[path.as_posix(), 'art'] for path in images] | |
def join_image_h(im1: PIL.Image.Image, im2: PIL.Image.Image) -> PIL.Image.Image: | |
im1 = im1.resize(im2.size) | |
dst = PIL.Image.new('RGB', (im1.width + im2.width, im1.height)) | |
dst.paste(im1, (0, 0)) | |
dst.paste(im2, (im1.width, 0)) | |
return dst | |
def predict( | |
image: PIL.Image.Image, | |
style: str, | |
*, | |
face_aligner: FaceAligner, | |
encoder: Encoder, | |
generator: Dict[str, Generator], | |
) -> Tuple[List[PIL.Image.Image], Optional[str]]: | |
images = face_aligner.align(image) | |
if len(images) == 0: | |
error_msg = "Cannot find any face in photo" | |
# gradio doesn't support empty list for images carusel, so we create dummy img | |
return [PIL.Image.new('RGB', (1, 1))], error_msg | |
results = [] | |
for img in images: | |
x = encoder.predict(img) | |
gen_img = generator[style].predict(x) | |
result = join_image_h(img, gen_img) | |
results.append(result) | |
return results, None | |
def get_model_path(repo_id: str, filename: str): | |
maybe_path = os.path.join(repo_id, filename) | |
if os.path.exists(maybe_path): | |
print('Using local models') | |
return os.path.abspath(maybe_path) | |
else: | |
return hf_hub_download( | |
repo_id, | |
filename, | |
) | |
def load_models(repo_id: str): | |
encoder_path = get_model_path( | |
repo_id, | |
'encoder.onnx', | |
) | |
generator_art_path = get_model_path( | |
repo_id, | |
'face2art.onnx', | |
) | |
generator_anime_path = get_model_path( | |
repo_id, | |
'face2kuvshinov2.onnx', | |
) | |
shape_predictor_path = get_model_path( | |
repo_id, | |
'shape_predictor_68_face_landmarks.bin', | |
) | |
face_aligner = FaceAligner( | |
image_size=512, | |
shape_predictor_path=shape_predictor_path, | |
) | |
encoder = Encoder(model_path=encoder_path) | |
generator_art = Generator(model_path=generator_art_path) | |
generator_anime = Generator(model_path=generator_anime_path) | |
return face_aligner, encoder, {'art': generator_art, 'anime': generator_anime} | |
def main(): | |
args = parse_args() | |
gr.close_all() | |
face_aligner, encoder, generator = load_models(args.models_repo_id) | |
generator_types = list(generator.keys()) | |
func = functools.partial( | |
predict, | |
face_aligner=face_aligner, | |
encoder=encoder, | |
generator=generator, | |
) | |
func = functools.update_wrapper(func, predict) | |
iface = gr.Interface( | |
fn=func, | |
inputs=[ | |
gr.Image( | |
type='pil', | |
label='Real photo with a face', | |
), | |
gr.Radio( | |
choices=generator_types, | |
type='value', | |
value=generator_types[0], | |
label='Style', | |
), | |
], | |
outputs=[ | |
gr.Gallery(label='Result'), | |
gr.Textbox(label='Error'), | |
], | |
examples=load_examples(), | |
title='Nijigenka: Portrait to Art', | |
allow_flagging='never', | |
) | |
iface.queue().launch() | |
if __name__ == '__main__': | |
main() | |