import math
import typing as tp
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange


@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


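# The codebook below references `uniform_init` and `kmeans`, which are not part of
# this excerpt. The helpers here are a minimal sketch of what those names are assumed
# to do (Kaiming-uniform initialization and plain Euclidean k-means), not necessarily
# the original implementation.


def uniform_init(*shape: int) -> torch.Tensor:
    # Random initial codebook entries.
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
    # Draw `num` rows from `samples`, with replacement if there are too few samples.
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]


def kmeans(samples: torch.Tensor, num_clusters: int, num_iters: int = 10):
    # Plain Lloyd's algorithm on [N, dim] samples; returns (means, bins) where
    # `bins` counts how many samples fell into each cluster at the last step.
    dim, dtype = samples.shape[-1], samples.dtype
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs ** 2).sum(dim=-1)
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        # Keep the previous mean for empty clusters.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins

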
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance lookup and optional k-means initialization.

    Args:
        dim (int): Dimension of each codebook entry.
        codebook_size (int): Number of codebook entries.
        kmeans_init (bool): Initialize the codebook with k-means on the first batch.
        kmeans_iters (int): Number of k-means iterations used for initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        epsilon=1e-5,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init, start from zeros and fill the codebook lazily on the
        # first batch; otherwise draw random entries right away.
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))

    def postprocess_emb(self, embed_ind, shape):
        # Restore the leading (batch, time) dimensions of flattened indices.
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        # Look up the codebook vectors for the given indices.
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize
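
    # `ResidualVectorQuantizer.encode` further down ultimately needs a nearest-neighbour
    # lookup on this codebook, which is not part of this excerpt. The methods below are
    # a minimal sketch of that encode path, assuming the input arrives as [..., dim] and
    # indices are returned with the same leading shape.

    def preprocess(self, x):
        # Flatten all leading dimensions so rows can be matched against the codebook.
        return rearrange(x, "... d -> (...) d")

    def quantize(self, x):
        # Nearest codebook entry by (negated) squared Euclidean distance.
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def encode(self, x):
        shape = x.shape
        x = self.preprocess(x)
        embed_ind = self.quantize(x)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

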
class VectorQuantization(nn.Module):
    """Single vector quantization layer: an optional projection around a EuclideanCodebook.

    Args:
        dim (int): Dimension of the input.
        codebook_size (int): Codebook size.
        codebook_dim (int, optional): Codebook dimension. If not defined, uses `dim`.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use k-means to initialize the codebook.
        kmeans_iters (int): Number of iterations used for k-means initialization.
        channels_last (bool): Channels are the last dimension in the input tensors.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim=None,
        decay=0.8,
        epsilon=1e-5,
        kmeans_init=False,
        kmeans_iters=10,
        channels_last=False,
    ):
        super().__init__()
        _codebook_dim = codebook_dim if codebook_dim is not None else dim

        # Project in and out of the codebook space only when the dimensions differ.
        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size
        self.channels_last = channels_last

    @property
    def codebook(self):
        return self._codebook.embed

    @property
    def inited(self):
        return self._codebook.inited

    def _postprocess(self, quantize):
        # Back to [B, D, T] unless the caller works channels-last.
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize
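
    # Encode counterpart of `decode`. Again only a sketch of what the residual
    # quantizers below assume: a channels-first input of shape [B, D, T] unless
    # `channels_last` is set; it relies on the codebook `encode` sketched above.

    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_ind = self._codebook.encode(x)
        return embed_ind

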
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        # q_indices has shape [n_q, B, T]; the contribution of each layer is summed up.
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
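
    # Encode counterpart of `decode`: quantize the running residual layer by layer.
    # This is a sketch of the method assumed by `ResidualVectorQuantizer.encode`
    # below; it relies on the per-layer `encode` sketched above and is not part of
    # the original excerpt.

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        # Stack to [n_q, B, T] so that `decode` can consume the result directly.
        return torch.stack(all_indices)

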
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer dropout at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for k-means initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any code
            whose exponential moving average cluster size is below this threshold with a
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weight.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (int, optional): Maximum number of codes to consider
            for orthogonal regularization.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            channels_last=False,
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            # Quantizer dropout: train with a random number of quantizers in [1, n_q].
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        # Bandwidth of a single quantizer in kbits/s: log2(bins) bits per frame.
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        # codes is [n_q, B, T]; reorder to [B, n_q, T].
        codes = codes.transpose(0, 1)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with `self.n_q` quantizers and return the
        indices produced by each quantizer, shaped [B, n_q, T].
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        # codes is [n_q, B, T]; reorder to [B, n_q, T].
        codes = codes.transpose(0, 1)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, n_q, T]; reorder to [n_q, B, T] as expected by the underlying RVQ.
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n
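

# Minimal usage sketch (illustrative only, relying on the encode/decode paths sketched
# above; the training forward with EMA codebook updates and commitment loss is not part
# of this excerpt):
#
#   rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024, kmeans_init=False)
#   x = torch.randn(2, 256, 50)   # [B, dimension, T] features
#   codes = rvq.encode(x)         # [B, n_q, T] integer codebook indices
#   x_hat = rvq.decode(codes)     # [B, dimension, T] quantized reconstruction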