# Last change: "update" by JustinLin610 (commit 10b0761).
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq import metrics, utils
@register_criterion("guided_label_smoothed_cross_entropy_with_accuracy")
class GuidedCrossEntAccCriterion(FairseqCriterion):
    """Cross entropy with accuracy logging and optional text-guided distillation.

    * Text-only batches (``src_tokens is None``) use label-smoothed cross
      entropy scaled by ``text_input_cost_ratio``.
    * Dual-input batches (speech + text) are scored twice:
      - The text half uses label-smoothed cross entropy.
      - The speech half uses :meth:`guide_loss_and_acc`, which interpolates
        the nll loss with a cross entropy against the (detached) text-branch
        distribution.
      - When the model returns an ``attn_cost``, it is added, weighted by
        ``attentive_cost_regularization``.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        guide_alpha,
        text_input_cost_ratio,
        label_smoothing,
        disable_text_guide_update_num=0,
        attentive_cost_regularization=0,
    ):
        """
        guide_alpha: weight used to interpolate the nll loss and the guide (kd)
            loss; must lie in [0, 1]
        text_input_cost_ratio: loss multiplier for text-only input data
        label_smoothing: label smoothing ratio (epsilon)
        disable_text_guide_update_num: only use the nll loss for the first N updates
        attentive_cost_regularization: weight of the attentive cost term
        """
        super().__init__(task)
        self.alpha = guide_alpha
        self.attn_beta = attentive_cost_regularization
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.text_input_cost_ratio = text_input_cost_ratio
        self.disable_update_num = disable_text_guide_update_num
        assert 0 <= self.alpha <= 1.0

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--guide-alpha', default=0., type=float, metavar='D',
                            help='alpha to merge kd cost from text to speech input with ce loss')
        parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D',
                            help='disable guided target from text for the first N updates.')
        parser.add_argument("--attentive-cost-regularization", default=0.0, type=float, metavar='D',
                            help="use encoder attentive loss regularization with cost ratio D")
        parser.add_argument("--attentive-cost-without-normalize", action='store_true',
                            help="Don't do normalization during attentive cost computation")
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Run ``model`` on ``sample`` and compute the training loss.

        Returns a tuple ``(loss, sample_size, logging_output)``.
        """
        reduction = 'sum' if reduce else 'none'
        net_input = sample["net_input"]
        net_output = model(**net_input)
        attn_cost = None
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        # Both speech features and text tokens were fed to the model.
        is_dual_input = (
            net_input['src_tokens'] is not None
            and net_input.get('src_txt_tokens') is not None
        )
        target = model.get_targets(sample, net_output)
        src_token_num = 0
        if is_dual_input:
            # lprobs_spch from speech encoder and lprobs_text from text encoder.
            # Both halves are scored against the same `target`, so they are
            # stacked along the batch axis: dim 0 when batch-first, dim 1 when
            # time-first (T x B x C).
            batch_dim = 0 if lprobs.batch_first else 1
            lprobs_spch, lprobs_text = torch.chunk(lprobs, 2, dim=batch_dim)
            lprobs_spch.batch_first = lprobs.batch_first
            lprobs_text.batch_first = lprobs.batch_first

            speech_loss, speech_nll_loss, speech_correct, speech_total = self.guide_loss_and_acc(
                model, lprobs_spch, lprobs_text, target, reduce=(reduction == 'sum')
            )
            text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(
                model, lprobs_text, target, reduction=reduction
            )
            loss = speech_loss + text_loss
            nll_loss = speech_nll_loss + text_nll_loss
            correct = speech_correct + text_correct
            total = speech_total + text_total

            attn_cost = net_output[1].get('attn_cost')
            if attn_cost is not None:
                # attn_cost is batch_first and padding tokens have been masked already
                src_token_num = attn_cost.ne(0).sum()
                attn_cost = attn_cost.sum()
                loss = loss + attn_cost * self.attn_beta
            else:
                attn_cost = 0
        else:
            loss, nll_loss, correct, total = self.compute_loss_and_acc(
                model, lprobs, target, reduction=reduction
            )
            if net_input['src_tokens'] is None:  # text input only
                loss = loss * self.text_input_cost_ratio
            speech_loss = None
            speech_nll_loss = None

        sample_size, logging_output = self.get_logging_output(
            sample, loss, nll_loss, correct, total, src_token_num,
            speech_loss, speech_nll_loss, attn_cost, is_dual_input,
        )
        return loss, sample_size, logging_output

    def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
        """Label-smoothed nll loss plus token accuracy.

        Returns ``(loss, nll_loss, correct, total)``:
        * ``correct`` counts non-padding positions whose argmax equals the target.
        * ``total`` counts non-padding positions.

        ``model`` is unused and kept for signature compatibility.
        """
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
        # reshape, not view: after transpose the tensor is non-contiguous.
        lprobs = lprobs.reshape(-1, lprobs.size(-1))  # -> (B x T) x C
        target = target.reshape(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
        )
        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return loss, nll_loss, correct, total

    def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True):
        """Loss and accuracy for ``lprobs``, guided by ``lprobs_teacher``.

        The loss is ``alpha * guide_loss + (1 - alpha) * nll_loss``, where
        ``guide_loss`` is the cross entropy against the detached teacher
        distribution (padding positions zeroed). No label smoothing is applied
        on this path.

        Falls back to :meth:`compute_loss_and_acc` when ``alpha == 0`` or when
        ``model.num_updates`` is below ``disable_text_guide_update_num``.

        With ``reduce=False``, the per-token losses have shape
        ``((B x T), 1)``.

        Returns ``(loss, nll_loss, correct, total)``.
        """
        if self.alpha == 0.0 or model.num_updates < self.disable_update_num:
            return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none'))
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
            lprobs_teacher = lprobs_teacher.transpose(0, 1)
        # reshape, not view: after transpose the tensors are non-contiguous.
        lprobs = lprobs.reshape(-1, lprobs.size(-1)).float()  # -> (B x T) x C
        lprobs_teacher = lprobs_teacher.reshape(-1, lprobs_teacher.size(-1)).float()  # -> (B x T) x C
        target = target.reshape(-1)

        loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none')
        if not reduce:
            # (B x T) -> (B x T) x 1 so it lines up with guide_loss below.
            # Otherwise the interpolation broadcasts to (B x T) x (B x T).
            loss = loss.unsqueeze(-1)
        nll_loss = loss

        # Teacher distribution with padding rows zeroed. It is detached, so the
        # guide term sends no gradient into the teacher branch.
        pad_mask = target.unsqueeze(-1).eq(self.padding_idx)
        probs_teacher = lprobs_teacher.exp().masked_fill_(pad_mask, 0).detach()
        token_guide = -(probs_teacher * lprobs)
        guide_loss = token_guide.sum() if reduce else token_guide.sum(-1, keepdim=True)
        loss = self.alpha * guide_loss + (1.0 - self.alpha) * loss

        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return loss, nll_loss, correct, total

    def get_logging_output(
        self,
        sample,
        loss,
        nll_loss,
        correct,
        total,
        src_token_num=0,
        speech_loss=None,
        speech_nll_loss=None,
        attn_cost=None,
        is_dual_input=False,
    ):
        """Build ``(sample_size, logging_output)`` for one batch.

        For dual-input batches every sample is scored twice (speech and text
        halves). For that reason:
        * token, sentence and sample-size counts are doubled;
        * ``sample_size_speech_cost`` keeps the undoubled size.
        """
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        mul_size = 2 if is_dual_input else 1
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "nll_loss": utils.item(nll_loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"] * mul_size,
            "nsentences": sample["target"].size(0) * mul_size,
            "sample_size": sample_size * mul_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "src_token_num": utils.item(src_token_num.data) if src_token_num > 0 else 0,
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }
        if speech_loss is not None:
            logging_output["speech_loss"] = utils.item(speech_loss.data)
            logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data)
            logging_output["sample_size_speech_cost"] = sample_size
            # attn_cost is either a scalar tensor still attached to the graph
            # or the int 0. Log a plain number so the graph is not retained.
            logging_output["speech_attn_loss"] = (
                utils.item(attn_cost.data) if torch.is_tensor(attn_cost) else attn_cost
            )
        return sample_size * mul_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Losses are normalized by their sample sizes and reported in base 2
        (divided by ln 2). Accuracy is a percentage of non-padding tokens.
        """

        def _total(key):
            # Sum one field across workers; missing entries count as 0.
            return sum(log.get(key, 0) for log in logging_outputs)

        correct_sum = _total("correct")
        total_sum = _total("total")
        src_token_sum = _total("src_token_num")
        loss_sum = _total("loss")
        nll_loss_sum = _total("nll_loss")
        ntokens = _total("ntokens")
        nsentences = _total("nsentences")
        sample_size = _total("sample_size")
        nframes = _total("nframes")
        speech_loss_sum = _total("speech_loss")
        speech_nll_loss_sum = _total("speech_nll_loss")
        speech_attn_loss_sum = _total("speech_attn_loss")
        sample_size_speech = _total("sample_size_speech_cost")

        agg_output = {
            # if args.sentence_avg, then sample_size is nsentences, and loss
            # is per-sentence loss; else sample_size is ntokens, and the loss
            # becomes per-output token loss
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            "nll_loss": nll_loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            "speech_loss": speech_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
            "speech_nll_loss": speech_nll_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
            "speech_attn_loss": speech_attn_loss_sum / src_token_sum / math.log(2) if src_token_sum > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            # total is the number of valid (non-padding) target tokens
            "correct": correct_sum,
            "total": total_sum,
            "src_token_num": src_token_sum,
        }
        return agg_output

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training and log them.

        Raw counts (``nsentences``, ``ntokens``, ``sample_size``) are not
        logged.
        """
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {'nsentences', 'ntokens', 'sample_size'}:
                continue
            metrics.log_scalar(k, v, round=3)