# NOTE(review): the lines below are web-scrape residue (Hugging Face Spaces
# page chrome: status, file size, commit hashes, gutter line numbers), not
# Python. Preserved as comments so the file parses; safe to delete.
# Spaces: Runtime error / Runtime error
# File size: 1,057 Bytes
# 45df722 2496d9f 45df722 7ea3fea 45df722 7ea3fea 45df722
import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid
from utils import (
ConditionalVariationalAutoEncoder,
classes,
species,
genders
)
# Number of images to sample per request (arranged as a 4x4 grid below).
N = 16

# Build the model and restore its trained weights. map_location forces every
# tensor onto the CPU so the app runs on GPU-less hosts.
cvae = ConditionalVariationalAutoEncoder()
state_dict = torch.load('cvae.pth', map_location=torch.device('cpu'))
cvae.load_state_dict(state_dict)
def generate(class_, species_, gender_):
    """Sample N images from the CVAE decoder for the chosen condition.

    Parameters
    ----------
    class_, species_, gender_ : str
        Dropdown selections; each is looked up in the corresponding
        label list imported from ``utils`` to obtain its integer index.

    Returns
    -------
    numpy.ndarray
        An (H, W, C) float image grid (4 images per row) with values
        scaled to [0, 1] for display by Gradio.
    """
    c_i = classes.index(class_)
    s_i = species.index(species_)
    g_i = genders.index(gender_)
    # Repeat each condition index N times — one copy per sampled image.
    c_ = torch.LongTensor([c_i] * N)
    s_ = torch.LongTensor([s_i] * N)
    g_ = torch.LongTensor([g_i] * N)
    # Inference only: disable autograd so no graph is built.
    # NOTE(review): latent size 1024 assumed to match the decoder — confirm
    # against ConditionalVariationalAutoEncoder's definition in utils.
    with torch.no_grad():
        latent = torch.randn(N, 1024)
        imgs = cvae.dec(latent, s_, c_, g_)
    Z = make_grid(imgs, nrow=4)
    # Fix: was `Z.min() < 1`, which is true for virtually any output and
    # remapped already-[0,1] images into [0.5, 1] (washed out). Only rescale
    # when the decoder actually produced tanh-style [-1, 1] values.
    if Z.min() < 0:
        Z = (Z + 1) / 2
    # CHW -> HWC for image display.
    Z = np.transpose(Z.cpu().numpy(), (1, 2, 0))
    return Z
# Wire the three condition dropdowns to the generator. `live=True` re-runs
# generate() on every dropdown change instead of waiting for a submit click.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(classes),
        gr.Dropdown(species),
        gr.Dropdown(genders),
    ],
    outputs="image",
    live=True,
)

# Fix: the original line ended with a stray ` |` (scrape artifact) that made
# the file a SyntaxError. `share=True` also exposes a public gradio.live URL.
demo.launch(share=True)