# JustinLin610
# update
# 8437114
# raw history blame
# No virus
# 7.01 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Return the summed cross entropy between ``logits`` and ``targets``.

    The log-softmax is taken over the last dimension of ``logits`` in
    float32, and the negative log-likelihood is summed (not averaged) over
    all positions whose target differs from ``ignore_index``. The default
    ``ignore_index`` matches the default used by ``F.cross_entropy`` in
    PyTorch.

    Raises:
        AssertionError: if the first dimension of ``logits`` does not equal
            the last dimension of ``targets``.
    """
    num_rows = logits.size(0)
    num_targets = targets.size(-1)
    assert num_rows == num_targets, "Logits and Targets tensor shapes don't match up"

    # Normalise in float32 regardless of the dtype of the incoming logits.
    log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        log_probs,
        targets,
        ignore_index=ignore_index,
        reduction="sum",
    )
@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
    """
    Criterion for masked language model (MLM) training with an optional
    next sentence prediction (NSP) term.

    The NSP term is added to the overall loss depending on the configured
    arguments. Three situations are handled:

    1) Generic MLM training without NSP loss: sentence_targets and
       sentence_logits are both None.
    2) BERT training without NSP loss: sentence_targets is not None but
       sentence_logits is None, so no sentence level loss is computed.
    3) BERT training with NSP loss: both sentence_targets and
       sentence_logits are not None and a sentence level loss is computed.
       Its weight is given by the ``nsp_loss_weight`` argument.
    """

    def __init__(self, task, masked_lm_only, nsp_loss_weight):
        super().__init__(task)
        self.nsp_loss_weight = nsp_loss_weight
        self.masked_lm_only = masked_lm_only

    @staticmethod
    def add_args(parser):
        """Register the command-line options of the MaskedLM loss."""
        # masked_lm_only defaults to False so that BERT training keeps working.
        parser.add_argument(
            "--masked-lm-only",
            action="store_true",
            default=False,
            help="compute MLM loss only",
        )
        parser.add_argument(
            "--nsp-loss-weight",
            type=float,
            default=1.0,
            help="weight for next sentence prediction loss (default 1)",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a 3-tuple of:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) the logging outputs to display while training
        """

        def _loggable(value):
            # Collapse to a Python scalar only when reduction is requested.
            return utils.item(value.data) if reduce else value.data

        token_logits, extra_outputs = model(**sample["net_input"])

        # Flatten the logits from (N,T,C) to (N*T,C) and the targets to match.
        vocab_dim = token_logits.size(-1)
        flat_logits = token_logits.view(-1, vocab_dim)
        flat_targets = sample["lm_target"].view(-1)
        lm_loss = compute_cross_entropy_loss(
            flat_logits, flat_targets, self.padding_idx
        )

        # The summed token loss is normalised by the number of non-padding
        # target positions.
        ntokens = utils.strip_pad(flat_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens

        nsentences = sample["nsentences"]
        sentence_loss = None

        # The sentence level loss is only considered when masked_lm_only is off.
        if not self.masked_lm_only:
            sentence_logits = extra_outputs["sentence_logits"]
            sentence_targets = sample["sentence_target"].view(-1)

            # Recomputed here because of differences between the TokenBlock
            # and BlockPair datasets.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # The logits can be None when remove_heads is set to true in the
            # BERT model. Ideally masked_lm_only would be true in that case,
            # but that needs a refactor of the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets
                )
                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: the loss is a sum of a per-token MLM loss and a per-sentence
        # NSP loss, so sample_size is not needed as the gradient denominator;
        # it is only used for logging.
        sample_size = 1

        # The sentence loss is not always computed; log 0.0 when it is absent.
        if sentence_loss is not None:
            logged_sentence_loss = _loggable(sentence_loss)
        else:
            logged_sentence_loss = 0.0

        logging_output = {
            "loss": _loggable(loss),
            "lm_loss": _loggable(lm_loss),
            "sentence_loss": logged_sentence_loss,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def _total(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        lm_loss_sum = _total("lm_loss")
        sentence_loss_sum = _total("sentence_loss")
        ntokens = _total("ntokens")
        nsentences = _total("nsentences")
        sample_size = _total("sample_size")
        agg_loss = _total("loss")

        # (metric name, summed value, normaliser); every metric is reported
        # in base 2 and weighted by its own normaliser.
        scalars = (
            ("loss", agg_loss, sample_size),
            ("lm_loss", lm_loss_sum, ntokens),
            ("sentence_loss", sentence_loss_sum, nsentences),
            ("nll_loss", lm_loss_sum, ntokens),
        )
        for name, summed, count in scalars:
            value = summed / count / math.log(2) if count > 0 else 0.0
            metrics.log_scalar(name, value, count, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True