# VQMIVC / model_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ConvNorm(torch.nn.Module):
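    '''
    Conv1d wrapper: when `padding` is None it defaults to dilation * (kernel_size - 1) / 2
    so that stride-1 convolutions preserve the time dimension, and the weights are
    initialised with Xavier-uniform scaled by the gain of `w_init_gain`.
    '''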
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
padding=None, dilation=1, bias=True, w_init_gain='linear'):
super(ConvNorm, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class Encoder_lf0(nn.Module):
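    '''
    Pitch (log-F0) encoder. With typ='no_emb' (the default) the raw lf0 contour is
    simply reshaped to (bz, T, 1); otherwise it is embedded by three conv blocks
    (the last with stride 2) followed by a bidirectional LSTM, giving (bz, T/2, 64).
    '''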
def __init__(self, typ='no_emb'):
super(Encoder_lf0, self).__init__()
self.type = typ
if typ != 'no_emb':
convolutions = []
for i in range(3):
conv_layer = nn.Sequential(
ConvNorm(1 if i==0 else 256, 256,
kernel_size=5, stride=2 if i==2 else 1,
padding=2,
dilation=1, w_init_gain='relu'),
nn.GroupNorm(256//16, 256),
nn.ReLU())
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
self.lstm = nn.LSTM(256, 32, 1, batch_first=True, bidirectional=True)
def forward(self, lf0):
if self.type != 'no_emb':
if len(lf0.shape) == 2:
lf0 = lf0.unsqueeze(1) # bz x 1 x 128
for conv in self.convolutions:
                lf0 = conv(lf0)  # bz x 256 x 128 after the first two convs, bz x 256 x 64 after the stride-2 conv
lf0 = lf0.transpose(1,2) # bz x 64 x 256
self.lstm.flatten_parameters()
lf0, _ = self.lstm(lf0) # bz x 64 x 64
else:
if len(lf0.shape) == 2:
                lf0 = lf0.unsqueeze(-1)  # bz x 128 x 1 (no downsampling)
return lf0
def pad_layer(inp, layer, pad_type='reflect'):
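    '''
    Pad `inp` along time (reflect by default) so that a stride-1 `layer` preserves
    the sequence length, then apply the layer; even kernel sizes get asymmetric padding.
    '''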
kernel_size = layer.kernel_size[0]
if kernel_size % 2 == 0:
pad = (kernel_size//2, kernel_size//2 - 1)
else:
pad = (kernel_size//2, kernel_size//2)
# padding
inp = F.pad(inp,
pad=pad,
mode=pad_type)
out = layer(inp)
return out
def conv_bank(x, module_list, act, pad_type='reflect'):
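    '''
    Apply each Conv1d in `module_list` (a bank of increasing kernel sizes) to x,
    pass the outputs through `act`, and concatenate them with x along the channel dimension.
    '''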
outs = []
for layer in module_list:
out = act(pad_layer(x, layer, pad_type))
outs.append(out)
out = torch.cat(outs + [x], dim=1)
return out
def get_act(act):
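    '''Map an activation name ('relu' or 'lrelu') to the corresponding nn module; defaults to ReLU.'''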
if act == 'relu':
return nn.ReLU()
elif act == 'lrelu':
return nn.LeakyReLU()
else:
return nn.ReLU()
class SpeakerEncoder(nn.Module):
    '''
    Speaker encoder, adapted from the speaker encoder of AdaIN-VC:
    https://github.com/jjery2243542/adaptive_voice_conversion/blob/master/model.py
    '''
def __init__(self, c_in=80, c_h=128, c_out=256, kernel_size=5,
bank_size=8, bank_scale=1, c_bank=128,
n_conv_blocks=6, n_dense_blocks=6,
subsample=[1, 2, 1, 2, 1, 2], act='relu', dropout_rate=0):
super(SpeakerEncoder, self).__init__()
self.c_in = c_in
self.c_h = c_h
self.c_out = c_out
self.kernel_size = kernel_size
self.n_conv_blocks = n_conv_blocks
self.n_dense_blocks = n_dense_blocks
self.subsample = subsample
self.act = get_act(act)
self.conv_bank = nn.ModuleList(
[nn.Conv1d(c_in, c_bank, kernel_size=k) for k in range(bank_scale, bank_size + 1, bank_scale)])
in_channels = c_bank * (bank_size // bank_scale) + c_in
self.in_conv_layer = nn.Conv1d(in_channels, c_h, kernel_size=1)
        self.first_conv_layers = nn.ModuleList([nn.Conv1d(c_h, c_h, kernel_size=kernel_size)
                                                for _ in range(n_conv_blocks)])
self.second_conv_layers = nn.ModuleList([nn.Conv1d(c_h, c_h, kernel_size=kernel_size, stride=sub)
for sub, _ in zip(subsample, range(n_conv_blocks))])
self.pooling_layer = nn.AdaptiveAvgPool1d(1)
self.first_dense_layers = nn.ModuleList([nn.Linear(c_h, c_h) for _ in range(n_dense_blocks)])
self.second_dense_layers = nn.ModuleList([nn.Linear(c_h, c_h) for _ in range(n_dense_blocks)])
self.output_layer = nn.Linear(c_h, c_out)
self.dropout_layer = nn.Dropout(p=dropout_rate)
def conv_blocks(self, inp):
out = inp
# convolution blocks
for l in range(self.n_conv_blocks):
y = pad_layer(out, self.first_conv_layers[l])
y = self.act(y)
y = self.dropout_layer(y)
y = pad_layer(y, self.second_conv_layers[l])
y = self.act(y)
y = self.dropout_layer(y)
if self.subsample[l] > 1:
out = F.avg_pool1d(out, kernel_size=self.subsample[l], ceil_mode=True)
out = y + out
return out
def dense_blocks(self, inp):
out = inp
# dense layers
for l in range(self.n_dense_blocks):
y = self.first_dense_layers[l](out)
y = self.act(y)
y = self.dropout_layer(y)
y = self.second_dense_layers[l](y)
y = self.act(y)
y = self.dropout_layer(y)
out = y + out
return out
def forward(self, x):
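        # x: (bz, c_in=80, T) mel-spectrogram -> fixed-size speaker embedding of shape (bz, c_out=256)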
out = conv_bank(x, self.conv_bank, act=self.act)
# dimension reduction layer
out = pad_layer(out, self.in_conv_layer)
out = self.act(out)
# conv blocks
out = self.conv_blocks(out)
# avg pooling
out = self.pooling_layer(out).squeeze(2)
# dense blocks
out = self.dense_blocks(out)
out = self.output_layer(out)
return out
class Encoder(nn.Module):
    '''
    Content encoder with a VQ bottleneck, adapted from:
    https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
def __init__(self, in_channels, channels, n_embeddings, z_dim, c_dim):
super(Encoder, self).__init__()
self.conv = nn.Conv1d(in_channels, channels, 4, 2, 1, bias=False)
self.encoder = nn.Sequential(
nn.LayerNorm(channels),
nn.ReLU(True),
nn.Linear(channels, channels, bias=False),
nn.LayerNorm(channels),
nn.ReLU(True),
nn.Linear(channels, channels, bias=False),
nn.LayerNorm(channels),
nn.ReLU(True),
nn.Linear(channels, channels, bias=False),
nn.LayerNorm(channels),
nn.ReLU(True),
nn.Linear(channels, channels, bias=False),
nn.LayerNorm(channels),
nn.ReLU(True),
nn.Linear(channels, z_dim),
)
self.codebook = VQEmbeddingEMA(n_embeddings, z_dim)
self.rnn = nn.LSTM(z_dim, c_dim, batch_first=True)
def encode(self, mel):
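        # returns quantized content z, LSTM context c, the pre-quantization features,
        # and the selected codebook indices (no VQ loss, unlike forward())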
z = self.conv(mel)
z_beforeVQ = self.encoder(z.transpose(1, 2))
z, r, indices = self.codebook.encode(z_beforeVQ)
c, _ = self.rnn(z)
return z, c, z_beforeVQ, indices
def forward(self, mels):
z = self.conv(mels.float()) # (bz, 80, 128) -> (bz, 512, 128/2)
z_beforeVQ = self.encoder(z.transpose(1, 2)) # (bz, 512, 128/2) -> (bz, 128/2, 512) -> (bz, 128/2, 64)
z, r, loss, perplexity = self.codebook(z_beforeVQ) # z: (bz, 128/2, 64)
        c, _ = self.rnn(z) # (bz, 128/2, 64) -> (bz, 128/2, 256)
return z, c, z_beforeVQ, loss, perplexity
class VQEmbeddingEMA(nn.Module):
    '''
    VQ codebook with exponential-moving-average (EMA) updates, adapted from:
    https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
super(VQEmbeddingEMA, self).__init__()
self.commitment_cost = commitment_cost
self.decay = decay
self.epsilon = epsilon
init_bound = 1 / 512
embedding = torch.Tensor(n_embeddings, embedding_dim)
embedding.uniform_(-init_bound, init_bound)
self.register_buffer("embedding", embedding) # only change during forward
self.register_buffer("ema_count", torch.zeros(n_embeddings))
self.register_buffer("ema_weight", self.embedding.clone())
def encode(self, x):
M, D = self.embedding.size()
x_flat = x.detach().reshape(-1, D)
distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
torch.sum(x_flat ** 2, dim=1, keepdim=True),
x_flat, self.embedding.t(),
alpha=-2.0, beta=1.0)
indices = torch.argmin(distances.float(), dim=-1)
quantized = F.embedding(indices, self.embedding)
quantized = quantized.view_as(x)
residual = x - quantized
return quantized, residual, indices.view(x.size(0), x.size(1))
def forward(self, x):
M, D = self.embedding.size()
x_flat = x.detach().reshape(-1, D)
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)  # squared Euclidean distance between every frame of x and every codebook entry
indices = torch.argmin(distances.float(), dim=-1)
encodings = F.one_hot(indices, M).float()
quantized = F.embedding(indices, self.embedding)
quantized = quantized.view_as(x)
if self.training: # EMA based codebook learning
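            # EMA codebook update with Laplace smoothing (as in the EMA variant of VQ-VAE):
            #   N_i <- decay * N_i + (1 - decay) * n_i          (cluster counts)
            #   m_i <- decay * m_i + (1 - decay) * sum_j x_j    (sum of vectors assigned to code i)
            #   e_i <- m_i / N_i                                (new codebook entry)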
self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
n = torch.sum(self.ema_count)
self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n
dw = torch.matmul(encodings.t(), x_flat)
self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = self.commitment_cost * e_latent_loss  # commitment loss only; the codebook itself is updated via EMA
        residual = x - quantized
        quantized = x + (quantized - x).detach()  # straight-through estimator: gradients flow past the quantization step
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
return quantized, residual, loss, perplexity
class CPCLoss(nn.Module):
    '''
    CPC (InfoNCE) loss where negative samples are drawn within-speaker.
    Adapted from: https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
def __init__(self, n_speakers_per_batch, n_utterances_per_speaker, n_prediction_steps, n_negatives, z_dim, c_dim):
super(CPCLoss, self).__init__()
self.n_speakers_per_batch = n_speakers_per_batch
self.n_utterances_per_speaker = n_utterances_per_speaker
        self.n_prediction_steps = n_prediction_steps // 2  # note: halved, so only the first n_prediction_steps // 2 predictors are used in forward()
self.n_negatives = n_negatives
self.z_dim = z_dim
self.c_dim = c_dim
self.predictors = nn.ModuleList([
nn.Linear(c_dim, z_dim) for _ in range(n_prediction_steps)
])
def forward(self, z, c): # z:(64, 70, 64), c:(64, 70, 256)
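        # InfoNCE: for each prediction step k, score the true future frame z_{t+k} against
        # n_negatives frames sampled from the same speaker, using a step-specific linear predictor applied to c_t.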
length = z.size(1) - self.n_prediction_steps # 64
z = z.reshape(
self.n_speakers_per_batch,
self.n_utterances_per_speaker,
-1,
self.z_dim
) # (64, 70, 64) -> (8, 8, 70, 64)
c = c[:, :-self.n_prediction_steps, :] # (64, 64, 256)
losses, accuracies = list(), list()
for k in range(1, self.n_prediction_steps+1):
z_shift = z[:, :, k:length + k, :] # (8, 8, 64, 64), positive samples
Wc = self.predictors[k-1](c) # (64, 64, 256) -> (64, 64, 64)
Wc = Wc.view(
self.n_speakers_per_batch,
self.n_utterances_per_speaker,
-1,
self.z_dim
) # (64, 64, 64) -> (8, 8, 64, 64)
batch_index = torch.randint(
0, self.n_utterances_per_speaker,
size=(
self.n_utterances_per_speaker,
self.n_negatives
),
device=z.device
)
batch_index = batch_index.view(
1, self.n_utterances_per_speaker, self.n_negatives, 1
) # (1, 8, 17, 1)
# seq_index: (8, 8, 17, 64)
seq_index = torch.randint(
1, length,
size=(
self.n_speakers_per_batch,
self.n_utterances_per_speaker,
self.n_negatives,
length
),
device=z.device
)
            seq_index += torch.arange(length, device=z.device)  # (1) add each frame's own time index
            seq_index = torch.remainder(seq_index, length)      # (2) with the random offsets in [1, length-1] above, (1)+(2) guarantee a negative never lands on the positive frame's own time step
speaker_index = torch.arange(self.n_speakers_per_batch, device=z.device) # within-speaker sampling
speaker_index = speaker_index.view(-1, 1, 1, 1)
# z_negatives: (8,8,17,64,64); z_negatives[0,0,:,0,:] is (17, 64) that is negative samples for first frame of first utterance of first speaker...
z_negatives = z_shift[speaker_index, batch_index, seq_index, :] # speaker_index has the original order (within-speaker sampling)
            # batch_index picks a random utterance of the same speaker for each negative slot (n_negatives per frame)
            # seq_index picks a random frame in [0, length), never equal to the current time step (see (1)+(2) above)
            # so for a positive frame at time t, negatives come from the same speaker (possibly another utterance,
            # possibly the same one) but never from time step t itself
zs = torch.cat((z_shift.unsqueeze(2), z_negatives), dim=2) # (8, 8, 1+17, 64, 64)
            f = torch.sum(zs * Wc.unsqueeze(2) / math.sqrt(self.z_dim), dim=-1)  # (8, 8, 1+17, 64): scaled dot product between each prediction and its candidates
f = f.view(
self.n_speakers_per_batch * self.n_utterances_per_speaker,
self.n_negatives + 1,
-1
) # (64, 1+17, 64)
labels = torch.zeros(
self.n_speakers_per_batch * self.n_utterances_per_speaker, length,
dtype=torch.long, device=z.device
) # (64, 64)
loss = F.cross_entropy(f, labels)
            accuracy = f.argmax(dim=1) == labels  # (64, 64)
accuracy = torch.mean(accuracy.float())
losses.append(loss)
accuracies.append(accuracy.item())
loss = torch.stack(losses).mean()
return loss, accuracies
class CPCLoss_sameSeq(nn.Module):
    '''
    CPC (InfoNCE) loss where negative samples are drawn within the same sequence/utterance.
    '''
def __init__(self, n_speakers_per_batch, n_utterances_per_speaker, n_prediction_steps, n_negatives, z_dim, c_dim):
super(CPCLoss_sameSeq, self).__init__()
self.n_speakers_per_batch = n_speakers_per_batch
self.n_utterances_per_speaker = n_utterances_per_speaker
self.n_prediction_steps = n_prediction_steps
self.n_negatives = n_negatives
self.z_dim = z_dim
self.c_dim = c_dim
self.predictors = nn.ModuleList([
nn.Linear(c_dim, z_dim) for _ in range(n_prediction_steps)
])
def forward(self, z, c): # z:(256, 64, 64), c:(256, 64, 256)
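        # Same InfoNCE scheme as CPCLoss, except that negatives are sampled from the same utterance only.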
        length = z.size(1) - self.n_prediction_steps  # e.g. 64 - 6 = 58; number of time steps per utterance used to compute the CPC loss
        n_speakers_per_batch = z.shape[0]  # each utterance is treated as its own "speaker"
c = c[:, :-self.n_prediction_steps, :] # (256, 58, 256)
losses, accuracies = list(), list()
for k in range(1, self.n_prediction_steps+1):
z_shift = z[:, k:length + k, :] # (256, 58, 64), positive samples
Wc = self.predictors[k-1](c) # (256, 58, 256) -> (256, 58, 64)
# seq_index: (256, 10, 58)
seq_index = torch.randint(
1, length,
size=(
n_speakers_per_batch,
self.n_negatives,
length
),
device=z.device
)
            seq_index += torch.arange(length, device=z.device)  # (1) add each frame's own time index
            seq_index = torch.remainder(seq_index, length)      # (2) with the random offsets in [1, length-1] above, (1)+(2) guarantee a negative never lands on the positive frame's own time step
speaker_index = torch.arange(n_speakers_per_batch, device=z.device) # within-utterance sampling
speaker_index = speaker_index.view(-1, 1, 1)
            z_negatives = z_shift[speaker_index, seq_index, :]  # (256, 10, 58, 64); z_negatives[i, :, j, :] is the set of negative samples for the i-th utterance at the j-th time step
zs = torch.cat((z_shift.unsqueeze(1), z_negatives), dim=1) # (256,11,58,64)
            f = torch.sum(zs * Wc.unsqueeze(1) / math.sqrt(self.z_dim), dim=-1)  # (256, 11, 58): scaled dot product between each prediction and its candidates
labels = torch.zeros(
n_speakers_per_batch, length,
dtype=torch.long, device=z.device
)
loss = F.cross_entropy(f, labels)
accuracy = f.argmax(dim=1) == labels # (256, 58)
accuracy = torch.mean(accuracy.float())
losses.append(loss)
accuracies.append(accuracy.item())
loss = torch.stack(losses).mean()
return loss, accuracies
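

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original VQMIVC training code.
    # The hyperparameters below (channels, n_embeddings, z_dim, c_dim, mel/frame
    # sizes) are illustrative assumptions, not necessarily the project's defaults.
    B, T = 4, 128
    mels = torch.randn(B, 80, T)   # (bz, n_mels, frames)
    lf0 = torch.randn(B, T)        # (bz, frames)

    content_encoder = Encoder(in_channels=80, channels=512, n_embeddings=512, z_dim=64, c_dim=256)
    speaker_encoder = SpeakerEncoder(c_in=80, c_out=256)
    lf0_encoder = Encoder_lf0()    # 'no_emb': just reshapes lf0 to (bz, frames, 1)

    z, c, z_beforeVQ, vq_loss, perplexity = content_encoder(mels)
    spk_emb = speaker_encoder(mels)
    lf0_emb = lf0_encoder(lf0)
    print(z.shape, c.shape, spk_emb.shape, lf0_emb.shape)
    # expected: (4, 64, 64), (4, 64, 256), (4, 256), (4, 128, 1)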