# Snapshot of commit e34aada (6,298 bytes).
import torch
import torch.nn as nn
from scipy.cluster.vq import kmeans2
from torch.nn import functional as F
class VQEmbeddingEMA(nn.Module):
    """Vector-quantization layer whose codebook is learned with exponential moving averages.

    The codebook is the ``embedding`` buffer of shape ``[n_embeddings, embedding_dim]``; it is not
    trained by gradients. In training mode the first batch only marks the module as initialised
    (``data_initialized``); every later batch updates the running per-code counts (``ema_count``)
    and per-code sums (``ema_weight``) and sets ``embedding = ema_weight / ema_count``.
    The returned loss is therefore only the commitment term pulling the input towards its codes.
    """

    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5,
                 print_vq_prob=False):
        """
        :param n_embeddings: number of codebook entries M.
        :param embedding_dim: size D of each code; must equal the last dimension of the input.
        :param commitment_cost: weight of the commitment loss.
        :param decay: EMA decay applied to ``ema_count`` and ``ema_weight``.
        :param epsilon: Laplace-smoothing constant that keeps every code count strictly positive.
        :param print_vq_prob: if True, print the average code usage on every forward call.
        """
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.n_embeddings = n_embeddings
        self.decay = decay
        self.epsilon = epsilon
        self.print_vq_prob = print_vq_prob
        # 0 until the first training batch has been seen; a buffer so it is saved in state_dict.
        self.register_buffer('data_initialized', torch.zeros(1))
        init_bound = 1 / 512
        embedding = torch.empty(n_embeddings, embedding_dim).uniform_(-init_bound, init_bound)
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())

    def encode(self, x):
        """Assign every frame of ``x`` to its nearest codebook entry (squared L2 distance).

        :param x: input whose last dimension is ``embedding_dim``, e.g. [B, T, D].
        :return: (x_flat [B*T, D] detached from autograd, quantized shaped like ``x``,
                  indices [B*T]).
        """
        D = self.embedding.size(1)
        x_flat = x.detach().reshape(-1, D)
        # ||e||^2 + ||x||^2 - 2 x.e for every (frame, code) pair in one fused op -> [B*T, M]
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)
        indices = torch.argmin(distances.float(), dim=-1)  # [B*T]
        quantized = F.embedding(indices, self.embedding).view_as(x)
        return x_flat, quantized, indices

    @torch.no_grad()
    def _ema_update(self, x_flat, encodings):
        """Fold one batch into the EMA statistics and recompute the codebook in place."""
        M = self.embedding.size(0)
        self.ema_count.mul_(self.decay).add_(torch.sum(encodings, dim=0), alpha=1 - self.decay)
        n = torch.sum(self.ema_count)
        # Laplace smoothing: keeps every count > 0 (so the division below is safe) while
        # preserving the total count n.
        self.ema_count.add_(self.epsilon).div_(n + M * self.epsilon).mul_(n)
        dw = torch.matmul(encodings.t(), x_flat)
        self.ema_weight.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_weight / self.ema_count.unsqueeze(-1))

    def forward(self, x):
        """Quantize ``x`` against the codebook with a straight-through estimator.

        :param x: [B, T, D]; frames whose features are all zero are treated as padding.
        :return: (quantized [B, T, D], commitment loss (scalar), indices [B, T],
                  perplexity of the batch's code usage (scalar)).
        """
        B, T, _ = x.shape
        M = self.embedding.size(0)
        # NOTE: no data-driven (k-means) codebook initialisation is done in this class; the
        # codebook starts from the uniform init in __init__.
        x_flat, quantized, indices = self.encode(x)
        encodings = F.one_hot(indices, M).float()
        indices = indices.reshape(B, T)

        if self.training:
            if self.data_initialized.item() != 0:
                self._ema_update(x_flat, encodings)
            else:
                # First training batch: skip the EMA update, but mark the module as initialised
                # so that subsequent batches do update the codebook.
                self.data_initialized.fill_(1)

        e_latent_loss = F.mse_loss(x, quantized.detach(), reduction='none')
        nonpadding = (x.abs().sum(-1) > 0).float()
        # clamp avoids 0/0 = NaN when every frame in the batch is padding
        e_latent_loss = (e_latent_loss.mean(-1) * nonpadding).sum() / nonpadding.sum().clamp(min=1)
        loss = self.commitment_cost * e_latent_loss

        # straight-through: forward pass uses the codes, gradients flow to x unchanged
        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        if self.print_vq_prob:
            print("| VQ code avg_probs: ", avg_probs)
        return quantized, loss, indices, perplexity
class VQEmbedding(nn.Module):
    """Vector-quantization layer with a k-means-initialised codebook.

    The codebook is the ``embedding`` buffer of shape ``[n_embeddings, embedding_dim]``. On the
    first forward call in training mode it is replaced by the k-means centroids of that batch's
    frames and ``data_initialized`` is set, so k-means runs only once.
    """

    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, lambda_kl=1.0):
        """
        :param n_embeddings: number of codebook entries M.
        :param embedding_dim: size D of each code; must equal the last dimension of the input.
        :param commitment_cost: weight of the commitment term in the loss.
        :param lambda_kl: overall scale applied to the returned loss.
        """
        super(VQEmbedding, self).__init__()
        self.commitment_cost = commitment_cost
        self.lambda_kl = lambda_kl
        self.n_embeddings = n_embeddings
        # Zero-filled rather than uninitialised memory, so inference before the k-means init
        # is deterministic and free of NaN/inf garbage.
        self.register_buffer("embedding", torch.zeros(n_embeddings, embedding_dim))
        self.register_buffer('data_initialized', torch.zeros(1))

    def encode(self, x):
        """Assign every frame of ``x`` to its nearest codebook entry (squared L2 distance).

        :param x: input whose last dimension is ``embedding_dim``, e.g. [B, T, D].
        :return: (x_flat [B*T, D] detached from autograd, quantized shaped like ``x``,
                  indices [B*T]).
        """
        D = self.embedding.size(1)
        x_flat = x.detach().reshape(-1, D)
        # ||e||^2 + ||x||^2 - 2 x.e for every (frame, code) pair in one fused op -> [B*T, M]
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)
        indices = torch.argmin(distances.float(), dim=-1)  # [B*T]
        quantized = F.embedding(indices, self.embedding).view_as(x)
        return x_flat, quantized, indices

    @torch.no_grad()
    def _kmeans_init(self, x):
        """Set the codebook to k-means centroids of the frames of ``x`` and mark it initialised."""
        print('| running kmeans in VQVAE')  # data driven initialization for the embeddings
        x_flat = x.detach().reshape(-1, self.embedding.size(1)).cpu()
        rp = torch.randperm(x_flat.size(0))
        kd = kmeans2(x_flat[rp].numpy(), self.n_embeddings, minit='points')
        self.embedding.copy_(torch.from_numpy(kd[0]))
        self.data_initialized.fill_(1)
        # TODO: this won't work in multi-GPU setups (each rank would cluster its own batch)

    def forward(self, x):
        """Quantize ``x`` against the codebook with a straight-through estimator.

        :param x: [B, T, D]
        :return: (quantized [B, T, D], loss (scalar), indices [B, T],
                  perplexity of the batch's code usage (scalar)).
        """
        B, T, _ = x.shape
        M = self.embedding.size(0)

        # Data-driven codebook init: not in the DeepMind reference, but needed in practice.
        if self.training and self.data_initialized.item() == 0:
            self._kmeans_init(x)

        x_flat, quantized, indices = self.encode(x)
        encodings = F.one_hot(indices, M).float()
        indices = indices.reshape(B, T)

        # Commitment term: pulls the encoder output x towards its (fixed) code.
        commitment_loss = (x - quantized.detach()).pow(2).mean()
        # Codebook term. NOTE: ``embedding`` is a buffer, so this term carries no gradient;
        # the codebook itself only changes through the k-means initialisation above.
        codebook_loss = (quantized - x.detach()).pow(2).mean()
        loss = (self.commitment_cost * commitment_loss + codebook_loss) * self.lambda_kl

        # straight-through: forward pass uses the codes, gradients flow to x unchanged
        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return quantized, loss, indices, perplexity