Spaces:
Runtime error
Runtime error
# Smoke test for ConditionalDiscreteVAE: build a small model, run one
# conditioned forward pass that returns logits, then decode a batch of random
# token ids back to image space and report the resulting tensor shapes.
import numpy as np
import torch

from models.cdvae import ConditionalDiscreteVAE

# Shared sizes, named once so the tensors built below cannot drift away from
# the values the model is constructed with.
BATCH_SIZE = 4
NUM_TOKENS = 8192  # the paper used 8192; could be smaller for downsized projects
COND_DIM = 100

vae = ConditionalDiscreteVAE(
    input_shape=(7, 7),  # presumably (height, width) of the input - confirm against the model
    # Number of downsamples; e.g. a 256-pixel side becomes 256 / (2 ** 3) = 32.
    # NOTE(review): that example is for a 256-pixel input, not the 7x7 used
    # here - check `vae.codebook_layer_shape` for the actual feature-map size.
    num_layers=3,
    num_tokens=NUM_TOKENS,  # number of visual tokens
    codebook_dim=512,  # codebook dimension
    cond_dim=COND_DIM,  # used below as the size of dim 1 of `cond`
    hidden_dim=64,  # hidden dimension
    num_resnet_blocks=1,  # number of resnet blocks
    temperature=0.9,  # gumbel softmax temperature; the lower, the harder the discretization
    straight_through=False,  # straight-through for gumbel softmax; unclear which way is better
)

# Random stand-ins for a batch of inputs and their conditioning tensors.
# NOTE(review): 3 input channels is taken from the original script - confirm
# it matches what the model expects.
images = torch.randn(BATCH_SIZE, 3, *vae.input_shape)
cond = torch.randn(BATCH_SIZE, COND_DIM, *vae.codebook_layer_shape)

logits = vae(images, cond=cond, return_logits=True)
# A bare `logits.shape` only displays in a notebook; print it so the script
# reports the shape when run directly.
print(logits.shape)

# One token id per codebook-layer position. `np.prod` returns a NumPy scalar,
# so convert it to a built-in int before using it in the size tuple.
num_positions = int(np.prod(vae.codebook_layer_shape))
# `torch.randint` excludes the upper bound, so ids fall in [0, NUM_TOKENS).
image_seq = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, num_positions))
image = vae.decode(image_seq, cond=cond)
print(image.shape)

# Training step sketch, left disabled as in the original script.
# NOTE(review): this call does not pass `cond` - confirm whether the
# conditional model requires it before enabling.
# loss = vae(images, return_loss = True)
# loss.backward()
# loss
# train with a lot of data to learn a good codebook