"""Gradio demo that samples images from a conditional VAE.

The user picks a class, species and gender; the app draws ``N`` random latent
vectors, decodes them with those conditions and shows the results as a grid
with 4 images per row.
"""
import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid

from utils import ConditionalVariationalAutoEncoder, classes, species, genders

# Number of images sampled per request (laid out 4 per row by make_grid).
N = 16
# Length of each latent vector fed to the decoder; must match the trained model.
LATENT_DIM = 1024

cvae = ConditionalVariationalAutoEncoder()
# Read the state dict and map every tensor onto the CPU so the demo runs on
# machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
cvae.load_state_dict(torch.load('cvae.pth', map_location=torch.device('cpu')))
# Inference only: switch layers such as dropout / batch-norm (if the model has
# any) to evaluation behaviour instead of training behaviour.
cvae.eval()


def generate(class_, species_, gender_):
    """Decode ``N`` random latents conditioned on the selected labels.

    Args:
        class_: an entry of ``classes``.
        species_: an entry of ``species``.
        gender_: an entry of ``genders``.

    Returns:
        An ``(H, W, C)`` float numpy array with values in ``[0, 1]`` holding
        the image grid, or ``None`` while any selection is still unset.
    """
    # With live=True Gradio may invoke this before every dropdown has a value;
    # list.index(None) would raise ValueError.
    if class_ is None or species_ is None or gender_ is None:
        return None

    c_i = classes.index(class_)
    s_i = species.index(species_)
    g_i = genders.index(gender_)

    # One condition label per sample in the batch.
    c_ = torch.LongTensor([c_i] * N)
    s_ = torch.LongTensor([s_i] * N)
    g_ = torch.LongTensor([g_i] * N)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        latent = torch.randn(N, LATENT_DIM)
        # NOTE(review): argument order (species, class, gender) is kept as-is;
        # confirm against ConditionalVariationalAutoEncoder.dec.
        imgs = cvae.dec(latent, s_, c_, g_)

    Z = make_grid(imgs, nrow=4)
    # Presumably the decoder emits either [-1, 1] (tanh) or [0, 1] values —
    # TODO confirm. Only negative values indicate the [-1, 1] case; a `< 1`
    # threshold would also rescale [0, 1] output and wash it out.
    if Z.min() < 0:
        Z = (Z + 1) / 2
    # Clip any overshoot so Gradio receives a valid [0, 1] float image.
    Z = Z.clamp(0, 1)
    # make_grid returns (C, H, W); image outputs expect (H, W, C).
    return np.transpose(Z.cpu().numpy(), (1, 2, 0))


demo = gr.Interface(
    fn=generate,
    inputs=[
        # Defaults ensure live updates never call generate with None.
        gr.Dropdown(classes, value=classes[0], label="Class"),
        gr.Dropdown(species, value=species[0], label="Species"),
        gr.Dropdown(genders, value=genders[0], label="Gender"),
    ],
    outputs="image",
    live=True,
)

if __name__ == "__main__":
    demo.launch()