import torch.nn as nn

from models.vq.encdec import Encoder, Decoder
from models.vq.residual_vq import ResidualVQ


class RVQVAE(nn.Module):
    """Residual VQ-VAE over motion sequences: a temporal convolutional
    encoder/decoder pair with a residual vector-quantization bottleneck."""

    def __init__(self,
                 args,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()
        assert output_emb_width == code_dim
        self.code_dim = code_dim
        self.num_code = nb_code

        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0,
            'nb_code': nb_code,
            'code_dim': code_dim,
            'args': args,
        }
        self.quantizer = ResidualVQ(**rvqvae_config)

    def preprocess(self, x):
        # (bs, T, Jx3) -> (bs, Jx3, T): the Conv1d stacks expect channels first
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        # (bs, Jx3, T) -> (bs, T, Jx3)
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        # code_idx: (N, T, Q) codebook indices, one column per quantizer layer;
        # all_codes: the matching latent vectors from every layer.
        code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
        return code_idx, all_codes

    def forward(self, x):
        x_in = self.preprocess(x)
        # Encode
        x_encoder = self.encoder(x_in)
        # Quantize (the codebook sampling temperature is hardcoded)
        x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5)
        # Decode
        x_out = self.decoder(x_quantized)
        return x_out, commit_loss, perplexity

    def forward_decoder(self, x):
        # Look up the latents for the given code indices, sum the residual
        # contributions across quantizer layers, then decode.
        x_d = self.quantizer.get_codes_from_indices(x)
        x = x_d.sum(dim=0).permute(0, 2, 1)
        x_out = self.decoder(x)
        return x_out
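

# Usage sketch (illustrative, not part of the original module). The `args`
# fields below are assumptions read off rvqvae_config above; ResidualVQ may
# read further fields from `args`, so extend the namespace to match your
# training options. With the default down_t=3, stride_t=2, the temporal
# length should be divisible by 2**3 = 8.
def _rvqvae_example():
    import torch
    from types import SimpleNamespace

    args = SimpleNamespace(num_quantizers=6,           # assumed values
                           shared_codebook=False,
                           quantize_dropout_prob=0.2)
    model = RVQVAE(args)                  # defaults: input_width=263, code_dim=512
    motion = torch.randn(2, 64, 263)      # (bs, T, Jx3)
    recon, commit_loss, perplexity = model(motion)
    code_idx, all_codes = model.encode(motion)
    print(recon.shape, code_idx.shape, commit_loss, perplexity)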


class LengthEstimator(nn.Module):
    """MLP head that maps a text embedding to logits over motion lengths."""

    def __init__(self, input_size, output_size):
        super(LengthEstimator, self).__init__()
        nd = 512
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(nd // 4, output_size)
        )

        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        # GPT-style initialization: small normal weights, zero biases,
        # identity LayerNorm.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, text_emb):
        return self.output(text_emb)
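

# Usage sketch (illustrative). The sizes are assumptions: a 512-d text
# embedding scored against 50 discrete length bins.
def _length_estimator_example():
    import torch

    estimator = LengthEstimator(input_size=512, output_size=50)
    text_emb = torch.randn(4, 512)   # (bs, text_emb_dim)
    logits = estimator(text_emb)     # (bs, output_size) unnormalized scores
    print(logits.shape)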