Understanding Vector Quantization in VQ-VAE
The Vector Quantized Variational Autoencoder (VQ-VAE) uses a mechanism called vector quantization to map continuous latent representations to discrete embeddings. In this article, I will try to explain the mechanism in a more hands-on way, building the quantization layer step by step in PyTorch.
Initialize the Layer
```python
import torch.nn as nn

class VQEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # Codebook: num_embeddings vectors, each of size embedding_dim
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
```
The VQEmbedding class creates and manages the embedding matrix (the codebook), where each row is one of the discrete embeddings the model can choose from. The matrix has shape [num_embeddings, embedding_dim]: num_embeddings is the number of codebook entries and embedding_dim is the size of each embedding vector.
A crucial part of initialization is setting the codebook weights from a uniform distribution: each weight is sampled independently and uniformly from the range [-1/num_embeddings, 1/num_embeddings], so the initial values are spread evenly across that range. The intent is to start every codebook vector small and centered around zero, so that no entry is systematically favored at the beginning of training and the model starts from a roughly neutral state.
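As a quick sanity check, the sketch below builds a small codebook the same way and confirms its shape and value range. The sizes (512 entries of dimension 64) are arbitrary choices for illustration, not values from the article.

```python
import torch
import torch.nn as nn

num_embeddings, embedding_dim = 512, 64  # illustrative sizes (assumption)

# Build the codebook exactly as VQEmbedding.__init__ does
codebook = nn.Embedding(num_embeddings, embedding_dim)
codebook.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

w = codebook.weight.detach()
print(w.shape)  # torch.Size([512, 64])
# Every weight lies inside [-1/num_embeddings, 1/num_embeddings]
print(w.min().item() >= -1 / num_embeddings, w.max().item() <= 1 / num_embeddings)
print(w.abs().max().item())  # largest magnitude, bounded by 1/512
```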
Flattening for Flexibility
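```python
def forward(self, z):
    b, c, h, w = z.shape  # z: [batch, embedding_dim, h, w]
    z_channel_last = z.permute(0, 2, 3, 1)  # -> [batch, h, w, embedding_dim]
    z_flattened = z_channel_last.reshape(b*h*w, self.embedding_dim)  # -> [batch*h*w, embedding_dim]
```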
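The first step in vector quantization is flattening the encoded inputs. Encoded inputs from an image typically have shape [batch, embedding_dim, h, w]. We first permute the tensor to channels-last, [batch, h, w, embedding_dim], so that the embedding_dim values belonging to each spatial position lie along the last axis; reshaping then gives [batch * h * w, embedding_dim], one row per spatial position. This not only simplifies the subsequent operations but also makes the module flexible: every position is treated as an independent vector, so the same code works for any spatial size h × w.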
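The snippet below checks this on a small random tensor (the sizes are arbitrary): after the permute and reshape, row `(i * h + y) * w + x` of the flattened tensor is exactly the channel vector `z[i, :, y, x]`. Reshaping without the permute would instead group values from different spatial positions into each row.

```python
import torch

torch.manual_seed(0)
b, c, h, w = 2, 3, 4, 5          # small illustrative sizes; c plays the role of embedding_dim
z = torch.randn(b, c, h, w)      # channels-first, as produced by a conv encoder

# Same two steps as in forward(): channels-last, then one row per spatial position
z_flattened = z.permute(0, 2, 3, 1).reshape(b * h * w, c)
print(z_flattened.shape)         # torch.Size([40, 3])

# Row (i * h + y) * w + x holds the channel vector of image i at position (y, x)
i, y, x = 1, 2, 3
row = (i * h + y) * w + x
print(torch.equal(z_flattened[row], z[i, :, y, x]))  # True

# Reshaping without the permute groups values along the wrong axis
naive = z.reshape(b * h * w, c)
print(torch.equal(naive[row], z[i, :, y, x]))
```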
The Distance Computation
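At the heart of vector quantization lies the distance computation between the encoded vectors and the codebook embeddings. To measure distance we use the Mean Squared Error (MSE). The MSE between two vectors $a$ (the original, encoded vector) and $b$ (a candidate quantized vector, i.e. a codebook entry) can be expressed as:

$$\mathrm{MSE}(a, b) = \frac{1}{n} \sum_{i=1}^{n} (a_i - b_i)^2$$

Where:
- $n$ is the number of elements in the vectors.
- $a_i$ and $b_i$ are the corresponding elements of the vectors $a$ and $b$.

This MSE can be rewritten using the formula for the square of a difference:

$$(a_i - b_i)^2 = a_i^2 - 2 a_i b_i + b_i^2$$

Substituting this back into the MSE formula, we get:

$$\mathrm{MSE}(a, b) = \frac{1}{n} \left( \sum_{i=1}^{n} a_i^2 + \sum_{i=1}^{n} b_i^2 - 2 \sum_{i=1}^{n} a_i b_i \right) = \frac{1}{n} \left( \lVert a \rVert^2 + \lVert b \rVert^2 - 2\, a \cdot b \right)$$

Since the factor $1/n$ is the same for every codebook entry, it does not change which entry is closest, so the code drops it and computes the squared Euclidean distance $\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2\, a \cdot b$ directly. Written this way, the cross term for all pairs is a single matrix multiplication.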
```python
# Calculate distances between z and the codebook embeddings |a-b|²
distances = (
    torch.sum(z_flattened ** 2, dim=-1, keepdim=True)  # a²: [b*h*w, 1]
    + torch.sum(self.embedding.weight.t() ** 2, dim=0, keepdim=True)  # b²: [1, num_embeddings]
    - 2 * torch.matmul(z_flattened, self.embedding.weight.t())  # -2ab: [b*h*w, num_embeddings]
)
```
Here, understanding the shapes of the matrices is crucial:
- The flattened encoded input has shape [b*h*w, embedding_dim].
- The embedding matrix (the codebook) has weights of shape [num_embeddings, embedding_dim].

Transposing the codebook to [embedding_dim, num_embeddings] makes the matrix multiplication produce a [b*h*w, num_embeddings] cross term. The a² term has shape [b*h*w, 1] and the b² term has shape [1, num_embeddings], so broadcasting expands both to [b*h*w, num_embeddings]. The result is a distance matrix in which entry (j, k) is the squared distance from the j-th encoded vector to the k-th codebook embedding.
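To convince yourself that the expanded formula is right, you can compare it with PyTorch's `torch.cdist`, which computes pairwise Euclidean distances directly, and with a brute-force per-pair MSE, which should differ only by the constant factor n = embedding_dim. The tensor sizes here are arbitrary, and the comparison uses a tolerance because the expanded form is subject to floating-point cancellation.

```python
import torch

torch.manual_seed(0)
num_vectors, num_embeddings, embedding_dim = 6, 10, 4  # illustrative sizes
z_flattened = torch.randn(num_vectors, embedding_dim)   # stands in for the encoder output
weight = torch.randn(num_embeddings, embedding_dim)     # stands in for embedding.weight

# Expanded form used in the article: a² + b² - 2ab, shape [num_vectors, num_embeddings]
distances = (
    torch.sum(z_flattened ** 2, dim=-1, keepdim=True)
    + torch.sum(weight.t() ** 2, dim=0, keepdim=True)
    - 2 * torch.matmul(z_flattened, weight.t())
)

# Reference: squared pairwise Euclidean distances
reference = torch.cdist(z_flattened, weight) ** 2
print(distances.shape)  # torch.Size([6, 10])
print(torch.allclose(distances, reference, atol=1e-4))

# Brute-force per-pair MSE: the same quantity divided by n = embedding_dim
mse = ((z_flattened[:, None, :] - weight[None, :, :]) ** 2).mean(dim=-1)
print(torch.allclose(mse * embedding_dim, distances, atol=1e-4))
print(distances.argmin(dim=-1), mse.argmin(dim=-1))  # the nearest entry per vector
```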
Selecting the Closest Codebook Embedding
Once we have the distance matrix, the next step is to find, for each vector, the index of the codebook entry with the smallest distance. This lookup is reminiscent of attention, where a query is scored against every key; the key difference is that attention takes a soft, softmax-weighted combination that emphasizes the highest similarity scores, whereas here we make a hard choice of the single nearest entry. This maps each input vector to its closest codebook entry.
```python
# Get the index with the smallest distance
encoding_indices = torch.argmin(distances, dim=-1)  # shape: [b*h*w]
```
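To make the contrast with attention concrete, the sketch below compares the hard selection (argmin, equivalently a one-hot row multiplied by the codebook) with an attention-style soft selection that weights every codebook entry by a softmax over negative distances. The temperature `tau` and the helper `soft_select` are hypothetical, introduced only for this illustration; VQ-VAE itself uses the hard argmin.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
codebook = torch.randn(8, 4)     # 8 illustrative entries of dimension 4
z_flattened = torch.randn(5, 4)  # 5 illustrative encoded vectors
distances = torch.cdist(z_flattened, codebook) ** 2

# Hard selection (VQ-VAE): pick the single nearest entry
encoding_indices = torch.argmin(distances, dim=-1)
one_hot = F.one_hot(encoding_indices, num_classes=codebook.shape[0]).float()
hard = one_hot @ codebook        # same rows as codebook[encoding_indices]

# Hypothetical attention-style soft selection: softmax over -distance / tau
def soft_select(tau):
    weights = torch.softmax(-distances / tau, dim=-1)
    return weights @ codebook

print(torch.allclose(hard, codebook[encoding_indices]))  # True
print((soft_select(1.0) - hard).abs().max())   # blend of many entries at tau = 1.0
print((soft_select(1e-3) - hard).abs().max())  # smaller tau concentrates on the nearest entry
```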
Quantization and Reshaping
With the indices of the nearest codebook embeddings in hand, we use the codebook's nn.Embedding module to look up the quantized vectors. These vectors, of shape [b*h*w, embedding_dim], are reshaped back to [b, h, w, embedding_dim] and permuted to channels-first, [b, embedding_dim, h, w], so that they match the encoder output and can be passed on to the decoder.
```python
# Get the quantized vector
z_q = self.embedding(encoding_indices)          # [b*h*w, embedding_dim]
z_q = z_q.reshape(b, h, w, self.embedding_dim)  # undo the flattening
z_q = z_q.permute(0, 3, 1, 2)                   # back to [b, embedding_dim, h, w]
```
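A useful property to check is that the reshape and permute exactly undo the flattening from the start of forward(). In the sketch below (illustrative sizes, random codebook, and `torch.cdist` used for the distances for brevity), every spatial position of `z_q` holds the codebook vector whose index was selected for that position, and reshaping the indices to [b, h, w] gives a map of codes laid out like the image.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, d, h, w, k = 2, 3, 4, 4, 16  # illustrative sizes; d = embedding_dim, k = num_embeddings
embedding = nn.Embedding(k, d)
z = torch.randn(b, d, h, w)

z_flattened = z.permute(0, 2, 3, 1).reshape(b * h * w, d)
encoding_indices = torch.cdist(z_flattened, embedding.weight).argmin(dim=-1)

# Same three steps as in the article
z_q = embedding(encoding_indices)
z_q = z_q.reshape(b, h, w, d).permute(0, 3, 1, 2)
print(z_q.shape)  # torch.Size([2, 3, 4, 4])

# The code chosen for each position, laid out as an image-sized grid
index_map = encoding_indices.reshape(b, h, w)
i, y, x = 1, 2, 0
print(torch.equal(z_q[i, :, y, x], embedding.weight[index_map[i, y, x]]))  # True
```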
Loss and Gradient Flow
While reading about VQ-VAE, I realized that the codebook itself is not what stands out; it is how the authors propagate gradients through the quantization step so that the model can be trained end to end.
In VQ-VAE, the commitment loss ensures that the encoder commits to a codebook entry that represents its input well. Without it, the encoder's outputs can drift away from the available codebook entries, which hurts reconstruction quality. The commitment loss is the MSE between the continuous encoder output and its quantized version, with a stop-gradient (detach()) on the quantized side, so it only updates the encoder: it penalizes the encoder when its output strays too far from the chosen codebook entry.

The code below also contains a second term, the codebook loss, which is the same MSE with the stop-gradient on the encoder side instead. It only updates the codebook, pulling each selected entry toward the encoder outputs assigned to it. The commitment term is scaled by commitment_cost (often written β; the original paper uses 0.25). Together, the two terms keep the encoder and the codebook aligned and help stabilize training.
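```python
# Codebook loss (updates the codebook only) + weighted commitment loss (updates the encoder only).
# commitment_cost is the weight β on the commitment term (e.g. 0.25); it must be defined before this line.
loss = F.mse_loss(z_q, z.detach()) + commitment_cost * F.mse_loss(z_q.detach(), z)
# Straight-through estimator trick for gradient backpropagation
z_q = z + (z_q - z).detach()
return z_q, loss, encoding_indices
```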
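The detach() calls decide which tensor each loss term can update. The sketch below uses a random stand-in for the flattened encoder output and an illustrative commitment_cost of 0.25, backpropagates each term separately, and records where the gradients land: the codebook loss reaches only the codebook weights, and the commitment loss reaches only the encoder output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
commitment_cost = 0.25                      # illustrative β
embedding = nn.Embedding(16, 3)             # illustrative codebook
z = torch.randn(10, 3, requires_grad=True)  # stands in for the flattened encoder output

indices = torch.cdist(z.detach(), embedding.weight.detach()).argmin(dim=-1)
z_q = embedding(indices)

# Codebook loss: z is detached, so only the codebook receives gradients
codebook_loss = F.mse_loss(z_q, z.detach())
codebook_loss.backward()
grad_reaches_z = z.grad is not None
grad_reaches_codebook = embedding.weight.grad is not None and bool(embedding.weight.grad.abs().sum() > 0)
print(grad_reaches_z, grad_reaches_codebook)  # False True

# Commitment loss: z_q is detached, so only the encoder output receives gradients
embedding.weight.grad = None
commitment_loss = commitment_cost * F.mse_loss(z_q.detach(), z)
commitment_loss.backward()
print(z.grad is not None, embedding.weight.grad is None)  # True True
```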
The straight-through estimator is the trick that makes end-to-end training work. The quantization step, which maps continuous vectors to discrete codebook entries via argmin, is non-differentiable, so gradients cannot flow back through it to the encoder with standard backpropagation. The straight-through estimator bypasses this step: during the backward pass it treats the quantized output as if it were the continuous encoder output, copying the gradient from the quantized vector to the original continuous vector. In code, z_q = z + (z_q - z).detach() has the value of z_q in the forward pass (the two z terms cancel, up to floating-point rounding), while in the backward pass the detached term contributes no gradient, so the derivative with respect to z is the identity. This lets the model be trained end to end despite the discrete bottleneck, keeping the benefits of gradient-based optimization.
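Both halves of that claim can be checked directly. The sketch below uses a random stand-in for z, an illustrative codebook, and an arbitrary downstream loss (a weighted sum, chosen only so the upstream gradient is easy to predict): the forward value matches z_q, and the gradient arriving at z equals the gradient with respect to the straight-through output.

```python
import torch

torch.manual_seed(0)
z = torch.randn(4, 3, requires_grad=True)  # stands in for the encoder output
codebook = torch.randn(8, 3)               # illustrative codebook (no gradient needed here)
z_q = codebook[torch.cdist(z.detach(), codebook).argmin(dim=-1)]  # hard quantization

# Straight-through estimator: value of z_q, gradient of the identity w.r.t. z
z_st = z + (z_q - z).detach()
print(torch.allclose(z_st, z_q, atol=1e-6))  # True (equal up to rounding)

# Arbitrary downstream loss whose gradient w.r.t. z_st is `upstream`
upstream = torch.randn(4, 3)
loss = (z_st * upstream).sum()
loss.backward()
print(torch.equal(z.grad, upstream))  # True: the gradient is copied straight through
```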
By combining the straight-through estimator with commitment loss, VQ-VAE successfully balances the need for discrete representations with the benefits of gradient-based optimization, enabling the model to learn rich, quantized embeddings that are both useful for downstream tasks and easy to optimize during training.
Bringing it together
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VQEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost  # β; the VQ-VAE paper uses 0.25
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

    def forward(self, z):
        # z: [b, embedding_dim, h, w] -> one row per spatial position
        b, c, h, w = z.shape
        z_channel_last = z.permute(0, 2, 3, 1)
        z_flattened = z_channel_last.reshape(b*h*w, self.embedding_dim)

        # Calculate distances between z and the codebook embeddings |a-b|²
        distances = (
            torch.sum(z_flattened ** 2, dim=-1, keepdim=True)  # a²
            + torch.sum(self.embedding.weight.t() ** 2, dim=0, keepdim=True)  # b²
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())  # -2ab
        )

        # Get the index with the smallest distance
        encoding_indices = torch.argmin(distances, dim=-1)

        # Get the quantized vector and restore the [b, embedding_dim, h, w] layout
        z_q = self.embedding(encoding_indices)
        z_q = z_q.reshape(b, h, w, self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2)

        # Codebook loss + weighted commitment loss
        loss = F.mse_loss(z_q, z.detach()) + self.commitment_cost * F.mse_loss(z_q.detach(), z)

        # Straight-through estimator trick for gradient backpropagation
        z_q = z + (z_q - z).detach()
        return z_q, loss, encoding_indices
```
One can also visit this repository to see VQ-VAE training on the CIFAR10 dataset.