# Source: commit 2496d9f by misterbrainley - "map saved model onto cpu"
import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid
from utils import (
ConditionalVariationalAutoEncoder,
classes,
species,
genders
)
# Number of samples drawn per request; generate() lays them out 4 per row.
N = 16

cvae = ConditionalVariationalAutoEncoder()
# The checkpoint may have been saved from a GPU run; map its tensors onto the
# CPU so the demo also starts on machines without CUDA.
cvae.load_state_dict(torch.load('cvae.pth', map_location=torch.device('cpu')))
# The model is only used for sampling: switch layers such as dropout or
# batch-norm (if the architecture has any) to inference behaviour so that
# generated images do not depend on training-mode randomness or batch stats.
cvae.eval()
def generate(class_, species_, gender_):
    """Decode N random latent samples for the chosen conditions as one image grid.

    Args:
        class_: An entry of ``classes`` (first dropdown).
        species_: An entry of ``species`` (second dropdown).
        gender_: An entry of ``genders`` (third dropdown).

    Returns:
        A NumPy array in (H, W, C) layout holding a grid of N decoded samples,
        4 per row, with values in [0, 1]; or ``None`` if any condition is unset.
    """
    # With live=True Gradio calls this on every dropdown change, so some inputs
    # may still be unset; list.index(None) would raise ValueError.
    if class_ is None or species_ is None or gender_ is None:
        return None

    # One label per sample: every image in the batch shares the same conditions.
    c_ = torch.full((N,), classes.index(class_), dtype=torch.long)
    s_ = torch.full((N,), species.index(species_), dtype=torch.long)
    g_ = torch.full((N,), genders.index(gender_), dtype=torch.long)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        # NOTE(review): 1024 is assumed to be the decoder's latent size - confirm
        # against ConditionalVariationalAutoEncoder.
        latent = torch.randn(N, 1024)
        # Argument order (species, class, gender) is kept exactly as before.
        imgs = cvae.dec(latent, s_, c_, g_)

    # Negative values indicate a [-1, 1] output range (e.g. tanh), so map it to
    # [0, 1]. The previous test `min() < 1` also fired for output already in
    # [0, 1], squashing it into [0.5, 1]. The test is done before make_grid so
    # the grid padding cannot influence it.
    if imgs.min() < 0:
        imgs = (imgs + 1) / 2
    # Clamp so the float image handed to Gradio stays in its expected range.
    grid = make_grid(imgs.clamp(0, 1), nrow=4)

    # make_grid returns channels-first (C, H, W); Gradio expects (H, W, C).
    return np.transpose(grid.cpu().numpy(), (1, 2, 0))
# Web UI: three condition dropdowns feeding generate(), re-run on every change.
demo = gr.Interface(
    fn=generate,
    inputs=[
        # Pre-select the first option of each list. With live=True the function
        # runs as soon as any dropdown changes, and an unset dropdown would
        # otherwise reach generate() as None.
        gr.Dropdown(classes, value=classes[0], label="Class"),
        gr.Dropdown(species, value=species[0], label="Species"),
        gr.Dropdown(genders, value=genders[0], label="Gender"),
    ],
    outputs="image",
    live=True,
)
# Launched at import time, as before; share=True requests a public link.
demo.launch(share=True)