zhzluke96
update
1df74c6
raw
history blame
696 Bytes
import torch
from einops import rearrange
from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ
def quantize(
    quantizer: GroupedResidualFSQ,
    audio_latents: torch.Tensor,  # (batch_size, audio_len, audio_dim=1024)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize audio latents and flatten the per-group code indices.

    Args:
        quantizer: Grouped residual FSQ module; it is called directly on
            ``audio_latents`` and is expected to return ``(features, indices)``.
        audio_latents: Latents laid out as (batch_size, audio_len, audio_dim).

    Returns:
        A ``(quantized_latents, input_ids)`` pair:

        * ``quantized_latents`` -- the quantizer's first output, passed through
          unchanged; expected shape (batch_size, audio_len, audio_dim).
        * ``input_ids`` -- the quantizer's indices with the group and residual
          axes merged into one trailing axis; expected shape
          (batch_size, audio_len, num_vq) where num_vq = GFSQ.G * GFSQ.R.
    """
    # Indices are expected as (GFSQ.G, batch_size, audio_len, GFSQ.R).
    quantized, indices = quantizer(audio_latents)

    # Bring batch/time to the front and fuse (group, residual) into a single
    # codebook axis, with the group index varying slowest.
    merged_ids = rearrange(indices, "g b t r ->b t (g r)")
    return quantized, merged_ids