Spaces:
Build error
Build error
from typing import Dict | |
from typing import Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..utils.nets_utils import make_pad_mask | |
class MaskedMSELoss(nn.Module):
    """Length-masked mel-spectrogram MSE loss plus stop-token BCE loss.

    Both losses are computed element-wise, multiplied by a 0/1 mask built
    from the per-utterance lengths, and averaged over the unmasked elements
    only, so padded positions do not contribute.

    Args:
        frames_per_step: Number of mel frames grouped into one stop-token
            step; the stop target is sub-sampled with this stride.
    """

    def __init__(self, frames_per_step):
        super().__init__()
        self.frames_per_step = frames_per_step
        # reduction='none' keeps per-element losses so that padding can be
        # masked out before averaging.
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        """Build a float mask that is 1.0 where ``position < length``.

        Args:
            lengths: Tensor of shape ``[B]`` holding the valid length of
                each sequence.
            max_len: Width of the mask. Defaults to ``lengths.max()``.

        Returns:
            Float tensor of shape ``[B, max_len]`` on ``lengths.device``.
        """
        if max_len is None:
            max_len = lengths.max()
        # Create the position index directly on the target device and with a
        # plain int length, instead of building it on the CPU and copying.
        positions = torch.arange(int(max_len), device=lengths.device)
        # Broadcasting [1, max_len] against [B, 1] yields the [B, max_len] mask.
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    @staticmethod
    def _masked_mean(values, mask):
        """Return the mean of ``values`` over the positions where ``mask`` is 1."""
        # The mask holds only 0/1, so its sum is either 0 or >= 1. Clamping
        # avoids a 0/0 NaN for a fully padded batch and changes nothing else.
        return (values * mask).sum() / mask.sum().clamp(min=1.0)

    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        """Compute the masked mel loss and the masked stop-token loss.

        Args:
            mel_pred: Predicted mel frames, same shape as ``mel_trg``.
            mel_pred_postnet: Second mel prediction, same shape as
                ``mel_trg`` (post-net output, judging by the name).
            mel_trg: Target mel frames, ``(B, T, D)``.
            lengths: Valid number of frames per utterance, ``[B]``.
            stop_target: Per-frame stop targets; reshaped to
                ``(B, -1, frames_per_step)`` so its second dimension must be
                a multiple of ``frames_per_step``.
            stop_pred: Stop-token logits, one per decoder step
                (assumed ``(B, T // frames_per_step)`` -- TODO confirm
                against the caller).

        Returns:
            Tuple ``(mel_loss, stop_loss)`` of scalar tensors, where
            ``mel_loss`` is the sum of the two masked mel MSE terms.
        """
        # Keep only the first frame of every decoder step as the stop target.
        batch_size = stop_target.size(0)
        stop_target = stop_target.reshape(
            batch_size, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(
            lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(
            stop_lengths, mel_trg.size(1) // self.frames_per_step)

        # Block gradients into the target without mutating the caller's
        # tensor; assigning ``requires_grad = False`` would raise for a
        # non-leaf tensor and silently alter a leaf one.
        mel_trg = mel_trg.detach()

        # (B, T) -> (B, T, 1) -> (B, T, D)
        mel_mask = self.get_mask(lengths, mel_trg.size(1))
        mel_mask = mel_mask.unsqueeze(-1).expand_as(mel_trg)

        mel_loss_pre = self._masked_mean(
            self.mel_loss_criterion(mel_pred, mel_trg), mel_mask)
        mel_loss_post = self._masked_mean(
            self.mel_loss_criterion(mel_pred_postnet, mel_trg), mel_mask)
        mel_loss = mel_loss_pre + mel_loss_post

        # Stop-token loss, averaged over valid decoder steps only.
        stop_loss = self._masked_mean(
            self.stop_loss_criterion(stop_pred, stop_target), stop_mask)
        return mel_loss, stop_loss