"""This file contains the definition of the lookup-free quantizer."""
from typing import Mapping, Text, Tuple

import torch
from einops import rearrange, reduce

from .quantizer_utils import entropy_loss_fn


class LookupFreeQuantizer(torch.nn.Module):
    def __init__(
        self,
        token_bits: int = 10,
        commitment_cost: float = 0.25,
        entropy_loss_weight: float = 0.1,
        entropy_loss_temperature: float = 0.01,
        entropy_gamma: float = 1.0,
    ):
        """Initializes the lookup-free quantizer.

        Args:
            token_bits -> int: The number of bits per token.
            commitment_cost -> float: The commitment cost.
            entropy_loss_weight -> float: The weight of the entropy loss.
            entropy_loss_temperature -> float: The temperature for the entropy loss.
            entropy_gamma -> float: The gamma for the entropy loss.
        """
        super().__init__()
        self.token_size = token_bits
        self.codebook_size = 2**token_bits

        self.commitment_cost = commitment_cost
        self.entropy_loss_weight = entropy_loss_weight
        self.entropy_loss_temperature = entropy_loss_temperature
        self.entropy_gamma = entropy_gamma

        # Powers of two used to turn a {0, 1} bit pattern into an integer index.
        bits_to_indices = torch.pow(
            2.0, torch.arange(0, self.token_size, dtype=torch.float32)
        )
        self.register_buffer("bits_to_indices", bits_to_indices.int())

        # The codebook exists only implicitly: code n is the {-1, 1} sign pattern of
        # the binary representation of n, so nothing is learned or looked up.
        all_codes = torch.arange(self.codebook_size)
        bits = ((all_codes[..., None].int() & self.bits_to_indices) != 0).float()
        self.register_buffer("codebook", bits * 2.0 - 1.0)

    def forward(
        self, z: torch.Tensor
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Forward pass of the quantizer.

        Args:
            z -> torch.Tensor: The input tensor.

        Returns:
            z_quantized -> torch.Tensor: The quantized latent representation.
            result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing
                additional results and losses from the quantizer.
        """
        # Quantize each channel to its sign: positive entries map to +1, the rest to -1.
        z = rearrange(z, "b c h w -> b h w c").contiguous()
        ones = torch.ones_like(z)
        sign_mask = z > 0.0
        z_quantized = torch.where(sign_mask, ones, -ones)

        min_encoding_indices = self.convert_bits_to_indices(z_quantized)

        # Commitment loss pulls the encoder output towards the quantized values.
        commitment_loss = self.commitment_cost * torch.mean(
            (z_quantized.detach() - z) ** 2
        )
        entropy_loss = torch.zeros((), device=z.device)
        per_sample_entropy = torch.zeros((), device=z.device)
        avg_entropy = torch.zeros((), device=z.device)

        # The entropy loss is only computed during training: it encourages confident
        # per-sample assignments while keeping overall codebook usage high.
        if self.entropy_loss_weight != 0.0 and self.training:
            # -d is proportional to the dot product between z and every code and is
            # used as the affinity (logits) for the entropy loss.
            d = -2 * torch.einsum("b h w c, n c -> b h w n", z, self.codebook)

            per_sample_entropy, avg_entropy = entropy_loss_fn(
                -1 * d, self.entropy_loss_temperature, self.entropy_gamma
            )
            entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)

        loss = commitment_loss + entropy_loss

        # Straight-through estimator: the forward value is z_quantized, while the
        # gradient flows to z as if quantization were the identity.
        z_quantized = z + (z_quantized - z).detach()

        z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            entropy_loss=entropy_loss,
            per_sample_entropy=per_sample_entropy,
            avg_entropy=avg_entropy,
            min_encoding_indices=min_encoding_indices,
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
        """Returns the codebook entry for the given indices.

        As the codebook exists only implicitly, this is mainly an integer conversion
        to a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.

        Returns:
            tokens -> torch.Tensor: The bit representation.
        """
        indices = indices.long()
        bits = ((indices[..., None].int() & self.bits_to_indices) != 0).float()
        tokens = bits * 2.0 - 1.0
        return tokens

    def convert_bits_to_indices(self, tokens: torch.Tensor) -> torch.Tensor:
        """Converts the given tokens to index numbers.

        As the codebook exists only implicitly, this is mainly an integer conversion
        from a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            tokens -> torch.Tensor: The tokens.

        Returns:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.
        """
        # The identity rearrange is a no-op apart from checking the "b h w c" layout.
        tokens = rearrange(tokens, "b h w c -> b h w c").contiguous()
        sign_mask = tokens > 0.0
        return reduce(sign_mask.int() * self.bits_to_indices, "b h w c -> b h w", "sum")

    def convert_indices_to_bits(self, indices: torch.Tensor) -> torch.Tensor:
        """Converts the given indices to tokens.

        As the codebook exists only implicitly, this is mainly an integer conversion
        to a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.

        Returns:
            tokens -> torch.Tensor: The bit representation.
        """
        indices = indices.long()
        return self.get_codebook_entry(indices)


if __name__ == "__main__":
    quantizer = LookupFreeQuantizer(
        token_bits=10,
        commitment_cost=0.25,
        entropy_loss_weight=0.1,
        entropy_loss_temperature=0.01,
        entropy_gamma=1.0,
    )
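
    # Quick sanity check of the implicit codebook built in __init__: one row per
    # index, one column per bit, every entry exactly -1 or +1.
    assert quantizer.codebook.shape == (quantizer.codebook_size, quantizer.token_size)
    assert torch.all(quantizer.codebook.abs() == 1.0)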

    # Round-trip check: indices -> bits -> indices must be the identity.
    all_entries = torch.arange(quantizer.codebook_size).reshape(1, 1, quantizer.codebook_size)
    indices = quantizer.convert_bits_to_indices(
        quantizer.convert_indices_to_bits(all_entries)
    )
    assert torch.equal(indices, all_entries)
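
    # A small worked example: index 5 = 0b101 has bits set at positions 0 and 2
    # (least significant bit first), so its code is +1 there and -1 elsewhere.
    expected = -torch.ones(1, quantizer.token_size)
    expected[0, 0] = 1.0
    expected[0, 2] = 1.0
    assert torch.equal(quantizer.convert_indices_to_bits(torch.tensor([5])), expected)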

    # The n-th row of the implicit codebook must decode back to index n.
    assert torch.equal(
        quantizer.convert_bits_to_indices(
            quantizer.codebook.reshape(1, 1, quantizer.codebook_size, quantizer.token_size)
        ),
        all_entries,
    )
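
    # Conversely, a handcrafted bit pattern with +1 at positions 1 and 9 should map
    # back to index 2 + 512 = 514.
    tokens = torch.full((1, 1, 1, quantizer.token_size), -1.0)
    tokens[..., 1] = 1.0
    tokens[..., 9] = 1.0
    assert quantizer.convert_bits_to_indices(tokens).item() == 514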
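
    # Forward-pass smoke test; a minimal sketch assuming a latent of shape
    # (batch, token_size, height, width). It runs in eval mode so the entropy-loss
    # branch is skipped and only the commitment loss is computed.
    quantizer.eval()
    z = torch.randn(2, quantizer.token_size, 4, 4)
    z_quantized, results = quantizer(z)
    assert z_quantized.shape == z.shape
    assert results["min_encoding_indices"].shape == (2, 4, 4)
    assert results["min_encoding_indices"].min() >= 0
    assert results["min_encoding_indices"].max() < quantizer.codebook_size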