import torch
import torch.nn as nn


class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss.

    Combines five terms: L1 loss on the mel predictions, L1 loss on the
    postnet mel predictions, and MSE losses on pitch, energy and
    log-duration predictions. Each term is computed only over the positions
    kept by the (inverted) source / mel masks.
    """

    def __init__(self, preprocess_config, model_config):
        """Read the pitch/energy feature levels and build the criteria.

        Args:
            preprocess_config: Mapping providing
                ``["preprocessing"]["pitch"]["feature"]`` and
                ``["preprocessing"]["energy"]["feature"]``. The values
                ``"phoneme_level"`` and ``"frame_level"`` select which mask
                is applied to the corresponding variance loss.
            model_config: Not used by this class; kept so the constructor
                signature stays compatible with existing callers.
        """
        super().__init__()
        preprocessing = preprocess_config["preprocessing"]
        self.pitch_feature_level = preprocessing["pitch"]["feature"]
        self.energy_feature_level = preprocessing["energy"]["feature"]
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    @staticmethod
    def _mask_by_level(feature_level, predictions, targets, src_masks, mel_masks):
        """Select the valid entries of a pitch/energy prediction/target pair.

        ``"phoneme_level"`` features are filtered with ``src_masks`` and
        ``"frame_level"`` features with ``mel_masks``. Any other level leaves
        both tensors untouched.
        """
        if feature_level == "phoneme_level":
            mask = src_masks
        elif feature_level == "frame_level":
            mask = mel_masks
        else:
            return predictions, targets
        return predictions.masked_select(mask), targets.masked_select(mask)

    def forward(self, inputs, predictions):
        """Compute the total loss and its individual components.

        Args:
            inputs: Sequence whose elements from index 7 onward are exactly
                ``(mel_targets, _, _, pitch_targets, energy_targets,
                duration_targets)``.
            predictions: Sequence of exactly ten elements:
                ``(mel_predictions, postnet_mel_predictions,
                pitch_predictions, energy_predictions,
                log_duration_predictions, _, src_masks, mel_masks, _, _)``.
                The two masks are inverted with ``~`` before being used for
                selection, so presumably ``True`` marks padding in the
                incoming masks -- confirm against the model.

        Returns:
            Tuple ``(total_loss, mel_loss, postnet_mel_loss, pitch_loss,
            energy_loss, duration_loss)``, where ``total_loss`` is the sum of
            the other five.
        """
        (
            mel_targets,
            _,
            _,
            pitch_targets,
            energy_targets,
            duration_targets,
        ) = inputs[7:]
        (
            mel_predictions,
            postnet_mel_predictions,
            pitch_predictions,
            energy_predictions,
            log_duration_predictions,
            _,
            src_masks,
            mel_masks,
            _,
            _,
        ) = predictions

        # After inversion, True marks the positions that enter the loss.
        src_masks = ~src_masks
        mel_masks = ~mel_masks

        # Targets must not receive gradients. detach() is used instead of
        # assigning ``requires_grad = False`` because the assignment raises
        # on non-leaf tensors and would flip the flag on the caller's own
        # tensors.
        log_duration_targets = torch.log(duration_targets.float() + 1).detach()
        # Truncate the mel targets along dim 1 to the length of the mel mask.
        mel_targets = mel_targets[:, : mel_masks.shape[1], :].detach()
        pitch_targets = pitch_targets.detach()
        energy_targets = energy_targets.detach()

        pitch_predictions, pitch_targets = self._mask_by_level(
            self.pitch_feature_level,
            pitch_predictions,
            pitch_targets,
            src_masks,
            mel_masks,
        )
        energy_predictions, energy_targets = self._mask_by_level(
            self.energy_feature_level,
            energy_predictions,
            energy_targets,
            src_masks,
            mel_masks,
        )

        log_duration_predictions = log_duration_predictions.masked_select(src_masks)
        log_duration_targets = log_duration_targets.masked_select(src_masks)

        # Broadcast the frame mask over the last (mel channel) dimension.
        frame_mask = mel_masks.unsqueeze(-1)
        mel_predictions = mel_predictions.masked_select(frame_mask)
        postnet_mel_predictions = postnet_mel_predictions.masked_select(frame_mask)
        mel_targets = mel_targets.masked_select(frame_mask)

        mel_loss = self.mae_loss(mel_predictions, mel_targets)
        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)

        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
        energy_loss = self.mse_loss(energy_predictions, energy_targets)
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)

        total_loss = (
            mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
        )

        return (
            total_loss,
            mel_loss,
            postnet_mel_loss,
            pitch_loss,
            energy_loss,
            duration_loss,
        )