# Hugging Face Space demo script (the hosted page reported "Runtime error").
import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid

from utils import (
    ConditionalVariationalAutoEncoder,
    classes,
    species,
    genders,
)

# Number of images sampled per request (rendered as a 4x4 grid below).
N = 16

cvae = ConditionalVariationalAutoEncoder()
# Load the trained weights, mapping onto CPU so the demo runs without a GPU.
cvae.load_state_dict(torch.load('cvae.pth', map_location=torch.device('cpu')))
# BUG FIX: switch to inference mode — without this, dropout / batch-norm
# layers (if present in the model) would behave as in training.
cvae.eval()
def generate(class_, species_, gender_):
    """Sample N images from the CVAE decoder for the chosen conditions.

    Args:
        class_: dropdown string value; looked up in ``classes`` for its index.
        species_: dropdown string value; looked up in ``species``.
        gender_: dropdown string value; looked up in ``genders``.

    Returns:
        A float numpy array of shape (H, W, C), values in [0, 1], holding a
        4-column grid of the N generated images.
    """
    class_idx = classes.index(class_)
    species_idx = species.index(species_)
    gender_idx = genders.index(gender_)

    # One conditioning label per sample in the batch.
    class_cond = torch.LongTensor([class_idx] * N)
    species_cond = torch.LongTensor([species_idx] * N)
    gender_cond = torch.LongTensor([gender_idx] * N)

    # Sample latents from the standard-normal prior; no gradients are needed
    # for generation, so skip building the autograd graph.
    with torch.no_grad():
        latent = torch.randn(N, 1024)
        # NOTE: decoder argument order is (latent, species, class, gender),
        # matching the original call — confirm against utils if it changes.
        imgs = cvae.dec(latent, species_cond, class_cond, gender_cond)
        grid = make_grid(imgs, nrow=4)

    # BUG FIX: the original check was `grid.min() < 1`, which is true for
    # virtually any image (a [0, 1] output has min ~0) and so rescaled
    # everything into [0.5, 1], washing the picture out.  Only remap when the
    # decoder actually emits tanh-style [-1, 1] values.
    if grid.min() < 0:
        grid = (grid + 1) / 2

    # CHW -> HWC for gradio's image output; clip to the valid display range.
    out = np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))
    return np.clip(out, 0.0, 1.0)
# Wire the generator up to a dropdown-driven web UI.  `live=True` re-renders
# the image grid whenever any dropdown value changes.
condition_inputs = [gr.Dropdown(options) for options in (classes, species, genders)]

demo = gr.Interface(
    fn=generate,
    inputs=condition_inputs,
    outputs="image",
    live=True,
)

demo.launch()