# QLIP-B-16-256 / bsq.py
# Commit ecd1674 ("first commit") by zhaoyue-zephyrus.
# Copyright (c) 2024, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/QLIP/blob/main/LICENSE
# MIT License
# Based on https://github.com/zhaoyue-zephyrus/bsq-vit/blob/main/transcoder/models/quantizer/bsq.py
import torch
import torch.nn as nn
from einops import rearrange, reduce
# Small constant added to probabilities before ``torch.special.entr`` is
# applied — presumably to keep entropy gradients finite at zero probability.
_EPS = 1.0e-8
class DifferentiableEntropyFunction(torch.autograd.Function):
    """Entropy of the empirical code histogram with a hand-written gradient.

    The forward pass counts how often each of the ``2**K`` codes occurs in
    ``zq`` and returns the entropy (natural log) of the eps-smoothed
    histogram. The backward pass sends ``-(log p_k + 1)`` (the derivative of
    ``-p_k log p_k``) to every position whose code is ``k``, multiplied by
    ``zq`` and divided by the number of positions and by ``K``.
    """

    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        """Compute the codebook entropy.

        Args:
            zq: Tensor whose last dimension holds K code bits, expected to be
                ±1. Any positive entry is read as bit 1, everything else as
                bit 0.
            basis: Tensor of shape (K,) with the place value of each bit, used
                to turn a code into its codebook index.
            K: Number of bits per code; the histogram has ``2**K`` bins.
            eps: Additive smoothing applied to every bin before normalising.

        Returns:
            Scalar tensor holding the entropy of the code histogram.
        """
        # Read bits from the sign instead of computing (zq + 1) / 2 in float:
        # straight-through outputs like 0.9999999 must still count as bit 1,
        # and integer arithmetic keeps the index exact for any K.
        bits = (zq > 0).to(torch.int64)
        zi = (bits * basis).sum(-1).to(torch.int64)
        flat_zi = zi.flatten()
        cnt = torch.scatter_reduce(
            torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
            0,
            flat_zi,
            torch.ones_like(flat_zi).to(zq.dtype),
            "sum",
        )
        prob = (cnt + eps) / (cnt + eps).sum()
        H = torch.special.entr(prob).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter the per-bin entropy derivative back onto ``zq``."""
        zq, zi, prob = ctx.saved_tensors
        # d(-p log p)/dp = -(log p + 1), normalised by #positions and #bits.
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        # Exactly one gradient per forward input: zq, basis, K, eps.
        return grad_input, None, None, None
def codebook_entropy(zq, basis, K, eps=1e-8):
    """Differentiable entropy of the code histogram of ``zq``.

    Thin functional wrapper around ``DifferentiableEntropyFunction``; see that
    class for the meaning of ``basis``, ``K`` and ``eps``.
    """
    entropy_fn = DifferentiableEntropyFunction.apply
    return entropy_fn(zq, basis, K, eps)
class BinarySphericalQuantizer(nn.Module):
    """Binary spherical quantizer.

    Every channel of the input is binarised to ±1 with a straight-through
    estimator. When ``l2_norm`` is set, the code is then scaled by
    ``1/sqrt(embed_dim)``. Entropy regularisers are computed either softly
    over sub-groups of ``group_size`` bits or from hard code assignments.

    Args:
        embed_dim: Number of bits per code (size of the channel dimension).
        group_size: Bits per sub-group; must divide ``embed_dim``.
        soft_entropy: Use ``soft_entropy_loss`` if True, else
            ``hard_entropy_loss``.
        beta: Weight of the commitment loss.
        gamma_0: Weight of the per-sample entropy term (E[H(q)]).
        gamma_1: Weight of the codebook entropy term (H[E[q]]).
        input_format: ``"bchw"`` (4-D, channels second) or ``"blc"``
            (channels last).
        persample_entropy_compute: ``"group"`` or ``"analytical"``; selects how
            the soft per-sample entropy is computed.
        l2_norm: Scale codes by ``1/sqrt(embed_dim)``.
        inv_temperature: Inverse temperature of the soft assignment; the
            entropy penalty returned by ``forward`` is also divided by it.
    """

    def __init__(
        self,
        embed_dim: int = 18,
        group_size: int = 9,
        soft_entropy: bool = True,
        beta: float = 0.0,  # commit loss
        gamma_0: float = 1.0,  # entropy loss (E[H(q)])
        gamma_1: float = 1.0,  # entropy loss (H[E[q]])
        input_format: str = "bchw",
        persample_entropy_compute: str = "group",
        l2_norm: bool = True,
        inv_temperature: float = 100.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.group_size = group_size
        assert embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.soft_entropy = soft_entropy
        self.beta = beta
        self.gamma_0 = gamma_0
        self.gamma_1 = gamma_1
        assert input_format in ["bchw", "blc"]
        self.input_format = input_format
        assert persample_entropy_compute in [
            "group",
            "analytical",
        ], "persample_entropy_compute must be either 'group' or 'analytical'"
        self.persample_entropy_compute = persample_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature
        # Place value of each bit, most significant first: [2**(D-1), ..., 2, 1].
        self.register_buffer("basis", 2 ** torch.arange(embed_dim - 1, -1, -1), persistent=False)
        self.register_buffer(
            "group_basis", 2 ** torch.arange(group_size - 1, -1, -1), persistent=False
        )
        # All 2**group_size sub-codes as ±1 rows: shape (2**group_size, group_size).
        group_codes = torch.arange(2**self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer("group_codebook", group_codebook, persistent=False)

    def quantize(self, z):
        """Binarise ``z`` to ±1 (positive -> 1, otherwise -1) with a straight-through gradient."""
        assert (
            z.shape[-1] == self.embed_dim
        ), f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
        zhat = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
        # Forward value is zhat; gradient flows to z unchanged.
        return z + (zhat - z).detach()

    def forward(self, z):
        """Quantize ``z`` and compute the auxiliary losses.

        Returns:
            ``(zq, loss, info)``. ``zq`` has the same layout as the input.
            ``loss`` is the commitment loss plus the entropy penalty divided by
            ``inv_temperature``. ``info`` holds the codebook entropy ``"H"``,
            the unique ``"used_codes"`` (``None`` in training mode), and the
            full ``"indices"`` and per-group ``"group_indices"``.
        """
        if self.input_format == "bchw":
            z = rearrange(z, "b c h w -> b h w c")
        zq = self.quantize(z)
        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())
        if not self.training:
            used_codes = torch.unique(indices, return_counts=False)
        else:
            used_codes = None
        if self.soft_entropy:
            persample_entropy, cb_entropy = self.soft_entropy_loss(z)
        else:
            # The hard estimate counts discrete codes, so it must see the
            # binarised ±1 codes rather than the continuous inputs.
            persample_entropy, cb_entropy = self.hard_entropy_loss(zq)
        entropy_penalty = self.gamma_0 * persample_entropy - self.gamma_1 * cb_entropy
        q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
        zq = zq * q_scale
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
        if self.input_format == "bchw":
            zq = rearrange(zq, "b h w c -> b c h w")
        return (
            zq,
            commit_loss + entropy_penalty / self.inv_temperature,
            {
                "H": cb_entropy,
                "used_codes": used_codes,
                "indices": indices,
                "group_indices": group_indices,
            },
        )

    def soft_entropy_loss(self, z):
        """Soft (temperature-scaled) per-sample and codebook entropy terms.

        Returns:
            ``(persample_entropy, cb_entropy)`` as scalar tensors.
        """
        group_codebook = self.group_codebook / (self.embed_dim**0.5 if self.l2_norm else 1)
        divided_z = rearrange(z, "... (g c) -> ... g c", c=self.group_size)
        if self.persample_entropy_compute == "group":
            # Soft assignment of every sub-group to the 2**group_size sub-codes.
            distance = -2 * torch.einsum("... g c, d c -> ... g d", divided_z, group_codebook)
            prob = (-distance * self.inv_temperature).softmax(dim=-1)
            persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean()
        else:
            # Per-bit Bernoulli probabilities; here prob is (..., embed_dim, 2),
            # so the reduction below treats each bit as its own "group".
            p = torch.sigmoid(
                -4 * z / (self.embed_dim**0.5 if self.l2_norm else 1) * self.inv_temperature
            )
            prob = torch.stack([p, 1 - p], dim=-1)
            persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean()
        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, "... g d -> g d", "mean")
        cb_entropy = torch.special.entr(avg_prob + _EPS).sum()
        return persample_entropy, cb_entropy

    def hard_entropy_loss(self, z):
        """Entropy terms estimated from hard ±1 code assignments.

        ``z`` is binarised with ``quantize`` first, which is a no-op for inputs
        that are already ±1. Continuous inputs are therefore counted by their
        sign instead of producing invalid codebook indices.

        Returns:
            ``(persample_entropy, cb_entropy)`` as scalar tensors.
        """
        zq = self.quantize(z)
        # Fraction of positions per sample whose bit is 1, for every channel.
        zb = ((zq + 1) / 2).reshape(zq.shape[0], -1, zq.shape[-1]).to(torch.float32)
        prob_per_dim = zb.sum(1) / zb.shape[1]
        prob = torch.stack([prob_per_dim, 1 - prob_per_dim], dim=-1)
        persample_entropy = torch.special.entr(prob + _EPS).sum((-1, -2)).mean()
        cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
        return persample_entropy, cb_entropy

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        assert (
            zhat.shape[-1] == self.embed_dim
        ), f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        # Sign-based bits: robust to straight-through values like 0.9999999
        # (which .int() would truncate to 0) and exact in integer arithmetic.
        bits = (zhat > 0).to(torch.int64)
        return (bits * self.basis).sum(dim=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        zhat_in_group = rearrange(zhat, "b ... (g c) -> b ... g c", c=self.group_size)
        bits = (zhat_in_group > 0).to(torch.int64)
        return (bits * self.group_basis).sum(dim=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`: returns ±1 codes with a trailing embed_dim axis."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2)
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `codes_to_group_indexes`: group indices (B, ..., G) -> ±1 codes (B, ..., G * group_size)."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(group_indices, self.group_basis), 2)
        codes_non_centered = rearrange(codes_non_centered, "b ... g c -> b ... (g c)")
        return codes_non_centered * 2 - 1

    def _scale_and_layout(self, z_q):
        """Apply the code scale; for ``"bchw"`` fold (B, L, C) into (B, C, H, W)."""
        q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
        z_q = z_q * q_scale
        if self.input_format == "bchw":
            # Assumes a square spatial grid: L == h * w with h == w.
            h = w = round(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], "Invalid sequence length"
            z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h)
        return z_q

    def get_group_codebook_entry(self, group_indices, one_hot=False):
        """Look up scaled codes from per-group indices.

        Args:
            group_indices: Integer tensor of shape (B, L, G), or, when
                ``one_hot`` is True, a float tensor of shape
                (B, L, G, 2**group_size) with one-hot (or soft) weights.
            one_hot: Whether ``group_indices`` is given as one-hot weights.
        """
        if one_hot:
            # (..., G, 2**gs) @ (2**gs, gs) -> (..., G, gs); flatten the groups so
            # both branches return (..., embed_dim) like group_indexes_to_codes.
            z_q = (group_indices @ self.group_codebook).flatten(-2)
        else:
            z_q = self.group_indexes_to_codes(group_indices)
        return self._scale_and_layout(z_q)

    def get_codebook_entry(self, indices, one_hot=False):
        """Look up scaled codes from full codebook indices.

        Args:
            indices: Integer tensor of shape (B, L), or, when ``one_hot`` is
                True, a float tensor of shape (B, L, 2**embed_dim).
            one_hot: Whether ``indices`` is given as one-hot weights; only
                supported when ``group_size == embed_dim``.
        """
        if one_hot:
            assert self.embed_dim == self.group_size, "one_hot is only supported for group_size == embed_dim"
            z_q = indices @ self.group_codebook
        else:
            z_q = self.indexes_to_codes(indices)
        return self._scale_and_layout(z_q)