# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

import torch

from fairseq import utils


class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(
        self,
        tgt_dict,
        softmax_batch=None,
        compute_alignment=False,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.softmax_batch = softmax_batch or sys.maxsize
        assert self.softmax_batch > 0
        self.compute_alignment = compute_alignment
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample["net_input"]

        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                # flatten to (1, bsz * tsz, dim) and yield softmax_batch
                # positions at a time to bound peak memory
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            # pick out, for each position, the probability of the reference token
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample["target"]

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if type(attn) is dict:
                attn = attn.get("attn", None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for bd, tgt, is_single in batched:
                sample["target"] = tgt
                # with a single model we can work in log space directly; for an
                # ensemble we need raw probabilities so they can be averaged
                # before taking the log
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample
                ).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    # fill a flat buffer chunk by chunk
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
                    )
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
            sample["target"] = orig_target

            probs = probs.view(sample["target"].shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                if torch.is_tensor(attn):
                    attn = attn.data
                else:
                    attn = attn[0]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            # average probabilities over the ensemble, then move to log space
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
        for i in range(bsz):
            # remove padding from ref
            ref = (
                utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad)
                if sample["target"] is not None
                else None
            )
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample["net_input"]["src_tokens"][i],
                        sample["target"][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append(
                [
                    {
                        "tokens": ref,
                        "score": score_i,
                        "attention": avg_attn_i,
                        "alignment": alignment,
                        "positional_scores": avg_probs_i,
                    }
                ]
            )
        return hypos
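

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original fairseq API): the toy example
# below demonstrates the two core tensor operations SequenceScorer relies on --
# gathering per-token target scores from a (bsz, tsz, vocab) tensor, and
# chunking the flattened decoder output so the softmax never materializes more
# than `softmax_batch` positions at once, as in batch_for_softmax above. All
# shapes and values here are made up for demonstration; no fairseq model or
# dictionary is required.
if __name__ == "__main__":
    bsz, tsz, vocab = 2, 5, 7
    softmax_batch = 4  # deliberately smaller than bsz * tsz to force chunking

    logits = torch.randn(bsz, tsz, vocab)
    target = torch.randint(0, vocab, (bsz, tsz))

    # Single-shot path: log-softmax over the full batch, then gather the
    # log-probability assigned to each reference token.
    full_lprobs = torch.log_softmax(logits, dim=-1)
    full_scores = full_lprobs.gather(dim=2, index=target.unsqueeze(-1)).squeeze(-1)

    # Chunked path: flatten to (1, bsz * tsz, vocab) and normalize/gather
    # softmax_batch positions at a time.
    flat = logits.contiguous().view(1, -1, vocab)
    flat_tgt = target.contiguous().view(1, -1)
    pieces = []
    s = 0
    while s < flat.size(1):
        e = s + softmax_batch
        chunk_lprobs = torch.log_softmax(flat[:, s:e], dim=-1)
        pieces.append(
            chunk_lprobs.gather(dim=2, index=flat_tgt[:, s:e].unsqueeze(-1)).view(-1)
        )
        s = e
    chunked_scores = torch.cat(pieces).view(bsz, tsz)

    # Both paths must agree: chunking only bounds peak memory, it does not
    # change the per-token scores (softmax normalizes over vocab per position).
    assert torch.allclose(full_scores, chunked_scores)
    print("per-token scores:", chunked_scores)
    print("sentence scores:", chunked_scores.sum(dim=1) / tsz)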