Florian Lux
add initial infrastructure
b3fa29f
raw history blame
No virus
4.43 kB
"""
Taken from ESPnet
"""
import torch
from Layers.DurationPredictor import DurationPredictorLoss
from Utility.utils import make_non_pad_mask
class FastSpeech2Loss(torch.nn.Module):
    """
    Loss module for FastSpeech 2.

    Combines:
      * an L1 loss on the spectrogram outputs (before the postnet, plus the
        after-postnet outputs when given),
      * a duration predictor loss,
      * MSE losses for the pitch and energy predictors.

    Padded positions are handled by one of two mutually exclusive strategies:
      * ``use_masking`` drops padded positions (``masked_select``) and
        averages over the remaining elements.
      * ``use_weighted_masking`` keeps per-element losses and weights each
        valid position by the inverse of its utterance length, so every
        utterance contributes equally regardless of length.
    """

    def __init__(self, use_masking=True, use_weighted_masking=False):
        """
        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.

        Raises:
            ValueError: if both ``use_masking`` and ``use_weighted_masking``
                are enabled; the two strategies cannot be combined.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so the guard is not stripped
        # under ``python -O``. Combining both modes would feed flattened
        # (masked) losses into the weighted branch, which expects padded shapes.
        if use_masking and use_weighted_masking:
            raise ValueError("use_masking and use_weighted_masking are mutually exclusive.")
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs per-element losses; otherwise average directly.
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(self, after_outs, before_outs, d_outs, p_outs, e_outs, ys,
                ds, ps, es, ilens, olens):
        """
        Args:
            after_outs (Tensor or None): Batch of outputs after postnets (B, Lmax, odim).
                When None, only ``before_outs`` contributes to the L1 loss.
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax).
                NOTE(review): documented upstream as LongTensor; the expected
                dtype/domain is whatever DurationPredictorLoss requires — confirm there.
            p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1).
            ys (Tensor): Batch of target features (B, Lmax, odim).
            ds (LongTensor): Batch of durations (B, Tmax).
            ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1).
            es (Tensor): Batch of target token-averaged energy (B, Tmax, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).
                Used to mask the duration, pitch and energy terms.
            olens (LongTensor): Batch of the lengths of each target (B,).
                Used to mask the spectrogram L1 term.

        Returns:
            Tensor: L1 loss value (before-postnet plus after-postnet term, if any).
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.
        """
        device = ys.device

        # Build each non-pad mask once and share it between the masking and
        # weighting strategies (only one of which can be active).
        if self.use_masking or self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(device)
            duration_masks = make_non_pad_mask(ilens).to(device)
            # Pitch/energy targets carry a trailing singleton dim, so they
            # reuse the token-level mask with an extra axis.
            pitch_masks = duration_masks.unsqueeze(-1)

        # Strategy 1: drop padded positions entirely (tensors become 1-D).
        if self.use_masking:
            before_outs = before_outs.masked_select(out_masks)
            if after_outs is not None:
                after_outs = after_outs.masked_select(out_masks)
            ys = ys.masked_select(out_masks)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # Calculate losses (scalars, or per-element when weighting is enabled).
        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            # Out-of-place add: avoids mutating the criterion's output tensor.
            l1_loss = l1_loss + self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # Strategy 2: weight each valid position by 1 / (its utterance length),
        # normalized by batch size (and feature dim for the spectrogram term),
        # then sum. ys/ds are unmasked here because masking is disabled.
        if self.use_weighted_masking:
            out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_weights = duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
            duration_weights /= ds.size(0)
            pitch_weights = duration_weights.unsqueeze(-1)

            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()

        return l1_loss, duration_loss, pitch_loss, energy_loss