import gradio as gr
import os
import numpy as np
import torch
import pickle
import types

from huggingface_hub import hf_hub_download

# Access token for the model repo, read from the environment (e.g. set HF_TOKEN
# as a secret on the host) instead of hardcoding it in the source.
TOKEN = os.environ.get("HF_TOKEN")

# Download the pickled model and load the EMA copy of the generator.
model_path = hf_hub_download(repo_id='i72sijia/ganbanales', filename='ganbanales.pkl', token=TOKEN)
with open(model_path, 'rb') as f:
    G = pickle.load(f)['G_ema']  # torch.nn.Module
    
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    # CPU inference cannot use the generator's mixed-precision (fp16) path, so
    # monkey-patch G.forward and G.synthesis.forward to always pass
    # force_fp32=True (the StyleGAN-style keyword the generator exposes).
    _old_forward = G.forward

    def _new_forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return _old_forward(*args, **kwargs)

    G.forward = types.MethodType(_new_forward, G)

    _old_synthesis_forward = G.synthesis.forward

    def _new_synthesis_forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return _old_synthesis_forward(*args, **kwargs)

    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
####################################################################
# Image generation    

def generate(num_images, interpolate):
    if interpolate and num_images > 1:
        # Linearly interpolate between two random latent codes in num_images steps.
        z1 = torch.randn([1, G.z_dim])
        z2 = torch.randn([1, G.z_dim])
        zs = torch.cat([z1 + (z2 - z1) * i / (num_images - 1) for i in range(num_images)], 0)
    else:
        # Independent random latent codes (interpolation needs at least two images).
        zs = torch.randn([num_images, G.z_dim])
    with torch.no_grad():
        zs = zs.to(device)
        # Unconditional generation: no class label, constant per-pixel noise.
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Convert from NCHW float in [-1, 1] to NHWC uint8 in [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
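
# A minimal smoke test, added here as an illustrative sketch rather than part of the
# original app: set EMOJIGAN_SMOKE_TEST=1 in the environment to render one emoji to
# disk at startup. The variable name and output filename are assumptions.
if os.environ.get("EMOJIGAN_SMOKE_TEST"):
    from PIL import Image  # Pillow is available as a Gradio dependency
    Image.fromarray(generate(1, False)[0]).save("smoke_test.png")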

####################################################################
# Graphical User Interface
def infer(num_images, interpolate):
    img = generate(round(num_images), interpolate)
    # gr.Gallery expects a list of individual images rather than a single batch array.
    imgs = list(img)
    return imgs
    
demo = gr.Blocks()

with demo:
    gr.Markdown(
    """
    # EmojiGAN
    Generate Emojis with AI
    """)
    images_num = gr.Slider(minimum=1, maximum=16, step=1, value=1, label="Num Images")
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    
    submit = gr.Button("Generate")
        
    out = gr.Gallery()

    submit.click(fn=infer,
                 inputs=[images_num, interpolate],
                 outputs=out)

demo.queue().launch()
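
# Usage note (assumptions: this file is saved as app.py and HF_TOKEN is set in the
# environment): `python app.py` starts a local Gradio server and prints a URL to open,
# and the same file can be deployed as a Hugging Face Space.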