import torch
from einops import rearrange
from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ


def quantize(
    quantizer: GroupedResidualFSQ,
    audio_latents: torch.Tensor,  # (batch_size, audio_len, audio_dim=1024)
) -> tuple[torch.Tensor, torch.Tensor]:
    # feat: (batch_size, audio_len, audio_dim)
    # ind:  (GFSQ.G, batch_size, audio_len, GFSQ.R), where G is the number
    #       of groups and R the number of residual quantizers per group,
    #       so num_vq = G * R
    feat, ind = quantizer(audio_latents)
    audio_quantized_latents = feat  # (batch_size, audio_len, audio_dim)
    # Merge the group and residual axes into a single token axis, so each
    # frame carries num_vq codebook ids.
    audio_input_ids = rearrange(  # (batch_size, audio_len, num_vq)
        ind,
        "g b t r -> b t (g r)",
    )
    return audio_quantized_latents, audio_input_ids
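

# A minimal usage sketch, not from the original source: the hyperparameters
# below (groups=2, num_quantizers=2, FSQ levels [5, 5, 5, 5]) are illustrative
# assumptions; only dim has to match the audio_dim of the latents.
quantizer = GroupedResidualFSQ(
    dim=1024,             # feature dim, split evenly across the groups
    levels=[5, 5, 5, 5],  # FSQ quantization levels per codebook dimension
    num_quantizers=2,     # R: residual quantizers per group
    groups=2,             # G: independent groups along the feature dim
)
audio_latents = torch.randn(2, 121, 1024)  # (batch_size, audio_len, audio_dim)
quantized, ids = quantize(quantizer, audio_latents)
print(quantized.shape)  # torch.Size([2, 121, 1024])
print(ids.shape)        # torch.Size([2, 121, 4]); num_vq = G * R = 4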