# MoMask / models / vq / model.py
# MeYourHint's picture
# first demo version
# c0eac48
import random
import torch.nn as nn
from models.vq.encdec import Encoder, Decoder
from models.vq.residual_vq import ResidualVQ
class RVQVAE(nn.Module):
    """Autoencoder with a residual vector quantizer between encoder and decoder.

    The module wires three sub-modules in sequence: ``Encoder`` ->
    ``ResidualVQ`` -> ``Decoder``.

    Args:
        args: Configuration object. This class reads ``args.num_quantizers``,
            ``args.shared_codebook`` and ``args.quantize_dropout_prob`` and
            also hands the whole object to ``ResidualVQ``.
        input_width: First positional argument of both ``Encoder`` and
            ``Decoder`` (presumably the per-frame feature width -- confirm
            against ``models.vq.encdec``).
        nb_code: Forwarded to ``ResidualVQ`` as ``nb_code``; also stored as
            ``self.num_code``.
        code_dim: Forwarded to ``ResidualVQ`` as ``code_dim``; also stored as
            ``self.code_dim``. Must equal ``output_emb_width``.
        output_emb_width: Second positional argument of ``Encoder`` and
            ``Decoder``.
        down_t, stride_t, width, depth, dilation_growth_rate, activation, norm:
            Forwarded unchanged to both ``Encoder`` and ``Decoder``.

    Raises:
        ValueError: If ``output_emb_width != code_dim``.
    """

    def __init__(self,
                 args,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under
        # ``python -O`` and the mismatch would then only show up later as an
        # opaque shape error.
        if output_emb_width != code_dim:
            raise ValueError(
                f"output_emb_width ({output_emb_width}) must equal "
                f"code_dim ({code_dim})"
            )
        self.code_dim = code_dim
        self.num_code = nb_code

        # Sub-modules are created in the order encoder, decoder, quantizer so
        # that parameter registration order stays stable.
        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0,
            'nb_code': nb_code,
            'code_dim': code_dim,
            'args': args,
        }
        self.quantizer = ResidualVQ(**rvqvae_config)

    def preprocess(self, x):
        """Swap the last two axes of a 3-D tensor and cast it to float32.

        Layout per the original author's note: (bs, T, Jx3) -> (bs, Jx3, T).
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap the last two axes of a 3-D tensor (no dtype change).

        Layout per the original author's note: (bs, Jx3, T) -> (bs, T, Jx3).
        """
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        """Encode ``x`` and quantize the encoder output.

        Args:
            x: 3-D tensor in the layout accepted by :meth:`preprocess`.

        Returns:
            The ``(code_idx, all_codes)`` pair produced by
            ``self.quantizer.quantize(..., return_latent=True)``. The original
            author's note gives ``code_idx`` as (N, T, Q) -- confirm against
            ``ResidualVQ.quantize``.

        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        # Explicit rank check; replaces an unpack into otherwise unused locals.
        if len(x.shape) != 3:
            raise ValueError(
                f"expected a 3-D input, got shape {tuple(x.shape)}"
            )
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
        return code_idx, all_codes

    def forward(self, x, sample_codebook_temp=0.5):
        """Run encoder -> quantizer -> decoder.

        Args:
            x: 3-D tensor in the layout accepted by :meth:`preprocess`.
            sample_codebook_temp: Passed to the quantizer under the same
                keyword. Defaults to 0.5, the value previously hard-coded.

        Returns:
            ``(x_out, commit_loss, perplexity)``. ``x_out`` is the decoder
            output as-is; :meth:`postprocess` is *not* applied here. The code
            indices returned by the quantizer are discarded.
        """
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_quantized, _code_idx, commit_loss, perplexity = self.quantizer(
            x_encoder, sample_codebook_temp=sample_codebook_temp)
        x_out = self.decoder(x_quantized)
        return x_out, commit_loss, perplexity

    def forward_decoder(self, x):
        """Decode from code indices.

        Args:
            x: Indices accepted by ``self.quantizer.get_codes_from_indices``.

        Returns:
            The decoder output for the summed code vectors.
        """
        x_d = self.quantizer.get_codes_from_indices(x)
        # Sum over the leading axis (presumably one slice per quantizer layer
        # -- confirm against ResidualVQ), then swap the last two axes before
        # decoding.
        x = x_d.sum(dim=0).permute(0, 2, 1)
        x_out = self.decoder(x)
        return x_out
class LengthEstimator(nn.Module):
    """Small MLP head mapping ``input_size`` features to ``output_size`` values.

    Three Linear -> LayerNorm -> LeakyReLU(0.2) stages with 512, 256 and 128
    units (Dropout(0.2) after the first two), followed by a final Linear
    projection. Judging by the names, it presumably predicts a length from a
    text embedding -- confirm with the caller.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden = 512
        half = hidden // 2
        quarter = hidden // 4

        # Module order is significant: it fixes the state-dict keys
        # (``output.0`` ... ``output.11``).
        stack = [
            nn.Linear(input_size, hidden),
            nn.LayerNorm(hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),

            nn.Linear(hidden, half),
            nn.LayerNorm(half),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),

            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(quarter, output_size),
        ]
        self.output = nn.Sequential(*stack)

        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        """Initialise one sub-module in place (used via ``Module.apply``).

        LayerNorm: bias 0, weight 1. Linear / Embedding: weight ~ N(0, 0.02);
        a Linear bias, when present, is zeroed. Other modules are untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        has_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_bias:
            nn.init.zeros_(module.bias)

    def forward(self, text_emb):
        """Apply the MLP stack to ``text_emb`` and return the result."""
        prediction = self.output(text_emb)
        return prediction