Spaces:
Runtime error
Runtime error
File size: 4,759 Bytes
5a67fb4 d872920 5a67fb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import math
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
import src.app.params as params
from src.models import ConditionalGenerator as InfoSCC_GAN
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
from src.models import ConditionalDecoder as cVAE
from src.data import get_labels_train
from src.utils import sample_labels
# Runtime settings mirrored from the shared app parameters.
device, size = params.device, params.size

# Derived from the image size; consumed below when building the per-layer
# ``zs`` input for the InfoSCC-GAN generator.
n_layers = int(math.log2(size) - 2)

# Number of interpolation steps, i.e. images rendered per model.
bs = 12

# Column vector of ``bs`` evenly spaced blend weights from 0 to 1.
lin_space = torch.linspace(0, 1, bs)[:, None]

# One caption per interpolation step, spelling out the blend of the two labels.
captions = [
    f'label_a * {(1 - x):.02f} + label_b * {x:.02f}'
    for x in lin_space.squeeze().numpy()
]
@st.cache(allow_output_mutation=True)
def load_model(model_type: str):
    """Build one of the supported generators and load its checkpoint.

    The checkpoint is read onto the CPU, loaded into the freshly built
    network, and the network is then put in eval mode and moved to
    ``params.device``. The result is cached by Streamlit.

    Args:
        model_type: One of ``'InfoSCC-GAN'``, ``'BigGAN'`` or ``'cVAE'``.

    Returns:
        The generator in eval mode on ``params.device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    print(f'Loading model: {model_type}')

    # Reject unknown names before touching any checkpoint file.
    if model_type not in ('InfoSCC-GAN', 'BigGAN', 'cVAE'):
        raise ValueError('Unsupported model')

    cpu = torch.device('cpu')
    if model_type == 'InfoSCC-GAN':
        generator = InfoSCC_GAN(
            size=params.size,
            y_size=params.shape_label,
            z_size=params.noise_dim,
        )
        # This checkpoint is a dict; the weights loaded here live under 'g_ema'.
        weights = torch.load(params.path_infoscc_gan, map_location=cpu)['g_ema']
    elif model_type == 'BigGAN':
        generator = BigGAN2Generator()
        weights = torch.load(params.path_biggan, map_location=cpu)
    else:
        generator = cVAE()
        weights = torch.load(params.path_cvae, map_location=cpu)

    generator.load_state_dict(weights)
    return generator.eval().to(device=params.device)
@st.cache
def get_labels() -> torch.Tensor:
    """Return the training labels read from ``params.path_labels``.

    Loading is delegated to ``get_labels_train``; the result is cached by
    Streamlit.
    """
    return get_labels_train(params.path_labels)
def get_eps(n: int) -> torch.Tensor:
    """Draw ``n`` standard-normal noise vectors.

    Args:
        n: Number of vectors to sample.

    Returns:
        A tensor of shape ``(n, params.dim_z)`` on the module-level ``device``.
    """
    return torch.randn((n, params.dim_z), device=device)
# Session-state keys, prefixed to avoid clashing with unrelated entries.
_LABELS_STATE_KEY = 'interpolate_labels/labels'
_EPS_STATE_KEY = 'interpolate_labels/eps'


def _sample_label_pair(labels_train):
    """Sample the two endpoint labels, each tiled to ``bs`` rows."""
    label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
    label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
    return label_a, label_b


def _sample_noise(infoscc_gan):
    """Sample one noise vector per noise source, each tiled to ``bs`` rows."""
    eps = get_eps(1).repeat(bs, 1)
    eps_infoscc = infoscc_gan.sample_eps(1).repeat(bs, 1)
    return eps, eps_infoscc


def _to_uint8_images(imgs, low, high, scale, shift):
    """Convert an image batch to a list of channel-last ``uint8`` arrays.

    Values are clipped to ``[low, high]`` before ``value * scale + shift`` is
    applied, so the cast to ``uint8`` never sees an out-of-range value (which
    would otherwise wrap around and show up as speckle artifacts).

    Assumes ``imgs`` is a channel-first batch on the CPU -- TODO confirm
    against the generators' output layout.
    """
    clipped = torch.clip(imgs, low, high)
    return [
        (img.permute(1, 2, 0).numpy() * scale + shift).astype(np.uint8)
        for img in clipped
    ]


def app():
    """Render the 'Interpolate Labels' page.

    Two labels are sampled from the training labels and linearly blended in
    ``bs`` steps; each of the three models (BigGAN, InfoSCC-GAN, cVAE) then
    generates one image per blended label from a single shared noise sample,
    and the results are shown in three side-by-side columns.

    The sampled labels and noise are kept in ``st.session_state`` so that they
    survive Streamlit reruns: 'Sample label' replaces only the labels and
    'Change eps' replaces only the noise.
    """
    st.title('Interpolate Labels')
    st.markdown('This app allows the generation of the images with the labels that are interpolated between two labels.')
    st.markdown('In each row there are images generated with the same interpolated label by one of the models')

    biggan = load_model('BigGAN')
    infoscc_gan = load_model('InfoSCC-GAN')
    cvae = load_model('cVAE')
    labels_train = get_labels()
    state = st.session_state

    # ==================== Labels ==============================================
    # The button is evaluated first so it is always rendered; labels are only
    # (re)sampled on the first run or when the button was pressed.
    sample_label = st.button('Sample label')
    if sample_label or _LABELS_STATE_KEY not in state:
        state[_LABELS_STATE_KEY] = _sample_label_pair(labels_train)
    label_a, label_b = state[_LABELS_STATE_KEY]

    # Everything is moved to ``device`` explicitly so the blend and the model
    # calls below do not mix devices when ``device`` is not the CPU.
    weights = lin_space.to(device)
    label_interpolated = (1 - weights) * label_a.to(device) + weights * label_b.to(device)
    # ==================== Labels ==============================================

    # ==================== Noise ==============================================
    # All-zero third input for the InfoSCC-GAN generator, one row per layer.
    zs_torch = torch.zeros((bs, n_layers, params.n_basis), dtype=torch.float32, device=device)

    st.subheader('Noise')
    st.markdown(r'Click on __Change eps__ button to change input $\varepsilon$ latent space')
    change_eps = st.button('Change eps')
    if change_eps or _EPS_STATE_KEY not in state:
        state[_EPS_STATE_KEY] = _sample_noise(infoscc_gan)
    eps, eps_infoscc = state[_EPS_STATE_KEY]
    # ==================== Noise ==============================================

    with torch.no_grad():
        # ``squeeze(0)`` is kept from the original; it is a no-op whenever the
        # leading dimension is larger than one.
        imgs_biggan = biggan(eps, label_interpolated).squeeze(0).cpu()
        imgs_infoscc = infoscc_gan(label_interpolated, eps_infoscc, zs_torch).squeeze(0).cpu()
        imgs_cvae = cvae(eps, label_interpolated).squeeze(0).cpu()

        if params.upsample:
            target = (size * 4, size * 4)
            imgs_biggan = F.interpolate(imgs_biggan, target, mode='bicubic')
            imgs_infoscc = F.interpolate(imgs_infoscc, target, mode='bicubic')
            imgs_cvae = F.interpolate(imgs_cvae, target, mode='bicubic')

    # BigGAN output is scaled as [0, 1]; the other two as [-1, 1]. Clipping is
    # applied to all three because bicubic upsampling can overshoot the range.
    imgs_biggan = _to_uint8_images(imgs_biggan, 0, 1, 255, 0.0)
    imgs_infoscc = _to_uint8_images(imgs_infoscc, -1, 1, 127.5, 127.5)
    imgs_cvae = _to_uint8_images(imgs_cvae, -1, 1, 127.5, 127.5)

    c1, c2, c3 = st.columns(3)
    c1.header('BigGAN')
    c1.image(imgs_biggan, use_column_width=True, caption=captions)
    c2.header('InfoSCC-GAN')
    c2.image(imgs_infoscc, use_column_width=True, caption=captions)
    c3.header('cVAE')
    c3.image(imgs_cvae, use_column_width=True, caption=captions)
|