import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class Encoder_lf0(nn.Module):
    def __init__(self, typ='no_emb'):
        super(Encoder_lf0, self).__init__()
        self.type = typ
        if typ != 'no_emb':
            convolutions = []
            for i in range(3):
                conv_layer = nn.Sequential(
                    ConvNorm(1 if i == 0 else 256, 256,
                             kernel_size=5,
                             stride=2 if i == 2 else 1,
                             padding=2, dilation=1,
                             w_init_gain='relu'),
                    nn.GroupNorm(256 // 16, 256),
                    nn.ReLU())
                convolutions.append(conv_layer)
            self.convolutions = nn.ModuleList(convolutions)
            self.lstm = nn.LSTM(256, 32, 1, batch_first=True, bidirectional=True)

    def forward(self, lf0):
        if self.type != 'no_emb':
            if len(lf0.shape) == 2:
                lf0 = lf0.unsqueeze(1)       # bz x 1 x 128
            for conv in self.convolutions:
                lf0 = conv(lf0)              # bz x 256 x 64 (the last conv has stride 2)
            lf0 = lf0.transpose(1, 2)        # bz x 64 x 256
            self.lstm.flatten_parameters()
            lf0, _ = self.lstm(lf0)          # bz x 64 x 64
        else:
            if len(lf0.shape) == 2:
                lf0 = lf0.unsqueeze(-1)      # bz x 128 x 1, no downsampling
        return lf0


def pad_layer(inp, layer, pad_type='reflect'):
    kernel_size = layer.kernel_size[0]
    if kernel_size % 2 == 0:
        pad = (kernel_size // 2, kernel_size // 2 - 1)
    else:
        pad = (kernel_size // 2, kernel_size // 2)
    # padding
    inp = F.pad(inp, pad=pad, mode=pad_type)
    out = layer(inp)
    return out


def conv_bank(x, module_list, act, pad_type='reflect'):
    outs = []
    for layer in module_list:
        out = act(pad_layer(x, layer, pad_type))
        outs.append(out)
    out = torch.cat(outs + [x], dim=1)
    return out


def get_act(act):
    if act == 'relu':
        return nn.ReLU()
    elif act == 'lrelu':
        return nn.LeakyReLU()
    else:
        return nn.ReLU()
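
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: a quick shape check for
# pad_layer/conv_bank. The helper name and the demo values below (c_in=80,
# c_bank=128, bank_size=8, bank_scale=1, T=128) are assumptions mirroring the
# SpeakerEncoder defaults. pad_layer pads to "same" length for stride-1 convs,
# so the bank output keeps T frames and stacks
# c_bank * (bank_size // bank_scale) + c_in channels.
def _demo_conv_bank_shapes():
    c_in, c_bank, bank_size, bank_scale = 80, 128, 8, 1   # assumed demo values
    bank = nn.ModuleList(
        [nn.Conv1d(c_in, c_bank, kernel_size=k)
         for k in range(bank_scale, bank_size + 1, bank_scale)])
    x = torch.randn(2, c_in, 128)                         # (bz, c_in, T)
    out = conv_bank(x, bank, act=nn.ReLU())
    return out.shape                                      # expected: (2, 1104, 128)
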

class SpeakerEncoder(nn.Module):
    '''
    reference from speaker-encoder of AdaIN-VC:
    https://github.com/jjery2243542/adaptive_voice_conversion/blob/master/model.py
    '''
    def __init__(self, c_in=80, c_h=128, c_out=256, kernel_size=5,
                 bank_size=8, bank_scale=1, c_bank=128,
                 n_conv_blocks=6, n_dense_blocks=6,
                 subsample=[1, 2, 1, 2, 1, 2],
                 act='relu', dropout_rate=0):
        super(SpeakerEncoder, self).__init__()
        self.c_in = c_in
        self.c_h = c_h
        self.c_out = c_out
        self.kernel_size = kernel_size
        self.n_conv_blocks = n_conv_blocks
        self.n_dense_blocks = n_dense_blocks
        self.subsample = subsample
        self.act = get_act(act)
        self.conv_bank = nn.ModuleList(
            [nn.Conv1d(c_in, c_bank, kernel_size=k)
             for k in range(bank_scale, bank_size + 1, bank_scale)])
        in_channels = c_bank * (bank_size // bank_scale) + c_in
        self.in_conv_layer = nn.Conv1d(in_channels, c_h, kernel_size=1)
        self.first_conv_layers = nn.ModuleList(
            [nn.Conv1d(c_h, c_h, kernel_size=kernel_size) for _ in range(n_conv_blocks)])
        self.second_conv_layers = nn.ModuleList(
            [nn.Conv1d(c_h, c_h, kernel_size=kernel_size, stride=sub)
             for sub, _ in zip(subsample, range(n_conv_blocks))])
        self.pooling_layer = nn.AdaptiveAvgPool1d(1)
        self.first_dense_layers = nn.ModuleList([nn.Linear(c_h, c_h) for _ in range(n_dense_blocks)])
        self.second_dense_layers = nn.ModuleList([nn.Linear(c_h, c_h) for _ in range(n_dense_blocks)])
        self.output_layer = nn.Linear(c_h, c_out)
        self.dropout_layer = nn.Dropout(p=dropout_rate)

    def conv_blocks(self, inp):
        out = inp
        # convolution blocks
        for l in range(self.n_conv_blocks):
            y = pad_layer(out, self.first_conv_layers[l])
            y = self.act(y)
            y = self.dropout_layer(y)
            y = pad_layer(y, self.second_conv_layers[l])
            y = self.act(y)
            y = self.dropout_layer(y)
            if self.subsample[l] > 1:
                out = F.avg_pool1d(out, kernel_size=self.subsample[l], ceil_mode=True)
            out = y + out
        return out

    def dense_blocks(self, inp):
        out = inp
        # dense layers
        for l in range(self.n_dense_blocks):
            y = self.first_dense_layers[l](out)
            y = self.act(y)
            y = self.dropout_layer(y)
            y = self.second_dense_layers[l](y)
            y = self.act(y)
            y = self.dropout_layer(y)
            out = y + out
        return out

    def forward(self, x):
        out = conv_bank(x, self.conv_bank, act=self.act)
        # dimension reduction layer
        out = pad_layer(out, self.in_conv_layer)
        out = self.act(out)
        # conv blocks
        out = self.conv_blocks(out)
        # avg pooling
        out = self.pooling_layer(out).squeeze(2)
        # dense blocks
        out = self.dense_blocks(out)
        out = self.output_layer(out)
        return out


class Encoder(nn.Module):
    '''
    reference from:
    https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
    def __init__(self, in_channels, channels, n_embeddings, z_dim, c_dim):
        super(Encoder, self).__init__()
        self.conv = nn.Conv1d(in_channels, channels, 4, 2, 1, bias=False)
        self.encoder = nn.Sequential(
            nn.LayerNorm(channels),
            nn.ReLU(True),
            nn.Linear(channels, channels, bias=False),
            nn.LayerNorm(channels),
            nn.ReLU(True),
            nn.Linear(channels, channels, bias=False),
            nn.LayerNorm(channels),
            nn.ReLU(True),
            nn.Linear(channels, channels, bias=False),
            nn.LayerNorm(channels),
            nn.ReLU(True),
            nn.Linear(channels, channels, bias=False),
            nn.LayerNorm(channels),
            nn.ReLU(True),
            nn.Linear(channels, z_dim),
        )
        self.codebook = VQEmbeddingEMA(n_embeddings, z_dim)
        self.rnn = nn.LSTM(z_dim, c_dim, batch_first=True)

    def encode(self, mel):
        z = self.conv(mel)
        z_beforeVQ = self.encoder(z.transpose(1, 2))
        z, r, indices = self.codebook.encode(z_beforeVQ)
        c, _ = self.rnn(z)
        return z, c, z_beforeVQ, indices

    def forward(self, mels):
        z = self.conv(mels.float())                         # (bz, 80, 128) -> (bz, 512, 128/2)
        z_beforeVQ = self.encoder(z.transpose(1, 2))        # (bz, 512, 128/2) -> (bz, 128/2, 512) -> (bz, 128/2, 64)
        z, r, loss, perplexity = self.codebook(z_beforeVQ)  # z: (bz, 128/2, 64)
        c, _ = self.rnn(z)                                  # (bz, 128/2, 64) -> (bz, 128/2, 256)
        return z, c, z_beforeVQ, loss, perplexity
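
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: the shape flow of
# Encoder.forward for an 80-dim, 128-frame mel input. The hyper-parameters
# below (channels=512, n_embeddings=512, z_dim=64, c_dim=256) are assumptions
# chosen to match the shape comments in forward(); the stride-2 front-end conv
# halves the time axis before quantisation and the LSTM aggregator.
def _demo_encoder_shapes():
    enc = Encoder(in_channels=80, channels=512, n_embeddings=512, z_dim=64, c_dim=256)
    mels = torch.randn(2, 80, 128)
    z, c, z_beforeVQ, loss, perplexity = enc(mels)
    # expected: z (2, 64, 64), c (2, 64, 256), z_beforeVQ (2, 64, 64)
    return z.shape, c.shape, z_beforeVQ.shape
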

class VQEmbeddingEMA(nn.Module):
    '''
    reference from:
    https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        init_bound = 1 / 512
        embedding = torch.Tensor(n_embeddings, embedding_dim)
        embedding.uniform_(-init_bound, init_bound)
        self.register_buffer("embedding", embedding)  # only change during forward
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())

    def encode(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)

        indices = torch.argmin(distances.float(), dim=-1)
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)
        residual = x - quantized
        return quantized, residual, indices.view(x.size(0), x.size(1))

    def forward(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)  # calculate the distance between each ele in embedding and x

        indices = torch.argmin(distances.float(), dim=-1)
        encodings = F.one_hot(indices, M).float()
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)

        if self.training:  # EMA based codebook learning
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)

            n = torch.sum(self.ema_count)
            self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

            dw = torch.matmul(encodings.t(), x_flat)
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw

            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = self.commitment_cost * e_latent_loss

        residual = x - quantized

        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, residual, loss, perplexity
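
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: the addmm call above
# computes squared Euclidean distances through the identity
# ||x - e||^2 = ||x||^2 + ||e||^2 - 2 * x . e, with beta=1.0 weighting the
# summed norms and alpha=-2.0 weighting x @ e^T. The helper name and tensor
# sizes are assumed demo values; torch.cdist serves as the reference.
def _demo_vq_distance_identity():
    x_flat = torch.randn(10, 64)          # flattened encoder outputs (N, D)
    embedding = torch.randn(512, 64)      # codebook (n_embeddings, D)
    distances = torch.addmm(torch.sum(embedding ** 2, dim=1) +
                            torch.sum(x_flat ** 2, dim=1, keepdim=True),
                            x_flat, embedding.t(), alpha=-2.0, beta=1.0)
    reference = torch.cdist(x_flat, embedding) ** 2
    return torch.allclose(distances, reference, atol=1e-3)
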

class CPCLoss(nn.Module):
    '''
    CPC-loss calculation: negative samples are drawn within-speaker
    reference from:
    https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py
    '''
    def __init__(self, n_speakers_per_batch, n_utterances_per_speaker,
                 n_prediction_steps, n_negatives, z_dim, c_dim):
        super(CPCLoss, self).__init__()
        self.n_speakers_per_batch = n_speakers_per_batch
        self.n_utterances_per_speaker = n_utterances_per_speaker
        self.n_prediction_steps = n_prediction_steps // 2  # only the first half of the predictors is used in forward
        self.n_negatives = n_negatives
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.predictors = nn.ModuleList([
            nn.Linear(c_dim, z_dim) for _ in range(n_prediction_steps)
        ])

    def forward(self, z, c):  # z: (64, 70, 64), c: (64, 70, 256)
        length = z.size(1) - self.n_prediction_steps  # 64

        z = z.reshape(
            self.n_speakers_per_batch,
            self.n_utterances_per_speaker,
            -1,
            self.z_dim
        )  # (64, 70, 64) -> (8, 8, 70, 64)
        c = c[:, :-self.n_prediction_steps, :]  # (64, 64, 256)

        losses, accuracies = list(), list()
        for k in range(1, self.n_prediction_steps + 1):
            z_shift = z[:, :, k:length + k, :]  # (8, 8, 64, 64), positive samples

            Wc = self.predictors[k - 1](c)  # (64, 64, 256) -> (64, 64, 64)
            Wc = Wc.view(
                self.n_speakers_per_batch,
                self.n_utterances_per_speaker,
                -1,
                self.z_dim
            )  # (64, 64, 64) -> (8, 8, 64, 64)

            batch_index = torch.randint(
                0, self.n_utterances_per_speaker,
                size=(
                    self.n_utterances_per_speaker,
                    self.n_negatives
                ),
                device=z.device
            )
            batch_index = batch_index.view(
                1, self.n_utterances_per_speaker, self.n_negatives, 1
            )  # (1, 8, 17, 1)

            # seq_index: (8, 8, 17, 64)
            seq_index = torch.randint(
                1, length,
                size=(
                    self.n_speakers_per_batch,
                    self.n_utterances_per_speaker,
                    self.n_negatives,
                    length
                ),
                device=z.device
            )
            seq_index += torch.arange(length, device=z.device)  # (1)
            seq_index = torch.remainder(seq_index, length)      # (2); (1)+(2) ensure that the current positive frame will not be selected as a negative sample

            speaker_index = torch.arange(self.n_speakers_per_batch, device=z.device)  # within-speaker sampling
            speaker_index = speaker_index.view(-1, 1, 1, 1)

            # z_negatives: (8, 8, 17, 64, 64); z_negatives[0, 0, :, 0, :] is the (17, 64) set of
            # negative samples for the first frame of the first utterance of the first speaker
            z_negatives = z_shift[speaker_index, batch_index, seq_index, :]
            # speaker_index keeps the original order (within-speaker sampling)
            # batch_index is randomly sampled from 0~7; each point has 17 negative samples
            # seq_index is randomly sampled within [0, length)
            # so for each positive frame with time-index t, the negative samples are drawn from
            # another or the current utterance, and the seq-index (frame-index) will never include t

            zs = torch.cat((z_shift.unsqueeze(2), z_negatives), dim=2)  # (8, 8, 1+17, 64, 64)

            f = torch.sum(zs * Wc.unsqueeze(2) / math.sqrt(self.z_dim), dim=-1)  # (8, 8, 1+17, 64), a dot product in fact
            f = f.view(
                self.n_speakers_per_batch * self.n_utterances_per_speaker,
                self.n_negatives + 1,
                -1
            )  # (64, 1+17, 64)

            labels = torch.zeros(
                self.n_speakers_per_batch * self.n_utterances_per_speaker, length,
                dtype=torch.long, device=z.device
            )  # (64, 64)

            loss = F.cross_entropy(f, labels)

            accuracy = f.argmax(dim=1) == labels  # (64, 64)
            accuracy = torch.mean(accuracy.float())

            losses.append(loss)
            accuracies.append(accuracy.item())

        loss = torch.stack(losses).mean()
        return loss, accuracies


class CPCLoss_sameSeq(nn.Module):
    '''
    CPC-loss calculation: negative samples are drawn within-sequence/utterance
    '''
    def __init__(self, n_speakers_per_batch, n_utterances_per_speaker,
                 n_prediction_steps, n_negatives, z_dim, c_dim):
        super(CPCLoss_sameSeq, self).__init__()
        self.n_speakers_per_batch = n_speakers_per_batch
        self.n_utterances_per_speaker = n_utterances_per_speaker
        self.n_prediction_steps = n_prediction_steps
        self.n_negatives = n_negatives
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.predictors = nn.ModuleList([
            nn.Linear(c_dim, z_dim) for _ in range(n_prediction_steps)
        ])

    def forward(self, z, c):  # z: (256, 64, 64), c: (256, 64, 256)
        length = z.size(1) - self.n_prediction_steps  # 64-6=58, the number of time-steps per utterance used for calculating the CPC loss

        n_speakers_per_batch = z.shape[0]  # each utterance is treated as a speaker

        c = c[:, :-self.n_prediction_steps, :]  # (256, 58, 256)

        losses, accuracies = list(), list()
        for k in range(1, self.n_prediction_steps + 1):
            z_shift = z[:, k:length + k, :]  # (256, 58, 64), positive samples

            Wc = self.predictors[k - 1](c)  # (256, 58, 256) -> (256, 58, 64)

            # seq_index: (256, 10, 58)
            seq_index = torch.randint(
                1, length,
                size=(
                    n_speakers_per_batch,
                    self.n_negatives,
                    length
                ),
                device=z.device
            )
            seq_index += torch.arange(length, device=z.device)  # (1)
            seq_index = torch.remainder(seq_index, length)      # (2); (1)+(2) ensure that the current positive frame will not be selected as a negative sample

            speaker_index = torch.arange(n_speakers_per_batch, device=z.device)  # within-utterance sampling
            speaker_index = speaker_index.view(-1, 1, 1)

            z_negatives = z_shift[speaker_index, seq_index, :]  # (256, 10, 58, 64); z_negatives[i, :, j, :] is the negative-sample set for the ith utterance and jth time-step

            zs = torch.cat((z_shift.unsqueeze(1), z_negatives), dim=1)  # (256, 11, 58, 64)

            f = torch.sum(zs * Wc.unsqueeze(1) / math.sqrt(self.z_dim), dim=-1)  # (256, 11, 58), a dot product in fact

            labels = torch.zeros(
                n_speakers_per_batch, length,
                dtype=torch.long, device=z.device
            )

            loss = F.cross_entropy(f, labels)

            accuracy = f.argmax(dim=1) == labels  # (256, 58)
            accuracy = torch.mean(accuracy.float())

            losses.append(loss)
            accuracies.append(accuracy.item())

        loss = torch.stack(losses).mean()
        return loss, accuracies
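
# ---------------------------------------------------------------------------
# Illustrative smoke test, not part of the original code. The batch size and
# hyper-parameters below are assumptions loosely based on the shape comments
# above (z_dim=64, c_dim=256, 6 prediction steps, 10 within-utterance
# negatives); the block only exercises the forward passes and prints shapes.
if __name__ == "__main__":
    bz, T, z_dim, c_dim = 8, 64, 64, 256
    z = torch.randn(bz, T, z_dim)
    c = torch.randn(bz, T, c_dim)
    cpc = CPCLoss_sameSeq(n_speakers_per_batch=bz, n_utterances_per_speaker=1,
                          n_prediction_steps=6, n_negatives=10,
                          z_dim=z_dim, c_dim=c_dim)
    loss, accuracies = cpc(z, c)
    print('CPC loss:', loss.item(), 'per-step accuracies:', accuracies)

    spk_enc = SpeakerEncoder(c_in=80, c_h=128, c_out=256)
    mel = torch.randn(2, 80, 128)
    print('speaker embedding:', spk_enc(mel).shape)   # expected: (2, 256)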