|
from typing import Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from torch.nn.utils import weight_norm
|
|
|
|
from dac.nn.layers import WNConv1d
|
|
|
|
class VectorQuantizeLegacy(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization

    This variant has no input/output projection: the nearest-neighbour
    lookup is done directly in the input space, using l2-normalized
    inputs and codebook entries.
    """

    def __init__(self, input_dim: int, codebook_size: int):
        """
        Parameters
        ----------
        input_dim : int
            Size of each codebook vector (the D axis of the input).
        codebook_size : int
            Number of entries in the codebook.
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook = nn.Embedding(codebook_size, input_dim)

    @staticmethod
    def _masked_mse(pred, target, z_mask):
        """Mask-weighted mean of the per-position MSE between `pred` and `target`.

        The squared error is averaged over dim 1, multiplied by `z_mask`,
        summed, and divided by `z_mask.sum()`. A mask that sums to zero
        would make that division produce NaN, so the denominator is replaced
        by 1 in that case only (the numerator is then 0 for a binary mask).
        """
        per_position = F.mse_loss(pred, target, reduction="none").mean(1)
        total = z_mask.sum()
        denom = torch.where(total == 0, torch.ones_like(total), total)
        return (per_position * z_mask).sum() / denom

    def forward(self, z, z_mask=None):
        """Quantizes the input tensor using the codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        z_mask : Tensor, optional
            Weights multiplied with the per-position error (which has shape
            [B x T] after averaging over D). When given, each loss is the
            weighted sum divided by ``z_mask.sum()``; a mask summing to zero
            yields a loss of 0 instead of NaN.

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input (straight-through:
            values come from the codebook, gradients flow to ``z``)
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            The unmodified input ``z`` (there is no projection in this class)
        Tensor
            Scalar commitment loss to train encoder to predict vectors closer
            to codebook entries
        Tensor
            Scalar codebook loss to update the codebook
        """
        z_e = z
        z_q, indices = self.decode_latents(z)

        if z_mask is not None:
            commitment_loss = self._masked_mse(z_e, z_q.detach(), z_mask)
            codebook_loss = self._masked_mse(z_q, z_e.detach(), z_mask)
        else:
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            codebook_loss = F.mse_loss(z_q, z_e.detach())

        # Straight-through estimator: forward value is z_q, gradient is
        # passed to z_e as if quantization were the identity.
        z_q = z_e + (z_q - z_e).detach()

        return z_q, indices, z_e, commitment_loss, codebook_loss

    def embed_code(self, embed_id):
        """Look up codebook vectors for integer indices (adds a trailing D axis)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Look up codebook vectors and swap axes 1 and 2 ([B x T] -> [B x D x T])."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Find the nearest codebook entry for every position of `latents`.

        Parameters
        ----------
        latents : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Selected (un-normalized) codebook vectors
        Tensor[B x T]
            Indices of the selected codebook entries
        """
        batch, dim, steps = latents.shape
        # "b d t -> (b t) d": one row per (batch, time) position.
        encodings = latents.transpose(1, 2).reshape(-1, dim)
        codebook = self.codebook.weight

        # l2-normalize both sides so the euclidean distance below ranks
        # entries by cosine similarity.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared euclidean distance: |e|^2 - 2 e.c + |c|^2
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        # "(b t) -> b t": index of the smallest distance for each position.
        indices = (-dist).max(1)[1].reshape(batch, steps)
        z_q = self.decode_code(indices)
        return z_q, indices
|
|
|
|
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        """
        Parameters
        ----------
        input_dim : int
            Channel count of the input and of the returned quantized tensor.
        codebook_size : int
            Number of entries in the codebook.
        codebook_dim : int
            Size of each codebook vector; the lookup happens in this space.
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # NOTE(review): WNConv1d comes from dac.nn.layers; presumably a
        # weight-normalized Conv1d - confirm against that module.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @staticmethod
    def _masked_mse(pred, target, z_mask):
        """Mask-weighted mean of the per-position MSE between `pred` and `target`.

        The squared error is averaged over dim 1, multiplied by `z_mask`,
        summed, and divided by `z_mask.sum()`. A mask that sums to zero
        would make that division produce NaN, so the denominator is replaced
        by 1 in that case only (the numerator is then 0 for a binary mask).
        """
        per_position = F.mse_loss(pred, target, reduction="none").mean(1)
        total = z_mask.sum()
        denom = torch.where(total == 0, torch.ones_like(total), total)
        return (per_position * z_mask).sum() / denom

    def forward(self, z, z_mask=None):
        """Quantizes the input tensor using the codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        z_mask : Tensor, optional
            Weights multiplied with the per-position error (which has shape
            [B x T] after averaging over the channel axis). When given, each
            loss is the weighted sum divided by ``z_mask.sum()``; a mask
            summing to zero yields a loss of 0 instead of NaN.

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input, after ``out_proj``
        Tensor
            Scalar commitment loss to train encoder to predict vectors closer
            to codebook entries
        Tensor
            Scalar codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x codebook_dim x T]
            Projected latents (output of ``in_proj``, before quantization)
        """
        z_e = self.in_proj(z)
        z_q, indices = self.decode_latents(z_e)

        if z_mask is not None:
            commitment_loss = self._masked_mse(z_e, z_q.detach(), z_mask)
            codebook_loss = self._masked_mse(z_q, z_e.detach(), z_mask)
        else:
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            codebook_loss = F.mse_loss(z_q, z_e.detach())

        # Straight-through estimator: forward value is z_q, gradient is
        # passed to z_e as if quantization were the identity.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        """Look up codebook vectors for integer indices (adds a trailing axis)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Look up codebook vectors and swap axes 1 and 2 ([B x T] -> [B x C x T])."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Find the nearest codebook entry for every position of `latents`.

        Parameters
        ----------
        latents : Tensor[B x codebook_dim x T]

        Returns
        -------
        Tensor[B x codebook_dim x T]
            Selected (un-normalized) codebook vectors
        Tensor[B x T]
            Indices of the selected codebook entries
        """
        batch, dim, steps = latents.shape
        # "b d t -> (b t) d": one row per (batch, time) position.
        encodings = latents.transpose(1, 2).reshape(-1, dim)
        codebook = self.codebook.weight

        # l2-normalize both sides so the euclidean distance below ranks
        # entries by cosine similarity.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared euclidean distance: |e|^2 - 2 e.c + |c|^2
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        # "(b t) -> b t": index of the smallest distance for each position.
        indices = (-dist).max(1)[1].reshape(batch, steps)
        z_q = self.decode_code(indices)
        return z_q, indices
|
|
|
|
|
|
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        """
        Parameters
        ----------
        input_dim : int
            Channel count passed to every quantizer as its input dimension.
        n_codebooks : int
            Number of quantizers stacked in the residual chain.
        codebook_size : int
            Number of entries in each codebook.
        codebook_dim : int or list
            Codebook vector size; an int is repeated for every codebook,
            a list is indexed per codebook.
        quantizer_dropout : float
            Fraction of the batch (taken from its start) that, in training
            mode, uses a randomly drawn number of quantizers.
        """
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim] * n_codebooks

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: Union[int, None] = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and
        returns the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use in eval mode; defaults to all of them
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: in training mode this argument is ignored. Every quantizer
            is run, and the first ``int(B * self.quantizer_dropout)`` batch
            items only accumulate a random number of them (drawn from
            1..n_codebooks); the remaining items accumulate all of them.

        Returns
        -------
        tuple
            A 5-tuple ``(z, codes, latents, commitment_loss, codebook_loss)``:

            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before
                quantization), concatenated along the channel axis
            commitment_loss : Tensor
                Commitment loss to train encoder to predict vectors closer to
                codebook entries, summed over the quantizers
            codebook_loss : Tensor
                Codebook loss to update the codebook, summed over the
                quantizers
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            batch_size = z.shape[0]
            # Default of n_codebooks + 1 keeps every quantizer active for
            # the batch items that are not subject to dropout.
            n_quantizers = torch.ones((batch_size,)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (batch_size,))
            n_dropout = int(batch_size * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # In eval mode n_quantizers is a plain int, so stop early.
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Per-item flag: does this batch item still use quantizer i?
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            # The residual is updated with the unmasked output.
            residual = residual - z_q_i

            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input; the first
            ``codes.shape[1]`` quantizers are used

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input (sum of every
            quantizer's ``out_proj`` output)
        Tensor[B x N*D x T]
            Codebook vectors of each quantizer before ``out_proj``,
            concatenated along the channel axis
        Tensor[B x N x T]
            The ``codes`` argument, returned unchanged
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            quantizer = self.quantizers[i]
            z_p_i = quantizer.decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q = z_q + quantizer.out_proj(z_p_i)
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N*D x T]
            Continuous representation of input after projection, with the
            per-quantizer latents concatenated along the channel axis

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x N*D x T]
            Quantized representation of latent space
        Tensor[B x N x T]
            Codebook indices selected by each quantizer
        """
        z_q = 0
        z_p = []
        codes = []
        # Channel offsets of each quantizer's slice: [0, d0, d0 + d1, ...].
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Number of quantizers whose whole slice fits in the given channels,
        # i.e. the largest index with dims[index] <= latents.shape[1].
        n_codebooks = int(np.searchsorted(dims, latents.shape[1], side="right")) - 1
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            quantizer = self.quantizers[i]
            z_p_i, codes_i = quantizer.decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q = z_q + quantizer.out_proj(z_p_i)

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run a random batch through the residual quantizer.
    # quantizer_dropout is a fraction of the batch; 1.0 applies dropout to
    # every item (the module is in training mode by default).
    rvq = ResidualVectorQuantize(quantizer_dropout=1.0)
    x = torch.randn(16, 512, 80)
    # forward returns a tuple, not a dict, so unpack it positionally.
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
|
|
|