|
import torch |
|
import torch.nn as nn |
|
import hparams as hp |
|
|
|
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss.

    Combines mean-squared error on the mel-spectrogram outputs (before and
    after the postnet) with mean-absolute error on the variance predictions
    (log-duration, pitch and energy). Only the entries selected by the
    masks contribute to each loss term.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(
        self,
        log_d_predicted,
        log_d_target,
        p_predicted,
        p_target,
        e_predicted,
        e_target,
        mel,
        mel_postnet,
        mel_target,
        src_mask,
        mel_mask,
    ):
        """Compute the five FastSpeech2 loss terms.

        Args:
            log_d_predicted, log_d_target: predicted / target log-durations,
                indexed by ``src_mask``.
            p_predicted, p_target: predicted / target pitch, indexed by
                ``src_mask``.
            e_predicted, e_target: predicted / target energy, indexed by
                ``src_mask``.
            mel, mel_postnet: mel output before and after the postnet.
            mel_target: ground-truth mel-spectrogram.
            src_mask: boolean mask; ``masked_select`` keeps the entries where
                it is True. NOTE(review): presumably True at valid
                (non-padded) positions, i.e. callers pass an inverted padding
                mask - confirm with the training loop.
            mel_mask: boolean mask over the mel frames, same convention as
                ``src_mask``. A trailing axis is added so it broadcasts
                across the last dimension of the mel tensors.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss)``
            of scalar tensors. MSE is used for the two mel terms and L1
            (MAE) for duration, pitch and energy. If a mask selects no
            entries, the corresponding mean-reduced loss is NaN.
        """
        # Targets are treated as constants. ``detach()`` returns a view that
        # is cut from the autograd graph without mutating the caller's
        # tensor. Assigning ``requires_grad = False`` would silently change
        # the caller's tensor, and it raises for non-leaf targets.
        log_d_target = log_d_target.detach()
        p_target = p_target.detach()
        e_target = e_target.detach()
        mel_target = mel_target.detach()

        # Keep only the selected source-axis entries (result is flattened).
        log_d_predicted = log_d_predicted.masked_select(src_mask)
        log_d_target = log_d_target.masked_select(src_mask)
        p_predicted = p_predicted.masked_select(src_mask)
        p_target = p_target.masked_select(src_mask)
        e_predicted = e_predicted.masked_select(src_mask)
        e_target = e_target.masked_select(src_mask)

        # Add a trailing axis once so the frame mask broadcasts over the
        # last (feature) dimension of every mel tensor.
        mel_frame_mask = mel_mask.unsqueeze(-1)
        mel = mel.masked_select(mel_frame_mask)
        mel_postnet = mel_postnet.masked_select(mel_frame_mask)
        mel_target = mel_target.masked_select(mel_frame_mask)

        mel_loss = self.mse_loss(mel, mel_target)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)

        d_loss = self.mae_loss(log_d_predicted, log_d_target)
        p_loss = self.mae_loss(p_predicted, p_target)
        e_loss = self.mae_loss(e_predicted, e_target)

        return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss
|
|