"""This file contains the definition of the look-free quantizer."""
from typing import Mapping, Text, Tuple
import torch
from einops import rearrange, reduce
from .quantizer_utils import entropy_loss_fn
class LookupFreeQuantizer(torch.nn.Module):
def __init__(
self,
token_bits: int = 10,
commitment_cost: float = 0.25,
entropy_loss_weight: float = 0.1,
entropy_loss_temperature: float = 0.01,
entropy_gamma: float = 1.0,
):
"""Initializes the lookup-free quantizer.
Args:
token_bits -> int: The number of bits per token.
commitment_cost -> float: The commitment cost.
entropy_loss_weight -> float: The weight of the entropy loss.
entropy_loss_temperature -> float: The temperature for the entropy loss.
entropy_gamma -> float: The gamma for the entropy loss.
"""
super().__init__()
self.token_size = token_bits
self.codebook_size = 2**token_bits
self.commitment_cost = commitment_cost
self.entropy_loss_weight = entropy_loss_weight
self.entropy_loss_temperature = entropy_loss_temperature
self.entropy_gamma = entropy_gamma
bits_to_indices = torch.pow(
2.0, torch.arange(0, self.token_size, dtype=torch.float32)
)
self.register_buffer("bits_to_indices", bits_to_indices.int())
all_codes = torch.arange(self.codebook_size)
bits = ((all_codes[..., None].int() & self.bits_to_indices) != 0).float()
self.register_buffer("codebook", bits * 2.0 - 1.0)
def forward(
self, z: torch.Tensor
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
"""Forward pass of the quantizer.
Args:
z -> torch.Tensor: The input tensor.
Returns:
z_quantized -> torch.Tensor: The quantized latent representation.
result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results
and losses from the quantizer.
"""
z = rearrange(z, "b c h w -> b h w c").contiguous()
ones = torch.ones_like(z)
sign_mask = z > 0.0
z_quantized = torch.where(sign_mask, ones, -ones)
min_encoding_indices = self.convert_bits_to_indices(z_quantized)
        # commitment loss between the continuous latents and their quantized values
commitment_loss = self.commitment_cost * torch.mean(
(z_quantized.detach() - z) ** 2
)
entropy_loss = torch.zeros((), device=z.device)
per_sample_entropy = torch.zeros((), device=z.device)
avg_entropy = torch.zeros((), device=z.device)
        # Entropy loss on the implicit codebook: encourages confident per-sample code
        # assignments while keeping average code usage spread across the codebook.
if self.entropy_loss_weight != 0.0 and self.training:
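            # -d is proportional to the dot product between each continuous latent and
            # every codebook entry, and serves as the affinity logits for the entropy loss.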
d = -2 * torch.einsum("b h w c, n c -> b h w n", z, self.codebook)
per_sample_entropy, avg_entropy = entropy_loss_fn(
-1 * d, self.entropy_loss_temperature, self.entropy_gamma
)
entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)
loss = commitment_loss + entropy_loss
        # preserve gradients via the straight-through estimator
z_quantized = z + (z_quantized - z).detach()
# reshape back to match original input shape
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
result_dict = dict(
quantizer_loss=loss,
commitment_loss=commitment_loss,
entropy_loss=entropy_loss,
per_sample_entropy=per_sample_entropy,
avg_entropy=avg_entropy,
min_encoding_indices=min_encoding_indices,
)
return z_quantized, result_dict
def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
"""Returns the `codebook entry` for the given indices.
As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation.
Note: The bits are represented by {-1, 1}.
Args:
indices -> torch.Tensor: The indices in range 0 to codebook size - 1.
Returns:
tokens -> torch.Tensor: The bit representation.
"""
indices = indices.long()
bits = ((indices[..., None].int() & self.bits_to_indices) != 0).float()
        tokens = bits * 2.0 - 1.0  # map {0, 1} bits to {-1, 1}
return tokens
def convert_bits_to_indices(self, tokens: torch.Tensor) -> torch.Tensor:
"""Converts the given tokens to index numbers.
As the codebook exists only implicitly, this is mainly an integer conversion from a bit representation.
Note: The bits are represented by {-1, 1}.
Args:
tokens -> torch.Tensor: The tokens.
Returns:
indices -> torch.Tensor: The indices in range 0 to codebook size - 1.
"""
        tokens = tokens.contiguous()
sign_mask = tokens > 0.0
return reduce(sign_mask.int() * self.bits_to_indices, "b h w c -> b h w", "sum")
def convert_indices_to_bits(self, indices: torch.Tensor) -> torch.Tensor:
"""Converts the given indices to tokens.
As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation.
Note: The bits are represented by {-1, 1}.
Args:
indices -> torch.Tensor: The indices in range 0 to codebook size - 1.
Returns:
tokens -> torch.Tensor: The bit representation.
"""
indices = indices.long()
return self.get_codebook_entry(indices)
if __name__ == "__main__":
quantizer = LookupFreeQuantizer(
token_bits=10,
commitment_cost=0.25,
entropy_loss_weight=0.1,
entropy_loss_temperature=0.01,
entropy_gamma=1.0,
)
all_entries = torch.arange(1024).reshape(1, 1, 1024)
indices = quantizer.convert_bits_to_indices(
quantizer.convert_indices_to_bits(all_entries)
)
assert torch.equal(indices, all_entries)
assert torch.equal(
quantizer.convert_bits_to_indices(quantizer.codebook.reshape(1, 1, 1024, 10)),
all_entries,
)
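    # Minimal forward-pass sketch (an assumed usage example, not part of the original
    # self-test): run a random channels-first latent through the quantizer in eval mode
    # and check the output shape, the {-1, +1} value range, and the index shape.
    quantizer.eval()
    z = torch.randn(2, 10, 4, 4)
    z_quantized, result_dict = quantizer(z)
    assert z_quantized.shape == z.shape
    assert torch.allclose(z_quantized.abs(), torch.ones_like(z_quantized))
    assert result_dict["min_encoding_indices"].shape == (2, 4, 4)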