Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Residual vector quantizer implementation.""" | |
from dataclasses import dataclass, field | |
import math | |
import typing as tp | |
import torch | |
from torch import nn | |
from module.core_vq import ResidualVectorQuantization | |
@dataclass
class QuantizedResult:
    """Record bundling the outputs of a quantization pass.

    Declared as a dataclass so that the annotated fields become constructor
    arguments and ``metrics`` receives a fresh ``dict`` per instance (via
    ``default_factory``) instead of exposing a raw ``dataclasses.Field``
    object at class level.
    """

    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None  # optional; absent by default
    metrics: dict = field(default_factory=dict)  # new empty dict per instance
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Stores the quantizer configuration and delegates the actual work to a
    ``ResidualVectorQuantization`` instance kept in ``self.vq``.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        # Keep the configuration on the module so callers can inspect it.
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.Any]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. A falsy value
                (``None`` or ``0``) selects all ``self.n_q`` quantizers.
            layers (list): Indices of the layers whose quantized output should
                be returned. Default: None.

        Returns:
            A 4-tuple ``(quantized, codes, commit_loss, quantized_list)``.
            ``quantized``, ``codes`` and ``quantized_list`` are passed through
            unchanged from ``self.vq``; ``commit_loss`` is the value returned
            by ``self.vq`` reduced with ``torch.mean``.
            NOTE(review): ``quantized_list`` is presumably the per-layer
            quantized outputs selected by ``layers`` -- confirm against
            ``ResidualVectorQuantization``.

        Raises:
            ValueError: If ``layers`` is non-empty and its largest index is
                not smaller than the number of quantizers actually used.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            # Report the effective quantizer count (the value the check uses),
            # not self.n_q, so the message cannot contradict itself.
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. "
                f"Number of quantizers in RVQ: B {n_q}. A must be less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the input tensor into codebook indices.

        Resolves the number of quantizers and the starting layer, then
        delegates to ``self.vq.encode``.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. A falsy value
                (``None`` or ``0``) selects all ``self.n_q`` quantizers.
            st (int): Layer from which to start encoding. Default: 0.

        Returns:
            torch.Tensor: The codes returned by ``self.vq.encode``.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer from which the given codes start. Default: 0.

        Returns:
            torch.Tensor: The value returned by ``self.vq.decode``.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized