import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformer import Encoder, Decoder, PostNet
from .modules import VarianceAdaptor
from utils.tools import get_mask_from_lengths, get_roberta_emotion_embeddings


class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model with optional speaker and emotion conditioning.

    Pipeline: ``Encoder`` -> (+ speaker embedding) -> (+ emotion embedding
    passed through ``emotion_linear``) -> ``VarianceAdaptor`` -> ``Decoder`` ->
    ``mel_linear`` -> residual ``PostNet``.

    Speaker conditioning is enabled by ``model_config["multi_speaker"]`` and
    emotion conditioning by ``model_config["multi_emotion"]``. The embedding
    table sizes are taken from the number of entries in ``speakers.json`` /
    ``emotions.json`` under ``preprocess_config["path"]["preprocessed_path"]``.
    """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2, self).__init__()
        self.model_config = model_config

        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        self.postnet = PostNet()

        # Speaker embedding: one vector per entry in speakers.json, sized to
        # the encoder hidden width so it can be added to encoder outputs.
        self.speaker_emb = None
        if model_config["multi_speaker"]:
            n_speaker = self._count_json_entries(
                os.path.join(
                    preprocess_config["path"]["preprocessed_path"], "speakers.json"
                )
            )
            self.speaker_emb = nn.Embedding(
                n_speaker,
                model_config["transformer"]["encoder_hidden"],
            )

        # Emotion embedding: one vector per entry in emotions.json, followed by
        # a Linear+ReLU projection before being added to encoder outputs.
        self.emotion_emb = None
        if model_config["multi_emotion"]:
            n_emotion = self._count_json_entries(
                os.path.join(
                    preprocess_config["path"]["preprocessed_path"], "emotions.json"
                )
            )
            encoder_hidden = model_config["transformer"]["encoder_hidden"]
            self.emotion_emb = nn.Embedding(n_emotion, encoder_hidden)
            self.emotion_linear = nn.Sequential(
                nn.Linear(encoder_hidden, encoder_hidden),
                nn.ReLU(),
            )

    @staticmethod
    def _count_json_entries(path):
        """Return ``len()`` of the top-level JSON value stored in *path*.

        The file is decoded as UTF-8 explicitly so that non-ASCII keys load the
        same way regardless of the platform's locale encoding.
        """
        with open(path, "r", encoding="utf-8") as f:
            return len(json.load(f))

    @staticmethod
    def _broadcast_over_time(embedding, length):
        """Repeat a 2-D ``(batch, hidden)`` tensor along a new axis 1.

        Returns a ``(batch, length, hidden)`` view (``expand`` does not copy),
        suitable for adding to per-time-step encoder outputs.
        """
        return embedding.unsqueeze(1).expand(-1, length, -1)

    def forward(
        self,
        speakers,
        texts,
        src_lens,
        max_src_len,
        emotions,
        mels=None,
        mel_lens=None,
        max_mel_len=None,
        p_targets=None,
        e_targets=None,
        d_targets=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Run the full text-to-mel pipeline.

        Args:
            speakers: speaker ids for ``self.speaker_emb`` (only read when
                multi-speaker is enabled).
            texts: token ids fed to the encoder.
            src_lens, max_src_len: source lengths and padded source length,
                used to build ``src_masks`` and to broadcast the conditioning
                embeddings over time.
            emotions: emotion ids for ``self.emotion_emb`` (only read when
                multi-emotion is enabled).
            mels: accepted for call-site compatibility but not read here.
            mel_lens, max_mel_len: target mel lengths; when ``mel_lens`` is
                None no mel mask is built and the variance adaptor decides
                the output length.
            p_targets, e_targets, d_targets: optional pitch / energy / duration
                targets forwarded to the variance adaptor (teacher forcing).
            p_control, e_control, d_control: pitch / energy / duration control
                factors forwarded to the variance adaptor.

        Returns:
            A 10-tuple ``(mel_output, postnet_output, p_predictions,
            e_predictions, log_d_predictions, d_rounded, src_masks, mel_masks,
            src_lens, mel_lens)``, where ``mel_lens`` / ``mel_masks`` are the
            values returned by the variance adaptor.
        """
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )

        output = self.encoder(texts, src_masks)

        # Add per-utterance conditioning vectors to every encoder time step.
        if self.speaker_emb is not None:
            output = output + self._broadcast_over_time(
                self.speaker_emb(speakers), max_src_len
            )
        if self.emotion_emb is not None:
            emotion_vec = self.emotion_linear(self.emotion_emb(emotions))
            output = output + self._broadcast_over_time(emotion_vec, max_src_len)

        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            d_targets,
            p_control,
            e_control,
            d_control,
        )

        output, mel_masks = self.decoder(output, mel_masks)
        output = self.mel_linear(output)

        # PostNet predicts a residual refinement on top of the coarse mel.
        postnet_output = self.postnet(output) + output

        return (
            output,
            postnet_output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            src_masks,
            mel_masks,
            src_lens,
            mel_lens,
        )