"""Gradio demo: edit generated faces by moving latent codes along attribute boundaries.

For each request, ``nb_images`` latent codes are sampled from the configured
generator and synthesized. The codes are then shifted along every selected
attribute boundary (scaled by ``coef``) and synthesized again. The original and
edited images are combined with ``concat_images``.
"""
import os
import torch
import PIL.Image  # NOTE(review): appears unused in this file
import numpy as np
import gradio as gr
from yarg import get  # NOTE(review): appears unused in this file; kept in case something relies on it
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
from utils.constants import VALID_CHOICES, ENABLE_GPU, MODEL_NAME, OUTPUT_LIST, description, title, css, article
from utils.image_manip import tensor_to_pil, concat_images


def get_generator(model_name):
    """Build the generator wrapper for ``model_name``.

    Args:
        model_name: ``'stylegan_ffhq'`` or ``'stylegan2_ffhq'``.

    Returns:
        The generator; moved with ``.cuda()`` when ``ENABLE_GPU`` is set.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'stylegan_ffhq':
        generator = StyleGANGenerator(model_name)
    elif model_name == 'stylegan2_ffhq':
        generator = StyleGAN2Generator(model_name)
    else:
        raise ValueError('Model name not recognized')
    if ENABLE_GPU:
        generator = generator.cuda()
    return generator


def _load_boundary(model_name, attribute):
    """Load ``boundaries/<model_name>/boundary_<attribute>.npy`` with size-1 axes squeezed out."""
    path = os.path.join('boundaries', model_name, 'boundary_%s.npy' % attribute)
    # Passing the path (rather than an unmanaged open() handle) lets numpy
    # close the file after reading it.
    return np.squeeze(np.load(path))


def _shift_latents(latent_codes, directions, choices, coef):
    """Return a copy of ``latent_codes`` shifted by ``coef * directions[c]`` for each ``c`` in ``choices``.

    The same shift is broadcast onto every row (one row per image), matching a
    per-row ``codes[i, :] += direction * coef`` update. Works for numpy arrays
    and torch tensors; the input is left unmodified.
    """
    # numpy arrays copy with .copy(); torch tensors have .clone() instead.
    if isinstance(latent_codes, torch.Tensor):
        shifted = latent_codes.clone()
    else:
        shifted = latent_codes.copy()
    for choice in choices:
        step = directions[choice] * coef
        if isinstance(shifted, torch.Tensor):
            # Convert explicitly so the add happens on the tensor's own device/dtype.
            step = torch.as_tensor(step, dtype=shifted.dtype, device=shifted.device)
        shifted += step
    return shifted


generator = get_generator(MODEL_NAME)
# Attribute name -> boundary direction, loaded once at startup.
boundaries = {choice: _load_boundary(MODEL_NAME, choice) for choice in VALID_CHOICES}


@torch.no_grad()
def inference(seed, coef, nb_images, list_choices):
    """Generate ``nb_images`` images plus their attribute-edited counterparts.

    Args:
        seed: seed for the global RNG(s) so that sampling is reproducible.
        coef: scale applied to each selected boundary before it is added to the latents.
        nb_images: number of latent codes to sample.
        list_choices: attribute names (keys of ``boundaries``) to apply; may be empty.

    Returns:
        The result of ``concat_images(original_images, edited_images)``.
    """
    # Slider values may arrive as floats; np.random.seed rejects floats and a
    # float sample count is not a valid size.
    seed = int(seed)
    np.random.seed(seed)
    # Also seed torch in case the generator samples with torch (not visible here).
    torch.manual_seed(seed)

    latent_codes = generator.easy_sample(int(nb_images))
    # The generator was already moved to the GPU in get_generator. Only torch
    # tensors have .cuda(); numpy arrays do not, so guard on the type instead
    # of crashing when ENABLE_GPU is set.
    if ENABLE_GPU and isinstance(latent_codes, torch.Tensor):
        latent_codes = latent_codes.cuda()

    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    new_latent_codes = _shift_latents(latent_codes, boundaries, list_choices or [], coef)
    modified_generated_images = tensor_to_pil(generator.easy_synthesize(new_latent_codes))

    return concat_images(generated_images, modified_generated_images)


# https://huggingface.co/spaces/osanseviero/6DRepNet/blob/main/app.py
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            default=644,
            label="Random seed to use for the generation"
        ),
        gr.inputs.Slider(
            minimum=-3,
            maximum=3,
            step=0.1,
            default=1,
            label="Modification scale",
        ),
        gr.inputs.Slider(
            minimum=1,
            maximum=8,
            step=1,
            default=2,
            label="Number of images to generate",
        ),
        gr.inputs.CheckboxGroup(
            VALID_CHOICES,
            default=[],
            type="value",
            label="Select attributes to modify",
            optional=False
        )
    ],
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach",
    description=description,
    title=title,
    css=css,
    article=article
)
iface.launch()