# models/cdvae.py
from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
from models.transformer import BasicTransformerModel, EncDecTransformerModel, EncDecXTransformer
from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange
# from dalle_pytorch import distributed_utils
# from dalle_pytorch.vae import OpenAIDiscreteVAE
# from dalle_pytorch.vae import VQGanVAE1024
# from dalle_pytorch.transformer import Transformer
# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def always(val):
    def inner(*args, **kwargs):
        return val
    return inner

def is_empty(t):
    return t.nelement() == 0

def masked_mean(t, mask, dim = 1):
    t = t.masked_fill(~mask[:, :, None], 0.)
    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helpers
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
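
# A minimal illustrative sketch (added; not from the original file): with
# thres = 0.9 and a 512-way codebook, only the k = 51 highest logits per row
# survive; everything else becomes -inf, so softmax assigns it zero mass.
#
#   logits = torch.randn(2, 512)            # (batch, num_tokens)
#   filtered = top_k(logits, thres = 0.9)   # keeps the top ~10% per row
#   probs = F.softmax(filtered, dim = -1)   # zero probability outside top-k
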
# discrete vae class
class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x
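
# Note (added): the residual path keeps channel count and spatial size
# unchanged (3x3 convs with padding 1, then a 1x1 conv), so the skip
# connection `self.net(x) + x` is shape-compatible for any input.
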
class ConditionalDiscreteVAEVision(nn.Module):
    def __init__(
        self,
        image_shape = (256, 256),
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        conditioning_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
        normalization = ((0.5,) * 3, (0.5,) * 3)
    ):
        super().__init__()
        assert log2(image_shape[0]).is_integer(), 'image size must be a power of 2'
        assert log2(image_shape[1]).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0

        self.image_shape = image_shape
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))
        enc_chans = [channels, *enc_chans]

        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if num_resnet_blocks > 0:
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # take care of normalization within class
        self.normalization = normalization

        # self._register_external_parameters()

    # def _register_external_parameters(self):
    #     """Register external parameters for DeepSpeed partitioning."""
    #     if (
    #         not distributed_utils.is_distributed
    #         or not distributed_utils.using_backend(
    #             distributed_utils.DeepSpeedBackend)
    #     ):
    #         return
    #
    #     deepspeed = distributed_utils.backend.backend_module
    #     deepspeed.zero.register_external_parameters(self, self.codebook.weight)

    def norm(self, images):
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        logits = self(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    def decode(
        self,
        img_seq
    ):
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
        images = self.decoder(image_embeds)
        return images
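
    # Note (added): decode assumes the token sequence came from a square
    # latent grid (h = w = sqrt(n)); a non-square grid would need its shape
    # passed in explicitly.
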
    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
    ):
        device, num_tokens, image_shape, kl_div_loss_weight = img.device, self.num_tokens, self.image_shape, self.kl_div_loss_weight
        assert img.shape[-1] == image_shape[1] and img.shape[-2] == image_shape[0], f'input must have the correct image size {image_shape[0]}x{image_shape[1]}'

        img = self.norm(img)
        logits = self.encoder(img)

        if return_logits:
            return logits # return logits for getting hard image indices for DALL-E training

        temp = default(temp, self.temperature)
        soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
        out = self.decoder(sampled)

        if not return_loss:
            return out

        # reconstruction loss
        recon_loss = self.loss_fn(img, out)

        # kl divergence
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        log_qy = F.log_softmax(logits, dim = -1)
        log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True)

        loss = recon_loss + (kl_div * kl_div_loss_weight)

        if not return_recons:
            return loss

        return loss, out
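
# A minimal training-step sketch for the vision VAE (added for illustration;
# the shapes and hyperparameters below are assumptions, not repo values):
#
#   vae = ConditionalDiscreteVAEVision(image_shape = (256, 256), num_tokens = 512)
#   images = torch.randn(4, 3, 256, 256)
#   loss = vae(images, return_loss = True)
#   loss.backward()
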
class ConditionalDiscreteVAE(nn.Module):
    def __init__(
        self,
        input_shape = (256, 256),
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        cond_dim = 0,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        kl_div_loss_weight = 0.,
        normalization = None,
        prior_nhead = 8,
        prior_dhid = 512,
        prior_nlayers = 8,
        prior_dropout = 0,
        prior_use_pos_emb = True,
        prior_use_x_transformers = False,
        opt = None,
        cond_vae = False
    ):
        super().__init__()
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0

        self.input_shape = input_shape
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(num_tokens, codebook_dim)
        self.cond_dim = cond_dim
        self.cond_vae = cond_vae

        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))

        if cond_vae:
            enc_chans = [channels + cond_dim, *enc_chans]
        else:
            enc_chans = [channels, *enc_chans]

        if not has_resblocks:
            if cond_vae:
                dec_init_chan = codebook_dim + cond_dim
            else:
                dec_init_chan = codebook_dim
        else:
            dec_init_chan = dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        if input_shape[0] == 1:
            kernel_size1 = 1
            padding_size1 = 0
            codebook_layer_shape1 = 1
        elif input_shape[0] in [2, 3, 4]:
            kernel_size1 = 3
            padding_size1 = 1
            codebook_layer_shape1 = input_shape[0]
        else:
            # kernel_size1 = 4
            kernel_size1 = 3
            padding_size1 = 1
            # codebook_layer_shape1 = input_shape[0] - num_layers
            codebook_layer_shape1 = input_shape[0]

        if input_shape[1] == 1:
            kernel_size2 = 1
            padding_size2 = 0
            codebook_layer_shape2 = 1
        elif input_shape[1] in [2, 3, 4]:
            kernel_size2 = 3
            padding_size2 = 1
            codebook_layer_shape2 = input_shape[1]
        else:
            # kernel_size2 = 4
            kernel_size2 = 3
            padding_size2 = 1
            # codebook_layer_shape2 = input_shape[1] - num_layers
            codebook_layer_shape2 = input_shape[1]

        self.codebook_layer_shape = (codebook_layer_shape1, codebook_layer_shape2)
        kernel_shape = (kernel_size1, kernel_size2)
        padding_shape = (padding_size1, padding_size2)
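
        # Note (added): with stride 1, the per-axis choices above (1x1 kernels
        # on singleton axes, otherwise 3 with padding 1) preserve each spatial
        # axis, which is why codebook_layer_shape equals input_shape and the
        # discrete latent grid has input_shape[0] * input_shape[1] positions.
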
        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, kernel_shape, stride = 1, padding = padding_shape), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, kernel_shape, stride = 1, padding = padding_shape), nn.ReLU()))

        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if num_resnet_blocks > 0:
            if cond_vae:
                dec_layers.insert(0, nn.Conv2d(codebook_dim + cond_dim, dec_chans[1], 1))
            else:
                dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        # upsampler to feed the conditioning to the input of the encoder
        self.cond_upsampler = torch.nn.Upsample(size = input_shape)

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # take care of normalization within class
        self.normalization = normalization

        latent_size = codebook_layer_shape1 * codebook_layer_shape2
        self.latent_size = latent_size

        if cond_dim > 0:
            self.prior_transformer = ContDiscTransformer(
                cond_dim, num_tokens, codebook_dim, prior_nhead, prior_dhid, prior_nlayers, prior_dropout,
                use_pos_emb = prior_use_pos_emb,
                src_length = latent_size,
                tgt_length = latent_size,
                use_x_transformers = prior_use_x_transformers,
                opt = opt)
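
        # Note (added): when cond_dim > 0, prior_transformer is an
        # autoregressive prior over the discrete latent codes, conditioned on
        # a continuous sequence; it is trained via prior_logp and sampled
        # from in generate.
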
        # self._register_external_parameters()

    # def _register_external_parameters(self):
    #     """Register external parameters for DeepSpeed partitioning."""
    #     if (
    #         not distributed_utils.is_distributed
    #         or not distributed_utils.using_backend(
    #             distributed_utils.DeepSpeedBackend)
    #     ):
    #         return
    #
    #     deepspeed = distributed_utils.backend.backend_module
    #     deepspeed.zero.register_external_parameters(self, self.codebook.weight)

    def norm(self, images):
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, inputs, cond = None):
        logits = self(inputs, cond, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    def decode(
        self,
        img_seq,
        cond = None
    ):
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)

        if cond is not None:
            image_embeds_cond = torch.cat([image_embeds, cond], dim = 1)
            images = self.decoder(image_embeds_cond)
        else:
            images = self.decoder(image_embeds)

        return images

    def prior_logp(
        self,
        inputs,
        cond = None,
        return_accuracy = False,
        detach_cond = False
    ):
        # import pdb; pdb.set_trace()
        # if cond is None: raise NotImplementedError("Haven't implemented non-conditional DVAEs")
        if len(inputs.shape) == 3:
            inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], *self.input_shape)
        if len(cond.shape) == 3:
            cond = cond.reshape(cond.shape[0], cond.shape[1], *self.codebook_layer_shape)

        with torch.no_grad():
            if self.cond_vae:
                labels = self.get_codebook_indices(inputs, cond)
            else:
                labels = self.get_codebook_indices(inputs)

        if detach_cond:
            cond = cond.detach()

        logits = self.prior_transformer(cond.squeeze(-1).permute(2, 0, 1), labels.permute(1, 0)).permute(1, 2, 0)
        loss = F.cross_entropy(logits, labels)

        if not return_accuracy:
            return loss

        # import pdb; pdb.set_trace()
        predicted = logits.argmax(dim = 1).flatten(1)
        accuracy = (predicted == labels).sum() / predicted.nelement()
        return loss, accuracy
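
    # Note (added): prior_logp teacher-forces the transformer prior. The VAE
    # encoder produces the target code indices under torch.no_grad(), so only
    # the prior (and the conditioning, unless detach_cond is set) receives
    # gradients from the cross-entropy loss.
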
    def generate(self, cond, temp = 1.0, filter_thres = 0.5):
        # if cond is None: raise NotImplementedError("Haven't implemented non-conditional DVAEs")
        if len(cond.shape) == 3:
            cond = cond.reshape(cond.shape[0], cond.shape[1], *self.codebook_layer_shape)

        dummy = torch.zeros(1, 1).long().to(cond.device)
        tokens = []
        for i in range(self.latent_size):
            # print(i)
            # logits come back as (batch, num_tokens, seq) after the permute;
            # keep the logits for the last position only
            logits = self.prior_transformer(cond.squeeze(-1).permute(2, 0, 1), torch.cat(tokens + [dummy], 0)).permute(1, 2, 0)[:, :, -1]
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temp, dim = -1)
            sampled = torch.multinomial(probs, 1)
            tokens.append(sampled)

        # print(tokens)
        embs = self.codebook(torch.cat(tokens, 0))
        # import pdb; pdb.set_trace()

        if self.cond_vae:
            sampled_cond = torch.cat([embs.permute(2, 0, 1).unsqueeze(0), cond], dim = 1)
        else:
            sampled_cond = embs.permute(2, 0, 1).unsqueeze(0)

        out = self.decoder(sampled_cond)
        return out
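
    # Note (added): generate samples latent codes one position at a time.
    # Each step feeds all previously sampled tokens (plus a dummy token that
    # ContDiscTransformer shifts out internally) back through the prior,
    # filters the last position's logits with top_k, and samples from the
    # tempered softmax.
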
    def forward(
        self,
        inp,
        cond = None,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
    ):
        if len(inp.shape) == 3:
            inp = inp.reshape(inp.shape[0], inp.shape[1], *self.input_shape)

        device, num_tokens, input_shape, kl_div_loss_weight = inp.device, self.num_tokens, self.input_shape, self.kl_div_loss_weight
        assert inp.shape[-1] == input_shape[1] and inp.shape[-2] == input_shape[0], f'input must have the correct image size {input_shape[0]}x{input_shape[1]}. Instead got {inp.shape[-2]}x{inp.shape[-1]}'

        inp = self.norm(inp)

        if cond is not None:
            if len(cond.shape) == 3:
                cond = cond.reshape(cond.shape[0], cond.shape[1], *self.codebook_layer_shape)
            cond_upsampled = self.cond_upsampler(cond)
            inp_cond = torch.cat([inp, cond_upsampled], dim = 1)
            # note: norm is applied a second time here; with this class's
            # default normalization = None both calls are the identity
            inp_cond = self.norm(inp_cond)
        else:
            inp_cond = self.norm(inp)

        logits = self.encoder(inp_cond)
        # codebook_indices = logits.argmax(dim = 1).flatten(1)
        # print(codebook_indices.shape)
        # print(codebook_indices)
        # print(list(self.encoder.parameters())[1].data)
        # for p in self.prior_transformer.parameters():
        #     print(p.norm())

        if return_logits:
            return logits # return logits for getting hard image indices for DALL-E training

        temp = default(temp, self.temperature)
        soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)

        if cond is not None:
            sampled_cond = torch.cat([sampled, cond], dim = 1)
            out = self.decoder(sampled_cond)
        else:
            out = self.decoder(sampled)

        if not return_loss:
            return out

        # reconstruction loss
        # import pdb; pdb.set_trace()
        recon_loss = self.loss_fn(inp, out)

        # kl divergence
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        log_qy = F.log_softmax(logits, dim = -1)
        log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True)

        loss = recon_loss + (kl_div * kl_div_loss_weight)

        if not return_recons:
            return loss

        return loss, out
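
# A minimal usage sketch (added for illustration; the shapes, cond_dim, and
# other hyperparameters below are assumptions, not values from this repo):
#
#   vae = ConditionalDiscreteVAE(input_shape = (10, 1), num_tokens = 512,
#                                codebook_dim = 256, cond_dim = 32, cond_vae = True)
#   x = torch.randn(2, 3, 10, 1)                  # (batch, channels, h, w)
#   c = torch.randn(2, 32, 10, 1)                 # conditioning features
#   loss = vae(x, cond = c, return_loss = True)   # reconstruction + weighted KL
#   prior_loss = vae.prior_logp(x, cond = c)      # autoregressive prior NLL
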
class ContDiscTransformer(nn.Module):
    def __init__(self, src_d, tgt_num_tokens, tgt_emb_dim, nhead, dhid, nlayers, dropout = 0.5,
                 use_pos_emb = False, src_length = 0, tgt_length = 0, use_x_transformers = False, opt = None):
        super(ContDiscTransformer, self).__init__()
        self.transformer = EncDecTransformerModel(tgt_num_tokens, src_d, tgt_emb_dim, nhead, dhid, nlayers,
                                                  dropout = dropout, use_pos_emb = use_pos_emb,
                                                  src_length = src_length, tgt_length = tgt_length,
                                                  use_x_transformers = use_x_transformers, opt = opt)
        # self.transformer = EncDecTransformerModel(tgt_num_tokens, src_d, tgt_emb_dim, nhead, dhid, nlayers, dropout=dropout, use_pos_emb=False, src_length=src_length, tgt_length=tgt_length, use_x_transformers=use_x_transformers, opt=opt)
        # self.transformer = EncDecXTransformer(dim=dhid, dec_dim_out=tgt_num_tokens, enc_dim_in=src_d, enc_dim_out=tgt_emb_dim, dec_din_in=tgt_emb_dim, enc_heads=nhead, dec_heads=nhead, enc_depth=nlayers, dec_depth=nlayers, enc_dropout=dropout, dec_dropout=dropout, enc_max_seq_len=1024, dec_max_seq_len=1024)
        self.embedding = nn.Embedding(tgt_num_tokens, tgt_emb_dim)
        self.first_input = nn.Parameter(torch.randn(1, 1, tgt_emb_dim))

    def forward(self, src, tgt):
        # shift the target right: drop the last token, prepend a learned start embedding
        tgt = tgt[:-1]
        embs = self.embedding(tgt)
        embs = torch.cat([torch.tile(self.first_input, (1, embs.shape[1], 1)), embs], 0)
        output = self.transformer(src, embs)
        return output
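
# Shape sketch (added; inferred from how prior_logp calls this class): src is
# (src_len, batch, src_d) of continuous features and tgt is (tgt_len, batch)
# of token ids; the output is (tgt_len, batch, tgt_num_tokens), which callers
# permute to (batch, num_tokens, seq) for F.cross_entropy.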