File size: 1,057 Bytes
45df722
 
 
 
 
 
 
 
 
 
 
 
 
 
2496d9f
 
 
45df722
7ea3fea
45df722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea3fea
45df722
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid
from utils import (
    ConditionalVariationalAutoEncoder,
    classes,
    species,
    genders
)

# Number of samples drawn per request; generate() tiles them 4 per row.
N = 16

cvae = ConditionalVariationalAutoEncoder()

# Read the state dict and map every tensor onto the CPU so the app can run on
# a machine without a GPU.
cvae.load_state_dict(torch.load('cvae.pth', map_location=torch.device('cpu')))
# The model is only used for sampling, so switch it to inference mode. Layers
# such as dropout / batch-norm (if the model has any — not visible from here)
# otherwise keep their training-time behaviour.
cvae.eval()

def generate(class_, species_, gender_):
    """Sample a 4-per-row grid of N images from the CVAE decoder.

    Each argument is one label from the matching dropdown (``classes``,
    ``species``, ``genders``); its index in that list is the integer
    condition passed to the decoder. All N samples share the same condition
    and differ only in their random latent vector.

    Returns:
        An (H, W, C) float numpy array with values clamped to [0, 1], or
        ``None`` while any of the three inputs is still unset.
    """
    # With live=True Gradio calls this on every dropdown change, so the other
    # inputs may still be None; list.index(None) would raise ValueError.
    if class_ is None or species_ is None or gender_ is None:
        return None

    c_i = classes.index(class_)
    s_i = species.index(species_)
    g_i = genders.index(gender_)

    # One (identical) condition label per sample in the batch.
    c_ = torch.full((N,), c_i, dtype=torch.long)
    s_ = torch.full((N,), s_i, dtype=torch.long)
    g_ = torch.full((N,), g_i, dtype=torch.long)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        # 1024 is assumed to match the decoder's latent dimension.
        latent = torch.randn(N, 1024)
        imgs = cvae.dec(latent, s_, c_, g_)

    Z = make_grid(imgs, nrow=4)

    # Rescale only when the decoder produced negative values (a [-1, 1]
    # output range). Testing `min() < 1` is always true here — make_grid pads
    # with zeros by default — and would squash [0, 1] outputs into [0.5, 1].
    if Z.min() < 0:
        Z = (Z + 1) / 2
    # Guard against small overshoots outside the displayable float range.
    Z = Z.clamp(0, 1)

    # CHW tensor -> HWC numpy array, as expected for an image output.
    return Z.permute(1, 2, 0).cpu().numpy()

# live=True re-runs generate() whenever any dropdown changes. Each dropdown
# starts on its first option so every call receives a valid label instead of
# None for the inputs the user has not touched yet.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(classes, value=classes[0], label="Class"),
        gr.Dropdown(species, value=species[0], label="Species"),
        gr.Dropdown(genders, value=genders[0], label="Gender"),
    ],
    outputs="image",
    live=True,
)

demo.launch(share=True)