"""Streamlit demo: decode a hand-picked latent vector with a pretrained VAE.

The first ``N_SLIDERS`` latent dimensions are driven by sliders; the
remaining dimensions are held at zero. The decoded output is shown as an
image.
"""
import numpy as np
import streamlit as st
import torch

import disvae
import transforms as trans

# Model directory handed to ``disvae.get_kl_dict`` / ``disvae.load_model``.
MODEL_PATH = "models/btcvae_celeba"
# Total length of the latent vector fed to the decoder.
# NOTE(review): must match the latent size of the model at MODEL_PATH — confirm.
LATENT_DIM = 10
# Number of leading latent dimensions exposed as sliders; the rest stay 0.
N_SLIDERS = 3
# Range and starting value of each latent slider.
SLIDER_MIN = -3.0
SLIDER_MAX = 3.0
SLIDER_DEFAULT = 0.0


@st.cache_resource
def load_decode_function():
    """Load the model once and return a NumPy-level ``decode`` callable.

    Decorated with ``st.cache_resource`` so the model is built once and
    reused across Streamlit reruns instead of being reloaded on every
    widget interaction.

    Returns:
        ``decode(latent)``: runs the composition of ``sorter.inv`` and
        ``vae.decoder`` (built by ``trans.sequential_function``) on
        ``latent`` under ``torch.no_grad()``, via ``trans.np_sample``.
        NOTE(review): ``np_sample`` presumably converts a single NumPy
        sample to/from torch tensors — confirm in ``transforms``.
    """
    # LatentSorter is built from the model's per-dimension KL dict;
    # presumably ``inv`` maps sorted latent order back to the model's own
    # order (assumption from naming — verify in ``transforms``).
    sorter = trans.LatentSorter(disvae.get_kl_dict(MODEL_PATH))
    vae = disvae.load_model(MODEL_PATH)
    _dec = trans.sequential_function(
        sorter.inv,
        vae.decoder,
    )

    def decode(latent):
        # No autograd graph is needed for pure inference.
        with torch.no_grad():
            return trans.np_sample(_dec)(latent)

    return decode


# GUI -----------------------------------------------------------
decode = load_decode_function()

slider_values = [
    st.slider(
        f"L{i}",
        min_value=SLIDER_MIN,
        max_value=SLIDER_MAX,
        value=SLIDER_DEFAULT,
    )
    for i in range(N_SLIDERS)
]
# Zero-fill to the full latent size, then write the slider-controlled
# leading dimensions; raises if N_SLIDERS exceeds LATENT_DIM.
latent_vector = np.zeros(LATENT_DIM)
latent_vector[:N_SLIDERS] = slider_values

value = decode(latent_vector)
# Move the leading axis to position 2, e.g. (C, H, W) -> (H, W, C), which is
# the channel-last layout st.image accepts.
# NOTE(review): assumes the decoder output is channel-first — confirm.
value = np.moveaxis(value, 0, 2)
# NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
# favor of ``use_container_width``; migrate once the pinned version allows.
st.image(value, use_column_width="always")