# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import math
from omegaconf import II

import torch
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MaskedLmConfig(FairseqDataclass):
    """Configuration for the ``masked_lm`` criterion."""

    # Interpolated from ``common.tpu`` (OmegaConf ``II``); when true the
    # criterion always projects every token (see ``MaskedLmLoss.forward``).
    tpu: bool = II("common.tpu")


@register_criterion("masked_lm", dataclass=MaskedLmConfig)
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    """

    def __init__(self, cfg: MaskedLmConfig, task):
        super().__init__(task)
        self.tpu = cfg.tpu

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Only target positions that are not ``self.padding_idx`` contribute to
        the (summed) cross-entropy loss. ``reduce`` is accepted for interface
        compatibility but is not used: the loss is always summed.
        """
        # Positions with a non-padding target; in MLM these are presumably
        # exactly the masked positions (hence the name).
        masked_tokens = sample["target"].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()

        # Rare: when *no* tokens are masked (every target is padding), project
        # all tokens rather than an empty selection. The all-padding targets
        # are then ignored by ``ignore_index`` below.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device("cpu"):
            if not masked_tokens.any():
                masked_tokens = None
        else:
            # The 1-element ``[True]`` tensor broadcasts to an all-True mask
            # when nothing is masked, selecting every position on-device.
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            # Keep targets aligned with the projected (selected) positions.
            targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )

        logging_output = {
            "loss": loss if self.tpu else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Logs ``loss`` as the summed loss divided by the summed sample size
        and by ``ln 2`` (i.e. in base 2), weighted by the sample size, plus a
        derived ``ppl`` computed from the running average of ``loss``.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Returning True
        improves distributed training speed.
        """
        return True