# Provenance (from the hosting page header): uploaded by eriquesouza, "app v1" (e831f85), 19.1 kB.
"""
Adapted from the FastSpeech 2 implementation in ESPnet.
"""
from abc import ABC
import torch
from Layers.Conformer import Conformer
from Layers.DurationPredictor import DurationPredictor
from Layers.LengthRegulator import LengthRegulator
from Layers.PostNet import PostNet
from Layers.VariancePredictor import VariancePredictor
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss
from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW
from Utility.utils import initialize
from Utility.utils import make_non_pad_mask
from Utility.utils import make_pad_mask
class FastSpeech2(torch.nn.Module, ABC):
    """
    FastSpeech 2 module.

    This is a module of FastSpeech 2 described in FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech. Instead of quantized pitch and
    energy, we use token-averaged value introduced in FastPitch: Parallel
    Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers
    instead of regular Transformers.

    https://arxiv.org/abs/2006.04558
    https://arxiv.org/abs/2006.06873
    https://arxiv.org/pdf/2005.08100
    """

    def __init__(self,
                 # network structure related
                 idim=66,
                 odim=80,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_layer_type="conv1d",
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=False,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 init_type="xavier_uniform",
                 init_enc_alpha=1.0,
                 init_dec_alpha=1.0,
                 use_masking=False,
                 use_weighted_masking=True,
                 # additional features
                 use_dtw_loss=False,
                 utt_embed_dim=704,
                 connect_utt_emb_at_encoder_out=True,
                 lang_embs=100):
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_dtw_loss = use_dtw_loss
        self.eos = 1
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        # language / speaker conditioning is switched off entirely when the corresponding size is None
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None

        # define encoder (its input layer is a small MLP from idim to adim)
        embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(100, adim))
        self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
                                 input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
                                 utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)

        # define duration predictor
        self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout)
        # continuous pitch + FastPitch style avg: a 1-channel value per token is projected to adim
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(pitch_embed_dropout))

        # define energy predictor
        self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout)
        # continuous energy + FastPitch style avg: a 1-channel value per token is projected to adim
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(energy_embed_dropout))

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder (no input layer: it consumes the length-regulated adim features directly)
        self.decoder = Conformer(idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after, positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = PostNet(idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha)

        # define criterions
        self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
        # NOTE(review): constructed unconditionally (with use_cuda=True), even when use_dtw_loss is False
        self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)

    def forward(self,
                text_tensors,
                text_lengths,
                gold_speech,
                speech_lengths,
                gold_durations,
                gold_pitch,
                gold_energy,
                utterance_embedding,
                return_mels=False,
                lang_ids=None):
        """
        Calculate forward propagation and the training loss.

        Args:
            text_tensors (Tensor): Batch of padded text inputs, first two axes (B, Tmax).
                NOTE(review): the encoder's input layer is Linear(idim, 100), so these are
                presumably float feature vectors of shape (B, Tmax, idim) — confirm.
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            gold_durations (LongTensor): Batch of padded durations (B, Tmax), one per input token.
            gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax, 1).
            gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax, 1).
            utterance_embedding (Tensor): Batch of utterance embeddings; ignored when the
                model was built with utt_embed_dim=None.
            return_mels (bool): Whether to also return the predicted spectrogram.
            lang_ids (LongTensor, optional): Batch of language IDs; ignored when the model
                was built with lang_embs=None.

        Returns:
            Tensor: Total loss (spectrogram L1 + duration + pitch + energy, plus the
                scaled soft-DTW term when use_dtw_loss is enabled).
            Tensor: Postnet-refined spectrogram (B, Lmax, odim); only returned when
                return_mels is True.
        """
        # Texts include EOS token from the teacher model already in this version

        # forward propagation (teacher forcing with the gold durations / pitch / energy)
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors, text_lengths, gold_speech, speech_lengths,
                                                                        gold_durations, gold_pitch, gold_energy, utterance_embedding=utterance_embedding,
                                                                        is_inference=False, lang_ids=lang_ids)

        # truncate target lengths to a multiple of the reduction factor
        if self.reduction_factor > 1:
            speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])

        # calculate loss
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs,
                                                                         e_outs=e_outs, ys=gold_speech, ds=gold_durations, ps=gold_pitch, es=gold_energy,
                                                                         ilens=text_lengths, olens=speech_lengths)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        if self.use_dtw_loss:
            # division to balance orders of magnitude with the other loss terms
            dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0
            loss = loss + dtw_loss

        if return_mels:
            return loss, after_outs
        return loss

    def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
                 gold_durations=None, gold_pitch=None, gold_energy=None,
                 is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
        """
        Shared forward pass for training and inference.

        With is_inference=False the gold durations, pitch and energy condition the
        decoder (teacher forcing); with is_inference=True the predicted values are
        used instead and alpha is forwarded to the length regulator.
        gold_speech is accepted for interface compatibility but is not read here.

        Returns:
            tuple: (before_outs, after_outs, d_outs, pitch_predictions, energy_predictions),
            where before_outs / after_outs are the spectrogram before / after the postnet.
        """
        if not self.multilingual_model:
            lang_ids = None
        if not self.multispeaker_model:
            utterance_embedding = None

        # forward encoder
        text_masks = self._source_mask(text_lens)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(text_lens, device=text_lens.device)
        variance_masks = d_masks.unsqueeze(-1)
        # optional detach: stop the pitch / energy losses from backpropagating into the encoder
        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), variance_masks)
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, variance_masks)
        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), variance_masks)
        else:
            energy_predictions = self.energy_predictor(encoded_texts, variance_masks)

        if is_inference:
            d_outs = self.duration_predictor.inference(encoded_texts, d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(encoded_texts, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)

        # forward decoder (output masks only when target lengths are known, i.e. training)
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim), applied as a residual refinement
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions

    def batch_inference(self, texts, text_lens, utt_emb, lang_ids=None, alpha=1.0):
        """
        Generate spectrograms for a padded batch using predicted durations, pitch and energy.

        Does not change the train/eval mode of the module; callers are responsible for that.

        Args:
            texts (Tensor): Batch of padded text inputs, first two axes (B, Tmax).
            text_lens (LongTensor): Lengths of each input (B,).
            utt_emb (Tensor): Batch of utterance embeddings; ignored when the model was
                built with utt_embed_dim=None.
            lang_ids (LongTensor, optional): Batch of language IDs; ignored when the model
                was built with lang_embs=None.
            alpha (float, optional): Speed control factor passed to the length regulator.

        Returns:
            tuple: (after_outs, d_outs) — postnet-refined spectrograms and predicted durations.
        """
        _, after_outs, d_outs, _, _ = self._forward(texts,
                                                    text_lens,
                                                    gold_speech=None,
                                                    is_inference=True,
                                                    alpha=alpha,
                                                    utterance_embedding=utt_emb,
                                                    lang_ids=lang_ids)
        return after_outs, d_outs

    def inference(self,
                  text,
                  speech=None,
                  durations=None,
                  pitch=None,
                  energy=None,
                  alpha=1.0,
                  use_teacher_forcing=False,
                  utterance_embedding=None,
                  return_duration_pitch_energy=False,
                  lang_id=None):
        """
        Generate the sequence of features for a single (unbatched) input sequence.

        The module is put into eval mode for the duration of the call and is then
        restored to the train/eval mode it was in before (also if an error occurs).

        Args:
            text (Tensor): Input sequence (T, ...); a batch axis is added internally.
            speech (Tensor, optional): Forwarded to _forward as gold_speech, which does
                not read it in this implementation.
            durations (LongTensor, optional): Groundtruth durations (T,); required with
                use_teacher_forcing.
            pitch (Tensor, optional): Groundtruth token-averaged pitch (T, 1); required
                with use_teacher_forcing.
            energy (Tensor, optional): Groundtruth token-averaged energy (T, 1); required
                with use_teacher_forcing.
            alpha (float, optional): Speed control for the length regulator (only used
                without teacher forcing).
            use_teacher_forcing (bool, optional): If True, the groundtruth durations,
                pitch and energy are used instead of the predicted ones.
            utterance_embedding (Tensor, optional): Single utterance embedding without a
                batch axis; may be None (e.g. for a model built with utt_embed_dim=None).
            return_duration_pitch_energy (bool, optional): Also return the predicted
                durations, pitch and energy, e.g. for plotting.
                NOTE(review): with teacher forcing the durations come from
                duration_predictor(...) instead of duration_predictor.inference(...),
                so they may be in a different domain — confirm.
            lang_id (LongTensor, optional): Language ID without a batch axis.

        Returns:
            Tensor: Output sequence of features (L, odim), or the tuple
            (features, durations, pitch, energy) when return_duration_pitch_energy is True.

        Raises:
            ValueError: If use_teacher_forcing is set but durations, pitch or energy is None.
        """
        if use_teacher_forcing and (durations is None or pitch is None or energy is None):
            raise ValueError("use_teacher_forcing requires durations, pitch and energy.")

        was_training = self.training
        self.eval()
        try:
            # setup batch axis
            ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
            xs = text.unsqueeze(0)
            ys = self._add_batch_axis(speech)
            utterance_embeddings = self._add_batch_axis(utterance_embedding)
            lang_ids = self._add_batch_axis(lang_id)

            if use_teacher_forcing:
                # use groundtruth of duration, pitch, and energy
                _, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                             ilens,
                                                                                             ys,
                                                                                             gold_durations=durations.unsqueeze(0),
                                                                                             gold_pitch=pitch.unsqueeze(0),
                                                                                             gold_energy=energy.unsqueeze(0),
                                                                                             utterance_embedding=utterance_embeddings,
                                                                                             lang_ids=lang_ids)  # (1, L, odim)
            else:
                _, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                             ilens,
                                                                                             ys,
                                                                                             is_inference=True,
                                                                                             alpha=alpha,
                                                                                             utterance_embedding=utterance_embeddings,
                                                                                             lang_ids=lang_ids)  # (1, L, odim)
        finally:
            # restore whatever mode the caller had the module in
            self.train(was_training)

        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    @staticmethod
    def _add_batch_axis(tensor):
        """Return tensor with a leading batch axis of size 1, or None when tensor is None."""
        return None if tensor is None else tensor.unsqueeze(0)

    def _source_mask(self, ilens):
        """
        Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Non-padding mask from make_non_pad_mask with an extra axis inserted
            at position -2, i.e. (B, 1, Tmax).
        """
        x_masks = make_non_pad_mask(ilens, device=ilens.device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
        """
        Initialize parameters with Utility.utils.initialize unless init_type is "pytorch",
        in which case PyTorch's default initialization is kept.

        init_enc_alpha and init_dec_alpha are accepted for signature compatibility but are
        not used here.
        """
        if init_type != "pytorch":
            initialize(self, init_type)