interfacegan_pp / app.py
younesbelkada
update app
09ef29c
raw history blame
No virus
2.92 kB
import os
import torch
import PIL.Image
import numpy as np
import gradio as gr
from yarg import get
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
# Attribute names offered in the UI. inference() loads the editing direction
# for the selected one from boundaries/<model_name>/boundary_<attribute>.npy.
VALID_CHOICES = [
    "Bald",
    "Young",
    "Mustache",
    "Eyeglasses",
    "Hat",
    "Smiling",
]

# When True, get_generator() calls .cuda() on the generator it builds.
ENABLE_GPU = False

# Generator names selectable in the UI; get_generator() must handle each one.
MODEL_NAMES = [
    'stylegan_ffhq',
    'stylegan2_ffhq',
]

# Images sampled per request; OUTPUT_LIST shows this many originals and this
# many edited images.
NB_IMG = 4
# NB_IMG slots for the sampled images followed by NB_IMG slots for the edited
# ones; this order matches the list inference() returns.
OUTPUT_LIST = [
    gr.outputs.Image(type="pil", label=caption)
    for caption in ["Generated Image"] * NB_IMG + ["Modified Image"] * NB_IMG
]
def tensor_to_pil(input_object):
    """Convert a batch of image arrays into a list of PIL images.

    Args:
        input_object: Either a dict holding the batch under the ``'image'``
            key (inference() passes the result of
            ``generator.easy_synthesize`` here), or an iterable of images.
            Each image must be accepted by ``PIL.Image.fromarray``;
            presumably an (H, W, C) uint8 array -- TODO confirm against the
            generator output.

    Returns:
        A list with one ``PIL.Image.Image`` per input image, in input order.
        An empty batch yields an empty list.
    """
    # Accept both the generator's output dict and a bare batch of images.
    if isinstance(input_object, dict):
        images = input_object['image']
    else:
        images = input_object
    return [PIL.Image.fromarray(image) for image in images]
def get_generator(model_name):
    """Construct the generator wrapper for ``model_name``.

    Args:
        model_name: ``'stylegan_ffhq'`` (StyleGANGenerator) or
            ``'stylegan2_ffhq'`` (StyleGAN2Generator).

    Returns:
        A newly constructed generator; ``.cuda()`` has been applied to it when
        ``ENABLE_GPU`` is true. Nothing is cached between calls.

    Raises:
        ValueError: If ``model_name`` is neither supported name.
    """
    if model_name not in ('stylegan_ffhq', 'stylegan2_ffhq'):
        raise ValueError('Model name not recognized')
    generator_cls = StyleGANGenerator if model_name == 'stylegan_ffhq' else StyleGAN2Generator
    model = generator_cls(model_name)
    return model.cuda() if ENABLE_GPU else model
def _load_boundary(model_name, choice):
    """Load boundaries/<model_name>/boundary_<choice>.npy with singleton dims squeezed out."""
    path = os.path.join('boundaries', model_name, 'boundary_%s.npy' % choice)
    # Given a path, np.load opens and closes the file itself, so no handle leaks.
    return np.squeeze(np.load(path))


@torch.no_grad()
def inference(seed, choice, model_name, coef, nb_images=NB_IMG):
    """Sample latent codes, synthesize faces, then re-synthesize them edited.

    Args:
        seed: Seed for NumPy's global RNG (presumably what ``easy_sample``
            draws from -- verify). Cast to ``int`` because ``np.random.seed``
            rejects float scalars and a UI slider may deliver a float.
        choice: Attribute name; selects ``boundary_<choice>.npy``.
        model_name: Generator name accepted by ``get_generator``.
        coef: Multiplier applied to the boundary before it is added to every
            latent code (negative values move in the opposite direction).
        nb_images: Number of latent codes to sample.

    Returns:
        A list of PIL images: the originals followed by the edited images,
        in matching order.
    """
    np.random.seed(int(seed))
    boundary = _load_boundary(model_name, choice)
    # get_generator already applies .cuda() when ENABLE_GPU is set.
    generator = get_generator(model_name)

    # latent_codes is handled as a NumPy array below (.copy(), broadcasting),
    # and NumPy arrays have no .cuda(); device placement is left to the generator.
    latent_codes = generator.easy_sample(nb_images)
    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    # Shift every code by the same offset in one broadcast. The in-place add
    # keeps latent_codes' dtype, as the original per-row update did.
    new_latent_codes = latent_codes.copy()
    new_latent_codes += boundary * coef
    modified_generated_images = tensor_to_pil(generator.easy_synthesize(new_latent_codes))
    return generated_images + modified_generated_images
# Gradio UI: seed, attribute, model and coefficient feed inference(), whose
# returned images fill OUTPUT_LIST in order.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            default=264,
            label="Random seed to use for the generation"
        ),
        # Both dropdowns are preselected so inference() is not called with an
        # empty selection (which would yield an invalid boundary path or make
        # get_generator() raise).
        gr.inputs.Dropdown(
            choices=VALID_CHOICES,
            type="value",
            default=VALID_CHOICES[0],
            label="Attribute to modify",
        ),
        gr.inputs.Dropdown(
            choices=MODEL_NAMES,
            type="value",
            default=MODEL_NAMES[0],
            label="Model to use",
        ),
        gr.inputs.Slider(
            minimum=-3,
            maximum=3,
            step=0.1,
            default=0,
            label="Modification coefficient",
        ),
    ],
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach"
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    iface.launch()