|
import gradio as gr |
|
import os |
|
import numpy as np |
|
import torch |
|
import pickle |
|
import types |
|
|
|
from huggingface_hub import hf_hub_url, cached_download |
|
|
|
# Hugging Face access token used to authenticate the checkpoint download below.
# A missing variable fails fast with KeyError.
TOKEN = os.environ['TOKEN']

# NOTE(review): the original script opened an undefined name ``pkl_file`` and
# therefore crashed with NameError at import time. The file imports
# ``hf_hub_url`` / ``cached_download`` and reads ``TOKEN``, which indicates the
# checkpoint was meant to be fetched from the Hugging Face Hub. The repository
# id and filename are not visible in this file, so they are read from the
# environment here -- TODO confirm the intended values (they can be
# hard-coded instead).
MODEL_REPO = os.environ['MODEL_REPO']
MODEL_FILENAME = os.environ['MODEL_FILENAME']
pkl_file = cached_download(
    hf_hub_url(MODEL_REPO, MODEL_FILENAME),
    use_auth_token=TOKEN,
)

# SECURITY: unpickling executes arbitrary code embedded in the file; only load
# checkpoints from a repository you trust.
with open(pkl_file, 'rb') as pickle_file:
    # The pickle holds a 3-tuple; only the last element is kept and used as the
    # generator ``G`` below. NOTE(review): presumably (G, D, G_ema) -- confirm.
    _G, _D, G = pickle.load(pickle_file)
|
|
|
def _pin_fp32(module):
    """Force every call of ``module.forward`` to run with ``force_fp32=True``.

    The module's current bound ``forward`` is captured, then replaced on the
    instance by a wrapper that overwrites any caller-supplied ``force_fp32``
    keyword before delegating to the captured callable.
    """
    wrapped = module.forward

    def forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return wrapped(*args, **kwargs)

    module.forward = types.MethodType(forward, module)


# Use the GPU when one is available. Otherwise stay on the CPU and pin both the
# generator and its synthesis network to full precision (presumably because
# the reduced-precision code paths need a GPU -- TODO confirm).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    G = G.to(device)
else:
    _pin_fp32(G)
    _pin_fp32(G.synthesis)
|
|
|
|
|
|
|
|
|
def generate(num_images, interpolate):
    """Sample ``num_images`` images from the module-level generator ``G``.

    Args:
        num_images: Number of images to produce; must be at least 1.
        interpolate: If true, the latents are evenly spaced points on the
            straight line from one random latent ``z1`` to another ``z2``
            (both endpoints included), so consecutive images morph smoothly.
            If false, every image gets an independent random latent.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` produced by permuting the
        generator output with ``(0, 2, 3, 1)``. Assuming ``G`` returns
        ``(N, C, H, W)``, this is ``(num_images, height, width, channels)``.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    if interpolate:
        z1 = torch.randn([1, G.z_dim])
        z2 = torch.randn([1, G.z_dim])
        # Interpolation weights 0, 1/(n-1), ..., 1 as a column vector.
        # linspace(0, 1, 1) is [0.], so a single image is just z1. The old
        # ``i / (num_images - 1)`` divided by zero in that case, giving NaNs.
        t = torch.linspace(0, 1, num_images).unsqueeze(1)
        zs = z1 + (z2 - z1) * t
    else:
        zs = torch.randn([num_images, G.z_dim])

    with torch.no_grad():
        zs = zs.to(device)
        # Second positional argument is None (presumably the conditioning
        # label of an unconditional model -- confirm against the network code).
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Channels-last, then the *127.5 + 128 mapping (which assumes outputs
        # roughly in [-1, 1]) clamped to the valid uint8 range.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
|
|
|
|
|
|
|
|
|
# Top-level Gradio Blocks app; its components are added inside ``with demo:``.
demo = gr.Blocks()


def infer(num_images, interpolate):
    """Gradio click handler that produces the gallery contents.

    The slider value is rounded to an int before being passed to ``generate``.
    The returned batch array is split along its first axis into a list, one
    array per gallery image.
    """
    return list(generate(round(num_images), interpolate))
|
|
|
# Build the UI inside the Blocks context, then serve it.
# ``gr.inputs.*`` with ``default=`` is the deprecated legacy namespace, removed
# in Gradio 4; the top-level components with ``value=`` are the Blocks-era
# equivalents.
with demo:
    gr.Markdown(
        """
# EmojiGAN
Generate Emojis with AI (StyleGAN2-ADA). Made by [mfrashad](https://github.com/mfrashad)
"""
    )

    images_num = gr.Slider(value=1, label="Num Images", minimum=1, maximum=16, step=1)
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")

    out = gr.Gallery()

    # Clicking "Generate" runs ``infer`` on the two inputs and fills the gallery.
    submit.click(fn=infer,
                 inputs=[images_num, interpolate],
                 outputs=out)

demo.launch()