MoMask-test / models /vq /residual_vq.py
andrewatef's picture
Upload folder using huggingface_hub
823807d verified
raw
history blame
No virus
7.22 kB
import random
from math import ceil
from functools import partial
from itertools import zip_longest
from random import randrange
import torch
from torch import nn
import torch.nn.functional as F
# from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA
from einops import rearrange, repeat, pack, unpack
# helper functions
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val* unless it is ``None``, in which case fall back to *d*."""
    if exists(val):
        return val
    return d
def round_up_multiple(num, mult):
    """Round *num* up to the nearest multiple of *mult*."""
    # Number of whole `mult`-sized chunks needed to cover `num`.
    chunk_count = ceil(num / mult)
    return chunk_count * mult
# main class
class ResidualVQ(nn.Module):
    """Residual vector quantizer built from a stack of ``QuantizeEMAReset`` layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf: every layer
    quantizes the residual left over by the layers before it, and the
    per-layer quantized outputs are summed to form the final result.
    """

    def __init__(
        self,
        num_quantizers,
        shared_codebook=False,
        quantize_dropout_prob=0.5,
        quantize_dropout_cutoff_index=0,
        **kwargs
    ):
        """Build the stack of quantizer layers.

        Args:
            num_quantizers: Number of residual quantization layers.
            shared_codebook: When True, one ``QuantizeEMAReset`` instance is
                reused for every layer instead of creating one per layer.
            quantize_dropout_prob: Probability, per training-mode forward
                call, of dropping the trailing quantizer layers.
            quantize_dropout_cutoff_index: Smallest value the randomly drawn
                "last kept layer" index may take.
            **kwargs: Forwarded unchanged to ``QuantizeEMAReset``.
        """
        super().__init__()

        self.num_quantizers = num_quantizers

        if shared_codebook:
            # The very same module object is registered num_quantizers times.
            layer = QuantizeEMAReset(**kwargs)
            self.layers = nn.ModuleList([layer] * num_quantizers)
        else:
            self.layers = nn.ModuleList(
                [QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)]
            )

        assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_prob = quantize_dropout_prob

    @property
    def codebooks(self):
        """Stack every layer's ``codebook`` along a new leading axis ('q c d')."""
        return torch.stack([layer.codebook for layer in self.layers], dim=0)

    def get_codes_from_indices(self, indices):
        """Dequantize: look up the code vector of every index in every layer.

        Args:
            indices: Integer tensor laid out as 'b n q'. It may carry fewer
                than ``num_quantizers`` entries on the last axis; the missing
                layers are padded with -1. Any -1 entry marks a dropped layer.

        Returns:
            Tensor laid out as 'q b n d'; positions whose index is -1 are
            filled with zeros.
        """
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # Because of quantize dropout, callers may pass coarse indices that
        # cover only the first few layers; pad the rest with the null index.
        if quantize_dim < self.num_quantizers:
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # Get ready for gathering along the codebook-entry axis.
        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b=batch)
        gather_indices = repeat(indices, 'b n q -> q b n d', d=codebooks.shape[-1])

        # Dropped positions fetch a dummy code (entry 0) that is zeroed below.
        mask = gather_indices == -1
        gather_indices = gather_indices.masked_fill(mask, 0)

        all_codes = codebooks.gather(2, gather_indices)

        # Mask out any codes that belong to dropped layers.
        return all_codes.masked_fill(mask, 0.)

    def get_codebook_entry(self, indices):
        """Sum the per-layer codes of ``indices`` ('b n q') into one latent.

        Returns:
            The summed codes with the last two axes swapped ('b n d' -> 'b d n').
        """
        all_codes = self.get_codes_from_indices(indices)  # 'q b n d'
        latent = torch.sum(all_codes, dim=0)  # 'b n d'
        return latent.permute(0, 2, 1)

    def forward(self, x, return_all_codes=False, sample_codebook_temp=None, force_dropout_index=-1):
        """Quantize ``x`` with the residual layer stack.

        Args:
            x: Input tensor. NOTE(review): presumably laid out as
                (batch, dim, seq) -- the null indices use ``x.shape[0]`` and
                ``x.shape[-1]`` -- confirm against ``QuantizeEMAReset``.
                The tensor is not modified.
            return_all_codes: When True, also return the per-layer codes
                obtained from ``get_codes_from_indices``.
            sample_codebook_temp: Passed to every layer as ``temperature``.
            force_dropout_index: When >= 0, every layer whose index is
                greater than this value is skipped, in training and in eval
                mode alike; this overrides the random dropout draw.

        Returns:
            ``(quantized_out, all_indices, all_losses, all_perplexity)`` where
            ``all_indices`` stacks the per-layer indices on a new last axis
            (-1 for skipped layers) and the loss / perplexity are averaged
            over the layers that actually ran. With ``return_all_codes`` the
            tuple gets a fifth element holding the codes in 'q b n d' layout.
        """
        num_quant = self.num_quantizers

        # To ensure the first-k layers learn as much as possible, the last
        # q - k layers are randomly dropped during training.
        should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob
        start_drop_quantize_index = num_quant

        if should_quantize_dropout:
            # Layers with index <= the drawn value are kept. TODO: vary in batch.
            start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

        if force_dropout_index >= 0:
            should_quantize_dropout = True
            start_drop_quantize_index = force_dropout_index

        null_indices = None
        if should_quantize_dropout:
            # Placeholder indices (-1) reported for every skipped layer.
            null_indices = torch.full(
                (x.shape[0], x.shape[-1]), -1, device=x.device, dtype=torch.long
            )

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []
        all_perplexity = []

        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > start_drop_quantize_index:
                all_indices.append(null_indices)
                continue

            quantized, embed_indices, loss, perplexity = layer(
                residual, return_idx=True, temperature=sample_codebook_temp
            )

            # Out-of-place on purpose: `residual` starts out as the caller's
            # `x`, and an in-place subtraction would overwrite that tensor
            # (and fails outright on a leaf tensor that requires grad).
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            all_indices.append(embed_indices)
            all_losses.append(loss)
            all_perplexity.append(perplexity)

        # Stack the indices and average the loss / perplexity over run layers.
        all_indices = torch.stack(all_indices, dim=-1)
        all_losses = sum(all_losses) / len(all_losses)
        all_perplexity = sum(all_perplexity) / len(all_perplexity)

        ret = (quantized_out, all_indices, all_losses, all_perplexity)

        if return_all_codes:
            # Codes from all codebooks across layers, shaped
            # (quantizer, batch, sequence length, codebook dimension).
            all_codes = self.get_codes_from_indices(all_indices)
            ret = (*ret, all_codes)

        return ret

    def quantize(self, x, return_latent=False):
        """Run every layer (no dropout) and return the code indices.

        Args:
            x: Input tensor, handed to the first layer as-is; it is not
                modified.
            return_latent: When True, also return the per-layer quantized
                outputs stacked on a new leading axis.

        Returns:
            ``code_idx`` (per-layer indices stacked on a new last axis), or
            ``(code_idx, all_codes)`` when ``return_latent`` is True.
        """
        all_indices = []
        all_codes = []
        residual = x

        for layer in self.layers:
            quantized, embed_indices, _loss, _perplexity = layer(residual, return_idx=True)
            residual = residual - quantized.detach()

            all_indices.append(embed_indices)
            all_codes.append(quantized)

        code_idx = torch.stack(all_indices, dim=-1)
        all_codes = torch.stack(all_codes, dim=0)

        if return_latent:
            return code_idx, all_codes
        return code_idx