EmojiGAN / app.py
mfrashad's picture
init code
5bde23c
raw
history blame contribute delete
No virus
2.18 kB
import gradio as gr
import os
import numpy as np
import torch
import pickle
import types
from huggingface_hub import hf_hub_url, cached_download
# Hugging Face access token for the model repo; raises KeyError if TOKEN is unset.
TOKEN = os.environ['TOKEN']

# NOTE(review): cached_download / use_auth_token appear to be deprecated in newer
# huggingface_hub releases (hf_hub_download(..., token=...) is the replacement) —
# verify against the pinned version before migrating this and the import together.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file; only
# load checkpoints from a repository you trust.
with open(cached_download(hf_hub_url('mfrashad/stylegan2_emoji_512', 'stylegan2_emoji_512.pkl'), use_auth_token=TOKEN), 'rb') as f:
    G = pickle.load(f)['G_ema']  # torch.nn.Module


def _force_fp32_forward(module):
    """Replace ``module.forward`` with a wrapper that always passes ``force_fp32=True``.

    Any caller-supplied ``force_fp32`` value is overwritten. The wrapper is bound to
    ``module`` as an instance attribute, so only this object is affected.
    """
    original_forward = module.forward  # bound method captured before patching

    def _forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return original_forward(*args, **kwargs)

    module.forward = types.MethodType(_forward, module)


device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    # On CPU, force full precision for both the whole generator and its synthesis
    # sub-network (presumably because half-precision paths are unavailable or slow
    # on CPU — confirm). Patched in the same order as before: G first, then G.synthesis.
    _force_fp32_forward(G)
    _force_fp32_forward(G.synthesis)
def generate(num_images, interpolate):
    """Sample ``num_images`` images from the module-level generator ``G``.

    Args:
        num_images: Number of images (batch size) to produce.
        interpolate: When truthy, the latent codes are evenly spaced points on the
            straight line from one random latent to another, endpoints included,
            so the images form a morph sequence. Otherwise every image gets an
            independent random latent.

    Returns:
        A ``uint8`` NumPy array, batch-first, with axis 1 of ``G``'s output moved
        to the end and values clamped to ``[0, 255]`` — presumably
        ``(num_images, H, W, C)``, assuming ``G`` returns NCHW (TODO confirm).
    """
    if interpolate:
        z1 = torch.randn([1, G.z_dim])  # start latent
        z2 = torch.randn([1, G.z_dim])  # end latent
        # Interpolation weights 0, 1/(n-1), ..., 1 as a column vector so they
        # broadcast across the latent dimension. linspace yields [0.] when
        # num_images == 1; the former i / (num_images - 1) divided by zero there
        # and produced a NaN latent (and thus a garbage image).
        weights = torch.linspace(0, 1, num_images).unsqueeze(1)
        zs = z1 + (z2 - z1) * weights
    else:
        zs = torch.randn([num_images, G.z_dim])  # independent latent codes
    with torch.no_grad():
        zs = zs.to(device)
        # c=None: no class label (assumes an unconditional model — TODO confirm).
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Move axis 1 last and map the output to 8-bit pixels
        # (assumes G emits values roughly in [-1, 1] — TODO confirm).
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img.cpu().numpy()
# Gradio Blocks container for the whole app; components are added under `with demo:` and the app is started with demo.launch().
demo = gr.Blocks()
def infer(num_images, interpolate):
    """Click handler for the Generate button.

    Rounds ``num_images`` (presumably because the slider may deliver a float —
    confirm against the Gradio version in use), calls ``generate`` and returns its
    output split along the first axis as a Python list, one entry per image.
    """
    count = round(num_images)
    return list(generate(count, interpolate))
# Declare the UI inside the Blocks context: a batch-size slider, an interpolation
# toggle, a Generate button and a gallery that `infer` fills on click.
with demo:
    gr.Markdown(
        """
# EmojiGAN
Generate Emojis with AI (StyleGAN2-ADA). Made by [mfrashad](https://github.com/mfrashad)
"""
    )
    # Top-level components replace the deprecated gr.inputs.* namespace (removed in
    # Gradio 4); they take the initial value via `value=` rather than `default=`.
    images_num = gr.Slider(value=1, label="Num Images", minimum=1, maximum=16, step=1)
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")
    out = gr.Gallery()
    submit.click(fn=infer, inputs=[images_num, interpolate], outputs=out)
demo.launch()