Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import unittest | |
import tests.utils as test_utils | |
import torch | |
from fairseq.sequence_scorer import SequenceScorer | |
class TestSequenceScorer(unittest.TestCase):
    """Checks ``SequenceScorer`` against hand-written per-step probabilities.

    A dummy translation task is configured with fixed probability tables
    (``args.beam_probs``); the scorer's hypotheses are then compared with the
    reference targets and with the probabilities picked out of those tables.
    """

    def test_sequence_scorer(self):
        """Score three reference targets and verify tokens and scores."""
        # Dummy dictionary; the assertions pin the special-symbol indices
        # that the hard-coded word ids below are chosen to follow.
        vocab = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(vocab.pad(), 1)
        self.assertEqual(vocab.eos(), 2)
        self.assertEqual(vocab.unk(), 3)
        eos = vocab.eos()
        w1 = 4
        w2 = 5

        # (source, target) token-id sequences, one pair per sentence.
        sentence_pairs = [
            ([w1, w2, eos], [w1, w2, w1, eos]),
            ([w2, eos], [w2, w1, eos]),
            ([w2, eos], [w2, eos]),
        ]
        data = [
            {
                "source": torch.LongTensor(src_tokens),
                "target": torch.LongTensor(tgt_tokens),
            }
            for src_tokens, tgt_tokens in sentence_pairs
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # Probability tables handed to the dummy model: one table per
        # decoding step, one row per sentence.
        unk = 0.0
        step_tables = [
            # step 0:
            [
                # eos  unk  w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ],
            # step 1:
            [
                # eos  unk  w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ],
            # step 2:
            [
                # eos   unk  w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ],
            # step 3:
            [
                # eos  unk  w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ],
        ]
        args = argparse.Namespace()
        args.beam_probs = [torch.FloatTensor(table) for table in step_tables]

        # Probability of each target token at its step, read off the tables.
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, vocab, vocab)
        model = task.build_model(args)
        scorer = SequenceScorer(task.target_dictionary)
        for sample in data_itr:
            hypos = task.inference_step(scorer, [model], sample)
            sample_ids = sample["id"].tolist()
            for sample_id, sentence_hypos in zip(sample_ids, hypos):
                # Only the first hypothesis of each sentence is checked.
                first_hypo = sentence_hypos[0]
                self.assertHypoTokens(first_hypo, data[sample_id]["target"])
                self.assertHypoScore(first_hypo, expected_scores[sample_id])

    def assertHypoTokens(self, hypo, tokens):
        """Assert that ``hypo["tokens"]`` equals ``tokens`` element-wise."""
        expected_tokens = torch.LongTensor(tokens)
        self.assertTensorEqual(hypo["tokens"], expected_tokens)

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        """Assert that a hypothesis carries the expected log-probabilities.

        Args:
            hypo: mapping with ``"positional_scores"``, ``"tokens"`` and
                ``"score"`` entries.
            pos_probs: expected per-position probabilities (not logs).
            normalized: when true, the summed log-probability is divided by
                ``len(pos_probs) ** lenpen`` before comparing with
                ``hypo["score"]``.
            lenpen: exponent used for that length normalization.
        """
        expected_pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo["positional_scores"], expected_pos_scores)
        num_positions = expected_pos_scores.numel()
        self.assertEqual(num_positions, hypo["tokens"].numel())
        expected_total = expected_pos_scores.sum()
        if normalized:
            expected_total /= num_positions ** lenpen
        score_gap = abs(expected_total - hypo["score"])
        self.assertLess(score_gap, 1e-6)

    def assertAlmostEqual(self, t1, t2):
        """Assert two tensors have equal size and differ by less than 1e-4.

        Note: this replaces ``unittest.TestCase.assertAlmostEqual`` with a
        tensor-specific, two-argument version.
        """
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        largest_gap = (t1 - t2).abs().max()
        self.assertLess(largest_gap, 1e-4)

    def assertTensorEqual(self, t1, t2):
        """Assert two tensors have equal size and no differing elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        mismatch_count = t1.ne(t2).long().sum()
        self.assertEqual(mismatch_count, 0)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()