"""This file contains the definition of the VQ quantizer.""" |
|
|
|
from typing import Mapping, Text, Tuple |
|
|
|
import torch |
|
from einops import rearrange |
|
|
|
from .quantizer_utils import entropy_loss_fn |
|
|
|
|
|
class SimpleVectorizer(torch.nn.Module):
    def __init__(
        self,
        codebook_size: int = 1024,
        token_size: int = 256,
        commitment_cost: float = 0.25,
        entropy_loss_weight: float = 0.0,
        entropy_loss_temperature: float = 0.01,
        entropy_gamma: float = 1.0,
        use_l2_normalisation: bool = False,
    ):
        """Initializes the quantizer.

        Args:
            codebook_size -> int: The size of the codebook.
            token_size -> int: The feature dimension of the tokens.
            commitment_cost -> float: The weight of the commitment loss.
            entropy_loss_weight -> float: The weight of the entropy loss.
            entropy_loss_temperature -> float: The temperature of the entropy loss.
            entropy_gamma -> float: The gamma of the entropy loss.
            use_l2_normalisation -> bool: Whether to L2-normalise the latents and codebook entries.
        """
        super().__init__()
        self.commitment_cost = commitment_cost

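        # Learnable codebook: one `token_size`-dimensional vector per code,
        # initialised uniformly in [-1/codebook_size, 1/codebook_size].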
        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

        self.entropy_loss_weight = entropy_loss_weight
        self.entropy_loss_temperature = entropy_loss_temperature
        self.entropy_gamma = entropy_gamma
        self.use_l2_normalisation = use_l2_normalisation

    def forward(
        self, z: torch.Tensor
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Computes the quantization loss and returns the quantized latent representation.

        Args:
            z -> torch.Tensor: The latent representation of shape (batch, channels, height, width).

        Returns:
            z_quantized -> torch.Tensor: The quantized latent representation.
            result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results
                and losses from the quantizer.
        """
        # Move to channel-last layout so each spatial position is one token vector.
        z = rearrange(z, "b c h w -> b h w c").contiguous()

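        # Optionally L2-normalise both the latents and the codebook so that the
        # nearest-neighbour search below is equivalent to using cosine similarity.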
        if self.use_l2_normalisation:
            z = torch.nn.functional.normalize(z, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight

        z_flattened = rearrange(z, "b h w c -> (b h w) c")

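        # Pairwise squared distances between latents and codebook entries, using
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 <z, e>.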
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(embedding**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T)
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

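        # VQ-VAE losses: the commitment loss pulls the encoder output towards its
        # (detached) codebook entry, while the codebook loss moves the codebook
        # entry towards the (detached) encoder output.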
        commitment_loss = self.commitment_cost * torch.mean(
            (z_quantized.detach() - z) ** 2
        )
        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
        entropy_loss = torch.zeros((), device=z.device)
        per_sample_entropy = torch.zeros((), device=z.device)
        avg_entropy = torch.zeros((), device=z.device)

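        # Optional entropy regularisation (training only): the negative distances are
        # passed to `entropy_loss_fn` as affinity scores over the codebook.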
        if self.entropy_loss_weight != 0.0 and self.training:
            per_sample_entropy, avg_entropy = entropy_loss_fn(
                -1 * d, self.entropy_loss_temperature, self.entropy_gamma
            )
            entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)

        loss = commitment_loss + codebook_loss + entropy_loss

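        # Straight-through estimator: the forward pass uses the quantized values,
        # while gradients flow back to the encoder as if quantization were the identity.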
        z_quantized = z + (z_quantized - z).detach()

        z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()

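        # Collect the losses and the per-position code indices (reshaped back to the
        # spatial grid) for logging and downstream use.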
        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            entropy_loss=entropy_loss,
            per_sample_entropy=per_sample_entropy,
            avg_entropy=avg_entropy,
            min_encoding_indices=min_encoding_indices.view(
                z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]
            ),
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
        """Returns the codebook entries for the given indices.

        Args:
            indices -> torch.Tensor: The indices of the codebook entries.

        Returns:
            z_quantized -> torch.Tensor: The codebook entries.
        """
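        # Look up the embeddings and, if enabled, re-normalise them so they match the
        # L2-normalised space used for the nearest-neighbour search in `forward`.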
        z_quantized = self.embedding(indices.int())
        if self.use_l2_normalisation:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized

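
# Minimal usage sketch (illustrative only, not part of the original module): builds the
# quantizer with its default hyperparameters and quantizes a random latent map. It assumes
# this file is imported as part of its package so that the relative import of
# `entropy_loss_fn` above resolves (e.g. run via `python -m <package>.<module>`).
if __name__ == "__main__":
    quantizer = SimpleVectorizer(codebook_size=1024, token_size=256)
    latents = torch.randn(2, 256, 16, 16)  # (batch, token_size, height, width)
    quantized, results = quantizer(latents)
    print(quantized.shape)  # torch.Size([2, 256, 16, 16])
    print(results["min_encoding_indices"].shape)  # torch.Size([2, 16, 16])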