Junyin's picture
Add files using upload-large-folder tool
05744dc verified
import torch
import torch.nn as nn
from .vq import VectorQuantizer
class ResidualVectorQuantizer(nn.Module):
""" References:
SoundStream: An End-to-End Neural Audio Codec
https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, n_e_list, e_dim, sk_epsilons,
kmeans_init = False, kmeans_iters = 100, sk_iters=100,):
super().__init__()
self.n_e_list = n_e_list
self.e_dim = e_dim
self.num_quantizers = len(n_e_list)
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.sk_epsilons = sk_epsilons
self.sk_iters = sk_iters
self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim,
kmeans_init = self.kmeans_init,
kmeans_iters = self.kmeans_iters,
sk_epsilon=sk_epsilon,
sk_iters=sk_iters)
for n_e, sk_epsilon in zip(n_e_list,sk_epsilons) ])
def get_codebook(self):
all_codebook = []
for quantizer in self.vq_layers:
codebook = quantizer.get_codebook()
all_codebook.append(codebook)
return torch.stack(all_codebook)
def forward(self, x, use_sk=True):
all_losses = []
all_indices = []
x_q = 0
residual = x
for quantizer in self.vq_layers:
x_res, loss, indices = quantizer(residual, use_sk=use_sk)
residual = residual - x_res
x_q = x_q + x_res
all_losses.append(loss)
all_indices.append(indices)
mean_losses = torch.stack(all_losses).mean()
all_indices = torch.stack(all_indices, dim=-1)
return x_q, mean_losses, all_indices