# myTest01 / analysis /sandbox_dalle.py
# meng2003's picture
# Upload 357 files
# 2d5fdd1
import torch
from models.cdvae import ConditionalDiscreteVAE
import numpy as np

# Smoke test for ConditionalDiscreteVAE: build a small model, run one forward
# pass on random inputs, then decode a random token sequence, printing the
# resulting shapes so the script gives feedback when run outside a notebook.

# Sizes that must agree between the model constructor and the dummy tensors
# below; naming them once keeps the two from drifting apart.
BATCH_SIZE = 4
COND_DIM = 100
NUM_TOKENS = 8192

vae = ConditionalDiscreteVAE(
    input_shape=(7, 7),
    num_layers=3,  # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens=NUM_TOKENS,  # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim=512,  # codebook dimension
    cond_dim=COND_DIM,
    hidden_dim=64,  # hidden dimension
    num_resnet_blocks=1,  # number of resnet blocks
    temperature=0.9,  # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through=False,  # straight-through for gumbel softmax. unclear if it is better one way or the other
)

# Forward pass on random data. The conditioning tensor carries COND_DIM
# channels laid out over the model's codebook grid.
# NOTE(review): 3 input channels is assumed to match the model's default - confirm.
images = torch.randn(BATCH_SIZE, 3, *vae.input_shape)
cond = torch.randn(BATCH_SIZE, COND_DIM, *vae.codebook_layer_shape)
logits = vae(images, cond=cond, return_logits=True)
print("logits shape:", logits.shape)

# Decode a random sequence holding one token id per codebook grid position.
# int() turns the NumPy scalar from np.prod into a plain Python int before it
# is used as a tensor dimension.
num_positions = int(np.prod(vae.codebook_layer_shape))
image_seq = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, num_positions))
image = vae.decode(image_seq, cond=cond)
print("decoded image shape:", image.shape)
# loss = vae(images, return_loss = True)
# loss.backward()
# loss
# train with a lot of data to learn a good codebook