# nijigenka / app.py
# Author: senior-sigan
# Commit 284b74f: remove progress bar because of queue error
import argparse
import functools
import pathlib
import os
from typing import Dict, List, Optional, Tuple
import gradio as gr
import PIL.Image
from encoder import Encoder
from face_detector import FaceAligner
from generator import Generator
from huggingface_hub import hf_hub_download
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Options:
        --models_repo_id: repo id passed to ``load_models``; defaults to
            'senior-sigan/nijigenka'. ``get_model_path`` first looks for the
            model files under a local directory of this name before
            downloading from the Hugging Face Hub.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--models_repo_id',
        default='senior-sigan/nijigenka',
        type=str,
    )
    namespace = cli.parse_args()
    return namespace
def load_examples():
    """Build Gradio example rows from the ``*.jpg`` files in ./examples.

    Each row is ``[relative POSIX path, 'art']``, i.e. an input image paired
    with the 'art' style. Rows are ordered by path. The directory is
    resolved relative to the current working directory.
    """
    example_rows = []
    for jpg in sorted(pathlib.Path('examples').glob('*.jpg')):
        example_rows.append([jpg.as_posix(), 'art'])
    return example_rows
def join_image_h(im1: PIL.Image.Image, im2: PIL.Image.Image) -> PIL.Image.Image:
    """Return a new RGB image with *im1* on the left and *im2* on the right.

    *im1* is first resized to exactly ``im2.size`` (aspect ratio is not
    preserved). The result is therefore twice im2's width and the same height.
    """
    width, height = im2.size
    left = im1.resize((width, height))
    canvas = PIL.Image.new('RGB', (2 * width, height))
    canvas.paste(left, (0, 0))
    canvas.paste(im2, (width, 0))
    return canvas
def predict(
    image: PIL.Image.Image,
    style: str,
    *,
    face_aligner: FaceAligner,
    encoder: Encoder,
    generator: Dict[str, Generator],
) -> Tuple[List[PIL.Image.Image], Optional[str]]:
    """Stylize every face found in *image*.

    Args:
        image: photo from the Gradio image input. Gradio passes None when the
            user submits without uploading an image.
        style: key into *generator* selecting the output style.
        face_aligner: finds and aligns the faces in *image*.
        encoder: turns an aligned face into the generator's input.
        generator: style name -> generator model.

    Returns:
        ``(gallery, error)``. On success, *gallery* holds one side-by-side
        image per face (aligned face | stylized face) and *error* is None.
        On failure, *gallery* is a single 1x1 placeholder and *error* is a
        message for the user.

    Raises:
        KeyError: if at least one face is found and *style* is not a key of
            *generator*.
    """
    # Gradio's Gallery does not accept an empty list, so failure paths
    # return a 1x1 dummy image together with the error message.
    if image is None:
        return [PIL.Image.new('RGB', (1, 1))], "Please upload a photo with a face"

    images = face_aligner.align(image)
    # len() instead of truthiness: align's return type isn't visible here and
    # may be an array, where `not images` would be ambiguous.
    if len(images) == 0:
        return [PIL.Image.new('RGB', (1, 1))], "Cannot find any face in photo"

    # The same style model is applied to every face, so look it up once.
    style_generator = generator[style]
    results = []
    for img in images:
        latent = encoder.predict(img)
        gen_img = style_generator.predict(latent)
        results.append(join_image_h(img, gen_img))
    return results, None
def get_model_path(repo_id: str, filename: str):
    """Return a filesystem path to model file *filename* from *repo_id*.

    If ``<repo_id>/<filename>`` exists relative to the current directory (or
    as an absolute path), its absolute path is returned. Otherwise the file
    is fetched with ``hf_hub_download`` and that function's result is
    returned.
    """
    local_candidate = os.path.join(repo_id, filename)
    if not os.path.exists(local_candidate):
        # No local copy: fall back to the Hugging Face Hub.
        return hf_hub_download(
            repo_id,
            filename,
        )
    print('Using local models')
    return os.path.abspath(local_candidate)
def load_models(repo_id: str):
    """Resolve all model files from *repo_id* and build the inference objects.

    Returns:
        ``(face_aligner, encoder, generators)``, where *generators* maps the
        style names 'art' and 'anime' (in that order) to their Generator.
    """
    # Logical name -> file in the repo. Dicts keep insertion order, so files
    # are resolved in the same sequence as listed here.
    model_files = {
        'encoder': 'encoder.onnx',
        'art': 'face2art.onnx',
        'anime': 'face2kuvshinov2.onnx',
        'landmarks': 'shape_predictor_68_face_landmarks.bin',
    }
    paths = {
        key: get_model_path(repo_id, name)
        for key, name in model_files.items()
    }

    face_aligner = FaceAligner(
        image_size=512,
        shape_predictor_path=paths['landmarks'],
    )
    encoder = Encoder(model_path=paths['encoder'])
    generators = {
        'art': Generator(model_path=paths['art']),
        'anime': Generator(model_path=paths['anime']),
    }
    return face_aligner, encoder, generators
def main():
    """Load the models and serve the Gradio demo.

    The interface takes a photo and a style choice and shows a gallery of
    results plus an error textbox. The request queue is enabled before
    launching.
    """
    args = parse_args()
    gr.close_all()
    face_aligner, encoder, generator = load_models(args.models_repo_id)
    style_names = list(generator)

    # Bind the models as keyword arguments so Gradio only supplies
    # (image, style). update_wrapper copies predict's metadata onto the
    # partial, presumably so Gradio sees it like a plain function.
    predict_fn = functools.update_wrapper(
        functools.partial(
            predict,
            face_aligner=face_aligner,
            encoder=encoder,
            generator=generator,
        ),
        predict,
    )

    input_components = [
        gr.Image(type='pil', label='Real photo with a face'),
        gr.Radio(
            choices=style_names,
            type='value',
            value=style_names[0],
            label='Style',
        ),
    ]
    output_components = [
        gr.Gallery(label='Result'),
        gr.Textbox(label='Error'),
    ]
    demo = gr.Interface(
        fn=predict_fn,
        inputs=input_components,
        outputs=output_components,
        examples=load_examples(),
        title='Nijigenka: Portrait to Art',
        allow_flagging='never',
    )
    demo.queue().launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()