# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    """Base class for criterions; subclasses implement :meth:`forward`."""

    def __init__(self, task):
        """Store *task* and, when available, derive ``padding_idx`` from it.

        ``padding_idx`` is only set when *task* has a ``target_dictionary``
        attribute: it is the dictionary's ``pad()`` index, or ``-100`` when
        the dictionary is ``None``. Otherwise the attribute is left unset.
        """
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        # Only classes carrying a "__dataclass" attribute contribute arguments;
        # a default instance of that dataclass is handed to the parser helper.
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion from command-line args.

        Each parameter of the class constructor is filled in by name:
        ``task`` and ``cfg`` receive the corresponding arguments, any other
        name is read from the attribute of the same name on *cfg*, and
        parameters with a default that are absent from *cfg* keep that
        default.

        Raises:
            NotImplementedError: if the constructor declares positional-only,
                ``*args`` or ``**kwargs`` parameters, or has a required
                parameter that cannot be resolved from *cfg*.
        """
        # Subclasses may override this method; by default the attributes of
        # ``cfg`` are mapped onto the same-named arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif p.name == "cfg":
                init_args["cfg"] = cfg
            elif hasattr(cfg, p.name):
                init_args[p.name] = getattr(cfg, p.name)
            elif p.default is not p.empty:
                # ``empty`` is a sentinel, so test identity: ``!=`` would call
                # the default value's own ``__ne__``, which may not yield a
                # plain bool.
                pass  # we'll use the default value
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training.

        Deprecated: this base implementation emits a deprecation warning and
        then raises ``NotImplementedError``.
        """
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training.

        This base implementation is a fallback: it warns, delegates to the
        deprecated ``aggregate_logging_outputs`` and logs every aggregated
        value as a scalar, except for the "nsentences", "ntokens" and
        "sample_size" keys, which are skipped.
        """
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


class LegacyFairseqCriterion(FairseqCriterion):
    """Criterion base that still receives an ``argparse.Namespace`` object.

    Instantiating it emits a deprecation warning asking authors to extend
    ``FairseqCriterion`` with explicit arguments instead.
    """

    def __init__(self, args, task):
        """Initialise the base criterion with *task* and keep *args*."""
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        return cls(args, task)