OFA-Image_Caption / fairseq /fairseq /criterions /label_smoothed_cross_entropy_with_alignment.py
JustinLin610
update
8437114
raw history blame
No virus
4.75 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from .label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
from dataclasses import dataclass, field
@dataclass
class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    """Configuration for label-smoothed cross entropy plus an alignment loss.

    Inherits every option of ``LabelSmoothedCrossEntropyCriterionConfig`` and
    adds a single knob controlling the weight of the alignment term.
    """

    # Multiplier applied to the alignment loss before it is added to the
    # label-smoothed cross-entropy loss in the criterion's ``forward``.
    alignment_lambda: float = field(
        default=0.05,
        metadata={"help": "weight for the alignment loss"},
    )
@register_criterion(
    "label_smoothed_cross_entropy_with_alignment",
    dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig,
)
class LabelSmoothedCrossEntropyCriterionWithAlignment(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy with an additional attention alignment loss.

    The final training loss is the parent criterion's loss plus
    ``alignment_lambda`` times the alignment loss, whenever the sample
    carries alignment supervision.
    """

    def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
        """Store the alignment weight on top of the parent criterion's setup.

        Args:
            task: forwarded unchanged to the parent criterion.
            sentence_avg: forwarded unchanged to the parent criterion; also
                selects the sample-size denominator in :meth:`forward`.
            label_smoothing: forwarded unchanged to the parent criterion.
            alignment_lambda: multiplier applied to the alignment loss.
        """
        super().__init__(task, sentence_avg, label_smoothing)
        self.alignment_lambda = alignment_lambda

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss (label-smoothed loss, plus the weighted alignment loss
           when alignments are available)
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Note that the ``"loss"`` entry of the logging outputs holds the
        label-smoothed loss only; the alignment term is reported separately
        under ``"alignment_loss"``.
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        if self.sentence_avg:
            sample_size = sample["target"].size(0)
        else:
            sample_size = sample["ntokens"]

        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        # Compute alignment loss only for training set and non dummy batches.
        alignment_loss = None
        if sample.get("alignments") is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output["alignment_loss"] = utils.item(alignment_loss.data)
            # Out-of-place add on purpose: with reduce=False the logging dict
            # holds `loss.data`, which shares storage with `loss`; an in-place
            # `+=` would silently fold the alignment term into the logged loss.
            loss = loss + self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        """Return the weighted negative log-likelihood of the aligned attention.

        Args:
            sample: batch dict; reads ``sample["alignments"]`` and
                ``sample["align_weights"]``.
            net_output: model output; the attention probabilities are read
                from ``net_output[1]["attn"][0]`` and unpacked as a 3-D tensor
                ``(bsz, tgt_sz, src_sz)``.

        Returns:
            A scalar tensor with the summed loss, or ``None`` when the batch
            contains no alignment pairs.
        """
        attn_prob = net_output[1]["attn"][0]
        bsz, tgt_sz, src_sz = attn_prob.shape
        # Flatten batch and target positions into a single row axis.
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample["alignments"]
        align_weights = sample["align_weights"].float()

        if len(align) == 0:
            return None

        # Alignment loss computation. align (shape [:, 2]) contains the src-tgt
        # index pairs corresponding to the alignments. align_weights (shape [:])
        # contains the 1 / frequency of a tgt index for normalizing.
        # NOTE(review): rows of `attn` span bsz * tgt_sz, so the tgt indices in
        # align[:, 1] are presumably already offset into the flattened batch --
        # confirm against the dataset collater.
        tgt_idx = align[:, 1][:, None]
        src_idx = align[:, 0][:, None]
        aligned_log_prob = attn[tgt_idx, src_idx].log()
        return -(aligned_log_prob * align_weights[:, None]).sum()

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Sums each logged quantity over ``logging_outputs`` and reports
        ``loss``, ``nll_loss`` and ``alignment_loss`` divided by ``ln(2)``
        (i.e. rescaled to base 2), plus a derived ``ppl`` meter.
        """

        def total(key):
            # Entries lacking `key` (e.g. batches without alignments)
            # contribute 0 to the sum.
            return utils.item(sum(log.get(key, 0) for log in logging_outputs))

        loss_sum = total("loss")
        nll_loss_sum = total("nll_loss")
        alignment_loss_sum = total("alignment_loss")
        ntokens = total("ntokens")
        sample_size = total("sample_size")

        ln2 = math.log(2)
        # `or 1` avoids ZeroDivisionError when the aggregated counts are zero
        # (e.g. empty logging_outputs); the meter weight remains the real count.
        sample_denom = sample_size or 1
        token_denom = ntokens or 1

        metrics.log_scalar(
            "loss", loss_sum / sample_denom / ln2, sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / token_denom / ln2, ntokens, round=3
        )
        metrics.log_scalar(
            "alignment_loss",
            alignment_loss_sum / sample_denom / ln2,
            sample_size,
            round=3,
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True