# NOTE: the lines below are residue from the web file viewer this source was
# copied from; they are not part of the program.
#   Spaces: Runtime error
#   File size: 5,450 Bytes
#   8437114
# (The viewer's line-number gutter, 1-154, has been omitted.)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import torch
from fairseq import utils
class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(
        self,
        tgt_dict,
        softmax_batch=None,
        compute_alignment=False,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        """Configure the scorer.

        Args:
            tgt_dict: object exposing ``pad()`` and ``eos()``; their return
                values are stored as the padding and end-of-sentence symbols.
            softmax_batch: maximum number of (batch * time) positions handed
                to ``get_normalized_probs`` in one call. ``None`` (or any
                other falsy value) means ``sys.maxsize``.
            compute_alignment: when true, ``generate`` also extracts a hard
                alignment whenever the model returns attention.
            eos: overrides ``tgt_dict.eos()`` when not ``None``.
            symbols_to_strip_from_output: optional set; it is stored on the
                instance with the EOS symbol added.
        """
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.softmax_batch = softmax_batch or sys.maxsize
        assert self.softmax_batch > 0, "softmax_batch must be positive"
        self.compute_alignment = compute_alignment
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations.

        Args:
            models: sized collection of models. Each one is switched to eval
                mode, called as ``model(**sample["net_input"])`` and asked
                for ``get_normalized_probs``.
            sample: dict with ``"net_input"`` and a 2-D ``"target"`` tensor;
                ``"start_indices"`` is optional (defaults to 0 for every
                row). ``sample["net_input"]["src_tokens"]`` is only read
                when an alignment is extracted.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            A list with one entry per batch row. Each entry is a one-element
            list holding a dict with the keys ``"tokens"``, ``"score"``,
            ``"attention"``, ``"alignment"`` and ``"positional_scores"``.

        Raises:
            ValueError: if ``models`` is empty.
        """
        num_models = len(models)
        if num_models == 0:
            raise ValueError("SequenceScorer.generate() requires at least one model")

        net_input = sample["net_input"]
        orig_target = sample["target"]
        # A single model can emit log-probabilities directly. An ensemble
        # averages plain probabilities and takes the log afterwards.
        use_log_probs = num_models == 1

        def batch_for_softmax(dec_out, target):
            """Yield ``(decoder_out, target, is_single)`` pieces no larger
            than ``self.softmax_batch`` positions each."""
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                # Flatten batch and time into one axis so it can be sliced
                # into fixed-size chunks.
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            """Pick, along the last axis, the entry indexed by each target id."""
            return probs.gather(dim=2, index=target.unsqueeze(-1))

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if isinstance(attn, dict):
                attn = attn.get("attn", None)

            probs, idx = None, 0
            for bd, tgt, is_single in batch_for_softmax(decoder_out, orig_target):
                # get_normalized_probs receives the sample, so expose the
                # matching target chunk for the duration of the call and
                # always put the caller's target back, even on failure.
                sample["target"] = tgt
                try:
                    curr_prob = model.get_normalized_probs(
                        bd, log_probs=use_log_probs, sample=sample
                    ).data
                finally:
                    sample["target"] = orig_target

                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        # Flat buffer, filled chunk by chunk below.
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
                    )
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end

            probs = probs.view(orig_target.shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)

            if attn is not None:
                if torch.is_tensor(attn):
                    attn = attn.data
                else:
                    attn = attn[0]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        if num_models > 1:
            avg_probs.div_(num_models)
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(num_models)

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
        for i in range(bsz):
            start = start_idxs[i]
            # remove padding from ref
            ref = utils.strip_pad(orig_target[i, start:], self.pad)
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start : start + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len

            avg_attn_i = alignment = None
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample["net_input"]["src_tokens"][i],
                        orig_target[i],
                        self.pad,
                        self.eos,
                    )

            hypos.append(
                [
                    {
                        "tokens": ref,
                        "score": score_i,
                        "attention": avg_attn_i,
                        "alignment": alignment,
                        "positional_scores": avg_probs_i,
                    }
                ]
            )
        return hypos
|