Spaces:
Runtime error
Runtime error
# Copyright (c) 2017-present, Facebook, Inc. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the LICENSE file in | |
# the root directory of this source tree. An additional grant of patent rights | |
# can be found in the PATENTS file in the same directory. | |
import logging | |
from typing import Any, Dict, List | |
from functools import lru_cache | |
from dataclasses import dataclass, field | |
import torch | |
from omegaconf import II | |
from fairseq import metrics, utils | |
from fairseq.criterions import FairseqCriterion, register_criterion | |
from fairseq.dataclass import FairseqDataclass | |
from fairseq.data.data_utils import lengths_to_mask | |
import torch.nn.functional as F | |
logger = logging.getLogger(__name__) | |
class Tacotron2CriterionConfig(FairseqDataclass): | |
bce_pos_weight: float = field( | |
default=1.0, | |
metadata={"help": "weight of positive examples for BCE loss"}, | |
) | |
n_frames_per_step: int = field( | |
default=0, | |
metadata={"help": "Number of frames per decoding step"}, | |
) | |
use_guided_attention_loss: bool = field( | |
default=False, | |
metadata={"help": "use guided attention loss"}, | |
) | |
guided_attention_loss_sigma: float = field( | |
default=0.4, | |
metadata={"help": "weight of positive examples for BCE loss"}, | |
) | |
ctc_weight: float = field( | |
default=0.0, metadata={"help": "weight for CTC loss"} | |
) | |
sentence_avg: bool = II("optimization.sentence_avg") | |
class GuidedAttentionLoss(torch.nn.Module): | |
""" | |
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional | |
Networks with Guided Attention (https://arxiv.org/abs/1710.08969) | |
""" | |
def __init__(self, sigma): | |
super().__init__() | |
self.sigma = sigma | |
def _get_weight(s_len, t_len, sigma): | |
grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len)) | |
grid_x = grid_x.to(s_len.device) | |
grid_y = grid_y.to(s_len.device) | |
w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2 | |
return 1.0 - torch.exp(-w / (2 * (sigma ** 2))) | |
def _get_weights(self, src_lens, tgt_lens): | |
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) | |
weights = torch.zeros((bsz, max_t_len, max_s_len)) | |
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): | |
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, | |
self.sigma) | |
return weights | |
def _get_masks(src_lens, tgt_lens): | |
in_masks = lengths_to_mask(src_lens) | |
out_masks = lengths_to_mask(tgt_lens) | |
return out_masks.unsqueeze(2) & in_masks.unsqueeze(1) | |
def forward(self, attn, src_lens, tgt_lens, reduction="mean"): | |
weights = self._get_weights(src_lens, tgt_lens).to(attn.device) | |
masks = self._get_masks(src_lens, tgt_lens).to(attn.device) | |
loss = (weights * attn.transpose(1, 2)).masked_select(masks) | |
loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss) | |
return loss | |
class Tacotron2Criterion(FairseqCriterion): | |
def __init__(self, task, sentence_avg, n_frames_per_step, | |
use_guided_attention_loss, guided_attention_loss_sigma, | |
bce_pos_weight, ctc_weight): | |
super().__init__(task) | |
self.sentence_avg = sentence_avg | |
self.n_frames_per_step = n_frames_per_step | |
self.bce_pos_weight = bce_pos_weight | |
self.guided_attn = None | |
if use_guided_attention_loss: | |
self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma) | |
self.ctc_weight = ctc_weight | |
def forward(self, model, sample, reduction="mean"): | |
bsz, max_len, _ = sample["target"].size() | |
feat_tgt = sample["target"] | |
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) | |
eos_tgt = torch.arange(max_len).to(sample["target"].device) | |
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) | |
eos_tgt = (eos_tgt == (feat_len - 1)).float() | |
src_tokens = sample["net_input"]["src_tokens"] | |
src_lens = sample["net_input"]["src_lengths"] | |
tgt_lens = sample["target_lengths"] | |
feat_out, eos_out, extra = model( | |
src_tokens=src_tokens, | |
src_lengths=src_lens, | |
prev_output_tokens=sample["net_input"]["prev_output_tokens"], | |
incremental_state=None, | |
target_lengths=tgt_lens, | |
speaker=sample["speaker"] | |
) | |
l1_loss, mse_loss, eos_loss = self.compute_loss( | |
extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt, | |
tgt_lens, reduction, | |
) | |
attn_loss = torch.tensor(0.).type_as(l1_loss) | |
if self.guided_attn is not None: | |
attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction) | |
ctc_loss = torch.tensor(0.).type_as(l1_loss) | |
if self.ctc_weight > 0.: | |
net_output = (feat_out, eos_out, extra) | |
lprobs = model.get_normalized_probs(net_output, log_probs=True) | |
lprobs = lprobs.transpose(0, 1) # T x B x C | |
src_mask = lengths_to_mask(src_lens) | |
src_tokens_flat = src_tokens.masked_select(src_mask) | |
ctc_loss = F.ctc_loss( | |
lprobs, src_tokens_flat, tgt_lens, src_lens, | |
reduction=reduction, zero_infinity=True | |
) * self.ctc_weight | |
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss | |
sample_size = sample["nsentences"] if self.sentence_avg \ | |
else sample["ntokens"] | |
logging_output = { | |
"loss": utils.item(loss.data), | |
"ntokens": sample["ntokens"], | |
"nsentences": sample["nsentences"], | |
"sample_size": sample_size, | |
"l1_loss": utils.item(l1_loss.data), | |
"mse_loss": utils.item(mse_loss.data), | |
"eos_loss": utils.item(eos_loss.data), | |
"attn_loss": utils.item(attn_loss.data), | |
"ctc_loss": utils.item(ctc_loss.data), | |
} | |
return loss, sample_size, logging_output | |
def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt, | |
eos_tgt, tgt_lens, reduction="mean"): | |
mask = lengths_to_mask(tgt_lens) | |
_eos_out = eos_out[mask].squeeze() | |
_eos_tgt = eos_tgt[mask] | |
_feat_tgt = feat_tgt[mask] | |
_feat_out = feat_out[mask] | |
_feat_out_post = feat_out_post[mask] | |
l1_loss = ( | |
F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + | |
F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction) | |
) | |
mse_loss = ( | |
F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + | |
F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction) | |
) | |
eos_loss = F.binary_cross_entropy_with_logits( | |
_eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight), | |
reduction=reduction | |
) | |
return l1_loss, mse_loss, eos_loss | |
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: | |
ns = [log.get("sample_size", 0) for log in logging_outputs] | |
ntot = sum(ns) | |
ws = [n / (ntot + 1e-8) for n in ns] | |
for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]: | |
vals = [log.get(key, 0) for log in logging_outputs] | |
val = sum(val * w for val, w in zip(vals, ws)) | |
metrics.log_scalar(key, val, ntot, round=3) | |
metrics.log_scalar("sample_size", ntot, len(logging_outputs)) | |
# inference metrics | |
if "targ_frames" not in logging_outputs[0]: | |
return | |
n = sum(log.get("targ_frames", 0) for log in logging_outputs) | |
for key, new_key in [ | |
("mcd_loss", "mcd_loss"), | |
("pred_frames", "pred_ratio"), | |
("nins", "ins_rate"), | |
("ndel", "del_rate"), | |
]: | |
val = sum(log.get(key, 0) for log in logging_outputs) | |
metrics.log_scalar(new_key, val / n, n, round=3) | |
def logging_outputs_can_be_summed() -> bool: | |
return False | |