File size: 2,381 Bytes
e8feed3
eddf472
 
e8feed3
 
 
eddf472
8f773c7
 
20ffa2c
8f773c7
0eacafb
8f773c7
cc9c2c0
e8feed3
eddf472
e8feed3
 
 
 
 
eddf472
e8feed3
 
 
eddf472
e8feed3
eddf472
e8feed3
eddf472
e8feed3
 
 
eddf472
e8feed3
eddf472
e8feed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eddf472
 
e8feed3
eddf472
e8feed3
 
 
eddf472
e8feed3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import os
import numpy as np
import torch
import pickle
import types

from huggingface_hub import hf_hub_url, cached_download

# Hugging Face access token used when downloading the model weights (presumably
# because the repo is private/gated — confirm). It is read from the environment
# instead of being committed to source. NOTE(review): the token that used to be
# hard-coded here is exposed in version control and should be revoked.
TOKEN = os.environ.get('HF_TOKEN')

# Download (or reuse the cached copy of) the pickled checkpoint and keep the
# 'G_ema' entry as the generator.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints from a repository you trust.
with open(cached_download(hf_hub_url('i72sijia/ganbanales', 'ganbanales.pkl'), use_auth_token=TOKEN), 'rb') as f:
    G = pickle.load(f)['G_ema']  # torch.nn.Module
    
def _force_fp32_forward(module):
    """Rebind ``module.forward`` so every call is made with ``force_fp32=True``.

    The replacement overrides any caller-supplied ``force_fp32`` value and
    delegates to the forward method that was bound when this helper ran.
    """
    wrapped = module.forward

    def forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return wrapped(*args, **kwargs)

    module.forward = types.MethodType(forward, module)


# Run on the GPU when one is available. Otherwise stay on the CPU and force
# fp32 execution for both the full generator and its synthesis sub-network
# (presumably because the half-precision path is not usable on CPU — TODO
# confirm). The generator is patched before its synthesis network, as before.
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    device = torch.device("cpu")
    _force_fp32_forward(G)
    _force_fp32_forward(G.synthesis)
    
####################################################################
# Image generation    

def generate(num_images, interpolate):
    """Sample ``num_images`` images from the generator ``G``.

    Args:
        num_images: Number of images to generate (an int; the UI slider
            guarantees at least 1).
        interpolate: If truthy, the latent codes are evenly spaced points on
            the straight line from one random latent code to another
            (both endpoints included), so consecutive images morph smoothly.
            Otherwise every image gets an independent random latent code.

    Returns:
        A uint8 ``numpy.ndarray`` of channels-last images with values in
        [0, 255], one image per entry along the first axis.
    """
    if interpolate:
        z1 = torch.randn([1, G.z_dim])  # start latent code
        z2 = torch.randn([1, G.z_dim])  # end latent code
        # Weights 0, 1/(n-1), ..., 1 as a column so they broadcast over the
        # latent dimension. Dividing by (num_images - 1) produced 0 / 0 = NaN
        # latents when num_images == 1; linspace returns [0.] in that case,
        # i.e. just z1.
        weights = torch.linspace(0, 1, num_images).unsqueeze(1)
        zs = z1 + (z2 - z1) * weights
    else:
        zs = torch.randn([num_images, G.z_dim])  # independent latent codes
    with torch.no_grad():
        zs = zs.to(device)
        # Second positional argument is presumably the conditioning label c
        # (None for an unconditional model) — confirm against the network code.
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Move channels last and map the output (assumed to lie roughly in
        # [-1, 1]) to uint8 pixel values in [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()

####################################################################
# Graphical User Interface

# Root Gradio Blocks container. The UI components are created inside the
# `with demo:` context further down, and the app is started with demo.launch().
demo = gr.Blocks()

def infer(num_images, interpolate):
    """Gradio click handler: generate images and return them for the gallery.

    ``num_images`` is rounded to the nearest integer before being passed to
    ``generate``. The resulting batch is split along its first axis into a
    list with one array per image.
    """
    batch = generate(round(num_images), interpolate)
    return [frame for frame in batch]

with demo:
    # NOTE(review): this title/credit text appears to come from the EmojiGAN
    # template, while the checkpoint loaded by this app is
    # 'i72sijia/ganbanales' — confirm the intended description.
    gr.Markdown(
    """
    # EmojiGAN
    Generate Emojis with AI (StyleGAN2-ADA). Made by [mfrashad](https://github.com/mfrashad)
    """)
    # gr.inputs.* is the legacy Interface API (deprecated in Gradio 3, removed
    # in Gradio 4). Blocks apps use the components directly, and the initial
    # value is passed as `value=` instead of `default=`.
    images_num = gr.Slider(minimum=1, maximum=16, value=1, step=1, label="Num Images")
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")

    out = gr.Gallery()

    # Clicking "Generate" calls infer(slider value, checkbox value) and shows
    # the returned list of images in the gallery.
    submit.click(fn=infer,
                 inputs=[images_num, interpolate],
                 outputs=out)

demo.launch()