Junyin's picture
Add files using upload-large-folder tool
05744dc verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import kmeans, sinkhorn_algorithm
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim,
beta = 0.25, kmeans_init = False, kmeans_iters = 10,
sk_epsilon=0.01, sk_iters=100):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.sk_epsilon = sk_epsilon
self.sk_iters = sk_iters
self.embedding = nn.Embedding(self.n_e, self.e_dim)
if not kmeans_init:
self.initted = True
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
else:
self.initted = False
self.embedding.weight.data.zero_()
def get_codebook(self):
return self.embedding.weight
def get_codebook_entry(self, indices, shape=None):
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
return z_q
def init_emb(self, data):
centers = kmeans(
data,
self.n_e,
self.kmeans_iters,
)
self.embedding.weight.data.copy_(centers)
self.initted = True
@staticmethod
def center_distance_for_constraint(distances):
# distances: B, K
max_distance = distances.max()
min_distance = distances.min()
middle = (max_distance + min_distance) / 2
amplitude = max_distance - middle + 1e-5
assert amplitude > 0
centered_distances = (distances - middle) / amplitude
return centered_distances
def forward(self, x, use_sk=True):
# Flatten input
latent = x.view(-1, self.e_dim)
if not self.initted and self.training:
self.init_emb(latent)
# Calculate the L2 Norm between latent and Embedded weights
d = torch.sum(latent**2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \
2 * torch.matmul(latent, self.embedding.weight.t())
if not use_sk or self.sk_epsilon <= 0:
indices = torch.argmin(d, dim=-1)
# print("=======",self.sk_epsilon)
else:
# print("++++++++",self.sk_epsilon)
d = self.center_distance_for_constraint(d)
d = d.double()
Q = sinkhorn_algorithm(d,self.sk_epsilon,self.sk_iters)
# print(Q.sum(0)[:10])
if torch.isnan(Q).any() or torch.isinf(Q).any():
print(f"Sinkhorn Algorithm returns nan/inf values.")
indices = torch.argmax(Q, dim=-1)
# indices = torch.argmin(d, dim=-1)
x_q = self.embedding(indices).view(x.shape)
# compute loss for embedding
commitment_loss = F.mse_loss(x_q.detach(), x)
codebook_loss = F.mse_loss(x_q, x.detach())
loss = codebook_loss + self.beta * commitment_loss
# preserve gradients
x_q = x + (x_q - x).detach()
indices = indices.view(x.shape[:-1])
return x_q, loss, indices