# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)

try:
    from simuleval.metrics.latency import (
        AverageLagging,
        AverageProportion,
        DifferentiableAverageLagging,
    )

    LATENCY_METRICS = {
        "average_lagging": AverageLagging,
        "average_proportion": AverageProportion,
        "differentiable_average_lagging": DifferentiableAverageLagging,
    }
except ImportError:
    LATENCY_METRICS = None


@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    latency_avg_weight: float = field(
        default=0.0,
        metadata={"help": "weight for the average latency loss"},
    )
    latency_var_weight: float = field(
        default=0.0,
        metadata={"help": "weight for the variance latency loss"},
    )
    latency_avg_type: str = field(
        default="differentiable_average_lagging",
        metadata={"help": "latency type for the average loss"},
    )
    latency_var_type: str = field(
        default="variance_delay",
        metadata={"help": "latency type for the variance loss"},
    )
    latency_gather_method: str = field(
        default="weighted_average",
        metadata={"help": "method to gather the latency loss over all heads"},
    )
    latency_update_after: int = field(
        default=0,
        metadata={"help": "add the latency loss only after this many updates"},
    )


@register_criterion(
    "latency_augmented_label_smoothed_cross_entropy",
    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        latency_avg_weight,
        latency_var_weight,
        latency_avg_type,
        latency_var_type,
        latency_gather_method,
        latency_update_after,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."

        self.latency_avg_weight = latency_avg_weight
        self.latency_var_weight = latency_var_weight
        self.latency_avg_type = latency_avg_type
        self.latency_var_type = latency_var_type
        self.latency_gather_method = latency_gather_method
        self.latency_update_after = latency_update_after

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        # 1. Compute cross entropy loss
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        # 2. Compute latency loss
        latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
            model, sample, net_output
        )

        if self.latency_update_after > 0:
            num_updates = getattr(model.decoder, "num_updates", None)
            assert (
                num_updates is not None
            ), "model.decoder doesn't have attribute 'num_updates'"
            if num_updates <= self.latency_update_after:
                latency_loss = 0

        loss += latency_loss

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "latency": expected_latency,
            "delays_var": expected_delays_var,
            "latency_loss": latency_loss,
        }

        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output
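
    # NOTE: compute_latency_loss below assumes a monotonic (e.g. MMA-style)
    # decoder whose forward output exposes the per-layer expected alignments
    # as net_output[1].attn_list, where each entry carries an "alpha" tensor
    # of shape (bsz, num_heads, tgt_len, src_len); models without this
    # structure cannot use the latency terms.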
    def compute_latency_loss(self, model, sample, net_output):
        assert (
            net_output[-1].encoder_padding_mask is None
            or not net_output[-1].encoder_padding_mask[:, 0].any()
        ), "Only right padding on source is supported."
        # 1. Obtain the expected alignment
        alpha_list = [item["alpha"] for item in net_output[1].attn_list]
        num_layers = len(alpha_list)
        bsz, num_heads, tgt_len, src_len = alpha_list[0].size()

        # bsz * num_layers * num_heads, tgt_len, src_len
        alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)

        # 2. Compute expected delays
        # bsz * num_layers * num_heads, tgt_len, src_len for MMA
        steps = (
            torch.arange(1, 1 + src_len)
            .unsqueeze(0)
            .unsqueeze(1)
            .expand_as(alpha_all)
            .type_as(alpha_all)
        )

        expected_delays = torch.sum(steps * alpha_all, dim=-1)

        target_padding_mask = (
            model.get_targets(sample, net_output)
            .eq(self.padding_idx)
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads, tgt_len)
            .contiguous()
            .view(-1, tgt_len)
        )

        src_lengths = (
            sample["net_input"]["src_lengths"]
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads)
            .contiguous()
            .view(-1)
        )
        expected_latency = LATENCY_METRICS[self.latency_avg_type](
            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
        )

        # 2.1 Gather the expected latency over heads
        # bsz, num_layers * num_heads
        expected_latency = expected_latency.view(bsz, -1)
        if self.latency_gather_method == "average":
            # bsz
            expected_latency = expected_latency.mean(dim=1)
        elif self.latency_gather_method == "weighted_average":
            weights = torch.nn.functional.softmax(expected_latency, dim=1)
            expected_latency = torch.sum(expected_latency * weights, dim=1)
        elif self.latency_gather_method == "max":
            expected_latency = expected_latency.max(dim=1)[0]
        else:
            raise NotImplementedError

        expected_latency = expected_latency.sum()
        avg_loss = self.latency_avg_weight * expected_latency

        # 2.2 Variance of the expected delays over heads
        expected_delays_var = (
            expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
        )
        expected_delays_var = expected_delays_var.sum()
        var_loss = self.latency_var_weight * expected_delays_var

        # 3. Final loss
        latency_loss = avg_loss + var_loss

        return latency_loss, expected_latency, expected_delays_var
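
    # Worked example for the expected-delay computation in
    # compute_latency_loss (illustrative numbers): for one head with
    # src_len = 3, a target step whose attention row is
    # alpha = [0.1, 0.2, 0.7] has steps = [1, 2, 3], giving an expected
    # delay of 0.1 * 1 + 0.2 * 2 + 0.7 * 3 = 2.6, i.e. that target token
    # is expected to be written after reading about 2.6 source tokens.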
    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        latency = sum(log.get("latency", 0) for log in logging_outputs)
        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar(
            "latency", latency.float() / nsentences, nsentences, round=3
        )
        metrics.log_scalar(
            "delays_var", delays_var / nsentences, nsentences, round=3
        )
        metrics.log_scalar(
            "latency_loss", latency_loss / nsentences, nsentences, round=3
        )
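
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the criterion). Fairseq derives CLI
# flags from the config dataclass fields above (underscores become dashes),
# so a training command would plausibly include flags like the following;
# the weights and step count are example values, and the exact flag names
# should be verified against your fairseq version:
#
#   fairseq-train <data> \
#       --criterion latency_augmented_label_smoothed_cross_entropy \
#       --label-smoothing 0.1 \
#       --latency-avg-weight 0.1 \
#       --latency-var-weight 0.1 \
#       --latency-avg-type differentiable_average_lagging \
#       --latency-gather-method weighted_average \
#       --latency-update-after 4000
# ---------------------------------------------------------------------------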