import torch import torch.nn as nn import hparams as hp class FastSpeech2Loss(nn.Module): """ FastSpeech2 Loss """ def __init__(self): super(FastSpeech2Loss, self).__init__() self.mse_loss = nn.MSELoss() self.mae_loss = nn.L1Loss() def forward(self, log_d_predicted, log_d_target, p_predicted, p_target, e_predicted, e_target, mel, mel_postnet, mel_target, src_mask, mel_mask): log_d_target.requires_grad = False p_target.requires_grad = False e_target.requires_grad = False mel_target.requires_grad = False log_d_predicted = log_d_predicted.masked_select(src_mask) log_d_target = log_d_target.masked_select(src_mask) p_predicted = p_predicted.masked_select(src_mask) p_target = p_target.masked_select(src_mask) e_predicted = e_predicted.masked_select(src_mask) e_target = e_target.masked_select(src_mask) mel = mel.masked_select(mel_mask.unsqueeze(-1)) mel_postnet = mel_postnet.masked_select(mel_mask.unsqueeze(-1)) mel_target = mel_target.masked_select(mel_mask.unsqueeze(-1)) mel_loss = self.mse_loss(mel, mel_target) mel_postnet_loss = self.mse_loss(mel_postnet, mel_target) d_loss = self.mae_loss(log_d_predicted, log_d_target) p_loss = self.mae_loss(p_predicted, p_target) e_loss = self.mae_loss(e_predicted, e_target) return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss