misterbrainley commited on
Commit
45df722
1 Parent(s): a3e10f4

initial commit

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from torchvision.utils import make_grid
5
+ from utils import (
6
+ ConditionalVariationalAutoEncoder,
7
+ classes,
8
+ species,
9
+ genders
10
+ )
11
+
12
+ N = 16
13
+
14
+ cvae = ConditionalVariationalAutoEncoder()
15
+ cvae.load_state_dict(torch.load('cvae.pth'))
16
+
17
+ def greet(class_, species_, gender_):
18
+ c_i = classes.index(class_)
19
+ s_i = species.index(species_)
20
+ g_i = genders.index(gender_)
21
+
22
+ c_ = torch.LongTensor([c_i] * N)
23
+ s_ = torch.LongTensor([s_i] * N)
24
+ g_ = torch.LongTensor([g_i] * N)
25
+
26
+ latent = torch.randn(N, 1024)
27
+ imgs = cvae.dec(latent, s_, c_, g_)
28
+
29
+ Z = make_grid(imgs, nrow=4)
30
+
31
+ if Z.min() < 1:
32
+ Z = (Z + 1) / 2
33
+
34
+ Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
35
+ return Z
36
+
37
+ demo = gr.Interface(
38
+ fn=greet,
39
+ inputs=[
40
+ gr.Dropdown(classes),
41
+ gr.Dropdown(species),
42
+ gr.Dropdown(genders)
43
+ ],
44
+ outputs="image",
45
+ live=True
46
+ )
47
+
48
+ demo.launch(share=True)