""" | |
Taken from ESPNet | |
""" | |
from abc import ABC

import torch

from Layers.Conformer import Conformer
from Layers.DurationPredictor import DurationPredictor
from Layers.LengthRegulator import LengthRegulator
from Layers.PostNet import PostNet
from Layers.VariancePredictor import VariancePredictor
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss
from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW
from Utility.utils import initialize
from Utility.utils import make_non_pad_mask
from Utility.utils import make_pad_mask


class FastSpeech2(torch.nn.Module, ABC):
    """
    FastSpeech 2 module.

    This is a module of FastSpeech 2 described in FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech. Instead of quantized pitch and
    energy, we use token-averaged values introduced in FastPitch: Parallel
    Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers
    instead of regular Transformers.

    https://arxiv.org/abs/2006.04558
    https://arxiv.org/abs/2006.06873
    https://arxiv.org/pdf/2005.08100
    """

    def __init__(self,
                 # network structure related
                 idim=66,
                 odim=80,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_layer_type="conv1d",
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=False,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 init_type="xavier_uniform",
                 init_enc_alpha=1.0,
                 init_dec_alpha=1.0,
                 use_masking=False,
                 use_weighted_masking=True,
                 # additional features
                 use_dtw_loss=False,
                 utt_embed_dim=704,
                 connect_utt_emb_at_encoder_out=True,
                 lang_embs=100):
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_dtw_loss = use_dtw_loss
        self.eos = 1
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None

        # define encoder
        embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(100, adim))
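        # (note, not in the original: because the input layer is a Linear
        # projection rather than an index Embedding, text_tensors must be
        # float feature vectors of shape (B, Tmax, idim), not token ids)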
        self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
                                 input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
                                 utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
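        # (note, not in the original: the utterance embedding conditions the
        # encoder on speaker identity; judging by its name,
        # connect_utt_emb_at_encoder_out controls whether it is fused into
        # the encoder output rather than into its input)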

        # define duration predictor
        self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate)

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout)

        # continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(pitch_embed_dropout))
        # define energy predictor
        self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout)

        # continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(energy_embed_dropout))

        # define length regulator
        self.length_regulator = LengthRegulator()

        self.decoder = Conformer(idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after, positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = PostNet(idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha)

        # define criterions
        self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
        self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)

    def forward(self,
                text_tensors,
                text_lengths,
                gold_speech,
                speech_lengths,
                gold_durations,
                gold_pitch,
                gold_energy,
                utterance_embedding,
                return_mels=False,
                lang_ids=None):
""" | |
Calculate forward propagation. | |
Args: | |
return_mels: whether to return the predicted spectrogram | |
text_tensors (LongTensor): Batch of padded text vectors (B, Tmax). | |
text_lengths (LongTensor): Batch of lengths of each input (B,). | |
gold_speech (Tensor): Batch of padded target features (B, Lmax, odim). | |
speech_lengths (LongTensor): Batch of the lengths of each target (B,). | |
gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1). | |
gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1). | |
gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1). | |
Returns: | |
Tensor: Loss scalar value. | |
Dict: Statistics to be monitored. | |
Tensor: Weight value. | |
""" | |
        # Texts include the EOS token from the teacher model already in this version

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors, text_lengths, gold_speech, speech_lengths,
                                                                        gold_durations, gold_pitch, gold_energy, utterance_embedding=utterance_embedding,
                                                                        is_inference=False, lang_ids=lang_ids)

        # trim the target lengths to a multiple of the reduction factor
        if self.reduction_factor > 1:
            speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])

        # calculate loss
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs,
                                                                         e_outs=e_outs, ys=gold_speech, ds=gold_durations, ps=gold_pitch, es=gold_energy,
                                                                         ilens=text_lengths, olens=speech_lengths)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        if self.use_dtw_loss:
            dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0  # division to balance orders of magnitude
            loss = loss + dtw_loss

        if return_mels:
            return loss, after_outs
        return loss

    def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
                 gold_durations=None, gold_pitch=None, gold_energy=None,
                 is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
        if not self.multilingual_model:
            lang_ids = None
        if not self.multispeaker_model:
            utterance_embedding = None

        # forward encoder
        text_masks = self._source_mask(text_lens)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(text_lens, device=text_lens.device)

        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            energy_predictions = self.energy_predictor(encoded_texts, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(encoded_texts, d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha)  # (B, Lmax, adim)
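            # (note, not in the original: assuming the ESPnet-style
            # LengthRegulator, the predicted durations are scaled by alpha,
            # so alpha > 1.0 gives slower and alpha < 1.0 faster speech)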
        else:
            d_outs = self.duration_predictor(encoded_texts, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)

        # forward decoder
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
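        # (note, not in the original: the postnet predicts a residual
        # correction on top of the coarse decoder output; the criterion
        # receives both before_outs and after_outs, as in Tacotron 2)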

        return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions

    def batch_inference(self, texts, text_lens, utt_emb):
        _, after_outs, d_outs, _, _ = self._forward(texts,
                                                    text_lens,
                                                    None,
                                                    is_inference=True,
                                                    alpha=1.0,
                                                    utterance_embedding=utt_emb)
        return after_outs, d_outs

    def inference(self,
                  text,
                  speech=None,
                  durations=None,
                  pitch=None,
                  energy=None,
                  alpha=1.0,
                  use_teacher_forcing=False,
                  utterance_embedding=None,
                  return_duration_pitch_energy=False,
                  lang_id=None):
""" | |
Generate the sequence of features given the sequences of characters. | |
Args: | |
text (LongTensor): Input sequence of characters (T,). | |
speech (Tensor, optional): Feature sequence to extract style (N, idim). | |
durations (LongTensor, optional): Groundtruth of duration (T + 1,). | |
pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1). | |
energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1). | |
alpha (float, optional): Alpha to control the speed. | |
use_teacher_forcing (bool, optional): Whether to use teacher forcing. | |
If true, groundtruth of duration, pitch and energy will be used. | |
return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting | |
Returns: | |
Tensor: Output sequence of features (L, odim). | |
""" | |
        self.eval()
        x, y = text, speech
        d, p, e = durations, pitch, energy

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                                   ilens,
                                                                                                   ys,
                                                                                                   gold_durations=ds,
                                                                                                   gold_pitch=ps,
                                                                                                   gold_energy=es,
                                                                                                   utterance_embedding=utterance_embedding.unsqueeze(0),
                                                                                                   lang_ids=lang_id)  # (1, L, odim)
        else:
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                                   ilens,
                                                                                                   ys,
                                                                                                   is_inference=True,
                                                                                                   alpha=alpha,
                                                                                                   utterance_embedding=utterance_embedding.unsqueeze(0),
                                                                                                   lang_ids=lang_id)  # (1, L, odim)
        self.train()

        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    def _source_mask(self, ilens):
        """
        Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
        """
        x_masks = make_non_pad_mask(ilens, device=ilens.device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)
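

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). It assumes the
    # repository's Layers/, TrainingInterfaces/ and Utility/ packages are
    # importable, and that text inputs are float feature vectors of size idim
    # (see the note on the encoder's Linear input layer above). All tensors
    # below are random placeholders, not real data.
    batch, tmax, idim, odim = 2, 10, 66, 80

    model = FastSpeech2()

    texts = torch.randn(batch, tmax, idim)                  # (B, Tmax, idim) float features
    text_lens = torch.tensor([tmax, tmax - 2])              # second sample has 2 padded tokens
    durations = torch.full((batch, tmax), 4, dtype=torch.long)
    durations[1, -2:] = 0                                   # padded tokens get zero duration
    speech_lens = durations.sum(dim=1)                      # frame counts implied by the durations
    gold_speech = torch.randn(batch, int(speech_lens.max()), odim)
    gold_pitch = torch.randn(batch, tmax, 1)                # token-averaged pitch targets
    gold_energy = torch.randn(batch, tmax, 1)               # token-averaged energy targets
    utt_emb = torch.randn(batch, 704)                       # utterance embeddings (utt_embed_dim=704)
    lang_ids = torch.zeros(batch, 1, dtype=torch.long)      # assumed shape for the language ids

    loss = model(texts, text_lens, gold_speech, speech_lens,
                 durations, gold_pitch, gold_energy, utt_emb, lang_ids=lang_ids)
    print(loss)

    # single-sample inference; the batch axis is added internally
    mel = model.inference(texts[0], utterance_embedding=utt_emb[0], lang_id=lang_ids[0])
    print(mel.shape)  # expected: (L, odim)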