diff --git a/InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py b/InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py new file mode 100644 index 0000000000000000000000000000000000000000..683f0873a6bfa800929586724c0bb21dc126f0dd --- /dev/null +++ b/InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py @@ -0,0 +1,256 @@ +from abc import ABC + +import torch + +from Layers.Conformer import Conformer +from Layers.DurationPredictor import DurationPredictor +from Layers.LengthRegulator import LengthRegulator +from Layers.PostNet import PostNet +from Layers.VariancePredictor import VariancePredictor +from Utility.utils import make_non_pad_mask +from Utility.utils import make_pad_mask + + +class FastSpeech2(torch.nn.Module, ABC): + + def __init__(self, # network structure related + weights, + idim=66, + odim=80, + adim=384, + aheads=4, + elayers=6, + eunits=1536, + dlayers=6, + dunits=1536, + postnet_layers=5, + postnet_chans=256, + postnet_filts=5, + positionwise_conv_kernel_size=1, + use_scaled_pos_enc=True, + use_batch_norm=True, + encoder_normalize_before=True, + decoder_normalize_before=True, + encoder_concat_after=False, + decoder_concat_after=False, + reduction_factor=1, + # encoder / decoder + use_macaron_style_in_conformer=True, + use_cnn_in_conformer=True, + conformer_enc_kernel_size=7, + conformer_dec_kernel_size=31, + # duration predictor + duration_predictor_layers=2, + duration_predictor_chans=256, + duration_predictor_kernel_size=3, + # energy predictor + energy_predictor_layers=2, + energy_predictor_chans=256, + energy_predictor_kernel_size=3, + energy_predictor_dropout=0.5, + energy_embed_kernel_size=1, + energy_embed_dropout=0.0, + stop_gradient_from_energy_predictor=True, + # pitch predictor + pitch_predictor_layers=5, + pitch_predictor_chans=256, + pitch_predictor_kernel_size=5, + pitch_predictor_dropout=0.5, + pitch_embed_kernel_size=1, + pitch_embed_dropout=0.0, + stop_gradient_from_pitch_predictor=True, + # training related + transformer_enc_dropout_rate=0.2, + transformer_enc_positional_dropout_rate=0.2, + transformer_enc_attn_dropout_rate=0.2, + transformer_dec_dropout_rate=0.2, + transformer_dec_positional_dropout_rate=0.2, + transformer_dec_attn_dropout_rate=0.2, + duration_predictor_dropout_rate=0.2, + postnet_dropout_rate=0.5, + # additional features + utt_embed_dim=704, + connect_utt_emb_at_encoder_out=True, + lang_embs=100): + super().__init__() + self.idim = idim + self.odim = odim + self.reduction_factor = reduction_factor + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.use_scaled_pos_enc = use_scaled_pos_enc + embed = torch.nn.Sequential(torch.nn.Linear(idim, 100), + torch.nn.Tanh(), + torch.nn.Linear(100, adim)) + self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, + input_layer=embed, dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer, + use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False, + utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, 
lang_embs=lang_embs) + self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, ) + self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, + n_chans=pitch_predictor_chans, + kernel_size=pitch_predictor_kernel_size, + dropout_rate=pitch_predictor_dropout) + self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim, + kernel_size=pitch_embed_kernel_size, + padding=(pitch_embed_kernel_size - 1) // 2), + torch.nn.Dropout(pitch_embed_dropout)) + self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, + n_chans=energy_predictor_chans, + kernel_size=energy_predictor_kernel_size, + dropout_rate=energy_predictor_dropout) + self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim, + kernel_size=energy_embed_kernel_size, + padding=(energy_embed_kernel_size - 1) // 2), + torch.nn.Dropout(energy_embed_dropout)) + self.length_regulator = LengthRegulator() + self.decoder = Conformer(idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size) + self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) + self.postnet = PostNet(idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate) + self.load_state_dict(weights) + + def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None, + gold_durations=None, gold_pitch=None, gold_energy=None, + is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None): + # forward encoder + text_masks = self._source_mask(text_lens) + + encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim) + + # forward duration predictor and variance predictors + duration_masks = make_pad_mask(text_lens, device=text_lens.device) + + if self.stop_gradient_from_pitch_predictor: + pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1)) + else: + pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1)) + + if self.stop_gradient_from_energy_predictor: + energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1)) + else: + energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1)) + + if is_inference: + if gold_durations is not None: + duration_predictions = gold_durations + else: + duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks) + if gold_pitch is not None: + pitch_predictions = gold_pitch + if gold_energy is not None: + energy_predictions = gold_energy + pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2) + energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 
2)).transpose(1, 2) + encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings + encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha) + else: + duration_predictions = self.duration_predictor(encoded_texts, duration_masks) + + # use groundtruth in training + pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2) + energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2) + encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings + encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim) + + # forward decoder + if speech_lens is not None and not is_inference: + if self.reduction_factor > 1: + olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens]) + else: + olens_in = speech_lens + h_masks = self._source_mask(olens_in) + else: + h_masks = None + zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim) + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim) + + # postnet -> (B, Lmax//r * r, odim) + after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2) + + return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions + + @torch.no_grad() + def forward(self, + text, + speech=None, + durations=None, + pitch=None, + energy=None, + utterance_embedding=None, + return_duration_pitch_energy=False, + lang_id=None): + """ + Generate the sequence of features given the sequences of characters. + + Args: + text: Input sequence of characters + speech: Feature sequence to extract style + durations: Groundtruth of duration + pitch: Groundtruth of token-averaged pitch + energy: Groundtruth of token-averaged energy + return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting + utterance_embedding: embedding of utterance wide parameters + + Returns: + Mel Spectrogram + + """ + self.eval() + # setup batch axis + ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device) + if speech is not None: + gold_speech = speech.unsqueeze(0) + else: + gold_speech = None + if durations is not None: + durations = durations.unsqueeze(0) + if pitch is not None: + pitch = pitch.unsqueeze(0) + if energy is not None: + energy = energy.unsqueeze(0) + if lang_id is not None: + lang_id = lang_id.unsqueeze(0) + + before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0), + ilens, + gold_speech=gold_speech, + gold_durations=durations, + is_inference=True, + gold_pitch=pitch, + gold_energy=energy, + utterance_embedding=utterance_embedding.unsqueeze(0), + lang_ids=lang_id) + self.train() + if return_duration_pitch_energy: + return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0] + return after_outs[0] + + def _source_mask(self, ilens): + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) diff --git a/InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py b/InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py new file mode 100644 index 0000000000000000000000000000000000000000..056b970b12d3c536a604e95aa9736d74cdf3e4fd --- /dev/null +++ b/InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py @@ -0,0 +1,91 @@ +import torch + +from Layers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock + + +class HiFiGANGenerator(torch.nn.Module): + + def __init__(self, + path_to_weights, + 
in_channels=80, + out_channels=1, + channels=512, + kernel_size=7, + upsample_scales=(8, 6, 4, 4), + upsample_kernel_sizes=(16, 12, 8, 8), + resblock_kernel_sizes=(3, 7, 11), + resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], + use_additional_convs=True, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + use_weight_norm=True, ): + super().__init__() + assert kernel_size % 2 == 1, "Kernal size must be odd number." + assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d(in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + self.upsamples += [ + torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + torch.nn.ConvTranspose1d(channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=(upsample_kernel_sizes[i] - upsample_scales[i]) // 2, ), )] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ResidualBlock(kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, )] + self.output_conv = torch.nn.Sequential( + torch.nn.LeakyReLU(), + torch.nn.Conv1d(channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, ), + torch.nn.Tanh(), ) + if use_weight_norm: + self.apply_weight_norm() + self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"]) + + def forward(self, c, normalize_before=False): + if normalize_before: + c = (c - self.mean) / self.scale + c = self.input_conv(c.unsqueeze(0)) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs = cs + self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + return c.squeeze(0).squeeze(0) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) diff --git a/InferenceInterfaces/InferenceArchitectures/__init__.py b/InferenceInterfaces/InferenceArchitectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/InferenceInterfaces/Meta_FastSpeech2.py b/InferenceInterfaces/Meta_FastSpeech2.py new file mode 100644 index 0000000000000000000000000000000000000000..295e8aaf4253f2df0b724e207c3f12719e842a82 --- /dev/null +++ b/InferenceInterfaces/Meta_FastSpeech2.py @@ -0,0 +1,76 @@ +import os + +import librosa.display as lbd +import matplotlib.pyplot as plt +import soundfile +import torch + +from InferenceInterfaces.InferenceArchitectures.InferenceFastSpeech2 import FastSpeech2 +from InferenceInterfaces.InferenceArchitectures.InferenceHiFiGAN import HiFiGANGenerator +from 
Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend +from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id +from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor + + +class Meta_FastSpeech2(torch.nn.Module): + + def __init__(self, device="cpu"): + super().__init__() + model_name = "Meta" + language = "en" + self.device = device + self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True) + checkpoint = torch.load(os.path.join("Models", f"FastSpeech2_{model_name}", "best.pt"), map_location='cpu') + self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device)) + self.mel2wav = HiFiGANGenerator(path_to_weights=os.path.join("Models", "HiFiGAN_combined", "best.pt")).to(torch.device(device)) + self.default_utterance_embedding = checkpoint["default_emb"].to(self.device) + self.phone2mel.eval() + self.mel2wav.eval() + self.lang_id = get_language_id(language) + self.to(torch.device(device)) + + def set_utterance_embedding(self, path_to_reference_audio): + wave, sr = soundfile.read(path_to_reference_audio) + self.default_utterance_embedding = ProsodicConditionExtractor(sr=sr).extract_condition_from_reference_wave(wave).to(self.device) + + def set_language(self, lang_id): + """ + The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs + """ + self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, silent=True) + self.lang_id = get_language_id(lang_id).to(self.device) + + def forward(self, text, view=False, durations=None, pitch=None, energy=None): + with torch.no_grad(): + phones = self.text2phone.string_to_tensor(text).to(torch.device(self.device)) + mel, durations, pitch, energy = self.phone2mel(phones, + return_duration_pitch_energy=True, + utterance_embedding=self.default_utterance_embedding, + durations=durations, + pitch=pitch, + energy=energy, + lang_id=self.lang_id) + mel = mel.transpose(0, 1) + wave = self.mel2wav(mel) + if view: + from Utility.utils import cumsum_durations + fig, ax = plt.subplots(nrows=2, ncols=1) + ax[0].plot(wave.cpu().numpy()) + lbd.specshow(mel.cpu().numpy(), + ax=ax[1], + sr=16000, + cmap='GnBu', + y_axis='mel', + x_axis=None, + hop_length=256) + ax[0].yaxis.set_visible(False) + ax[1].yaxis.set_visible(False) + duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) + ax[1].set_xticks(duration_splits, minor=True) + ax[1].xaxis.grid(True, which='minor') + ax[1].set_xticks(label_positions, minor=False) + ax[1].set_xticklabels(self.text2phone.get_phone_string(text)) + ax[0].set_title(text) + plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0) + plt.show() + return wave diff --git a/InferenceInterfaces/__init__.py b/InferenceInterfaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Layers/Attention.py b/Layers/Attention.py new file mode 100644 index 0000000000000000000000000000000000000000..eb241e315de718099901a075feae2ed0e31c7347 --- /dev/null +++ b/Layers/Attention.py @@ -0,0 +1,324 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + +from Utility.utils import make_non_pad_mask + 
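+# Rough shape reference for the attention modules defined below
+# (illustrative comment only, not executed as part of this module):
+#
+#     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+#     q = k = v = torch.randn(2, 50, 256)            # (batch, time, n_feat)
+#     mask = torch.ones(2, 1, 50, dtype=torch.bool)  # non-padded positions
+#     out = mha(q, k, v, mask)                       # (batch, time, n_feat)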
+ +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """ + Construct an MultiHeadedAttention object. + """ + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """ + Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """ + Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """ + Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """ + Multi-Head Attention layer with relative position encoding. 
+ Details can be found in https://github.com/espnet/espnet/pull/2816. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """ + Compute relative positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + Returns: + torch.Tensor: Output tensor. + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """ + Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class GuidedAttentionLoss(torch.nn.Module): + """ + Guided attention loss function module. 
+ + This module calculates the guided attention loss described + in `Efficiently Trainable Text-to-Speech System Based + on Deep Convolutional Networks with Guided Attention`_, + which forces the attention to be diagonal. + + .. _`Efficiently Trainable Text-to-Speech System + Based on Deep Convolutional Networks with Guided Attention`: + https://arxiv.org/abs/1710.08969 + """ + + def __init__(self, sigma=0.4, alpha=1.0): + """ + Initialize guided attention loss module. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + """ + super(GuidedAttentionLoss, self).__init__() + self.sigma = sigma + self.alpha = alpha + self.guided_attn_masks = None + self.masks = None + + def _reset_masks(self): + self.guided_attn_masks = None + self.masks = None + + def forward(self, att_ws, ilens, olens): + """ + Calculate forward propagation. + + Args: + att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). + + Returns: + Tensor: Guided attention loss value. + """ + self._reset_masks() + self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device) + self.masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + self._reset_masks() + return self.alpha * loss + + def _make_guided_attention_masks(self, ilens, olens): + n_batches = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma) + return guided_attn_masks + + @staticmethod + def _make_guided_attention_mask(ilen, olen, sigma): + """ + Make guided attention mask. + """ + grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float()) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) + + @staticmethod + def _make_masks(ilens, olens): + """ + Make masks indicating non-padded part. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + olens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor indicating non-padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + """ + in_masks = make_non_pad_mask(ilens, device=ilens.device) # (B, T_in) + out_masks = make_non_pad_mask(olens, device=olens.device) # (B, T_out) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in) + + +class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): + """ + Guided attention loss function module for multi head attention. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + """ + + def forward(self, att_ws, ilens, olens): + """ + Calculate forward propagation. + + Args: + att_ws (Tensor): + Batch of multi head attention weights (B, H, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). 
+ + Returns: + Tensor: Guided attention loss value. + """ + if self.guided_attn_masks is None: + self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1)) + if self.masks is None: + self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + if self.reset_always: + self._reset_masks() + + return self.alpha * loss diff --git a/Layers/Conformer.py b/Layers/Conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca87bfbf18bcfb84830501dc3d00e3a38916966 --- /dev/null +++ b/Layers/Conformer.py @@ -0,0 +1,144 @@ +""" +Taken from ESPNet +""" + +import torch +import torch.nn.functional as F + +from Layers.Attention import RelPositionMultiHeadedAttention +from Layers.Convolution import ConvolutionModule +from Layers.EncoderLayer import EncoderLayer +from Layers.LayerNorm import LayerNorm +from Layers.MultiLayeredConv1d import MultiLayeredConv1d +from Layers.MultiSequential import repeat +from Layers.PositionalEncoding import RelPositionalEncoding +from Layers.Swish import Swish + + +class Conformer(torch.nn.Module): + """ + Conformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + pos_enc_layer_type (str): Conformer positional encoding layer type. + selfattention_layer_type (str): Conformer attention layer type. + activation_type (str): Conformer activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. 
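+
+    Example (rough usage sketch, not executed; note that ``input_layer``
+    must be passed as a module or ``None``, since string values raise a
+    ValueError in the constructor below):
+
+        embed = torch.nn.Linear(66, 384)
+        encoder = Conformer(idim=66, attention_dim=384, attention_heads=4,
+                            input_layer=embed, macaron_style=True,
+                            use_cnn_module=True, cnn_module_kernel=7)
+        xs = torch.randn(2, 50, 66)                     # (batch, time, idim)
+        masks = torch.ones(2, 1, 50, dtype=torch.bool)  # non-padded positions
+        encoded, out_masks = encoder(xs, masks)         # encoded: (2, 50, 384)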
+ + """ + + def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, + attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, + macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, connect_utt_emb_at_encoder_out=True, + spk_emb_bottleneck_size=128, lang_embs=None): + super(Conformer, self).__init__() + + activation = Swish() + self.conv_subsampling_factor = 1 + + if isinstance(input_layer, torch.nn.Module): + self.embed = input_layer + self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate) + elif input_layer is None: + self.embed = None + self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate)) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.normalize_before = normalize_before + + self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out + if utt_embed is not None: + self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim) + # embedding projection derived from https://arxiv.org/pdf/1705.08947.pdf + self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size), + torch.nn.Softsign()) + if lang_embs is not None: + self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim) + + # self-attention module definition + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) + + # feed-forward module definition + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,) + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate, + normalize_before, concat_after)) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks, utterance_embedding=None, lang_ids=None): + """ + Encode input sequence. + + Args: + utterance_embedding: embedding containing lots of conditioning signals + step: indicator for when to start updating the embedding function + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). 
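+
+        Example of building the ``masks`` argument from sequence lengths for an
+        already constructed ``encoder`` (sketch only; relies on the
+        ``make_non_pad_mask`` utility from ``Utility.utils``, mirroring how
+        ``FastSpeech2._source_mask`` constructs it):
+
+            text_lens = torch.tensor([50, 42])
+            masks = make_non_pad_mask(text_lens).unsqueeze(-2)  # (batch, 1, time)
+            encoded_texts, _ = encoder(xs, masks)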
+ + """ + + if self.embed is not None: + xs = self.embed(xs) + + if lang_ids is not None: + lang_embs = self.language_embedding(lang_ids) + xs = xs + lang_embs # offset the phoneme distribution of a language + + if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out: + xs = self._integrate_with_utt_embed(xs, utterance_embedding) + + xs = self.pos_enc(xs) + + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + + if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out: + xs = self._integrate_with_utt_embed(xs, utterance_embedding) + + return xs, masks + + def _integrate_with_utt_embed(self, hs, utt_embeddings): + # project embedding into smaller space + speaker_embeddings_projected = self.embedding_projection(utt_embeddings) + # concat hidden states with spk embeds and then apply projection + speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1)) + return hs diff --git a/Layers/Convolution.py b/Layers/Convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e56e85d5908b0db5fceaea1e701d197a824d4b --- /dev/null +++ b/Layers/Convolution.py @@ -0,0 +1,55 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux 2021 + + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ + ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + super(ConvolutionModule, self).__init__() + # kernel_size should be an odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) + self.norm = nn.GroupNorm(num_groups=32, num_channels=channels) + self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.activation = activation + + def forward(self, x): + """ + Compute convolution module. + + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
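+
+        Example (illustrative sketch, not executed; ``channels`` must be
+        divisible by the 32 groups of the GroupNorm defined above):
+
+            conv = ConvolutionModule(channels=256, kernel_size=31)
+            x = torch.randn(2, 50, 256)  # (batch, time, channels)
+            y = conv(x)                  # (batch, time, channels)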
+ + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/Layers/DurationPredictor.py b/Layers/DurationPredictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccfe1c4584a1de8f9f7b65fc7997885539976b1 --- /dev/null +++ b/Layers/DurationPredictor.py @@ -0,0 +1,139 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + + +import torch + +from Layers.LayerNorm import LayerNorm + + +class DurationPredictor(torch.nn.Module): + """ + Duration predictor module. + + This is a module of duration predictor described + in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The duration predictor predicts a duration of each frame in log domain + from the hidden embeddings of encoder. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + Note: + The calculation domain of outputs is different + between in `forward` and in `inference`. In `forward`, + the outputs are calculated in log domain but in `inference`, + those are calculated in linear domain. + + """ + + def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0): + """ + Initialize duration predictor module. + + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + offset (float, optional): Offset value to avoid nan in log domain. + + """ + super(DurationPredictor, self).__init__() + self.offset = offset + self.conv = torch.nn.ModuleList() + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ), torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), torch.nn.Dropout(dropout_rate), )] + self.linear = torch.nn.Linear(n_chans, 1) + + def _forward(self, xs, x_masks=None, is_inference=False): + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + + # NOTE: calculate in log domain + xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax) + + if is_inference: + # NOTE: calculate in linear domain + xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value + + if x_masks is not None: + xs = xs.masked_fill(x_masks, 0.0) + + return xs + + def forward(self, xs, x_masks=None): + """ + Calculate forward propagation. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + + """ + return self._forward(xs, x_masks, False) + + def inference(self, xs, x_masks=None): + """ + Inference duration. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + LongTensor: Batch of predicted durations in linear domain (B, Tmax). 
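+
+        Example (illustrative sketch, not executed; the predicted values
+        depend entirely on the trained weights):
+
+            predictor = DurationPredictor(idim=384)
+            xs = torch.randn(2, 50, 384)         # encoder states (B, Tmax, idim)
+            durations = predictor.inference(xs)  # LongTensor (B, Tmax), values >= 0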
+ + """ + return self._forward(xs, x_masks, True) + + +class DurationPredictorLoss(torch.nn.Module): + """ + Loss function module for duration predictor. + + The loss value is Calculated in log domain to make it Gaussian. + + """ + + def __init__(self, offset=1.0, reduction="mean"): + """ + Args: + offset (float, optional): Offset value to avoid nan in log domain. + reduction (str): Reduction type in loss calculation. + + """ + super(DurationPredictorLoss, self).__init__() + self.criterion = torch.nn.MSELoss(reduction=reduction) + self.offset = offset + + def forward(self, outputs, targets): + """ + Calculate forward propagation. + + Args: + outputs (Tensor): Batch of prediction durations in log domain (B, T) + targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) + + Returns: + Tensor: Mean squared error loss value. + + Note: + `outputs` is in log domain but `targets` is in linear domain. + + """ + # NOTE: outputs is in log domain while targets in linear + targets = torch.log(targets.float() + self.offset) + loss = self.criterion(outputs, targets) + + return loss diff --git a/Layers/EncoderLayer.py b/Layers/EncoderLayer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae91c25d7a88dfff0603c263b63b8bb0f05c80a --- /dev/null +++ b/Layers/EncoderLayer.py @@ -0,0 +1,144 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux 2021 + + +import torch +from torch import nn + +from Layers.LayerNorm import LayerNorm + + +class EncoderLayer(nn.Module): + """ + Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """ + Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/Layers/LayerNorm.py b/Layers/LayerNorm.py new file mode 100644 index 0000000000000000000000000000000000000000..c4cb4c15df0ccc0195bc18e124f4b50fb6bcee80 --- /dev/null +++ b/Layers/LayerNorm.py @@ -0,0 +1,36 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 
(http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """ + Layer normalization module. + + Args: + nout (int): Output dim size. + dim (int): Dimension to be normalized. + """ + + def __init__(self, nout, dim=-1): + """ + Construct an LayerNorm object. + """ + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """ + Apply layer normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor. + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) diff --git a/Layers/LengthRegulator.py b/Layers/LengthRegulator.py new file mode 100644 index 0000000000000000000000000000000000000000..e375cf18524e4695da5d0909b65a56a178696d40 --- /dev/null +++ b/Layers/LengthRegulator.py @@ -0,0 +1,62 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + +from abc import ABC + +import torch + +from Utility.utils import pad_list + + +class LengthRegulator(torch.nn.Module, ABC): + """ + Length regulator module for feed-forward Transformer. + + This is a module of length regulator described in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The length regulator expands char or + phoneme-level embedding features to frame-level by repeating each + feature based on the corresponding predicted durations. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, pad_value=0.0): + """ + Initialize length regulator module. + + Args: + pad_value (float, optional): Value used for padding. + """ + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + + def forward(self, xs, ds, alpha=1.0): + """ + Calculate forward propagation. + + Args: + xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D). + ds (LongTensor): Batch of durations of each frame (B, T). + alpha (float, optional): Alpha value to control speed of speech. + + Returns: + Tensor: replicated input tensor based on durations (B, T*, D). + """ + if alpha != 1.0: + assert alpha > 0 + ds = torch.round(ds.float() * alpha).long() + + if ds.sum() == 0: + ds[ds.sum(dim=1).eq(0)] = 1 + + return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value) + + def _repeat_one_sequence(self, x, d): + """ + Repeat each frame according to duration + """ + return torch.repeat_interleave(x, d, dim=0) diff --git a/Layers/MultiLayeredConv1d.py b/Layers/MultiLayeredConv1d.py new file mode 100644 index 0000000000000000000000000000000000000000..f2de4a06a06d891fbaca726959b0f0d34d93d7cc --- /dev/null +++ b/Layers/MultiLayeredConv1d.py @@ -0,0 +1,87 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + +""" +Layer modules for FFT block in FastSpeech (Feed-forward Transformer). +""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """ + Multi-layered conv1d for Transformer block. + + This is a module of multi-layered conv1d designed + to replace positionwise feed-forward network + in Transformer block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. 
_`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """ + Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """ + Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """ + Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """ + Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """ + Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/Layers/MultiSequential.py b/Layers/MultiSequential.py new file mode 100644 index 0000000000000000000000000000000000000000..bccf8cd18bf94a42fcc1ef94f3fb23e86a114394 --- /dev/null +++ b/Layers/MultiSequential.py @@ -0,0 +1,33 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +import torch + + +class MultiSequential(torch.nn.Sequential): + """ + Multi-input multi-output torch.nn.Sequential. + """ + + def forward(self, *args): + """ + Repeat. + """ + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """ + Repeat module N times. + + Args: + N (int): Number of repeat time. + fn (Callable): Function to generate module. + + Returns: + MultiSequential: Repeated model instance. + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/Layers/PositionalEncoding.py b/Layers/PositionalEncoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8929a7fa6298f00e97fba1630524da014b738ace --- /dev/null +++ b/Layers/PositionalEncoding.py @@ -0,0 +1,166 @@ +""" +Taken from ESPNet +""" + +import math + +import torch + + +class PositionalEncoding(torch.nn.Module): + """ + Positional encoding. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. 
+ reverse (bool): Whether to reverse the input position. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """ + Construct an PositionalEncoding object. + """ + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len)) + + def extend_pe(self, x): + """ + Reset the positional encodings. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x): + """ + Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module (new implementation). + Details can be found in https://github.com/espnet/espnet/pull/2816. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """ + Construct an PositionalEncoding object. + """ + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i