|
import os |
|
import json |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformer import Encoder, Decoder, PostNet |
|
from .modules import VarianceAdaptor |
|
from utils.tools import get_mask_from_lengths, get_roberta_emotion_embeddings |
|
|
|
|
|
class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Pipeline: phoneme encoder -> (optional speaker / emotion conditioning)
    -> variance adaptor (pitch / energy / duration) -> mel decoder ->
    mel linear projection -> residual PostNet refinement.

    Speaker and emotion embeddings are created only when the corresponding
    ``multi_speaker`` / ``multi_emotion`` flags are set in ``model_config``;
    their vocabulary sizes are read from ``speakers.json`` / ``emotions.json``
    in the preprocessed-data directory.
    """

    def __init__(self, preprocess_config, model_config):
        super().__init__()
        self.model_config = model_config

        # Core sub-modules of the synthesis pipeline.
        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        self.postnet = PostNet()

        # Hoist repeated config lookups used by the optional embeddings below.
        hidden_size = model_config["transformer"]["encoder_hidden"]
        data_dir = preprocess_config["path"]["preprocessed_path"]

        # Optional speaker conditioning: one learned vector per speaker id.
        self.speaker_emb = None
        if model_config["multi_speaker"]:
            with open(os.path.join(data_dir, "speakers.json"), "r") as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(n_speaker, hidden_size)

        # Optional emotion conditioning: embedding followed by a small
        # nonlinear projection before being added to the encoder output.
        self.emotion_emb = None
        if model_config["multi_emotion"]:
            with open(os.path.join(data_dir, "emotions.json"), "r") as f:
                n_emotion = len(json.load(f))
            self.emotion_emb = nn.Embedding(n_emotion, hidden_size)
            self.emotion_linear = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            )

    def forward(
        self,
        speakers,
        texts,
        src_lens,
        max_src_len,
        emotions,
        mels=None,
        mel_lens=None,
        max_mel_len=None,
        p_targets=None,
        e_targets=None,
        d_targets=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Run one synthesis pass.

        ``speakers`` / ``emotions`` are per-utterance id tensors; ``texts``
        holds phoneme sequences padded to ``max_src_len``. The ``*_targets``
        are teacher-forcing values used during training (``None`` at
        inference), and the ``*_control`` scalars scale the predicted pitch,
        energy, and duration at inference time.

        Returns the tuple
        ``(mel_before_postnet, mel_after_postnet, pitch_pred, energy_pred,
        log_duration_pred, duration_rounded, src_masks, mel_masks,
        src_lens, mel_lens)``.
        """
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        # Mel masks exist only when mel lengths are known (training).
        if mel_lens is not None:
            mel_masks = get_mask_from_lengths(mel_lens, max_mel_len)
        else:
            mel_masks = None

        output = self.encoder(texts, src_masks)

        # Broadcast each per-utterance conditioning vector across time.
        if self.speaker_emb is not None:
            spk = self.speaker_emb(speakers).unsqueeze(1)
            output = output + spk.expand(-1, max_src_len, -1)

        if self.emotion_emb is not None:
            emo = self.emotion_linear(self.emotion_emb(emotions)).unsqueeze(1)
            output = output + emo.expand(-1, max_src_len, -1)

        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            d_targets,
            p_control,
            e_control,
            d_control,
        )

        output, mel_masks = self.decoder(output, mel_masks)
        output = self.mel_linear(output)

        # PostNet predicts a residual correction on top of the coarse mel.
        postnet_output = self.postnet(output) + output

        return (
            output,
            postnet_output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            src_masks,
            mel_masks,
            src_lens,
            mel_lens,
        )
|
|