"""This file contains the definition of the VQ quantizer."""
from typing import Mapping, Text, Tuple
import torch
from einops import rearrange
from .quantizer_utils import entropy_loss_fn
class SimpleVectorizer(torch.nn.Module):
def __init__(
self,
codebook_size: int = 1024,
token_size: int = 256,
commitment_cost: float = 0.25,
entropy_loss_weight: float = 0.0,
entropy_loss_temperature: float = 0.01,
entropy_gamma: float = 1.0,
use_l2_normalisation: bool = False,
):
"""Initializes the quantizer.
Args:
codebook_size -> int: The size of the codebook.
token_size -> int: The feature dimensions of the tokens.
commitment_cost -> float: The commitment cost.
entropy_loss_weight -> float: The weight of the entropy loss.
entropy_loss_temperature -> float: The temperature of the entropy loss.
entropy_gamma -> float: The gamma of the entropy loss.
use_l2_normalisation -> bool: Whether to use L2 normalisation.
"""
super().__init__()
self.commitment_cost = commitment_cost
self.embedding = torch.nn.Embedding(codebook_size, token_size)
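        # Initialise the codebook uniformly in [-1/codebook_size, 1/codebook_size],
        # a common VQ-VAE initialisation.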
self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
self.entropy_loss_weight = entropy_loss_weight
self.entropy_loss_temperature = entropy_loss_temperature
self.entropy_gamma = entropy_gamma
self.use_l2_normalisation = use_l2_normalisation

    def forward(
self, z: torch.Tensor
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
"""Computes the quantization loss and returns the quantized latent representation.
Args:
z -> torch.Tensor: The latent representation.
Returns:
z_quantized -> torch.Tensor: The quantized latent representation.
result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results
and losses from the quantizer.
"""
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
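        # Optionally L2-normalise both the latents and the codebook so that the
        # nearest-neighbour lookup below is equivalent to a cosine-similarity match.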
if self.use_l2_normalisation:
z = torch.nn.functional.normalize(z, dim=-1)
embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
else:
embedding = self.embedding.weight
z_flattened = rearrange(z, "b h w c -> (b h w) c")
        # distances from z to embeddings e_j: d = (z - e)^2 = z^2 + e^2 - 2 * e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(embedding**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T)
)
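        # Assign every flattened latent vector to its nearest codebook entry.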
min_encoding_indices = torch.argmin(d, dim=1)
z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
        # Commitment loss pulls the encoder output towards its assigned code;
        # the codebook loss moves the code towards the (detached) encoder output.
commitment_loss = self.commitment_cost * torch.mean(
(z_quantized.detach() - z) ** 2
)
codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
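        # Entropy terms default to zero so the result dict always exposes the same
        # keys, even when the entropy loss is disabled or the module is in eval mode.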
entropy_loss = torch.zeros((), device=z.device)
per_sample_entropy = torch.zeros((), device=z.device)
avg_entropy = torch.zeros((), device=z.device)
        # Optional entropy regularisation on the soft code assignments: it rewards
        # confident per-sample assignments while encouraging diverse codebook usage.
if self.entropy_loss_weight != 0.0 and self.training:
per_sample_entropy, avg_entropy = entropy_loss_fn(
-1 * d, self.entropy_loss_temperature, self.entropy_gamma
)
entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)
loss = commitment_loss + codebook_loss + entropy_loss
        # preserve gradients (straight-through estimator: gradients flow from
        # z_quantized back to z, bypassing the non-differentiable argmin)
z_quantized = z + (z_quantized - z).detach()
# reshape back to match original input shape
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
result_dict = dict(
quantizer_loss=loss,
commitment_loss=commitment_loss,
codebook_loss=codebook_loss,
entropy_loss=entropy_loss,
per_sample_entropy=per_sample_entropy,
avg_entropy=avg_entropy,
min_encoding_indices=min_encoding_indices.view(
z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]
),
)
return z_quantized, result_dict

    def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
"""Returns the codebook entry for the given indices.
Args:
indices -> torch.Tensor: The indices of the codebook entries.
Returns:
z_quantized -> torch.Tensor: The codebook entries.
"""
# get quantized latent vectors
z_quantized = self.embedding(indices.int())
if self.use_l2_normalisation:
z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
return z_quantized
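

# Minimal usage sketch (illustrative, not part of the original module): run it as
# `python -m <package>.<this_module>` so the relative import above resolves.
# The shapes and hyper-parameters below are assumptions chosen for the example.
if __name__ == "__main__":
    quantizer = SimpleVectorizer(codebook_size=1024, token_size=256)
    quantizer.eval()

    latents = torch.randn(2, 256, 16, 16)  # (batch, channels, height, width)
    z_q, results = quantizer(latents)

    print(z_q.shape)                              # torch.Size([2, 256, 16, 16])
    print(results["min_encoding_indices"].shape)  # torch.Size([2, 16, 16])
    print(results["quantizer_loss"].item())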