|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange, reduce |
|
|
|
# Small constant added to probabilities before they are passed to
# `torch.special.entr` in the quantizer's entropy terms.
_EPS = 1e-8
|
|
|
|
|
class DifferentiableEntropyFunction(torch.autograd.Function):
    """Entropy of the empirical code histogram with a hand-written gradient.

    The forward pass maps every code to an integer index, counts how often
    each index occurs and returns the entropy of the smoothed, normalised
    counts. Counting goes through integer indices, so the gradient with
    respect to the codes is supplied explicitly in ``backward``.
    """

    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        """Return the entropy (natural log) of the histogram of code indices.

        Args:
            zq: Tensor whose last dimension holds the bits of one code;
                entries are expected to be -1 or +1.
            basis: Per-bit weights, broadcastable against the last dimension
                of ``zq``; the weighted sum of the bits is the code index.
            K: Number of bits per code; the histogram has ``2**K`` bins.
            eps: Constant added to every bin count before normalising.

        Returns:
            Scalar tensor holding the entropy of the smoothed histogram.
        """
        # Take the bit from the sign and accumulate in the dtype of `basis`
        # (an integer tensor in this file) rather than in the floating dtype
        # of `zq`: a floating-point weighted sum followed by truncation loses
        # low-order bits once the index exceeds the mantissa width.
        bits = (zq > 0).to(basis.dtype)
        zi = (bits * basis).sum(-1).to(torch.int64)
        flat_index = zi.flatten()
        cnt = torch.scatter_reduce(
            torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
            0,
            flat_index,
            torch.ones_like(flat_index, dtype=zq.dtype),
            "sum",
        )
        smoothed = cnt + eps
        prob = smoothed / smoothed.sum()
        H = torch.special.entr(prob).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``zq`` and ``None`` for the other inputs.

        Each code element receives
        ``-grad_output * (log(prob[index]) + 1) / (num_codes * K)`` multiplied
        by the element itself, where ``index`` is the bin of its code.
        """
        zq, zi, prob = ctx.saved_tensors
        grad_per_bin = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        # Look up the gradient of the bin every code fell into.
        grad_per_code = grad_per_bin[zi]
        grad_input = grad_per_code.unsqueeze(-1) * zq
        # One entry per forward input: zq, basis, K, eps.
        return grad_input, None, None, None
|
|
|
|
|
def codebook_entropy(zq, basis, K, eps=1e-8):
    """Functional entry point for `DifferentiableEntropyFunction`.

    All four arguments are forwarded positionally to
    `DifferentiableEntropyFunction.apply`, whose result is returned as is.

    Args:
        zq: Tensor of codes, forwarded unchanged.
        basis: Per-bit weights, forwarded unchanged.
        K: Number of bits per code, forwarded unchanged.
        eps: Smoothing constant, forwarded unchanged (default ``1e-8``).
    """
    entropy = DifferentiableEntropyFunction.apply(zq, basis, K, eps)
    return entropy
|
|
|
|
|
class BinarySphericalQuantizer(nn.Module):
    """Sign-based binary quantizer with an entropy regulariser.

    Every element of the input is quantized to -1 or +1 by its sign (with a
    straight-through gradient) and, when ``l2_norm`` is set, scaled by
    ``1 / sqrt(embed_dim)``. ``forward`` also returns a loss made of a commit
    term and an entropy penalty, plus the integer indices of the codes.
    """

    def __init__(
        self,
        embed_dim: int = 18,
        group_size: int = 9,
        soft_entropy: bool = True,
        beta: float = 0.0,
        gamma_0: float = 1.0,
        gamma_1: float = 1.0,
        input_format: str = "bchw",
        persample_entropy_compute: str = "group",
        l2_norm: bool = True,
        inv_temperature: float = 100.0,
    ):
        """Store the configuration and build the index/codebook buffers.

        Args:
            embed_dim: Number of bits (channels) per code.
            group_size: Number of bits per group; must divide ``embed_dim``.
            soft_entropy: Use `soft_entropy_loss` if true, otherwise
                `hard_entropy_loss`.
            beta: Weight of the commit loss.
            gamma_0: Weight of the per-sample entropy (added to the penalty).
            gamma_1: Weight of the codebook entropy (subtracted from the
                penalty).
            input_format: ``"bchw"`` or ``"blc"``.
            persample_entropy_compute: ``"group"`` or ``"analytical"``.
            l2_norm: Scale quantized codes by ``1 / sqrt(embed_dim)``.
            inv_temperature: Sharpness of the soft assignment; the entropy
                penalty is also divided by this value in ``forward``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.group_size = group_size
        assert embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.soft_entropy = soft_entropy
        self.beta = beta
        self.gamma_0 = gamma_0
        self.gamma_1 = gamma_1
        assert input_format in ["bchw", "blc"]
        self.input_format = input_format
        assert persample_entropy_compute in [
            "group",
            "analytical",
        ], "persample_entropy_compute must be either 'group' or 'analytical'"
        self.persample_entropy_compute = persample_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature

        # Powers of two, most significant bit first; used to turn bit
        # patterns into integer indices and back.
        self.register_buffer("basis", 2 ** torch.arange(embed_dim - 1, -1, -1), persistent=False)
        self.register_buffer(
            "group_basis", 2 ** torch.arange(group_size - 1, -1, -1), persistent=False
        )

        # All 2**group_size sign patterns of one group: decode the indices
        # with the full basis and keep the trailing `group_size` bits.
        group_codes = torch.arange(2**self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer("group_codebook", group_codebook, persistent=False)

    def _q_scale(self):
        """Return the factor applied to quantized codes (1 without l2_norm)."""
        return 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0

    def _to_output_format(self, z_q):
        """Reshape a (B, L, C) tensor to (B, C, H, W) when the format is "bchw".

        The sequence length must be a perfect square. Tensors are returned
        unchanged for the "blc" format.
        """
        if self.input_format != "bchw":
            return z_q
        seq_len = z_q.shape[1]
        side = int(round(seq_len**0.5))
        assert side * side == seq_len, "Invalid sequence length"
        return rearrange(z_q, "b (h w) c -> b c h w", h=side)

    def quantize(self, z):
        """Quantize ``z`` to -1/+1 by sign with a straight-through gradient.

        Elements that are not strictly positive map to -1. The returned
        tensor has the value of the sign pattern while gradients flow to
        ``z`` unchanged.
        """
        assert (
            z.shape[-1] == self.embed_dim
        ), f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
        zhat = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
        return z + (zhat - z).detach()

    def forward(self, z):
        """Quantize ``z`` and compute the training loss.

        Args:
            z: Tensor laid out according to ``input_format`` with
                ``embed_dim`` channels.

        Returns:
            A tuple ``(zq, loss, info)`` where ``zq`` is the scaled quantized
            tensor in the input layout, ``loss`` is the commit loss plus the
            entropy penalty divided by ``inv_temperature``, and ``info`` is a
            dict with the codebook entropy (``"H"``), the unique code indices
            in eval mode (``"used_codes"``, ``None`` while training), and the
            per-code and per-group indices.
        """
        if self.input_format == "bchw":
            z = rearrange(z, "b c h w -> b h w c")
        zq = self.quantize(z)

        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())

        used_codes = None if self.training else torch.unique(indices, return_counts=False)

        if self.soft_entropy:
            persample_entropy, cb_entropy = self.soft_entropy_loss(z)
        else:
            # The hard statistics count -1/+1 codes, so they need the
            # quantized tensor rather than the raw input.
            persample_entropy, cb_entropy = self.hard_entropy_loss(zq)
        entropy_penalty = self.gamma_0 * persample_entropy - self.gamma_1 * cb_entropy

        zq = zq * self._q_scale()
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))

        if self.input_format == "bchw":
            zq = rearrange(zq, "b h w c -> b c h w")

        info = {
            "H": cb_entropy,
            "used_codes": used_codes,
            "indices": indices,
            "group_indices": group_indices,
        }
        return zq, commit_loss + entropy_penalty / self.inv_temperature, info

    def soft_entropy_loss(self, z):
        """Return ``(persample_entropy, cb_entropy)`` from soft assignments.

        In "group" mode each group of ``group_size`` channels gets a softmax
        over the ``2**group_size`` group codes, based on its dot product with
        every (scaled) code times ``inv_temperature``. In "analytical" mode
        each channel gets a two-way probability from a sigmoid. The
        per-sample entropy sums the entropies over the last two dimensions
        and averages the rest; the codebook entropy is the entropy of the
        probabilities averaged over all leading dimensions.
        """
        if self.persample_entropy_compute == "group":
            group_codebook = self.group_codebook / (self.embed_dim**0.5 if self.l2_norm else 1)
            divided_z = rearrange(z, "... (g c) -> ... g c", c=self.group_size)
            distance = -2 * torch.einsum("... g c, d c -> ... g d", divided_z, group_codebook)
            prob = (-distance * self.inv_temperature).softmax(dim=-1)
        else:
            p = torch.sigmoid(
                -4 * z / (self.embed_dim**0.5 if self.l2_norm else 1) * self.inv_temperature
            )
            prob = torch.stack([p, 1 - p], dim=-1)
        persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean()

        avg_prob = reduce(prob, "... g d -> g d", "mean")
        cb_entropy = torch.special.entr(avg_prob + _EPS).sum()

        return persample_entropy, cb_entropy

    def hard_entropy_loss(self, z):
        """Return ``(persample_entropy, cb_entropy)`` from hard codes.

        Args:
            z: Tensor of codes with ``embed_dim`` as last dimension; entries
                are expected to be -1 or +1.

        The per-sample term is the entropy of the per-channel frequency of
        +1 within each sample; the codebook term is `codebook_entropy` over
        all codes.
        """
        zb = ((z + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
        prob_per_dim = zb.sum(1) / zb.shape[1]
        prob = torch.stack([prob_per_dim, 1 - prob_per_dim], dim=-1)
        persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean()
        cb_entropy = codebook_entropy(z, self.basis, self.embed_dim)

        return persample_entropy, cb_entropy

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}

        Returns:
            An int64 tensor of shape (B, ...) with the index of every code.
        """
        assert (
            zhat.shape[-1] == self.embed_dim
        ), f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        # Integer arithmetic throughout: a float32 weighted sum cannot
        # represent every index once embed_dim exceeds 24 bits.
        bits = (zhat > 0).to(self.basis.dtype)
        return (bits * self.basis).sum(dim=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.

        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}

        Returns:
            An int64 tensor of shape (B, ..., G) with one index per group of
            ``group_size`` bits.
        """
        zhat_in_group = rearrange(zhat, "b ... (g c) -> b ... g c", c=self.group_size)
        bits = (zhat_in_group > 0).to(self.group_basis.dtype)
        return (bits * self.group_basis).sum(dim=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`.

        Appends a trailing dimension of size ``embed_dim`` holding the bits
        of every index as -1/+1, most significant bit first.
        """
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2)
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `codes_to_group_indexes`.

        Decodes every group index into ``group_size`` bits (-1/+1) and
        flattens the group and bit dimensions into one trailing dimension.
        """
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(group_indices, self.group_basis), 2)
        codes_non_centered = rearrange(codes_non_centered, "b ... g c -> b ... (g c)")
        return codes_non_centered * 2 - 1

    def get_group_codebook_entry(self, group_indices, one_hot=False):
        """Look up scaled codes from per-group indices.

        Args:
            group_indices: Integer group indices of shape (B, L, G), or, with
                ``one_hot=True``, a tensor whose last dimension has size
                ``2**group_size`` and is multiplied with the group codebook.
            one_hot: Whether ``group_indices`` is one-hot encoded.

        Returns:
            The codes multiplied by the quantization scale; for the "bchw"
            format a (B, L, C) result is reshaped to (B, C, H, W).
        """
        if one_hot:
            # NOTE(review): this path keeps the group dimension separate,
            # unlike the integer path which flattens it - confirm callers.
            z_q = group_indices @ self.group_codebook
        else:
            z_q = self.group_indexes_to_codes(group_indices)
        z_q = z_q * self._q_scale()
        return self._to_output_format(z_q)

    def get_codebook_entry(self, indices, one_hot=False):
        """Look up scaled codes from full-code indices.

        Args:
            indices: Integer code indices of shape (B, L), or, with
                ``one_hot=True``, a tensor whose last dimension has size
                ``2**group_size`` and is multiplied with the group codebook.
            one_hot: Whether ``indices`` is one-hot encoded; only supported
                when ``group_size == embed_dim``.

        Returns:
            The codes multiplied by the quantization scale; for the "bchw"
            format a (B, L, C) result is reshaped to (B, C, H, W).
        """
        if one_hot:
            assert self.embed_dim == self.group_size, "one_hot is only supported for group_size == embed_dim"
            z_q = indices @ self.group_codebook
        else:
            z_q = self.indexes_to_codes(indices)
        z_q = z_q * self._q_scale()
        return self._to_output_format(z_q)
|
|