"""Fastspeech2 related modules for ESPnet2.""" |
|
|
|
import logging |
|
|
|
from typing import Dict |
|
from typing import Sequence |
|
from typing import Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from typeguard import check_argument_types |
|
|
|
from espnet.nets.pytorch_backend.conformer.encoder import ( |
|
Encoder as ConformerEncoder, |
|
) |
|
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor |
|
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( |
|
DurationPredictorLoss, |
|
) |
|
from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator |
|
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet |
|
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.encoder import ( |
|
Encoder as TransformerEncoder, |
|
) |
|
|
|
from espnet2.torch_utils.device_funcs import force_gatherable |
|
from espnet2.torch_utils.initialize import initialize |
|
from espnet2.tts.abs_tts import AbsTTS |
|
from espnet2.tts.gst.style_encoder import StyleEncoder |
|
from espnet2.tts.variance_predictor import VariancePredictor |
|
|
|
|
|
class FastSpeech2(AbsTTS): |
|
"""FastSpeech2 module. |
|
|
|
This is a module of FastSpeech2 described in `FastSpeech 2: Fast and |
|
High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and |
|
    energy, we use the token-averaged values introduced in `FastPitch: Parallel
|
Text-to-speech with Pitch Prediction`_. |
|
|
|
.. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`: |
|
https://arxiv.org/abs/2006.04558 |
|
.. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`: |
|
https://arxiv.org/abs/2006.06873 |
|
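    Examples:
        >>> # A minimal instantiation sketch; the sizes here are illustrative,
        >>> # not a tuned configuration.
        >>> model = FastSpeech2(idim=10, odim=80)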
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
|
|
idim: int, |
|
odim: int, |
|
adim: int = 384, |
|
aheads: int = 4, |
|
elayers: int = 6, |
|
eunits: int = 1536, |
|
dlayers: int = 6, |
|
dunits: int = 1536, |
|
postnet_layers: int = 5, |
|
postnet_chans: int = 512, |
|
postnet_filts: int = 5, |
|
positionwise_layer_type: str = "conv1d", |
|
positionwise_conv_kernel_size: int = 1, |
|
use_scaled_pos_enc: bool = True, |
|
use_batch_norm: bool = True, |
|
encoder_normalize_before: bool = True, |
|
decoder_normalize_before: bool = True, |
|
encoder_concat_after: bool = False, |
|
decoder_concat_after: bool = False, |
|
reduction_factor: int = 1, |
|
encoder_type: str = "transformer", |
|
decoder_type: str = "transformer", |
|
|
|
conformer_rel_pos_type: str = "legacy", |
|
conformer_pos_enc_layer_type: str = "rel_pos", |
|
conformer_self_attn_layer_type: str = "rel_selfattn", |
|
conformer_activation_type: str = "swish", |
|
use_macaron_style_in_conformer: bool = True, |
|
use_cnn_in_conformer: bool = True, |
|
zero_triu: bool = False, |
|
conformer_enc_kernel_size: int = 7, |
|
conformer_dec_kernel_size: int = 31, |
|
|
|
duration_predictor_layers: int = 2, |
|
duration_predictor_chans: int = 384, |
|
duration_predictor_kernel_size: int = 3, |
|
|
|
energy_predictor_layers: int = 2, |
|
energy_predictor_chans: int = 384, |
|
energy_predictor_kernel_size: int = 3, |
|
energy_predictor_dropout: float = 0.5, |
|
energy_embed_kernel_size: int = 9, |
|
energy_embed_dropout: float = 0.5, |
|
stop_gradient_from_energy_predictor: bool = False, |
|
|
|
pitch_predictor_layers: int = 2, |
|
pitch_predictor_chans: int = 384, |
|
pitch_predictor_kernel_size: int = 3, |
|
pitch_predictor_dropout: float = 0.5, |
|
pitch_embed_kernel_size: int = 9, |
|
pitch_embed_dropout: float = 0.5, |
|
stop_gradient_from_pitch_predictor: bool = False, |
|
|
|
spk_embed_dim: int = None, |
|
spk_embed_integration_type: str = "add", |
|
|
|
use_gst: bool = False, |
|
gst_tokens: int = 10, |
|
gst_heads: int = 4, |
|
gst_conv_layers: int = 6, |
|
gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), |
|
gst_conv_kernel_size: int = 3, |
|
gst_conv_stride: int = 2, |
|
gst_gru_layers: int = 1, |
|
gst_gru_units: int = 128, |
|
|
|
transformer_enc_dropout_rate: float = 0.1, |
|
transformer_enc_positional_dropout_rate: float = 0.1, |
|
transformer_enc_attn_dropout_rate: float = 0.1, |
|
transformer_dec_dropout_rate: float = 0.1, |
|
transformer_dec_positional_dropout_rate: float = 0.1, |
|
transformer_dec_attn_dropout_rate: float = 0.1, |
|
duration_predictor_dropout_rate: float = 0.1, |
|
postnet_dropout_rate: float = 0.5, |
|
init_type: str = "xavier_uniform", |
|
init_enc_alpha: float = 1.0, |
|
init_dec_alpha: float = 1.0, |
|
use_masking: bool = False, |
|
use_weighted_masking: bool = False, |
|
): |
|
"""Initialize FastSpeech2 module.""" |
|
assert check_argument_types() |
|
super().__init__() |
|
|
|
|
|
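        # store hyperparameters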
self.idim = idim |
|
self.odim = odim |
|
self.eos = idim - 1 |
|
self.reduction_factor = reduction_factor |
|
self.encoder_type = encoder_type |
|
self.decoder_type = decoder_type |
|
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor |
|
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor |
|
self.use_scaled_pos_enc = use_scaled_pos_enc |
|
self.use_gst = use_gst |
|
self.spk_embed_dim = spk_embed_dim |
|
if self.spk_embed_dim is not None: |
|
self.spk_embed_integration_type = spk_embed_integration_type |
|
|
|
|
|
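        # use idx 0 as padding idx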
self.padding_idx = 0 |
|
|
|
|
|
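        # get positional encoding class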
pos_enc_class = ( |
|
ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding |
|
) |
|
|
|
|
|
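        # check relative positional encoding compatibility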
if "conformer" in [encoder_type, decoder_type]: |
|
if conformer_rel_pos_type == "legacy": |
|
if conformer_pos_enc_layer_type == "rel_pos": |
|
conformer_pos_enc_layer_type = "legacy_rel_pos" |
|
                    logging.warning(
                        "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                        "for backward compatibility. If you want to use the new one, "
                        "please set conformer_rel_pos_type = 'latest'."
                    )
|
if conformer_self_attn_layer_type == "rel_selfattn": |
|
conformer_self_attn_layer_type = "legacy_rel_selfattn" |
|
                    logging.warning(
                        "Fallback to "
                        "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                        "for backward compatibility. If you want to use the new one, "
                        "please set conformer_rel_pos_type = 'latest'."
                    )
|
elif conformer_rel_pos_type == "latest": |
|
assert conformer_pos_enc_layer_type != "legacy_rel_pos" |
|
assert conformer_self_attn_layer_type != "legacy_rel_selfattn" |
|
else: |
|
raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}") |
|
|
|
|
|
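        # define encoder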
encoder_input_layer = torch.nn.Embedding( |
|
num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx |
|
) |
|
if encoder_type == "transformer": |
|
self.encoder = TransformerEncoder( |
|
idim=idim, |
|
attention_dim=adim, |
|
attention_heads=aheads, |
|
linear_units=eunits, |
|
num_blocks=elayers, |
|
input_layer=encoder_input_layer, |
|
dropout_rate=transformer_enc_dropout_rate, |
|
positional_dropout_rate=transformer_enc_positional_dropout_rate, |
|
attention_dropout_rate=transformer_enc_attn_dropout_rate, |
|
pos_enc_class=pos_enc_class, |
|
normalize_before=encoder_normalize_before, |
|
concat_after=encoder_concat_after, |
|
positionwise_layer_type=positionwise_layer_type, |
|
positionwise_conv_kernel_size=positionwise_conv_kernel_size, |
|
) |
|
elif encoder_type == "conformer": |
|
self.encoder = ConformerEncoder( |
|
idim=idim, |
|
attention_dim=adim, |
|
attention_heads=aheads, |
|
linear_units=eunits, |
|
num_blocks=elayers, |
|
input_layer=encoder_input_layer, |
|
dropout_rate=transformer_enc_dropout_rate, |
|
positional_dropout_rate=transformer_enc_positional_dropout_rate, |
|
attention_dropout_rate=transformer_enc_attn_dropout_rate, |
|
normalize_before=encoder_normalize_before, |
|
concat_after=encoder_concat_after, |
|
positionwise_layer_type=positionwise_layer_type, |
|
positionwise_conv_kernel_size=positionwise_conv_kernel_size, |
|
macaron_style=use_macaron_style_in_conformer, |
|
pos_enc_layer_type=conformer_pos_enc_layer_type, |
|
selfattention_layer_type=conformer_self_attn_layer_type, |
|
activation_type=conformer_activation_type, |
|
use_cnn_module=use_cnn_in_conformer, |
|
cnn_module_kernel=conformer_enc_kernel_size, |
|
zero_triu=zero_triu, |
|
) |
|
else: |
|
raise ValueError(f"{encoder_type} is not supported.") |
|
|
|
|
|
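        # define GST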
if self.use_gst: |
|
self.gst = StyleEncoder( |
|
idim=odim, |
|
gst_tokens=gst_tokens, |
|
gst_token_dim=adim, |
|
gst_heads=gst_heads, |
|
conv_layers=gst_conv_layers, |
|
conv_chans_list=gst_conv_chans_list, |
|
conv_kernel_size=gst_conv_kernel_size, |
|
conv_stride=gst_conv_stride, |
|
gru_layers=gst_gru_layers, |
|
gru_units=gst_gru_units, |
|
) |
|
|
|
|
|
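        # define projection layer for speaker embedding integration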
if self.spk_embed_dim is not None: |
|
if self.spk_embed_integration_type == "add": |
|
self.projection = torch.nn.Linear(self.spk_embed_dim, adim) |
|
else: |
|
self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) |
|
|
|
|
|
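        # define duration predictor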
self.duration_predictor = DurationPredictor( |
|
idim=adim, |
|
n_layers=duration_predictor_layers, |
|
n_chans=duration_predictor_chans, |
|
kernel_size=duration_predictor_kernel_size, |
|
dropout_rate=duration_predictor_dropout_rate, |
|
) |
|
|
|
|
|
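        # define pitch predictor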
self.pitch_predictor = VariancePredictor( |
|
idim=adim, |
|
n_layers=pitch_predictor_layers, |
|
n_chans=pitch_predictor_chans, |
|
kernel_size=pitch_predictor_kernel_size, |
|
dropout_rate=pitch_predictor_dropout, |
|
) |
|
|
|
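        # NOTE: continuous token-averaged pitch (FastPitch style) is embedded with Conv1d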
self.pitch_embed = torch.nn.Sequential( |
|
torch.nn.Conv1d( |
|
in_channels=1, |
|
out_channels=adim, |
|
kernel_size=pitch_embed_kernel_size, |
|
padding=(pitch_embed_kernel_size - 1) // 2, |
|
), |
|
torch.nn.Dropout(pitch_embed_dropout), |
|
) |
|
|
|
|
|
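        # define energy predictor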
self.energy_predictor = VariancePredictor( |
|
idim=adim, |
|
n_layers=energy_predictor_layers, |
|
n_chans=energy_predictor_chans, |
|
kernel_size=energy_predictor_kernel_size, |
|
dropout_rate=energy_predictor_dropout, |
|
) |
|
|
|
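        # the token-averaged energy is embedded in the same way as pitch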
self.energy_embed = torch.nn.Sequential( |
|
torch.nn.Conv1d( |
|
in_channels=1, |
|
out_channels=adim, |
|
kernel_size=energy_embed_kernel_size, |
|
padding=(energy_embed_kernel_size - 1) // 2, |
|
), |
|
torch.nn.Dropout(energy_embed_dropout), |
|
) |
|
|
|
|
|
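        # define length regulator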
self.length_regulator = LengthRegulator() |
|
|
|
|
|
|
|
|
|
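        # define decoder
        # NOTE: we reuse the encoder module as the decoder because FastSpeech's
        # decoder has the same architecture as its encoder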
if decoder_type == "transformer": |
|
self.decoder = TransformerEncoder( |
|
idim=0, |
|
attention_dim=adim, |
|
attention_heads=aheads, |
|
linear_units=dunits, |
|
num_blocks=dlayers, |
|
input_layer=None, |
|
dropout_rate=transformer_dec_dropout_rate, |
|
positional_dropout_rate=transformer_dec_positional_dropout_rate, |
|
attention_dropout_rate=transformer_dec_attn_dropout_rate, |
|
pos_enc_class=pos_enc_class, |
|
normalize_before=decoder_normalize_before, |
|
concat_after=decoder_concat_after, |
|
positionwise_layer_type=positionwise_layer_type, |
|
positionwise_conv_kernel_size=positionwise_conv_kernel_size, |
|
) |
|
elif decoder_type == "conformer": |
|
self.decoder = ConformerEncoder( |
|
idim=0, |
|
attention_dim=adim, |
|
attention_heads=aheads, |
|
linear_units=dunits, |
|
num_blocks=dlayers, |
|
input_layer=None, |
|
dropout_rate=transformer_dec_dropout_rate, |
|
positional_dropout_rate=transformer_dec_positional_dropout_rate, |
|
attention_dropout_rate=transformer_dec_attn_dropout_rate, |
|
normalize_before=decoder_normalize_before, |
|
concat_after=decoder_concat_after, |
|
positionwise_layer_type=positionwise_layer_type, |
|
positionwise_conv_kernel_size=positionwise_conv_kernel_size, |
|
macaron_style=use_macaron_style_in_conformer, |
|
pos_enc_layer_type=conformer_pos_enc_layer_type, |
|
selfattention_layer_type=conformer_self_attn_layer_type, |
|
activation_type=conformer_activation_type, |
|
use_cnn_module=use_cnn_in_conformer, |
|
cnn_module_kernel=conformer_dec_kernel_size, |
|
) |
|
else: |
|
raise ValueError(f"{decoder_type} is not supported.") |
|
|
|
|
|
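        # define final projection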
self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) |
|
|
|
|
|
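        # define postnet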
self.postnet = ( |
|
None |
|
if postnet_layers == 0 |
|
else Postnet( |
|
idim=idim, |
|
odim=odim, |
|
n_layers=postnet_layers, |
|
n_chans=postnet_chans, |
|
n_filts=postnet_filts, |
|
use_batch_norm=use_batch_norm, |
|
dropout_rate=postnet_dropout_rate, |
|
) |
|
) |
|
|
|
|
|
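        # initialize parameters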
self._reset_parameters( |
|
init_type=init_type, |
|
init_enc_alpha=init_enc_alpha, |
|
init_dec_alpha=init_dec_alpha, |
|
) |
|
|
|
|
|
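        # define criterion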
self.criterion = FastSpeech2Loss( |
|
use_masking=use_masking, use_weighted_masking=use_weighted_masking |
|
) |
|
|
|
def forward( |
|
self, |
|
text: torch.Tensor, |
|
text_lengths: torch.Tensor, |
|
speech: torch.Tensor, |
|
speech_lengths: torch.Tensor, |
|
durations: torch.Tensor, |
|
durations_lengths: torch.Tensor, |
|
pitch: torch.Tensor, |
|
pitch_lengths: torch.Tensor, |
|
energy: torch.Tensor, |
|
energy_lengths: torch.Tensor, |
|
spembs: torch.Tensor = None, |
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
text (LongTensor): Batch of padded token ids (B, Tmax). |
|
text_lengths (LongTensor): Batch of lengths of each input (B,). |
|
speech (Tensor): Batch of padded target features (B, Lmax, odim). |
|
speech_lengths (LongTensor): Batch of the lengths of each target (B,). |
|
durations (LongTensor): Batch of padded durations (B, Tmax + 1). |
|
durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1). |
|
pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1). |
|
pitch_lengths (LongTensor): Batch of pitch lengths (B, Tmax + 1). |
|
energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1). |
|
energy_lengths (LongTensor): Batch of energy lengths (B, Tmax + 1). |
|
spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). |
|
|
|
Returns: |
|
Tensor: Loss scalar value. |
|
Dict: Statistics to be monitored. |
|
Tensor: Weight value. |
|
|
|
""" |
|
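        # trim paddings to the max length in the batch (for data-parallel)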
text = text[:, : text_lengths.max()] |
|
speech = speech[:, : speech_lengths.max()] |
|
durations = durations[:, : durations_lengths.max()] |
|
pitch = pitch[:, : pitch_lengths.max()] |
|
energy = energy[:, : energy_lengths.max()] |
|
|
|
batch_size = text.size(0) |
|
|
|
|
|
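        # add eos token at the end of each sequence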
xs = F.pad(text, [0, 1], "constant", self.padding_idx) |
|
        for i, length in enumerate(text_lengths):
            xs[i, length] = self.eos
|
ilens = text_lengths + 1 |
|
|
|
ys, ds, ps, es = speech, durations, pitch, energy |
|
olens = speech_lengths |
|
|
|
|
|
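        # forward propagation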
before_outs, after_outs, d_outs, p_outs, e_outs = self._forward( |
|
xs, ilens, ys, olens, ds, ps, es, spembs=spembs, is_inference=False |
|
) |
|
|
|
|
|
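        # modify mod part of groundtruth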
if self.reduction_factor > 1: |
|
olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) |
|
max_olen = max(olens) |
|
ys = ys[:, :max_olen] |
|
|
|
|
|
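        # when postnet is not used, drop its branch from the loss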
if self.postnet is None: |
|
after_outs = None |
|
|
|
|
|
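        # calculate loss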
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( |
|
after_outs=after_outs, |
|
before_outs=before_outs, |
|
d_outs=d_outs, |
|
p_outs=p_outs, |
|
e_outs=e_outs, |
|
ys=ys, |
|
ds=ds, |
|
ps=ps, |
|
es=es, |
|
ilens=ilens, |
|
olens=olens, |
|
) |
|
loss = l1_loss + duration_loss + pitch_loss + energy_loss |
|
|
|
stats = dict( |
|
l1_loss=l1_loss.item(), |
|
duration_loss=duration_loss.item(), |
|
pitch_loss=pitch_loss.item(), |
|
energy_loss=energy_loss.item(), |
|
loss=loss.item(), |
|
) |
|
|
|
|
|
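        # report extra information (learnable alphas of scaled positional encoding)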
if self.encoder_type == "transformer" and self.use_scaled_pos_enc: |
|
stats.update( |
|
encoder_alpha=self.encoder.embed[-1].alpha.data.item(), |
|
) |
|
if self.decoder_type == "transformer" and self.use_scaled_pos_enc: |
|
stats.update( |
|
decoder_alpha=self.decoder.embed[-1].alpha.data.item(), |
|
) |
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
|
return loss, stats, weight |
|
|
|
def _forward( |
|
self, |
|
xs: torch.Tensor, |
|
ilens: torch.Tensor, |
|
ys: torch.Tensor = None, |
|
olens: torch.Tensor = None, |
|
ds: torch.Tensor = None, |
|
ps: torch.Tensor = None, |
|
es: torch.Tensor = None, |
|
spembs: torch.Tensor = None, |
|
is_inference: bool = False, |
|
alpha: float = 1.0, |
|
) -> Sequence[torch.Tensor]: |
|
|
|
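        # forward encoder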
x_masks = self._source_mask(ilens) |
|
hs, _ = self.encoder(xs, x_masks) |
|
|
|
|
|
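        # integrate with GST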
if self.use_gst: |
|
style_embs = self.gst(ys) |
|
hs = hs + style_embs.unsqueeze(1) |
|
|
|
|
|
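        # integrate speaker embedding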
if self.spk_embed_dim is not None: |
|
hs = self._integrate_with_spk_embed(hs, spembs) |
|
|
|
|
|
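        # forward duration predictor and variance predictors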
d_masks = make_pad_mask(ilens).to(xs.device) |
|
|
|
if self.stop_gradient_from_pitch_predictor: |
|
p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1)) |
|
else: |
|
p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1)) |
|
if self.stop_gradient_from_energy_predictor: |
|
e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1)) |
|
else: |
|
e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1)) |
|
|
|
if is_inference: |
|
d_outs = self.duration_predictor.inference(hs, d_masks) |
|
|
|
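            # use predictions in inference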
p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2) |
|
e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2) |
|
hs = hs + e_embs + p_embs |
|
hs = self.length_regulator(hs, d_outs, alpha) |
|
else: |
|
d_outs = self.duration_predictor(hs, d_masks) |
|
|
|
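            # use groundtruth in training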
p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2) |
|
e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2) |
|
hs = hs + e_embs + p_embs |
|
hs = self.length_regulator(hs, ds) |
|
|
|
|
|
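        # forward decoder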
if olens is not None and not is_inference: |
|
if self.reduction_factor > 1: |
|
olens_in = olens.new([olen // self.reduction_factor for olen in olens]) |
|
else: |
|
olens_in = olens |
|
h_masks = self._source_mask(olens_in) |
|
else: |
|
h_masks = None |
|
zs, _ = self.decoder(hs, h_masks) |
|
        before_outs = self.feat_out(zs).view(
            zs.size(0), -1, self.odim
        )  # (B, Lmax, odim)
|
|
|
|
|
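        # postnet -> (B, Lmax//r * r, odim)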
if self.postnet is None: |
|
after_outs = before_outs |
|
else: |
|
after_outs = before_outs + self.postnet( |
|
before_outs.transpose(1, 2) |
|
).transpose(1, 2) |
|
|
|
return before_outs, after_outs, d_outs, p_outs, e_outs |
|
|
|
def inference( |
|
self, |
|
text: torch.Tensor, |
|
speech: torch.Tensor = None, |
|
spembs: torch.Tensor = None, |
|
durations: torch.Tensor = None, |
|
pitch: torch.Tensor = None, |
|
energy: torch.Tensor = None, |
|
alpha: float = 1.0, |
|
use_teacher_forcing: bool = False, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""Generate the sequence of features given the sequences of characters. |
|
|
|
Args: |
|
text (LongTensor): Input sequence of characters (T,). |
|
            speech (Tensor, optional): Feature sequence to extract style (N, odim).
|
spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). |
|
durations (LongTensor, optional): Groundtruth of duration (T + 1,). |
|
pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1). |
|
energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1). |
|
alpha (float, optional): Alpha to control the speed. |
|
use_teacher_forcing (bool, optional): Whether to use teacher forcing. |
|
If true, groundtruth of duration, pitch and energy will be used. |
|
|
|
Returns: |
|
Tensor: Output sequence of features (L, odim). |
|
None: Dummy for compatibility. |
|
None: Dummy for compatibility. |
|
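        Examples:
            >>> # Usage sketch; assumes a constructed `model` and a LongTensor
            >>> # `text` of token ids prepared by the caller (both hypothetical).
            >>> feats, _, _ = model.inference(text, alpha=1.3)  # alpha > 1 slows speech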
|
|
""" |
|
x, y = text, speech |
|
spemb, d, p, e = spembs, durations, pitch, energy |
|
|
|
|
|
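        # add eos at the last of sequence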
x = F.pad(x, [0, 1], "constant", self.eos) |
|
|
|
|
|
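        # setup batch axis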
ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) |
|
xs, ys = x.unsqueeze(0), None |
|
if y is not None: |
|
ys = y.unsqueeze(0) |
|
if spemb is not None: |
|
spembs = spemb.unsqueeze(0) |
|
|
|
if use_teacher_forcing: |
|
|
|
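            # use groundtruth of duration, pitch, and energy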
ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0) |
|
_, outs, *_ = self._forward( |
|
xs, |
|
ilens, |
|
ys, |
|
ds=ds, |
|
ps=ps, |
|
es=es, |
|
spembs=spembs, |
|
) |
|
else: |
|
_, outs, *_ = self._forward( |
|
xs, |
|
ilens, |
|
ys, |
|
spembs=spembs, |
|
is_inference=True, |
|
alpha=alpha, |
|
) |
|
|
|
return outs[0], None, None |
|
|
|
def _integrate_with_spk_embed( |
|
self, hs: torch.Tensor, spembs: torch.Tensor |
|
) -> torch.Tensor: |
|
"""Integrate speaker embedding with hidden states. |
|
|
|
Args: |
|
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). |
|
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). |
|
|
|
Returns: |
|
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). |
|
|
|
""" |
|
if self.spk_embed_integration_type == "add": |
|
|
|
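            # apply projection and then add to hidden states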
spembs = self.projection(F.normalize(spembs)) |
|
hs = hs + spembs.unsqueeze(1) |
|
elif self.spk_embed_integration_type == "concat": |
|
|
|
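            # concat hidden states with spk embeds and then apply projection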
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) |
|
hs = self.projection(torch.cat([hs, spembs], dim=-1)) |
|
else: |
|
            raise NotImplementedError("spk_embed_integration_type must be 'add' or 'concat'.")
|
|
|
return hs |
|
|
|
def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: |
|
"""Make masks for self-attention. |
|
|
|
Args: |
|
ilens (LongTensor): Batch of lengths (B,). |
|
|
|
Returns: |
|
Tensor: Mask tensor for self-attention. |
|
                dtype=torch.uint8 in PyTorch < 1.2
                dtype=torch.bool in PyTorch >= 1.2
|
|
|
Examples: |
|
>>> ilens = [5, 3] |
|
>>> self._source_mask(ilens) |
|
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
|
|
|
""" |
|
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) |
|
return x_masks.unsqueeze(-2) |
|
|
|
def _reset_parameters( |
|
self, init_type: str, init_enc_alpha: float, init_dec_alpha: float |
|
): |
|
|
|
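        # initialize parameters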
if init_type != "pytorch": |
|
initialize(self, init_type) |
|
|
|
|
|
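        # initialize alpha in scaled positional encoding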
if self.encoder_type == "transformer" and self.use_scaled_pos_enc: |
|
self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) |
|
if self.decoder_type == "transformer" and self.use_scaled_pos_enc: |
|
self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) |
|
|
|
|
|
class FastSpeech2Loss(torch.nn.Module): |
|
"""Loss function module for FastSpeech2.""" |
|
|
|
def __init__(self, use_masking: bool = True, use_weighted_masking: bool = False): |
|
"""Initialize feed-forward Transformer loss module. |
|
|
|
Args: |
|
use_masking (bool): |
|
Whether to apply masking for padded part in loss calculation. |
|
use_weighted_masking (bool): |
|
                Whether to apply weighted masking in loss calculation.
|
|
|
""" |
|
assert check_argument_types() |
|
super().__init__() |
|
|
|
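        # use_masking and use_weighted_masking cannot be enabled at the same time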
assert (use_masking != use_weighted_masking) or not use_masking |
|
self.use_masking = use_masking |
|
self.use_weighted_masking = use_weighted_masking |
|
|
|
|
|
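        # define criterions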
reduction = "none" if self.use_weighted_masking else "mean" |
|
self.l1_criterion = torch.nn.L1Loss(reduction=reduction) |
|
self.mse_criterion = torch.nn.MSELoss(reduction=reduction) |
|
self.duration_criterion = DurationPredictorLoss(reduction=reduction) |
|
|
|
def forward( |
|
self, |
|
after_outs: torch.Tensor, |
|
before_outs: torch.Tensor, |
|
d_outs: torch.Tensor, |
|
p_outs: torch.Tensor, |
|
e_outs: torch.Tensor, |
|
ys: torch.Tensor, |
|
ds: torch.Tensor, |
|
ps: torch.Tensor, |
|
es: torch.Tensor, |
|
ilens: torch.Tensor, |
|
olens: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). |
|
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). |
|
d_outs (LongTensor): Batch of outputs of duration predictor (B, Tmax). |
|
p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1). |
|
e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1). |
|
ys (Tensor): Batch of target features (B, Lmax, odim). |
|
ds (LongTensor): Batch of durations (B, Tmax). |
|
ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1). |
|
es (Tensor): Batch of target token-averaged energy (B, Tmax, 1). |
|
ilens (LongTensor): Batch of the lengths of each input (B,). |
|
olens (LongTensor): Batch of the lengths of each target (B,). |
|
|
|
Returns: |
|
Tensor: L1 loss value. |
|
Tensor: Duration predictor loss value. |
|
Tensor: Pitch predictor loss value. |
|
Tensor: Energy predictor loss value. |
|
|
|
""" |
|
|
|
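        # apply mask to remove padded part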
if self.use_masking: |
|
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) |
|
before_outs = before_outs.masked_select(out_masks) |
|
if after_outs is not None: |
|
after_outs = after_outs.masked_select(out_masks) |
|
ys = ys.masked_select(out_masks) |
|
duration_masks = make_non_pad_mask(ilens).to(ys.device) |
|
d_outs = d_outs.masked_select(duration_masks) |
|
ds = ds.masked_select(duration_masks) |
|
pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device) |
|
p_outs = p_outs.masked_select(pitch_masks) |
|
e_outs = e_outs.masked_select(pitch_masks) |
|
ps = ps.masked_select(pitch_masks) |
|
es = es.masked_select(pitch_masks) |
|
|
|
|
|
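        # calculate loss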
l1_loss = self.l1_criterion(before_outs, ys) |
|
if after_outs is not None: |
|
l1_loss += self.l1_criterion(after_outs, ys) |
|
duration_loss = self.duration_criterion(d_outs, ds) |
|
pitch_loss = self.mse_criterion(p_outs, ps) |
|
energy_loss = self.mse_criterion(e_outs, es) |
|
|
|
|
|
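        # make weighted mask and apply it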
if self.use_weighted_masking: |
|
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) |
|
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() |
|
out_weights /= ys.size(0) * ys.size(2) |
|
duration_masks = make_non_pad_mask(ilens).to(ys.device) |
|
duration_weights = ( |
|
duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float() |
|
) |
|
duration_weights /= ds.size(0) |
|
|
|
|
|
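            # apply weight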
l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() |
|
duration_loss = ( |
|
duration_loss.mul(duration_weights).masked_select(duration_masks).sum() |
|
) |
|
pitch_masks = duration_masks.unsqueeze(-1) |
|
pitch_weights = duration_weights.unsqueeze(-1) |
|
pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum() |
|
energy_loss = ( |
|
energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum() |
|
) |
|
|
|
return l1_loss, duration_loss, pitch_loss, energy_loss |
|
|